/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.node.regression;

import java.nio.file.Path;
import java.util.List;
import java.util.stream.Stream;
import org.neo4j.gds.VerifyThatModelCanBeStored;
import org.neo4j.gds.compat.GdsVersionInfoProvider;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.executor.AlgorithmSpec;
import org.neo4j.gds.executor.ComputationResultConsumer;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.executor.ExecutionMode;
import org.neo4j.gds.executor.GdsCallable;
import org.neo4j.gds.executor.NewConfigFunction;
import org.neo4j.gds.executor.validation.BeforeLoadValidation;
import org.neo4j.gds.executor.validation.ValidationConfiguration;
import org.neo4j.gds.ml.models.Regressor;
import org.neo4j.gds.ml.pipeline.node.regression.TrainResult;
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionPipelineModelInfo;
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionPipelineTrainConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainAlgorithm;
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainPipelineAlgorithmFactory;
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainResult;
import org.neo4j.graphdb.GraphDatabaseService;

@GdsCallable(name="gds.alpha.pipeline.nodeRegression.train", description="Trains a node regression model based on a pipeline", executionMode=ExecutionMode.TRAIN)
public class NodeRegressionPipelineTrainSpec
implements AlgorithmSpec<NodeRegressionTrainAlgorithm, NodeRegressionTrainResult.NodeRegressionTrainPipelineResult, NodeRegressionPipelineTrainConfig, Stream<TrainResult>, NodeRegressionTrainPipelineAlgorithmFactory> {
    public String name() {
        return "NodeRegressionPipelineTrain";
    }

    public NodeRegressionTrainPipelineAlgorithmFactory algorithmFactory(ExecutionContext executionContext) {
        String gdsVersion = GdsVersionInfoProvider.GDS_VERSION_INFO.gdsVersion();
        return new NodeRegressionTrainPipelineAlgorithmFactory(executionContext, gdsVersion);
    }

    public NewConfigFunction<NodeRegressionPipelineTrainConfig> newConfigFunction() {
        return NodeRegressionPipelineTrainConfig::of;
    }

    public ComputationResultConsumer<NodeRegressionTrainAlgorithm, NodeRegressionTrainResult.NodeRegressionTrainPipelineResult, NodeRegressionPipelineTrainConfig, Stream<TrainResult>> computationResultConsumer() {
        return (computationResult, executionContext) -> computationResult.result().map(result -> {
            Model model = result.model();
            ModelCatalog modelCatalog = executionContext.modelCatalog();
            assert (modelCatalog != null) : "ModelCatalog should have been set in the ExecutionContext by this point!!!";
            modelCatalog.set(model);
            if (((NodeRegressionPipelineTrainConfig)computationResult.config()).storeModelToDisk()) {
                try {
                    GraphDatabaseService databaseService = (GraphDatabaseService)executionContext.dependencyResolver().resolveDependency(GraphDatabaseService.class);
                    modelCatalog.checkLicenseBeforeStoreModel(databaseService, "Store a model");
                    Path modelDir = modelCatalog.getModelDirectory(databaseService);
                    modelCatalog.store(model.creator(), model.name(), modelDir);
                }
                catch (Exception e) {
                    executionContext.log().error("Failed to store model to disk after training.", new Object[]{e.getMessage()});
                    throw e;
                }
            }
            return Stream.of(new TrainResult((Model<Regressor.RegressorData, NodeRegressionPipelineTrainConfig, NodeRegressionPipelineModelInfo>)model, result.trainingStatistics(), computationResult.computeMillis()));
        }).orElseGet(Stream::empty);
    }

    public ValidationConfiguration<NodeRegressionPipelineTrainConfig> validationConfig(final ExecutionContext executionContext) {
        return new ValidationConfiguration<NodeRegressionPipelineTrainConfig>(){

            public List<BeforeLoadValidation<NodeRegressionPipelineTrainConfig>> beforeLoadValidations() {
                ModelCatalog modelCatalog = executionContext.modelCatalog();
                assert (modelCatalog != null) : "ModelCatalog should have been set in the ExecutionContext by this point!!!";
                return List.of(new VerifyThatModelCanBeStored(modelCatalog, executionContext.username(), "LinkPrediction"));
            }
        };
    }
}

