/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.node.classification.predict;

import java.nio.file.Path;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
import org.neo4j.gds.VerifyThatModelCanBeStored;
import org.neo4j.gds.compat.GdsVersionInfoProvider;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.executor.AlgorithmSpec;
import org.neo4j.gds.executor.ComputationResult;
import org.neo4j.gds.executor.ComputationResultConsumer;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.executor.ExecutionMode;
import org.neo4j.gds.executor.GdsCallable;
import org.neo4j.gds.executor.validation.BeforeLoadValidation;
import org.neo4j.gds.executor.validation.ValidationConfiguration;
import org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationPipelineTrainResult;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationModelResult;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationTrainAlgorithm;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationTrainPipelineAlgorithmFactory;
import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction;
import org.neo4j.graphdb.GraphDatabaseService;

/**
 * Algorithm specification backing the {@code gds.beta.pipeline.nodeClassification.train} procedure.
 *
 * <p>Supplies the algorithm factory, the configuration parser, a before-load validation and the
 * result consumer. On success, the result consumer registers the trained model in the model catalog
 * and, when requested, also stores it on disk.
 */
@GdsCallable(
    name = "gds.beta.pipeline.nodeClassification.train",
    description = "Trains a node classification model based on a pipeline",
    executionMode = ExecutionMode.TRAIN
)
public class NodeClassificationPipelineTrainSpec
    implements AlgorithmSpec<NodeClassificationTrainAlgorithm, NodeClassificationModelResult, NodeClassificationPipelineTrainConfig, Stream<NodeClassificationPipelineTrainResult>, NodeClassificationTrainPipelineAlgorithmFactory> {

    /**
     * Returns the internal name of this algorithm.
     *
     * @return the constant {@code "NodeClassificationPipelineTrain"}
     */
    @Override
    public String name() {
        return "NodeClassificationPipelineTrain";
    }

    /**
     * Creates the factory that builds the training algorithm.
     *
     * @param executionContext the context of the current procedure call
     * @return a factory bound to the given context and to the running GDS version
     */
    @Override
    public NodeClassificationTrainPipelineAlgorithmFactory algorithmFactory(ExecutionContext executionContext) {
        return new NodeClassificationTrainPipelineAlgorithmFactory(executionContext, GdsVersionInfoProvider.GDS_VERSION_INFO.gdsVersion());
    }

    /**
     * Returns the function used to parse the user-supplied configuration.
     *
     * @return a reference to {@link NodeClassificationPipelineTrainConfig#of}
     */
    @Override
    public NewConfigFunction<NodeClassificationPipelineTrainConfig> newConfigFunction() {
        return NodeClassificationPipelineTrainConfig::of;
    }

    /**
     * Builds the consumer that turns a training computation into procedure output rows.
     *
     * <p>If the computation has no result, the consumer emits an empty stream. Otherwise it:
     * <ol>
     *   <li>registers the trained model in the model catalog;</li>
     *   <li>if {@code storeModelToDisk()} is set, checks the license, resolves the model directory and
     *       stores the model there, logging and rethrowing any failure;</li>
     *   <li>emits a single {@link NodeClassificationPipelineTrainResult} row.</li>
     * </ol>
     *
     * @return the result consumer for this procedure
     */
    @Override
    public ComputationResultConsumer<NodeClassificationTrainAlgorithm, NodeClassificationModelResult, NodeClassificationPipelineTrainConfig, Stream<NodeClassificationPipelineTrainResult>> computationResultConsumer() {
        return (computationResult, executionContext) -> {
            Optional<NodeClassificationModelResult> trainResult = computationResult.result();
            if (!trainResult.isPresent()) {
                // No training result: there is no model to register and no row to report.
                return Stream.empty();
            }

            Model<?, ?, ?> model = trainResult.get().model();
            ModelCatalog modelCatalog = executionContext.modelCatalog();
            modelCatalog.set(model);

            if (computationResult.config().storeModelToDisk()) {
                try {
                    GraphDatabaseService databaseService = executionContext
                        .dependencyResolver()
                        .resolveDependency(GraphDatabaseService.class);
                    modelCatalog.checkLicenseBeforeStoreModel(databaseService, "Store a model");
                    Path modelDir = modelCatalog.getModelDirectory(databaseService);
                    modelCatalog.store(model.creator(), model.name(), modelDir);
                } catch (Exception e) {
                    // The format string needs placeholders; previously the cause argument was
                    // passed but never rendered, so the log line carried no failure detail.
                    executionContext.log().error(
                        "Failed to store model '%s' to disk after training: %s",
                        model.name(),
                        e.getMessage()
                    );
                    // Propagate the original exception unchanged so the procedure call fails.
                    throw e;
                }
            }

            return Stream.of(constructProcResult(computationResult));
        };
    }

    /**
     * Declares the validations to run for this procedure.
     *
     * @param executionContext the context of the current procedure call; supplies the model catalog
     *                         and the username passed to the validation
     * @return a configuration whose only before-load validation is {@link VerifyThatModelCanBeStored}
     */
    @Override
    public ValidationConfiguration<NodeClassificationPipelineTrainConfig> validationConfig(final ExecutionContext executionContext) {
        // ValidationConfiguration is not a functional interface, so an anonymous class is used.
        return new ValidationConfiguration<NodeClassificationPipelineTrainConfig>() {

            @Override
            public List<BeforeLoadValidation<NodeClassificationPipelineTrainConfig>> beforeLoadValidations() {
                return List.of(new VerifyThatModelCanBeStored(executionContext.modelCatalog(), executionContext.username(), "NodeClassification"));
            }
        };
    }

    /**
     * Wraps the (possibly absent) training result and the compute time into the procedure result row.
     *
     * @param computationResult the finished training computation
     * @return the result row for the procedure output
     */
    private NodeClassificationPipelineTrainResult constructProcResult(
        ComputationResult<NodeClassificationTrainAlgorithm, NodeClassificationModelResult, NodeClassificationPipelineTrainConfig> computationResult
    ) {
        Optional<NodeClassificationModelResult> trainResult = computationResult.result();
        return new NodeClassificationPipelineTrainResult(trainResult, computationResult.computeMillis());
    }
}

