package org.tribuo.classification.evaluation;

import com.oracle.labs.mlrg.olcut.util.SortUtil;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.tribuo.util.Util;

/**
 * Static helpers for ranking-based binary classification metrics: averaged precision, the
 * precision-recall curve, the ROC curve and the area under the ROC curve.
 *
 * <p>Every metric takes two parallel arrays: {@code yPos}, marking which examples are truly
 * positive, and {@code yScore}, the score assigned to each example. Examples are ranked by score
 * (highest first) and each distinct score value is used as a decision threshold.
 */
public final class LabelEvaluationUtil {

    /**
     * A precision-recall curve.
     *
     * <p>As produced by {@link LabelEvaluationUtil#generatePRCurve}, {@code precision} and
     * {@code recall} contain one more element than {@code thresholds}: the final point
     * (precision 1, recall 0) has no associated threshold.
     */
    public static class PRCurve {
        /** Precision at each point of the curve. */
        public final double[] precision;
        /** Recall at each point of the curve. */
        public final double[] recall;
        /** The score threshold that produced each point. */
        public final double[] thresholds;

        /**
         * Builds a curve from already computed arrays (stored without copying).
         *
         * @param precision The precision values.
         * @param recall The recall values.
         * @param thresholds The thresholds.
         */
        public PRCurve(double[] precision, double[] recall, double[] thresholds) {
            this.precision = precision;
            this.recall = recall;
            this.thresholds = thresholds;
        }
    }

    /**
     * A receiver operating characteristic curve.
     *
     * <p>The three arrays are parallel: point {@code i} is ({@code fpr[i]}, {@code tpr[i]}) obtained
     * at {@code thresholds[i]}.
     */
    public static class ROC {
        /** False positive rate at each point. */
        public final double[] fpr;
        /** True positive rate at each point. */
        public final double[] tpr;
        /** The score threshold that produced each point. */
        public final double[] thresholds;

        /**
         * Builds a curve from already computed arrays (stored without copying).
         *
         * @param fpr The false positive rates.
         * @param tpr The true positive rates.
         * @param thresholds The thresholds.
         */
        public ROC(double[] fpr, double[] tpr, double[] thresholds) {
            this.fpr = fpr;
            this.tpr = tpr;
            this.thresholds = thresholds;
        }
    }

    /**
     * Cumulative true/false positive counts at each distinct score threshold, as computed by
     * {@code generateTPFPs}. Intended as an internal holder for this class.
     */
    public static class TPFP {
        /** Number of false positives at each threshold. */
        public final List<Integer> falsePos;
        /** Number of true positives at each threshold. */
        public final List<Integer> truePos;
        /** The distinct score thresholds, in ranking order. */
        public final List<Double> thresholds;
        /** Total number of positive examples. */
        public final int totalPos;

        /**
         * Builds the holder (lists are stored without copying).
         *
         * @param falsePos The false positive counts.
         * @param truePos The true positive counts.
         * @param thresholds The thresholds.
         * @param totalPos The total number of positives.
         */
        public TPFP(List<Integer> falsePos, List<Integer> truePos, List<Double> thresholds, int totalPos) {
            this.falsePos = falsePos;
            this.truePos = truePos;
            this.thresholds = thresholds;
            this.totalPos = totalPos;
        }
    }

    // Static utility class, not instantiable.
    private LabelEvaluationUtil() {
    }

    /**
     * Computes the averaged precision: the sum over the precision-recall curve of
     * precision[i] * (recall[i] - recall[i+1]).
     *
     * @param yPos Whether each example is a true positive.
     * @param yScore The score of each example.
     * @return The averaged precision.
     * @throws IllegalArgumentException If the arrays differ in length or are empty.
     */
    public static double averagedPrecision(boolean[] yPos, double[] yScore) {
        PRCurve curve = generatePRCurve(yPos, yScore);
        double score = 0.0;
        // Recall is non-increasing along the curve, so each step (recall[i+1] - recall[i]) is <= 0;
        // the accumulated sum is therefore negated on return.
        for (int i = 0; i < curve.precision.length - 1; i++) {
            score += (curve.recall[i + 1] - curve.recall[i]) * curve.precision[i];
        }
        return -score;
    }

    /**
     * Computes the precision-recall curve.
     *
     * <p>Points are emitted for each distinct score threshold until full recall is reached, then
     * reversed so recall is non-increasing. A final point of precision 1 and recall 0 is appended
     * (without a threshold), so {@code precision} and {@code recall} are one element longer than
     * {@code thresholds}. Points with zero true positives have precision and recall 0.
     *
     * @param yPos Whether each example is a true positive.
     * @param yScore The score of each example.
     * @return The precision-recall curve.
     * @throws IllegalArgumentException If the arrays differ in length or are empty.
     */
    public static PRCurve generatePRCurve(boolean[] yPos, double[] yScore) {
        TPFP tpfp = generateTPFPs(yPos, yScore);
        int numThresholds = tpfp.falsePos.size();

        List<Double> precisionList = new ArrayList<>(numThresholds + 1);
        List<Double> recallList = new ArrayList<>(numThresholds + 1);
        List<Double> thresholdList = new ArrayList<>(numThresholds);

        for (int i = 0; i < numThresholds; i++) {
            double falsePos = tpfp.falsePos.get(i);
            double truePos = tpfp.truePos.get(i);

            double precision = 0.0;
            double recall = 0.0;
            if (truePos != 0.0) {
                precision = truePos / (truePos + falsePos);
                recall = truePos / tpfp.totalPos;
            }

            precisionList.add(precision);
            recallList.add(recall);
            thresholdList.add(tpfp.thresholds.get(i));

            // Once every positive has been found, later thresholds can only add false positives.
            if (truePos == tpfp.totalPos) {
                break;
            }
        }

        Collections.reverse(precisionList);
        Collections.reverse(recallList);
        Collections.reverse(thresholdList);

        // Terminal point of the curve, which has no corresponding threshold.
        precisionList.add(1.0);
        recallList.add(0.0);

        return new PRCurve(Util.toPrimitiveDouble(precisionList), Util.toPrimitiveDouble(recallList), Util.toPrimitiveDouble(thresholdList));
    }

    /**
     * Computes the area under the ROC curve for a binary problem.
     *
     * @param yPos Whether each example is a true positive.
     * @param yScore The score of each example.
     * @return The area under the ROC curve, as computed by {@code Util.auc}.
     * @throws IllegalArgumentException If the arrays differ in length or are empty.
     */
    public static double binaryAUCROC(boolean[] yPos, double[] yScore) {
        ROC roc = generateROCCurve(yPos, yScore);
        return Util.auc(roc.fpr, roc.tpr);
    }

    /**
     * Computes the ROC curve.
     *
     * <p>If the first threshold does not already give zero true and false positives, the curve is
     * anchored at (0, 0) with a threshold of {@link Double#POSITIVE_INFINITY}. Counts are normalised
     * by the counts at the last threshold. If there are no positives (or no negatives), that count is
     * zero and the corresponding rates are NaN.
     *
     * @param yPos Whether each example is a true positive.
     * @param yScore The score of each example.
     * @return The ROC curve.
     * @throws IllegalArgumentException If the arrays differ in length or are empty.
     */
    public static ROC generateROCCurve(boolean[] yPos, double[] yScore) {
        TPFP tpfp = generateTPFPs(yPos, yScore);

        // Make sure the curve starts at the origin.
        if (tpfp.truePos.get(0) != 0 || tpfp.falsePos.get(0) != 0) {
            tpfp.truePos.add(0, 0);
            tpfp.falsePos.add(0, 0);
            tpfp.thresholds.add(0, Double.POSITIVE_INFINITY);
        }

        double[] tpr = Util.toPrimitiveDoubleFromInteger(tpfp.truePos);
        double[] fpr = Util.toPrimitiveDoubleFromInteger(tpfp.falsePos);
        double[] thresholds = Util.toPrimitiveDouble(tpfp.thresholds);

        // The last point is expected to be the lowest threshold (every example predicted positive),
        // so its counts are the totals used to turn counts into rates.
        double tpMax = tpr[tpr.length - 1];
        double fpMax = fpr[fpr.length - 1];
        for (int i = 0; i < tpr.length; i++) {
            tpr[i] /= tpMax;
            fpr[i] /= fpMax;
        }

        return new ROC(fpr, tpr, thresholds);
    }

    /**
     * Ranks examples by score and counts true/false positives at each distinct score value.
     *
     * @param yPos Whether each example is a true positive.
     * @param yScore The score of each example.
     * @return The cumulative counts, thresholds and total positive count.
     * @throws IllegalArgumentException If the arrays differ in length or are empty.
     */
    private static TPFP generateTPFPs(boolean[] yPos, double[] yScore) {
        if (yPos.length != yScore.length) {
            throw new IllegalArgumentException("yPos and yScore must be the same length, yPos.length = " + yPos.length + ", yScore.length = " + yScore.length);
        }
        if (yScore.length == 0) {
            // Without this check an empty input fails later with an opaque index exception.
            throw new IllegalArgumentException("yPos and yScore must not be empty");
        }

        // Sort the scores (the 'false' flag presumably requests descending order) and apply the same
        // permutation to the labels.
        int[] order = SortUtil.argsort(yScore, false);
        double[] sortedScore = new double[yScore.length];
        boolean[] sortedPos = new boolean[yPos.length];
        int totalPos = 0;
        for (int i = 0; i < yScore.length; i++) {
            sortedScore[i] = yScore[order[i]];
            sortedPos[i] = yPos[order[i]];
            if (sortedPos[i]) {
                totalPos++;
            }
        }

        // Indices at which the sorted score changes value; each marks the end of a tied block.
        int[] distinctIndices = Util.differencesIndices(sortedScore);
        // Running count of positives up to each rank (assumed inclusive of that index).
        int[] tpCounts = Util.cumulativeSum(sortedPos);

        List<Integer> truePositives = new ArrayList<>(distinctIndices.length);
        List<Integer> falsePositives = new ArrayList<>(distinctIndices.length);
        List<Double> thresholds = new ArrayList<>(distinctIndices.length);
        for (int idx : distinctIndices) {
            thresholds.add(sortedScore[idx]);
            truePositives.add(tpCounts[idx]);
            // idx + 1 examples are ranked at or above this threshold; the non-positive ones are false positives.
            falsePositives.add(1 + (idx - tpCounts[idx]));
        }

        return new TPFP(falsePositives, truePositives, thresholds, totalPos);
    }
}
