package org.neo4j.gds.ml.metrics.classification;

import com.carrotsearch.hppc.BitSet;
import java.util.Comparator;
import java.util.Optional;
import java.util.concurrent.atomic.LongAdder;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.collections.haa.HugeAtomicLongArray;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.ml.decisiontree.DecisionTreePredictor;
import org.neo4j.gds.ml.metrics.Metric;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.negativeSampling.NegativeSampler;

/**
 * Out-of-bag (OOB) error metric for ensembles of decision trees.
 *
 * <p>Votes are accumulated per training vector and per class in a shared
 * {@link HugeAtomicLongArray} laid out as {@code trainSetIndex * numberOfClasses + classId}
 * (see {@link #addPredictionsForTree}). {@link #evaluate} then takes the majority vote of every
 * vector that received at least one vote and reports the fraction of those vectors whose majority
 * class differs from the true label.
 *
 * <p>The metric depends on how the model was trained, hence {@link #isModelSpecific()} returns
 * {@code true}.
 */
public final class OutOfBagError implements Metric {
    public static final OutOfBagError OUT_OF_BAG_ERROR = new OutOfBagError();

    private OutOfBagError() {
    }

    @Override
    public boolean isModelSpecific() {
        return true;
    }

    /**
     * Adds one vote, for the class predicted by {@code decisionTreePredictor}, to every train-set
     * vector that is NOT marked in {@code sampledVectors}.
     *
     * @param decisionTreePredictor tree whose predictions are recorded
     * @param numberOfClasses       number of classes; width of each vector's slot in {@code predictions}
     * @param features              feature vectors, looked up by the value stored in {@code trainSet}
     * @param trainSet              maps a train-set index to the index used for feature lookup
     * @param sampledVectors        train-set indices to skip (presumably the tree's in-bag sample)
     * @param predictions           vote counts, indexed {@code trainSetIndex * numberOfClasses + classId}
     */
    public static void addPredictionsForTree(
        DecisionTreePredictor<Integer> decisionTreePredictor,
        int numberOfClasses,
        Features features,
        ReadOnlyHugeLongArray trainSet,
        BitSet sampledVectors,
        HugeAtomicLongArray predictions
    ) {
        long trainSetSize = trainSet.size();
        for (long trainSetIdx = 0; trainSetIdx < trainSetSize; trainSetIdx++) {
            if (sampledVectors.get(trainSetIdx)) {
                // Only out-of-bag vectors get a vote from this tree.
                continue;
            }
            int predictedClass = decisionTreePredictor
                .predict(features.get(trainSet.get(trainSetIdx)))
                .intValue();
            // long arithmetic: trainSetIdx is long, so the slot offset cannot overflow int.
            predictions.getAndAdd(trainSetIdx * numberOfClasses + predictedClass, 1L);
        }
    }

    /**
     * Computes the out-of-bag error from the accumulated votes.
     *
     * @param trainSet        maps a train-set index to the index used for label lookup
     * @param numberOfClasses number of classes; width of each vector's slot in {@code predictions}
     * @param allLabels       true class ids, looked up by the value stored in {@code trainSet}
     * @param concurrency     concurrency used to scan the train set
     * @param predictions     vote counts, indexed {@code trainSetIndex * numberOfClasses + classId}
     * @return the fraction of voted-on vectors whose majority class is wrong, or {@code 0.0} when
     *     no vector received any vote
     */
    public static double evaluate(
        ReadOnlyHugeLongArray trainSet,
        int numberOfClasses,
        HugeIntArray allLabels,
        Concurrency concurrency,
        HugeAtomicLongArray predictions
    ) {
        LongAdder totalMistakes = new LongAdder();
        LongAdder totalOutOfAnyBagVectors = new LongAdder();

        RunWithConcurrency.builder()
            .concurrency(concurrency)
            .tasks(PartitionUtils.rangePartition(
                concurrency,
                trainSet.size(),
                partition -> accumulationTask(
                    partition,
                    numberOfClasses,
                    trainSet,
                    predictions,
                    allLabels,
                    totalMistakes,
                    totalOutOfAnyBagVectors
                ),
                Optional.empty()
            ))
            .run();

        // No vector received a vote: report 0 rather than dividing by zero (NaN).
        // Explicit literal instead of an unrelated constant that merely happens to equal 0.
        if (totalOutOfAnyBagVectors.longValue() == 0) {
            return 0.0;
        }
        return totalMistakes.doubleValue() / totalOutOfAnyBagVectors.doubleValue();
    }

    /**
     * Builds a task that scores the train-set indices of one partition by majority vote and adds
     * its local counts to the shared adders.
     *
     * <p>NOTE(review): intended as an internal helper of {@link #evaluate}; the decompiler reports
     * its access modifier was changed from private, and it is kept public for compatibility.
     *
     * @param partition               range of train-set indices to score
     * @param numberOfClasses         number of classes; width of each vector's slot in {@code predictions}
     * @param trainSet                maps a train-set index to the index used for label lookup
     * @param predictions             vote counts, indexed {@code trainSetIndex * numberOfClasses + classId}
     * @param allLabels               true class ids, looked up by the value stored in {@code trainSet}
     * @param totalMistakes           receives the number of wrongly voted vectors in the partition
     * @param totalOutOfAnyBagVectors receives the number of vectors with at least one vote
     * @return the scoring task
     */
    public static Runnable accumulationTask(
        Partition partition,
        int numberOfClasses,
        ReadOnlyHugeLongArray trainSet,
        HugeAtomicLongArray predictions,
        HugeIntArray allLabels,
        LongAdder totalMistakes,
        LongAdder totalOutOfAnyBagVectors
    ) {
        return () -> {
            long localMistakes = 0;
            long localEvaluated = 0;
            long start = partition.startNode();
            long end = start + partition.nodeCount();

            for (long trainSetIdx = start; trainSetIdx < end; trainSetIdx++) {
                long offset = trainSetIdx * numberOfClasses;

                // Majority vote; the strict '>' keeps the lowest class id on ties.
                long maxVotes = 0;
                int majorityClass = 0;
                for (int classId = 0; classId < numberOfClasses; classId++) {
                    long votes = predictions.get(offset + classId);
                    if (votes > maxVotes) {
                        maxVotes = votes;
                        majorityClass = classId;
                    }
                }

                // Vectors that received no vote do not count towards the error.
                if (maxVotes == 0) {
                    continue;
                }
                localEvaluated++;
                if (majorityClass != allLabels.get(trainSet.get(trainSetIdx))) {
                    localMistakes++;
                }
            }

            // Publish once per partition to keep contention on the shared adders low.
            totalMistakes.add(localMistakes);
            totalOutOfAnyBagVectors.add(localEvaluated);
        };
    }

    @Override
    public String name() {
        return "OUT_OF_BAG_ERROR";
    }

    @Override
    public String toString() {
        return name();
    }

    /**
     * Returns the natural ordering of metric values.
     *
     * <p>NOTE(review): for an error metric lower is better, so {@code Comparator.reverseOrder()}
     * would be expected if this comparator is ever used to rank models. Confirm how
     * model-specific metrics consume it before changing.
     */
    @Override
    public Comparator<Double> comparator() {
        return Comparator.naturalOrder();
    }
}
