package org.neo4j.gds.ml.pipeline.nodePipeline.classification.train;

import org.neo4j.gds.GraphStoreAlgorithmFactory;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.mem.MemoryEstimations;
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.ml.pipeline.PipelineCompanion;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeFeatureProducer;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;

/* loaded from: input_file:org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrainPipelineAlgorithmFactory.class */
/**
 * Algorithm factory that builds a {@link NodeClassificationTrainAlgorithm} for a node
 * classification training pipeline stored in the {@link PipelineCatalog}.
 *
 * <p>The pipeline is resolved from the catalog using the configuration's {@code username()}
 * and {@code pipeline()} values.
 */
public class NodeClassificationTrainPipelineAlgorithmFactory extends GraphStoreAlgorithmFactory<NodeClassificationTrainAlgorithm, NodeClassificationPipelineTrainConfig> {
    private final ExecutionContext executionContext;
    // NOTE(review): stored but never read anywhere in this class — confirm whether it is still needed.
    private final String gdsVersion;

    /**
     * Creates the factory.
     *
     * @param executionContext context passed to the node feature producer when building, and
     *                         queried for the model catalog and algorithms facade during memory
     *                         estimation
     * @param gdsVersion       GDS version string retained by this factory
     */
    public NodeClassificationTrainPipelineAlgorithmFactory(ExecutionContext executionContext, String gdsVersion) {
        this.executionContext = executionContext;
        this.gdsVersion = gdsVersion;
    }

    /**
     * Builds the training algorithm for the pipeline named in the configuration.
     *
     * <p>The pipeline is looked up in the {@link PipelineCatalog} first. Any failure from
     * that lookup, such as an unknown pipeline name, is whatever {@code PipelineCatalog.getTyped}
     * raises. The call then delegates to
     * {@link #build(GraphStore, NodeClassificationPipelineTrainConfig, NodeClassificationTrainingPipeline, ProgressTracker)}.
     *
     * @param graphStore      graph store to train on
     * @param config          training configuration; supplies the pipeline owner and name
     * @param progressTracker tracker handed to the training components
     * @return the configured training algorithm
     */
    public NodeClassificationTrainAlgorithm build(
        GraphStore graphStore,
        NodeClassificationPipelineTrainConfig config,
        ProgressTracker progressTracker
    ) {
        return build(graphStore, config, lookupPipeline(config), progressTracker);
    }

    /**
     * Builds the training algorithm for an already-resolved pipeline.
     *
     * <p>The steps run in this order:
     * <ol>
     *   <li>Validate the first configured metric (the main metric) against the pipeline.</li>
     *   <li>Create the node feature producer.</li>
     *   <li>Check the pipeline's node property steps against that producer.</li>
     *   <li>Create the {@link NodeClassificationTrain} instance.</li>
     *   <li>Wrap it in a {@link NodeClassificationTrainAlgorithm}.</li>
     * </ol>
     *
     * @param graphStore      graph store to train on
     * @param config          training configuration
     * @param pipeline        the training pipeline to execute
     * @param progressTracker tracker handed to the feature producer, the trainer and the algorithm
     * @return the configured training algorithm
     */
    public NodeClassificationTrainAlgorithm build(
        GraphStore graphStore,
        NodeClassificationPipelineTrainConfig config,
        NodeClassificationTrainingPipeline pipeline,
        ProgressTracker progressTracker
    ) {
        // The first entry of metrics() is treated as the main metric.
        // NOTE(review): this assumes the config guarantees a non-empty metrics list — confirm.
        PipelineCompanion.validateMainMetric(pipeline, config.metrics().get(0).toString());

        NodeFeatureProducer featureProducer = NodeFeatureProducer.create(
            graphStore,
            config,
            this.executionContext,
            progressTracker
        );
        featureProducer.validateNodePropertyStepsContextConfigs(pipeline.nodePropertySteps());

        NodeClassificationTrain trainer = NodeClassificationTrain.create(
            graphStore,
            pipeline,
            config,
            featureProducer,
            progressTracker
        );
        return new NodeClassificationTrainAlgorithm(trainer, pipeline, graphStore, config, progressTracker);
    }

    /**
     * Estimates the memory needed to train the pipeline named in the configuration.
     *
     * <p>The estimation is named after {@link NodeClassificationTrain}. It contains a single
     * component produced by {@code NodeClassificationTrain.estimate}, using the model catalog
     * and algorithms facade from the execution context.
     *
     * @param config training configuration; supplies the pipeline owner and name
     * @return the memory estimation tree
     */
    public MemoryEstimation memoryEstimation(NodeClassificationPipelineTrainConfig config) {
        return MemoryEstimations.builder(NodeClassificationTrain.class.getSimpleName())
            .add(NodeClassificationTrain.estimate(
                lookupPipeline(config),
                config,
                this.executionContext.modelCatalog(),
                this.executionContext.algorithmsProcedureFacade()
            ))
            .build();
    }

    /**
     * Returns the fixed task name used for this factory's algorithm.
     *
     * @return {@code "Node Classification Train Pipeline"}
     */
    public String taskName() {
        return "Node Classification Train Pipeline";
    }

    /**
     * Builds the progress task for the pipeline named in the configuration.
     *
     * @param graphStore graph store whose node count sizes the task
     * @param config     training configuration; supplies the pipeline owner and name
     * @return the progress task
     */
    public Task progressTask(GraphStore graphStore, NodeClassificationPipelineTrainConfig config) {
        return progressTask(graphStore, lookupPipeline(config));
    }

    /**
     * Builds the progress task for a given pipeline, sized by the graph store's node count.
     *
     * @param graphStore graph store whose node count sizes the task
     * @param pipeline   the training pipeline
     * @return the progress task produced by {@code NodeClassificationTrain.progressTask}
     */
    public static Task progressTask(GraphStore graphStore, NodeClassificationTrainingPipeline pipeline) {
        return NodeClassificationTrain.progressTask(pipeline, graphStore.nodeCount());
    }

    /**
     * Resolves the training pipeline that {@code config.pipeline()} names for
     * {@code config.username()} from the {@link PipelineCatalog}.
     *
     * <p>This is the single place where the catalog lookup and its cast happen. Previously the
     * same lookup was repeated in three methods.
     */
    private static NodeClassificationTrainingPipeline lookupPipeline(NodeClassificationPipelineTrainConfig config) {
        return (NodeClassificationTrainingPipeline) PipelineCatalog.getTyped(
            config.username(),
            config.pipeline(),
            NodeClassificationTrainingPipeline.class
        );
    }
}
