/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.node.classification.predict;

import java.util.Collection;
import java.util.List;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.api.TrainingMethod;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.ClassifierFactory;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.FeaturesFactory;
import org.neo4j.gds.ml.nodeClassification.NodeClassificationPredict;
import org.neo4j.gds.ml.pipeline.ImmutablePipelineGraphFilter;
import org.neo4j.gds.ml.pipeline.NodePropertyStepExecutor;
import org.neo4j.gds.ml.pipeline.Pipeline;
import org.neo4j.gds.ml.pipeline.PipelineGraphFilter;
import org.neo4j.gds.ml.pipeline.PredictPipelineExecutor;
import org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationPipelineResult;
import org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationPredictPipelineBaseConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineModelInfo;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;
import org.neo4j.gds.utils.StringFormatting;
import org.neo4j.gds.utils.StringJoining;

/**
 * Predict-time executor for a trained node classification pipeline.
 *
 * <p>It exposes a graph filter built from the configured node labels and relationship types
 * (see {@link #nodePropertyStepFilter()}). It then predicts classes for the nodes of the
 * configured node labels with {@link NodeClassificationPredict}. Raw predictions are turned into a
 * {@link NodeClassificationPipelineResult} using the supplied class id map.
 */
public class NodeClassificationPredictPipelineExecutor
extends PredictPipelineExecutor<NodeClassificationPredictPipelineBaseConfig, NodePropertyPredictPipeline, NodeClassificationPipelineResult> {

    /**
     * Batch size handed to {@link NodeClassificationPredict}. It is used both by the memory
     * estimation in {@code estimate} and by {@link #execute()}, so the estimate is computed for
     * the same batch size that prediction actually uses.
     */
    private static final int MIN_BATCH_SIZE = 100;

    private final Classifier.ClassifierData modelData;
    private final LocalIdMap classIdMap;
    // Node labels / relationship types resolved from the config against the graph store once, at construction.
    private final PipelineGraphFilter predictGraphFilter;

    /**
     * Creates an executor for one prediction run.
     *
     * @param pipeline         pipeline whose node property steps and feature properties are used
     * @param config           prediction configuration (node labels, relationship types, concurrency, ...)
     * @param executionContext execution context, forwarded to the superclass
     * @param graphStore       graph store to predict on
     * @param progressTracker  progress tracker, forwarded to the superclass and used by the prediction
     * @param modelData        trained classifier data; its feature dimension must match the pipeline's features
     * @param classIdMap       presumably maps the classifier's internal class ids back to the original
     *                         class values (TODO confirm); passed to {@link NodeClassificationPipelineResult#of}
     */
    public NodeClassificationPredictPipelineExecutor(
        NodePropertyPredictPipeline pipeline,
        NodeClassificationPredictPipelineBaseConfig config,
        ExecutionContext executionContext,
        GraphStore graphStore,
        ProgressTracker progressTracker,
        Classifier.ClassifierData modelData,
        LocalIdMap classIdMap
    ) {
        super(pipeline, config, executionContext, graphStore, progressTracker);
        this.modelData = modelData;
        this.classIdMap = classIdMap;
        this.predictGraphFilter = ImmutablePipelineGraphFilter.builder()
            .nodeLabels(config.nodeLabelIdentifiers(graphStore))
            .relationshipTypes(config.internalRelationshipTypes(graphStore))
            .build();
    }

    /**
     * Builds the progress task tree for a prediction run.
     *
     * <p>The tree has two children: the node property step tasks and the prediction task. Both
     * are sized by the node count of the whole graph store.
     *
     * @param taskName   name of the root task
     * @param pipeline   pipeline whose node property steps are tracked
     * @param graphStore graph store whose node count sizes the subtasks
     * @return the root progress task
     */
    public static Task progressTask(String taskName, NodePropertyPredictPipeline pipeline, GraphStore graphStore) {
        return Tasks.task(
            taskName,
            NodePropertyStepExecutor.tasks(pipeline.nodePropertySteps(), graphStore.nodeCount()),
            NodeClassificationPredict.progressTask(graphStore.nodeCount())
        );
    }

    /**
     * Estimates the memory needed to run the pipeline for prediction.
     *
     * <p>The node property steps and the prediction are estimated separately, and the larger of
     * the two is returned. Node labels and relationship types come from {@code configuration}
     * when non-empty. Otherwise they fall back to the values the model was trained with.
     *
     * @param model         trained model providing the pipeline, class count, feature dimension and train config
     * @param configuration prediction configuration
     * @param modelCatalog  model catalog, forwarded to the node property step estimation
     * @return the maximum of the node-property-step and prediction estimations
     */
    public static MemoryEstimation estimate(
        Model<Classifier.ClassifierData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> model,
        NodeClassificationPredictPipelineBaseConfig configuration,
        ModelCatalog modelCatalog
    ) {
        NodeClassificationPipelineModelInfo modelInfo = model.customInfo();
        Classifier.ClassifierData classifierData = model.data();
        NodeClassificationPipelineTrainConfig trainConfig = model.trainConfig();

        NodePropertyPredictPipeline pipeline = modelInfo.pipeline();
        int classCount = modelInfo.classes().size();
        int featureCount = classifierData.featureDimension();

        List<String> combinedNodeLabels = configuration.targetNodeLabels().isEmpty()
            ? trainConfig.targetNodeLabels()
            : configuration.targetNodeLabels();
        List<String> combinedRelationshipTypes = configuration.relationshipTypes().isEmpty()
            ? trainConfig.relationshipTypes()
            : configuration.relationshipTypes();

        MemoryEstimation nodePropertyStepEstimation = NodePropertyStepExecutor.estimateNodePropertySteps(
            modelCatalog,
            configuration.username(),
            pipeline.nodePropertySteps(),
            combinedNodeLabels,
            combinedRelationshipTypes
        );

        MemoryEstimation predictionEstimation = MemoryEstimations.builder()
            .add(
                "Pipeline Predict",
                NodeClassificationPredict.memoryEstimationWithDerivedBatchSize(
                    classifierData.trainerMethod(),
                    configuration.includePredictedProbabilities(),
                    MIN_BATCH_SIZE,
                    featureCount,
                    classCount,
                    false
                )
            )
            .build();

        return MemoryEstimations.maxEstimation(List.of(nodePropertyStepEstimation, predictionEstimation));
    }

    /**
     * Returns the graph filter (configured node labels and relationship types) for the
     * pipeline's node property steps. It is presumably consumed by the superclass when running those
     * steps (TODO confirm).
     */
    @Override
    protected PipelineGraphFilter nodePropertyStepFilter() {
        return this.predictGraphFilter;
    }

    /**
     * Predicts classes for the nodes of the configured node labels.
     *
     * @return prediction result, built from the raw predictions and {@code classIdMap}
     * @throws IllegalArgumentException if the total dimension of the pipeline's feature properties
     *                                  differs from the feature dimension the model was trained with
     */
    @Override
    protected NodeClassificationPipelineResult execute() {
        // Only node labels matter here: prediction reads node features, not relationships.
        Graph nodesGraph = this.graphStore.getGraph(this.predictGraphFilter.nodeLabels());
        Features features = FeaturesFactory.extractLazyFeatures(nodesGraph, this.pipeline.featureProperties());

        if (features.featureDimension() != this.modelData.featureDimension()) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "Model expected features %s to have a total dimension of `%d`, but got `%d`.",
                StringJoining.join(this.pipeline.featureProperties()),
                this.modelData.featureDimension(),
                features.featureDimension()
            ));
        }

        NodeClassificationPredict predictAlgo = new NodeClassificationPredict(
            ClassifierFactory.create(this.modelData),
            features,
            MIN_BATCH_SIZE,
            this.config.concurrency(),
            this.config.includePredictedProbabilities(),
            this.progressTracker,
            this.terminationFlag
        );
        NodeClassificationPredict.NodeClassificationResult nodeClassificationResult = predictAlgo.compute();
        return NodeClassificationPipelineResult.of(nodeClassificationResult, this.classIdMap);
    }
}

