package org.neo4j.gds.ml.pipeline.nodePipeline.regression;

import org.neo4j.gds.GraphStoreAlgorithmFactory;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.ml.pipeline.PipelineCompanion;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeFeatureProducer;

/* loaded from: input_file:org/neo4j/gds/ml/pipeline/nodePipeline/regression/NodeRegressionTrainPipelineAlgorithmFactory.class */
/**
 * Algorithm factory that builds a {@link NodeRegressionTrainAlgorithm} which trains a node regression
 * pipeline on a {@link GraphStore}.
 */
public class NodeRegressionTrainPipelineAlgorithmFactory extends GraphStoreAlgorithmFactory<NodeRegressionTrainAlgorithm, NodeRegressionPipelineTrainConfig> {
    // Handed to NodeFeatureProducer.create(...) whenever an algorithm is built.
    private final ExecutionContext executionContext;

    /**
     * @param executionContext execution context passed on to the node feature producer of every built algorithm
     */
    public NodeRegressionTrainPipelineAlgorithmFactory(ExecutionContext executionContext) {
        this.executionContext = executionContext;
    }

    /**
     * Builds the training algorithm, resolving the pipeline named by {@code config.pipeline()} from the
     * {@link PipelineCatalog} entries of {@code config.username()}.
     *
     * @param graphStore      graph store the pipeline is trained on
     * @param config          training configuration
     * @param progressTracker tracker used by the feature producer, the trainer and the algorithm
     * @return the configured training algorithm
     */
    public NodeRegressionTrainAlgorithm build(GraphStore graphStore, NodeRegressionPipelineTrainConfig config, ProgressTracker progressTracker) {
        return build(graphStore, config, lookupPipeline(config), progressTracker);
    }

    /**
     * Builds the training algorithm for an already-resolved pipeline.
     *
     * <p>The main metric is validated against the pipeline before any training components are created.
     *
     * @param graphStore      graph store the pipeline is trained on
     * @param config          training configuration
     * @param pipeline        the node regression training pipeline to train
     * @param progressTracker tracker used by the feature producer, the trainer and the algorithm
     * @return the configured training algorithm
     */
    public NodeRegressionTrainAlgorithm build(GraphStore graphStore, NodeRegressionPipelineTrainConfig config, NodeRegressionTrainingPipeline pipeline, ProgressTracker progressTracker) {
        // The first configured metric is the one validated as the pipeline's main metric.
        // NOTE(review): assumes metrics() is non-empty (presumably enforced by config validation) — confirm.
        PipelineCompanion.validateMainMetric(pipeline, config.metrics().get(0).toString());
        return new NodeRegressionTrainAlgorithm(
            NodeRegressionTrain.create(
                graphStore,
                pipeline,
                config,
                NodeFeatureProducer.create(graphStore, config, this.executionContext, progressTracker),
                progressTracker
            ),
            pipeline,
            graphStore,
            config,
            progressTracker
        );
    }

    /**
     * @return the human-readable name of the task tracked for this algorithm
     */
    public String taskName() {
        return "Node Regression Train Pipeline";
    }

    /**
     * Builds the progress task for training the pipeline referenced by {@code config} on {@code graphStore}.
     *
     * @param graphStore graph store whose node count sizes the task
     * @param config     training configuration naming the pipeline
     * @return the progress task
     */
    public Task progressTask(GraphStore graphStore, NodeRegressionPipelineTrainConfig config) {
        return progressTask(lookupPipeline(config), graphStore.nodeCount());
    }

    /**
     * Builds the progress task for training {@code pipeline} on a graph with {@code nodeCount} nodes.
     *
     * @param pipeline  the node regression training pipeline
     * @param nodeCount number of nodes in the graph
     * @return the progress task, as produced by {@code NodeRegressionTrain.progressTask}
     */
    public static Task progressTask(NodeRegressionTrainingPipeline pipeline, long nodeCount) {
        return NodeRegressionTrain.progressTask(pipeline, nodeCount);
    }

    // Single place that resolves the configured pipeline from the catalog, shared by build(...) and progressTask(...).
    private static NodeRegressionTrainingPipeline lookupPipeline(NodeRegressionPipelineTrainConfig config) {
        return (NodeRegressionTrainingPipeline) PipelineCatalog.getTyped(config.username(), config.pipeline(), NodeRegressionTrainingPipeline.class);
    }
}
