package org.neo4j.gds.ml.nodemodels;

import java.util.ArrayList;
import java.util.List;
import org.neo4j.gds.GraphStoreAlgorithmFactory;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.nodemodels.pipeline.NodeClassificationTrainPipelineExecutor;
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationTrain;

/* loaded from: input_file:org/neo4j/gds/ml/nodemodels/NodeClassificationTrainPipelineAlgorithmFactory.class */
/**
 * Factory that wires up a {@link NodeClassificationTrainPipelineExecutor} for a
 * node classification training run.
 *
 * <p>Every entry point resolves the {@link NodeClassificationTrainingPipeline} from the
 * {@link PipelineCatalog} using the {@code username()} and {@code pipeline()} values of the
 * supplied train config; that lookup lives in {@link #lookupPipeline}.
 *
 * <p>NOTE(review): the public methods below look like overrides of the
 * {@link GraphStoreAlgorithmFactory} contract, but the superclass is not visible from this
 * file; confirm before adding {@code @Override}.
 */
public class NodeClassificationTrainPipelineAlgorithmFactory extends GraphStoreAlgorithmFactory<NodeClassificationTrainPipelineExecutor, NodeClassificationPipelineTrainConfig> {
    private final ExecutionContext executionContext;

    /**
     * @param executionContext context handed to every executor this factory builds; its
     *                         {@code modelCatalog()} is also used for memory estimation
     */
    public NodeClassificationTrainPipelineAlgorithmFactory(ExecutionContext executionContext) {
        this.executionContext = executionContext;
    }

    /**
     * Creates the executor that trains the pipeline referenced by the given config.
     *
     * @param graphStore                           graph store the executor operates on
     * @param nodeClassificationPipelineTrainConfig train config; supplies the user name, the
     *                                             pipeline reference and the graph name
     * @param progressTracker                      tracker passed through to the executor
     * @return a new executor bound to the pipeline resolved from the catalog
     */
    public NodeClassificationTrainPipelineExecutor build(GraphStore graphStore, NodeClassificationPipelineTrainConfig nodeClassificationPipelineTrainConfig, ProgressTracker progressTracker) {
        NodeClassificationTrainingPipeline pipeline = lookupPipeline(nodeClassificationPipelineTrainConfig);
        return new NodeClassificationTrainPipelineExecutor(
            pipeline,
            nodeClassificationPipelineTrainConfig,
            this.executionContext,
            graphStore,
            nodeClassificationPipelineTrainConfig.graphName(),
            progressTracker
        );
    }

    /**
     * Builds the memory estimation for a training run.
     *
     * @param nodeClassificationPipelineTrainConfig train config identifying the pipeline
     * @return an estimation with a single "Pipeline executor" component, delegated to
     *         {@code NodeClassificationTrainPipelineExecutor.estimate}
     */
    public MemoryEstimation memoryEstimation(NodeClassificationPipelineTrainConfig nodeClassificationPipelineTrainConfig) {
        NodeClassificationTrainingPipeline pipeline = lookupPipeline(nodeClassificationPipelineTrainConfig);
        MemoryEstimation executorEstimation = NodeClassificationTrainPipelineExecutor.estimate(
            pipeline,
            nodeClassificationPipelineTrainConfig,
            this.executionContext.modelCatalog()
        );
        return MemoryEstimations.builder(NodeClassificationTrainPipelineExecutor.class)
            .add("Pipeline executor", executorEstimation)
            .build();
    }

    /**
     * @return the fixed display name used for the root progress task
     */
    public String taskName() {
        return "Node Classification Train Pipeline";
    }

    /**
     * Describes the progress task tree for a training run.
     *
     * <p>The root task is named by {@link #taskName()} and has these children, in order:
     * one fixed-iteration task for the node property steps (one "Step" leaf per iteration,
     * iterated once per configured node property step), followed by whatever
     * {@code NodeClassificationTrain.progressTasks} returns for the pipeline's validation
     * fold count and number of model selection trials.
     *
     * @param graphStore                            not read by this implementation
     * @param nodeClassificationPipelineTrainConfig train config identifying the pipeline
     * @return the root progress task
     */
    public Task progressTask(GraphStore graphStore, NodeClassificationPipelineTrainConfig nodeClassificationPipelineTrainConfig) {
        NodeClassificationTrainingPipeline pipeline = lookupPipeline(nodeClassificationPipelineTrainConfig);

        // A plain list instead of double-brace initialization: no anonymous ArrayList
        // subclass and no hidden reference back to this factory.
        List<Task> subTasks = new ArrayList<>();
        subTasks.add(Tasks.iterativeFixed(
            "Execute node property steps",
            () -> List.of(Tasks.leaf("Step")),
            pipeline.nodePropertySteps().size()
        ));
        subTasks.addAll(NodeClassificationTrain.progressTasks(
            pipeline.splitConfig().validationFolds(),
            pipeline.numberOfModelSelectionTrials()
        ));

        return Tasks.task(taskName(), subTasks);
    }

    /** Resolves the training pipeline named in the config from the pipeline catalog. */
    private static NodeClassificationTrainingPipeline lookupPipeline(NodeClassificationPipelineTrainConfig config) {
        return PipelineCatalog.getTyped(config.username(), config.pipeline(), NodeClassificationTrainingPipeline.class);
    }
}
