package org.neo4j.gds.ml.nodemodels.pipeline;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.pipeline.ImmutableGraphFilter;
import org.neo4j.gds.ml.pipeline.PipelineExecutor;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationTrain;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationTrainResult;

/* loaded from: input_file:org/neo4j/gds/ml/nodemodels/pipeline/NodeClassificationTrainPipelineExecutor.class */
/**
 * Training-pipeline executor for node classification.
 *
 * <p>It reports a single {@link PipelineExecutor.DatasetSplits#FEATURE_INPUT} split built from the
 * configured node labels and relationship types. Training itself is delegated to
 * {@link NodeClassificationTrain}.
 */
public class NodeClassificationTrainPipelineExecutor extends PipelineExecutor<NodeClassificationPipelineTrainConfig, NodeClassificationTrainingPipeline, NodeClassificationTrainResult> {

    /**
     * Creates an executor for one training run. All arguments are forwarded unchanged to
     * {@link PipelineExecutor}.
     *
     * @param nodeClassificationTrainingPipeline the pipeline whose steps and split config are used
     * @param nodeClassificationPipelineTrainConfig the training configuration
     * @param executionContext the execution context
     * @param graphStore the graph store the pipeline reads from
     * @param str forwarded as-is to the superclass constructor
     *     (NOTE(review): presumably the graph name — confirm against {@link PipelineExecutor})
     * @param progressTracker the progress tracker used during training
     */
    public NodeClassificationTrainPipelineExecutor(NodeClassificationTrainingPipeline nodeClassificationTrainingPipeline, NodeClassificationPipelineTrainConfig nodeClassificationPipelineTrainConfig, ExecutionContext executionContext, GraphStore graphStore, String str, ProgressTracker progressTracker) {
        super(nodeClassificationTrainingPipeline, nodeClassificationPipelineTrainConfig, executionContext, graphStore, str, progressTracker);
    }

    /**
     * Estimates the memory needed to run the training pipeline.
     *
     * <p>The estimate is the maximum of two parts: the node-property-step estimate and the
     * training estimate.
     *
     * @param nodeClassificationTrainingPipeline the pipeline to estimate; its training parameter
     *     space is validated first
     * @param nodeClassificationPipelineTrainConfig the training configuration
     *     (node labels and relationship types are read from it)
     * @param modelCatalog the catalog passed to the node-property-step estimation
     * @return the combined memory estimation
     */
    public static MemoryEstimation estimate(NodeClassificationTrainingPipeline nodeClassificationTrainingPipeline, NodeClassificationPipelineTrainConfig nodeClassificationPipelineTrainConfig, ModelCatalog modelCatalog) {
        PipelineExecutor.validateTrainingParameterSpace(nodeClassificationTrainingPipeline);
        // maxEstimation rather than a sum: the two phases are presumably not resident at the
        // same time (NOTE(review): confirm).
        return MemoryEstimations.maxEstimation(
            "Pipeline executor",
            List.of(
                PipelineExecutor.estimateNodePropertySteps(
                    modelCatalog,
                    nodeClassificationTrainingPipeline.nodePropertySteps(),
                    nodeClassificationPipelineTrainConfig.nodeLabels(),
                    nodeClassificationPipelineTrainConfig.relationshipTypes()
                ),
                MemoryEstimations.builder()
                    .add("Pipeline Train", NodeClassificationTrain.estimate(nodeClassificationTrainingPipeline, nodeClassificationPipelineTrainConfig))
                    .build()
            )
        );
    }

    /**
     * Describes the dataset split(s) this executor works on.
     *
     * @return a single {@code FEATURE_INPUT} entry whose filter uses the configured node labels
     *     and relationship types, resolved against the graph store
     */
    @Override
    public Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> splitDataset() {
        return Map.of(
            PipelineExecutor.DatasetSplits.FEATURE_INPUT,
            ImmutableGraphFilter.of(
                this.config.nodeLabelIdentifiers(this.graphStore),
                this.config.internalRelationshipTypes(this.graphStore)
            )
        );
    }

    /**
     * Trains the node classification pipeline.
     *
     * <p>Steps:
     * <ol>
     *   <li>Validates the training parameter space.</li>
     *   <li>Builds the graph from the configured labels and types.</li>
     *   <li>Validates the split config against that graph.</li>
     *   <li>Runs {@link NodeClassificationTrain}.</li>
     * </ol>
     *
     * @param map the dataset splits. NOTE(review): not read here; the graph is rebuilt from the
     *     same config-derived labels/types that {@link #splitDataset()} reports.
     * @return the training result
     */
    @Override
    protected NodeClassificationTrainResult execute(Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> map) {
        PipelineExecutor.validateTrainingParameterSpace(this.pipeline);
        Graph graph = this.graphStore.getGraph(
            this.config.nodeLabelIdentifiers(this.graphStore),
            this.config.internalRelationshipTypes(this.graphStore),
            Optional.empty()
        );
        // Check the split config against the actual graph before training starts
        // (presumably this rejects split sets with too few nodes).
        this.pipeline.splitConfig().validateMinNumNodesInSplitSets(graph);
        return NodeClassificationTrain
            .create(graph, this.pipeline, this.config, this.progressTracker, this.terminationFlag)
            .compute();
    }

    /**
     * Raw-typed forwarder to {@link #execute(Map)}.
     *
     * <p>This is a decompiler-renamed copy of the synthetic bridge method. javac already generates
     * an equivalent bridge for the generic override. It is kept only so that any existing caller
     * of this name keeps compiling; it can be removed once none remain.
     */
    @SuppressWarnings("unchecked") // callers pass the same Map type the generic override expects
    protected /* bridge */ /* synthetic */ Object m15execute(Map map) {
        return execute((Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter>) map);
    }
}
