package org.neo4j.gds.ml.nodeClassification;

import java.util.Optional;
import org.jetbrains.annotations.Nullable;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.mem.Estimate;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.mem.MemoryEstimations;
import org.neo4j.gds.ml.api.TrainingMethod;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.ClassifierFactory;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionClassifier;
import org.neo4j.gds.termination.TerminationFlag;

/* loaded from: input_file:org/neo4j/gds/ml/nodeClassification/NodeClassificationPredict.class */
/**
 * Runs a {@link Classifier} over a set of {@link Features} by delegating to a
 * {@link ParallelNodeClassifier}, and optionally allocates a per-entry array
 * that can hold class probabilities alongside the predicted classes.
 */
public class NodeClassificationPredict {
    private final Classifier classifier;
    private final Features features;
    private final boolean produceProbabilities;
    private final ProgressTracker progressTracker;
    private final ParallelNodeClassifier predictor;

    /**
     * Outcome of a prediction run: the predicted classes and, when requested,
     * the probabilities array that was handed to the predictor.
     */
    @ValueClass
    public interface NodeClassificationResult {
        /** @return the predicted classes as returned by the predictor */
        HugeIntArray predictedClasses();

        /** @return the probabilities array, or empty when none was produced */
        Optional<HugeObjectArray<double[]>> predictedProbabilities();

        /**
         * Builds a result from the predicted classes and an optional probabilities array.
         *
         * @param hugeIntArray    the predicted classes
         * @param hugeObjectArray the probabilities array; {@code null} maps to an empty {@link Optional}
         * @return the assembled result
         */
        static NodeClassificationResult of(HugeIntArray hugeIntArray, @Nullable HugeObjectArray<double[]> hugeObjectArray) {
            Optional<HugeObjectArray<double[]>> probabilities = Optional.ofNullable(hugeObjectArray);
            return ImmutableNodeClassificationResult.builder()
                .predictedClasses(hugeIntArray)
                .predictedProbabilities(probabilities)
                .build();
        }
    }

    /**
     * Creates a predictor and wires up the underlying {@link ParallelNodeClassifier}.
     *
     * @param classifier      the classifier used for prediction
     * @param features        the features to classify
     * @param i               forwarded unchanged to {@link ParallelNodeClassifier}
     *                        (presumably the batch size — TODO confirm)
     * @param concurrency     forwarded to {@link ParallelNodeClassifier}
     * @param z               whether {@link #compute()} should allocate and return a probabilities array
     * @param progressTracker tracker notified by {@link #compute()} and passed to the parallel classifier
     * @param terminationFlag forwarded to {@link ParallelNodeClassifier}
     */
    public NodeClassificationPredict(Classifier classifier, Features features, int i, Concurrency concurrency, boolean z, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
        this.classifier = classifier;
        this.features = features;
        this.produceProbabilities = z;
        this.progressTracker = progressTracker;
        this.predictor = new ParallelNodeClassifier(
            classifier,
            features,
            i,
            concurrency,
            terminationFlag,
            progressTracker
        );
    }

    /**
     * Creates the leaf progress task describing a prediction run.
     *
     * @param j the volume passed to the leaf task
     * @return a leaf task named "Node classification predict"
     */
    public static Task progressTask(long j) {
        return Tasks.leaf("Node classification predict", j);
    }

    /**
     * Estimates memory for a prediction run with an explicitly given batch size.
     *
     * @param z  whether the probabilities array is part of the estimate
     * @param i  first argument to the prediction-variable size computation
     *           (presumably the batch size — TODO confirm)
     * @param i2 second argument to the prediction-variable size computation
     *           (presumably the feature count — TODO confirm)
     * @param i3 length of each per-node probabilities array; also passed as the last two
     *           arguments to the prediction-variable size computation
     *           (presumably the class count — TODO confirm)
     * @return the composed memory estimation
     */
    public static MemoryEstimation memoryEstimation(boolean z, int i, int i2, int i3) {
        MemoryEstimations.Builder estimation = resultArraysEstimation(z, i3);
        long computationGraphBytes = LogisticRegressionClassifier.sizeOfPredictionsVariableInBytes(i, i2, i3, i3);
        estimation.fixed("computation graph", computationGraphBytes);
        return estimation.build();
    }

    /**
     * Estimates memory for a prediction run where the batch size is derived from the
     * graph dimensions and the concurrency at estimation time.
     *
     * @param trainingMethod the training method forwarded to {@link ClassifierFactory}
     * @param z              whether the probabilities array is part of the estimate
     * @param i              second argument to {@link BatchQueue#computeBatchSize}
     *                       (presumably a minimum batch size — TODO confirm)
     * @param i2             forwarded to the classifier runtime overhead estimation
     * @param i3             length of each per-node probabilities array; also forwarded to the
     *                       classifier runtime overhead estimation
     * @param z2             forwarded to the classifier runtime overhead estimation
     * @return the composed memory estimation
     */
    public static MemoryEstimation memoryEstimationWithDerivedBatchSize(TrainingMethod trainingMethod, boolean z, int i, int i2, int i3, boolean z2) {
        MemoryEstimations.Builder estimation = resultArraysEstimation(z, i3);
        estimation.perGraphDimension("classifier runtime", (graphDimensions, concurrency) -> {
            // The batch size is capped at the node count.
            int batchSize = (int) Math.min(
                graphDimensions.nodeCount(),
                BatchQueue.computeBatchSize(graphDimensions.nodeCount(), i, concurrency)
            );
            return ClassifierFactory.runtimeOverheadMemoryEstimation(trainingMethod, batchSize, i3, i2, z2);
        });
        return estimation.build();
    }

    /**
     * Starts an estimation builder containing the components shared by both public
     * estimation methods: the optional probabilities array and the predicted classes.
     *
     * NOTE(review): the "predicted classes" component is sized with HugeLongArray although
     * compute() returns a HugeIntArray — looks like an over-estimate; confirm it is intended.
     */
    private static MemoryEstimations.Builder resultArraysEstimation(boolean withProbabilities, int probabilitiesPerNode) {
        MemoryEstimations.Builder estimation = MemoryEstimations.builder(NodeClassificationPredict.class.getSimpleName());
        if (withProbabilities) {
            estimation.perNode(
                "predicted probabilities",
                nodeCount -> HugeObjectArray.memoryEstimation(nodeCount, Estimate.sizeOfDoubleArray(probabilitiesPerNode))
            );
        }
        estimation.perNode("predicted classes", HugeLongArray::memoryEstimation);
        return estimation;
    }

    /**
     * Runs the prediction: opens the progress sub task, registers one step per feature
     * entry, allocates the probabilities array if requested, delegates to the parallel
     * classifier and closes the sub task.
     *
     * @return the predicted classes together with the (possibly absent) probabilities array
     */
    public NodeClassificationResult compute() {
        this.progressTracker.beginSubTask();
        this.progressTracker.setSteps(this.features.size());

        HugeObjectArray<double[]> probabilities = initProbabilities();
        HugeIntArray predictedClasses = this.predictor.predict(probabilities);

        this.progressTracker.endSubTask();
        return NodeClassificationResult.of(predictedClasses, probabilities);
    }

    /**
     * Allocates one zero-initialised {@code double[numberOfClasses]} per feature entry.
     *
     * @return the allocated array, or {@code null} when probabilities were not requested
     */
    @Nullable
    private HugeObjectArray<double[]> initProbabilities() {
        if (this.produceProbabilities) {
            int classCount = this.classifier.numberOfClasses();
            HugeObjectArray<double[]> probabilities = HugeObjectArray.newArray(double[].class, this.features.size());
            probabilities.setAll(index -> new double[classCount]);
            return probabilities;
        }
        return null;
    }
}
