package org.neo4j.gds.ml.pipeline.nodePipeline.regression;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.models.Regressor;
import org.neo4j.gds.ml.pipeline.ImmutableGraphFilter;
import org.neo4j.gds.ml.pipeline.PipelineExecutor;
import org.neo4j.gds.ml.pipeline.TrainingPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictPipeline;
import org.neo4j.gds.ml.training.TrainingStatistics;

/**
 * {@link PipelineExecutor} for a {@link NodeRegressionTrainingPipeline}.
 *
 * <p>{@link #execute} validates the pipeline, trains a regressor with {@link NodeRegressionTrain}
 * on the graph restricted to the configured node labels, and wraps the winning regressor together
 * with its metrics and a {@link NodePropertyPredictPipeline} view of the pipeline into a catalog
 * {@link Model}.
 */
public class NodeRegressionTrainPipelineExecutor extends PipelineExecutor<NodeRegressionPipelineTrainConfig, NodeRegressionTrainingPipeline, NodeRegressionTrainPipelineResult> {

    /**
     * Outcome of a node regression training run.
     */
    @ValueClass
    public interface NodeRegressionTrainPipelineResult {
        /**
         * The trained regressor as a catalog model, carrying the train config and the model info
         * (winning test/outer-train metrics, best candidate, and the predict pipeline).
         */
        Model<Regressor.RegressorData, NodeRegressionPipelineTrainConfig, NodeRegressionPipelineModelInfo> model();

        /** Statistics gathered by {@link NodeRegressionTrain} during model selection. */
        TrainingStatistics trainingStatistics();
    }

    /**
     * Builds the root progress task for training {@code pipeline}.
     *
     * <p>The first child covers the pipeline's node property steps. The remaining children are the
     * tasks reported by {@link NodeRegressionTrain#progressTasks} for the pipeline's split
     * configuration and number of model selection trials.
     *
     * @param pipeline  the pipeline whose steps are tracked
     * @param nodeCount node count forwarded to the sub-task builders
     * @return the "Node Regression Train Pipeline" task
     */
    public static Task progressTask(NodeRegressionTrainingPipeline pipeline, long nodeCount) {
        // A plain list replaces the original anonymous ArrayList subclass. The decompiled version read
        // the pipeline through `NodeRegressionTrainingPipeline.this`, which is not an enclosing
        // instance of this static method and does not compile. The captured parameter is the
        // intended receiver.
        List<Task> subTasks = new ArrayList<>();
        subTasks.add(nodePropertyStepTasks(pipeline.nodePropertySteps(), nodeCount));
        subTasks.addAll(NodeRegressionTrain.progressTasks(
            pipeline.splitConfig(),
            pipeline.numberOfModelSelectionTrials(),
            nodeCount
        ));
        return Tasks.task("Node Regression Train Pipeline", subTasks);
    }

    /**
     * @param pipeline         the training pipeline to execute
     * @param config           the train configuration; its {@code graphName()} is passed to the base executor
     * @param executionContext the execution context forwarded to the base executor
     * @param graphStore       the graph store to train on
     * @param progressTracker  tracker for progress reporting
     */
    public NodeRegressionTrainPipelineExecutor(NodeRegressionTrainingPipeline pipeline, NodeRegressionPipelineTrainConfig config, ExecutionContext executionContext, GraphStore graphStore, ProgressTracker progressTracker) {
        super(pipeline, config, executionContext, graphStore, config.graphName(), progressTracker);
    }

    /**
     * Returns a single {@code FEATURE_INPUT} split that covers the configured node labels and
     * relationship types.
     *
     * <p>No train/test split is produced here. Splitting appears to happen inside
     * {@link NodeRegressionTrain}, presumably driven by {@code pipeline.splitConfig()}; TODO confirm.
     */
    @Override
    public Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> splitDataset() {
        return Map.of(
            PipelineExecutor.DatasetSplits.FEATURE_INPUT,
            ImmutableGraphFilter.of(config.nodeLabelIdentifiers(graphStore), config.internalRelationshipTypes(graphStore))
        );
    }

    /**
     * Trains the regression model and packages it as a catalog model.
     *
     * <p>The JVM-level {@code execute(Map)} bridge is generated by the compiler. Declaring it by hand
     * would clash with this method's erasure.
     *
     * @param dataSplits the splits from {@link #splitDataset()}; not read here because training works
     *                   on the graph filtered by the configured node labels
     * @return the trained model together with its training statistics
     */
    @Override
    protected NodeRegressionTrainPipelineResult execute(Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> dataSplits) {
        PipelineExecutor.validateTrainingParameterSpace(pipeline);

        Graph graph = graphStore.getGraph(config.nodeLabelIdentifiers(graphStore));
        // Check the split sizes before any training work starts.
        pipeline.splitConfig().validateMinNumNodesInSplitSets(graph);

        NodeRegressionTrainResult trainResult = NodeRegressionTrain
            .create(graph, pipeline, config, progressTracker, terminationFlag)
            .compute();
        TrainingStatistics statistics = trainResult.trainingStatistics();

        Model<Regressor.RegressorData, NodeRegressionPipelineTrainConfig, NodeRegressionPipelineModelInfo> model = Model.of(
            config.username(),
            config.modelName(),
            NodeRegressionTrainingPipeline.MODEL_TYPE,
            schemaBeforeSteps,
            trainResult.regressor().data(),
            config,
            NodeRegressionPipelineModelInfo.of(
                statistics.winningModelTestMetrics(),
                statistics.winningModelOuterTrainMetrics(),
                statistics.bestCandidate(),
                NodePropertyPredictPipeline.from(pipeline)
            )
        );
        return ImmutableNodeRegressionTrainPipelineResult.of(model, statistics);
    }
}
