package org.neo4j.gds.ml.pipeline.nodePipeline.classification.train;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.pipeline.ImmutableGraphFilter;
import org.neo4j.gds.ml.pipeline.PipelineExecutor;
import org.neo4j.gds.ml.pipeline.TrainingPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
import org.neo4j.gds.ml.training.TrainingStatistics;

/**
 * Pipeline executor that trains a node classification model from a
 * {@link NodeClassificationTrainingPipeline} and wraps the outcome in a
 * {@link NodeClassificationTrainPipelineResult}.
 */
public class NodeClassificationTrainPipelineExecutor extends PipelineExecutor<NodeClassificationPipelineTrainConfig, NodeClassificationTrainingPipeline, NodeClassificationTrainPipelineResult> {

    /**
     * Result of a training run: the trained model together with the statistics
     * collected while training it.
     */
    @ValueClass
    public interface NodeClassificationTrainPipelineResult {
        /** @return the trained classification model */
        Model<Classifier.ClassifierData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> model();

        /** @return the statistics gathered during training */
        TrainingStatistics trainingStatistics();
    }

    /**
     * Creates an executor; all arguments are forwarded unchanged to the
     * {@link PipelineExecutor} constructor.
     */
    public NodeClassificationTrainPipelineExecutor(NodeClassificationTrainingPipeline nodeClassificationTrainingPipeline, NodeClassificationPipelineTrainConfig nodeClassificationPipelineTrainConfig, ExecutionContext executionContext, GraphStore graphStore, String str, ProgressTracker progressTracker) {
        super(nodeClassificationTrainingPipeline, nodeClassificationPipelineTrainConfig, executionContext, graphStore, str, progressTracker);
    }

    /**
     * Builds the progress task tree for a training run: one sub-task for the
     * node property steps, followed by the sub-tasks reported by
     * {@link NodeClassificationTrain#progressTasks}.
     *
     * @param str                                description of the returned root task
     * @param nodeClassificationTrainingPipeline pipeline whose node property steps, split
     *                                           config and number of model selection trials
     *                                           determine the sub-tasks
     * @param j                                  size value forwarded to both sub-task factories
     *                                           (presumably the node count — confirm against callers)
     * @return the root task containing all sub-tasks
     */
    public static Task progressTask(String str, NodeClassificationTrainingPipeline nodeClassificationTrainingPipeline, long j) {
        // A plain list replaces the former double-brace initializer, which created
        // an anonymous ArrayList subclass just to populate two entries.
        List<Task> subTasks = new ArrayList<>();
        subTasks.add(nodePropertyStepTasks(nodeClassificationTrainingPipeline.nodePropertySteps(), j));
        subTasks.addAll(NodeClassificationTrain.progressTasks(
            nodeClassificationTrainingPipeline.splitConfig(),
            nodeClassificationTrainingPipeline.numberOfModelSelectionTrials(),
            j
        ));
        return Tasks.task(str, subTasks);
    }

    /**
     * Estimates memory for a training run as the maximum of the node property
     * step estimation and the "Pipeline Train" estimation.
     * The training parameter space is validated before estimating.
     *
     * @param nodeClassificationTrainingPipeline   pipeline to estimate
     * @param nodeClassificationPipelineTrainConfig config supplying node labels and relationship types
     * @param modelCatalog                         catalog passed to the node property step estimation
     * @return the combined memory estimation named "Pipeline executor"
     */
    public static MemoryEstimation estimate(NodeClassificationTrainingPipeline nodeClassificationTrainingPipeline, NodeClassificationPipelineTrainConfig nodeClassificationPipelineTrainConfig, ModelCatalog modelCatalog) {
        PipelineExecutor.validateTrainingParameterSpace(nodeClassificationTrainingPipeline);

        MemoryEstimation nodePropertyStepsEstimation = PipelineExecutor.estimateNodePropertySteps(
            modelCatalog,
            nodeClassificationTrainingPipeline.nodePropertySteps(),
            nodeClassificationPipelineTrainConfig.nodeLabels(),
            nodeClassificationPipelineTrainConfig.relationshipTypes()
        );
        MemoryEstimation trainingEstimation = MemoryEstimations.builder()
            .add("Pipeline Train", NodeClassificationTrain.estimate(nodeClassificationTrainingPipeline, nodeClassificationPipelineTrainConfig))
            .build();

        return MemoryEstimations.maxEstimation("Pipeline executor", List.of(nodePropertyStepsEstimation, trainingEstimation));
    }

    /**
     * Returns a single {@code FEATURE_INPUT} split whose filter is built from the
     * configured node labels and relationship types resolved against the graph store.
     */
    @Override
    public Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> splitDataset() {
        NodeClassificationPipelineTrainConfig trainConfig = (NodeClassificationPipelineTrainConfig) this.config;
        return Map.of(
            PipelineExecutor.DatasetSplits.FEATURE_INPUT,
            ImmutableGraphFilter.of(
                trainConfig.nodeLabelIdentifiers(this.graphStore),
                trainConfig.internalRelationshipTypes(this.graphStore)
            )
        );
    }

    /**
     * Validates the pipeline, trains the classifier on the graph restricted to the
     * configured node labels, and packages the trained model with its statistics.
     *
     * @param map dataset splits; not read by this implementation
     * @return the trained model and the training statistics
     */
    @Override
    protected NodeClassificationTrainPipelineResult execute(Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> map) {
        NodeClassificationTrainingPipeline trainingPipeline = (NodeClassificationTrainingPipeline) this.pipeline;
        NodeClassificationPipelineTrainConfig trainConfig = (NodeClassificationPipelineTrainConfig) this.config;

        PipelineExecutor.validateTrainingParameterSpace(trainingPipeline);

        Graph graph = this.graphStore.getGraph(trainConfig.nodeLabelIdentifiers(this.graphStore));
        // Validate the split sizes against the graph before starting training.
        trainingPipeline.splitConfig().validateMinNumNodesInSplitSets(graph);

        NodeClassificationTrainResult compute = NodeClassificationTrain
            .create(graph, trainingPipeline, trainConfig, this.progressTracker, this.terminationFlag)
            .compute();
        TrainingStatistics trainingStatistics = compute.trainingStatistics();

        NodeClassificationPipelineModelInfo modelInfo = NodeClassificationPipelineModelInfo.of(
            trainingStatistics.winningModelTestMetrics(),
            trainingStatistics.winningModelOuterTrainMetrics(),
            trainingStatistics.bestCandidate(),
            NodePropertyPredictPipeline.from(this.pipeline),
            compute.classifier().classIdMap().originalIdsList()
        );

        return ImmutableNodeClassificationTrainPipelineResult.of(
            Model.of(
                trainConfig.username(),
                trainConfig.modelName(),
                NodeClassificationTrainingPipeline.MODEL_TYPE,
                this.schemaBeforeSteps,
                compute.classifier().data(),
                trainConfig,
                modelInfo
            ),
            trainingStatistics
        );
    }
}
