package org.neo4j.gds.ml.pipeline.linkPipeline.train;

import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.linkmodels.SignedProbabilities;
import org.neo4j.gds.ml.linkmodels.metrics.LinkMetric;
import org.neo4j.gds.models.Classifier;
import org.neo4j.gds.models.Features;

/* loaded from: input_file:org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionEvaluationMetricComputer.class */
public final class LinkPredictionEvaluationMetricComputer {

    /**
     * Label value marking a positive example. It is also the original class id whose predicted
     * probability is read from the classifier output.
     */
    private static final long POSITIVE_LABEL = 1L;

    private LinkPredictionEvaluationMetricComputer() {
    }

    /**
     * Estimates the memory needed for the signed-probability buffer of an evaluation run.
     *
     * <p>The decompiler widened this from package-private; it is kept public so existing callers
     * keep compiling.
     *
     * @param evaluationSetSize number of examples that will be evaluated
     * @return {@code MemoryRange.of(SignedProbabilities.estimateMemory(evaluationSetSize))}
     */
    public static MemoryRange estimate(long evaluationSetSize) {
        return MemoryRange.of(SignedProbabilities.estimateMemory(evaluationSetSize));
    }

    /**
     * Scores every example of {@code featuresAndLabels} with {@code classifier} and evaluates each
     * metric configured in {@code linkPredictionTrainConfig}.
     *
     * <p>For every example, the predicted probability of the positive class is recorded. It is
     * recorded as-is when the example's label equals {@value #POSITIVE_LABEL}, and negated
     * otherwise. The metrics are then computed over these signed probabilities.
     *
     * <p>The decompiler widened this from package-private; it is kept public so existing callers
     * keep compiling.
     *
     * @param featuresAndLabels the evaluation examples; its size sets the progress volume
     * @param classifier the trained classifier used for prediction
     * @param batchQueue supplies the batches of example ids to score
     * @param linkPredictionTrainConfig provides concurrency, metrics and the negative class weight
     * @param progressTracker receives the total volume and per-batch progress
     * @param terminationFlag passed through to {@code batchQueue.parallelConsume}
     * @return one entry per configured metric, mapping it to its computed value
     * @throws IllegalStateException if the configured metrics contain duplicates, because of
     *     {@link Collectors#toMap(Function, Function)}
     */
    public static Map<LinkMetric, Double> computeMetric(
        FeaturesAndLabels featuresAndLabels,
        Classifier classifier,
        BatchQueue batchQueue,
        LinkPredictionTrainConfig linkPredictionTrainConfig,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        progressTracker.setVolume(featuresAndLabels.size());

        SignedProbabilities signedProbabilities = SignedProbabilities.create(featuresAndLabels.size());
        HugeLongArray labels = featuresAndLabels.labels();
        Features features = featuresAndLabels.features();
        // Column of the probability matrix that holds the positive-class probability.
        int positiveClassId = classifier.classIdMap().toMapped(POSITIVE_LABEL);

        batchQueue.parallelConsume(linkPredictionTrainConfig.concurrency(), threadId -> batch -> {
            Matrix probabilities = classifier.predictProbabilities(batch, features);
            // Row r of the matrix corresponds to the r-th id yielded by batch.nodeIds().
            int row = 0;
            for (Long exampleId : batch.nodeIds()) {
                double positiveProbability = probabilities.dataAt(row, positiveClassId);
                row++;
                boolean isPositive = labels.get(exampleId) == POSITIVE_LABEL;
                // NOTE(review): add() is called from concurrent consumers; this relies on
                // SignedProbabilities.add being thread-safe — confirm.
                signedProbabilities.add(isPositive ? positiveProbability : -positiveProbability);
            }
            progressTracker.logProgress(batch.size());
        }, terminationFlag);

        return linkPredictionTrainConfig.metrics().stream().collect(Collectors.toMap(
            Function.identity(),
            metric -> metric.compute(signedProbabilities, linkPredictionTrainConfig.negativeClassWeight())
        ));
    }
}
