/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.node.classification.predict;

import java.util.Collection;
import org.neo4j.gds.GraphStoreAlgorithmFactory;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationPredictPipelineBaseConfig;
import org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationPredictPipelineExecutor;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineModelInfo;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;

/**
 * Algorithm factory that builds a {@link NodeClassificationPredictPipelineExecutor} for a trained node
 * classification pipeline model.
 *
 * <p>The model is looked up in the {@link ModelCatalog} obtained from the {@link ExecutionContext},
 * using the model name and username from the configuration. The model's pipeline, classifier data and
 * class ids are then handed to the executor.
 *
 * @param <CONFIG> the predict-pipeline configuration type
 */
public class NodeClassificationPredictPipelineAlgorithmFactory<CONFIG extends NodeClassificationPredictPipelineBaseConfig>
extends GraphStoreAlgorithmFactory<NodeClassificationPredictPipelineExecutor, CONFIG> {
    private final ModelCatalog modelCatalog;
    private final ExecutionContext executionContext;

    /**
     * Creates the factory.
     *
     * @param executionContext context whose model catalog is used for model lookups; it is also
     *                         passed on to every executor this factory builds
     */
    NodeClassificationPredictPipelineAlgorithmFactory(ExecutionContext executionContext) {
        this.modelCatalog = executionContext.modelCatalog();
        this.executionContext = executionContext;
    }

    /**
     * Builds the progress task from the pipeline stored in the trained model.
     *
     * @param graphStore the graph store the prediction will run on
     * @param config     configuration identifying the trained model (name and username)
     * @return the progress task produced by {@link NodeClassificationPredictPipelineExecutor#progressTask}
     */
    @Override
    public Task progressTask(GraphStore graphStore, CONFIG config) {
        NodePropertyPredictPipeline trainingPipeline = trainedModel(config).customInfo().pipeline();
        return NodeClassificationPredictPipelineExecutor.progressTask(this.taskName(), trainingPipeline, graphStore);
    }

    /**
     * @return the fixed task name {@code "Node Classification Predict Pipeline"}
     */
    @Override
    public String taskName() {
        return "Node Classification Predict Pipeline";
    }

    /**
     * Creates an executor that runs the trained model's pipeline on the given graph store.
     *
     * @param graphStore      the graph store to predict on
     * @param configuration   configuration identifying the trained model (name and username)
     * @param progressTracker tracker passed through to the executor
     * @return a new executor wired with the model's pipeline, classifier data and class id mapping
     */
    @Override
    public NodeClassificationPredictPipelineExecutor build(GraphStore graphStore, CONFIG configuration, ProgressTracker progressTracker) {
        Model<Classifier.ClassifierData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> model =
            trainedModel(configuration);
        NodeClassificationPipelineModelInfo modelInfo = model.customInfo();
        // Map the class ids recorded at training time to local indices; this mapping is passed to the executor.
        LocalIdMap classIdMap = LocalIdMap.of(modelInfo.classes());
        return new NodeClassificationPredictPipelineExecutor(
            modelInfo.pipeline(),
            configuration,
            this.executionContext,
            graphStore,
            progressTracker,
            model.data(),
            classIdMap
        );
    }

    /**
     * Estimates memory for prediction by delegating to
     * {@link NodeClassificationPredictPipelineExecutor#estimate} for the trained model.
     *
     * @param configuration configuration identifying the trained model (name and username)
     * @return a memory estimation named after {@link NodeClassificationPredictPipelineExecutor}
     */
    @Override
    public MemoryEstimation memoryEstimation(CONFIG configuration) {
        Model<Classifier.ClassifierData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> model =
            trainedModel(configuration);
        return MemoryEstimations.builder(NodeClassificationPredictPipelineExecutor.class.getSimpleName())
            .add("Pipeline executor", NodeClassificationPredictPipelineExecutor.estimate(model, configuration, this.modelCatalog))
            .build();
    }

    /**
     * Looks up the trained model referenced by {@code config} in this factory's model catalog.
     */
    private Model<Classifier.ClassifierData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> trainedModel(CONFIG config) {
        return getTrainedNCPipelineModel(this.modelCatalog, config.modelName(), config.username());
    }

    /**
     * Fetches a node classification pipeline model from the catalog, typed with classifier data, the
     * node classification train config and the node classification model info.
     *
     * <p>NOTE(review): behavior for a missing model or a model of another type is whatever
     * {@link ModelCatalog#get} does; it is not visible here.
     *
     * @param modelCatalog the catalog to query
     * @param modelName    name of the trained model
     * @param username     owner of the model
     * @return the typed model
     */
    private static Model<Classifier.ClassifierData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> getTrainedNCPipelineModel(ModelCatalog modelCatalog, String modelName, String username) {
        return modelCatalog.get(username, modelName, Classifier.ClassifierData.class, NodeClassificationPipelineTrainConfig.class, NodeClassificationPipelineModelInfo.class);
    }
}

