package org.neo4j.gds.ml.linkmodels.pipeline.train;

import java.util.List;
import java.util.stream.Stream;
import org.neo4j.gds.VerifyThatModelCanBeStored;
import org.neo4j.gds.compat.GdsVersionInfoProvider;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.executor.AlgorithmSpec;
import org.neo4j.gds.executor.ComputationResultConsumer;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.executor.ExecutionMode;
import org.neo4j.gds.executor.GdsCallable;
import org.neo4j.gds.executor.NewConfigFunction;
import org.neo4j.gds.executor.validation.BeforeLoadValidation;
import org.neo4j.gds.executor.validation.ValidationConfiguration;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainPipelineExecutor;
import org.neo4j.graphdb.GraphDatabaseService;

/**
 * Algorithm specification for the {@code gds.beta.pipeline.linkPrediction.train} procedure.
 *
 * <p>Wires together the algorithm factory, the configuration parser, the pre-load validation
 * and the consumer that registers the trained model in the {@link ModelCatalog}.
 */
@GdsCallable(name = "gds.beta.pipeline.linkPrediction.train", description = "Trains a link prediction model based on a pipeline", executionMode = ExecutionMode.TRAIN)
/* loaded from: input_file:org/neo4j/gds/ml/linkmodels/pipeline/train/LinkPredictionPipelineTrainSpec.class */
public class LinkPredictionPipelineTrainSpec implements AlgorithmSpec<LinkPredictionTrainPipelineExecutor, LinkPredictionTrainPipelineExecutor.LinkPredictionTrainPipelineResult, LinkPredictionTrainConfig, Stream<TrainResult>, LinkPredictionTrainPipelineAlgorithmFactory> {

    /**
     * @return the fixed display name of this algorithm, {@code "LinkPredictionPipelineTrain"}
     */
    public String name() {
        return "LinkPredictionPipelineTrain";
    }

    /**
     * Creates the factory for the training pipeline executor, handing it the execution context
     * and the GDS version reported by {@link GdsVersionInfoProvider}.
     *
     * <p>NOTE(review): the decompiler reports this method as "renamed from: algorithmFactory,
     * reason: merged with bridge method [inline-methods]". The generated name is kept here so the
     * visible signature stays untouched; confirm against {@link AlgorithmSpec} whether it should
     * be restored to {@code algorithmFactory}.
     *
     * @param executionContext the context of the current procedure call
     * @return a new algorithm factory bound to {@code executionContext}
     */
    public LinkPredictionTrainPipelineAlgorithmFactory m15algorithmFactory(ExecutionContext executionContext) {
        return new LinkPredictionTrainPipelineAlgorithmFactory(executionContext, GdsVersionInfoProvider.GDS_VERSION_INFO.gdsVersion());
    }

    /**
     * @return the function that builds a {@link LinkPredictionTrainConfig}, i.e. {@code LinkPredictionTrainConfig::of}
     */
    public NewConfigFunction<LinkPredictionTrainConfig> newConfigFunction() {
        return LinkPredictionTrainConfig::of;
    }

    /**
     * Builds the consumer that turns a finished training run into the procedure's result stream.
     *
     * <p>When the computation produced a result, the trained model is put into the model catalog
     * and, if the configuration asks for it via {@code storeModelToDisk()}, also stored on disk.
     * A single {@link TrainResult} is then emitted. When there is no result, the stream is empty.
     *
     * <p>Any exception raised while storing the model is logged together with its cause and
     * rethrown unchanged.
     *
     * @return the result consumer for the train procedure
     */
    public ComputationResultConsumer<LinkPredictionTrainPipelineExecutor, LinkPredictionTrainPipelineExecutor.LinkPredictionTrainPipelineResult, LinkPredictionTrainConfig, Stream<TrainResult>> computationResultConsumer() {
        return (computationResult, executionContext) -> {
            return (Stream) computationResult.result().map(trainingResult -> {
                Model model = trainingResult.model();
                ModelCatalog modelCatalog = executionContext.modelCatalog();
                assert modelCatalog != null : "ModelCatalog should have been set in the ExecutionContext by this point!!!";

                // The model is registered in the catalog before any attempt to persist it.
                modelCatalog.set(model);

                if (computationResult.config().storeModelToDisk()) {
                    try {
                        GraphDatabaseService graphDatabaseService = (GraphDatabaseService) executionContext.dependencyResolver().resolveDependency(GraphDatabaseService.class);
                        modelCatalog.checkLicenseBeforeStoreModel(graphDatabaseService, "Store a model");
                        modelCatalog.store(model.creator(), model.name(), modelCatalog.getModelDirectory(graphDatabaseService));
                    } catch (Exception e) {
                        // Hand over the throwable itself: the message has no format placeholder,
                        // so an argument array would be dropped and the cause never logged.
                        executionContext.log().error("Failed to store model to disk after training.", e);
                        throw e;
                    }
                }

                return Stream.of(new TrainResult(model, trainingResult.trainingStatistics(), computationResult.computeMillis()));
            }).orElseGet(Stream::empty);
        };
    }

    /**
     * Provides the validations that run before the graph is loaded.
     *
     * <p>The only validation is a {@link VerifyThatModelCanBeStored} check, created with the model
     * catalog and the username from the execution context and the model type {@code "LinkPrediction"}.
     *
     * @param executionContext the context of the current procedure call
     * @return the validation configuration for the train procedure
     */
    public ValidationConfiguration<LinkPredictionTrainConfig> validationConfig(final ExecutionContext executionContext) {
        return new ValidationConfiguration<LinkPredictionTrainConfig>() { // from class: org.neo4j.gds.ml.linkmodels.pipeline.train.LinkPredictionPipelineTrainSpec.1
            public List<BeforeLoadValidation<LinkPredictionTrainConfig>> beforeLoadValidations() {
                ModelCatalog modelCatalog = executionContext.modelCatalog();
                assert modelCatalog != null : "ModelCatalog should have been set in the ExecutionContext by this point!!!";
                return List.of(new VerifyThatModelCanBeStored(modelCatalog, executionContext.username(), "LinkPrediction"));
            }
        };
    }
}
