package hex;

import java.util.Arrays;
import java.util.Comparator;
import water.Iced;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Vec;

/* loaded from: input_file:hex/AUC2.class */
/**
 * Binned binary-classification AUC summary.
 *
 * <p>Built from an {@link AUCBuilder} histogram. After construction the bins are ordered by
 * descending threshold and the true/false positive arrays hold running totals. Index {@code i}
 * therefore describes the confusion matrix obtained by thresholding at {@code _ths[i]}.
 */
public class AUC2 extends Iced {
    /** Number of bins (thresholds) retained from the builder. */
    public final int _nBins;
    /** Bin thresholds, in descending order. */
    public final double[] _ths;
    /** Running total of actual-1 weight over bins 0..i (true positives at threshold i). */
    public final double[] _tps;
    /** Running total of actual-0 weight over bins 0..i (false positives at threshold i). */
    public final double[] _fps;
    /** Total weight of actual-1 rows. */
    public final double _p;
    /** Total weight of actual-0 rows. */
    public final double _n;
    /** Area under the ROC curve (trapezoid rule over the cumulative fp/tp points). */
    public final double _auc;
    /** Gini coefficient, {@code 2 * _auc - 1}. */
    public final double _gini;
    /** Bin index maximizing {@link #DEFAULT_CM}, or -1 when no bin yields a comparable value. */
    public final int _max_idx;
    /** Criterion used to pick the default threshold and confusion matrix. */
    public static final ThresholdCriterion DEFAULT_CM;
    /** Bin count used by the two-argument {@code AUC2(Vec, Vec)} constructor. */
    public static final int NBINS = 400;
    // Mirror of the javac-synthesized assertion flag (this file appears to be decompiled);
    // assigned in the static initializer and read by this class's assertion checks.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:hex/AUC2$AUCBuilder.class */
    /**
     * Streaming, mergeable histogram of (prediction, actual) pairs used to approximate the ROC
     * curve.
     *
     * <p>Bins occupy {@code [0, _n)} of parallel arrays kept sorted by ascending threshold
     * {@code _ths}. Each bin records:
     * <ul>
     *   <li>the weight of actual-1 rows ({@code _tps});</li>
     *   <li>the weight of actual-0 rows ({@code _fps});</li>
     *   <li>its accumulated squared error ({@code _sqe}).</li>
     * </ul>
     * Whenever more than {@code _nBins} bins exist, the adjacent pair whose merge adds the least
     * squared error is merged. Arrays are allocated at {@code 2 * nBins} so that {@link #reduce}
     * can merge two builders in place.
     */
    public static class AUCBuilder extends Iced {
        final int _nBins;
        int _n;
        final double[] _ths;
        final double[] _sqe;
        final double[] _tps;
        final double[] _fps;
        /** Left index of the cheapest adjacent pair to merge, or -1 when not currently known. */
        int _ssx = -1;

        /**
         * @param nBins maximum number of bins retained between calls
         */
        public AUCBuilder(int nBins) {
            _nBins = nBins;
            _ths = new double[nBins << 1];
            _sqe = new double[nBins << 1];
            _tps = new double[nBins << 1];
            _fps = new double[nBins << 1];
        }

        /**
         * Adds one weighted observation to the histogram.
         *
         * @param pred prediction value; must not be NaN
         * @param act  actual class, 0 or 1
         * @param w    observation weight
         */
        public void perRow(double pred, int act, double w) {
            assert !Double.isNaN(pred);
            assert act == 0 || act == 1;
            int idx = Arrays.binarySearch(_ths, 0, _n, pred);
            if (idx >= 0) {
                // Threshold already present: only this bin's counts change.
                if (act == 0) {
                    _fps[idx] += w;
                } else {
                    _tps[idx] += w;
                }
                _ssx = -1;
                return;
            }
            idx = -idx - 1;
            // Every mutator merges back down to _nBins, so there is always room to insert here.
            // Insertion creates new adjacent pairs (and may split the cached one), so the cached
            // cheapest-merge index can no longer be trusted.
            _ssx = -1;
            int tail = _n - idx;
            System.arraycopy(_ths, idx, _ths, idx + 1, tail);
            System.arraycopy(_sqe, idx, _sqe, idx + 1, tail);
            System.arraycopy(_tps, idx, _tps, idx + 1, tail);
            System.arraycopy(_fps, idx, _fps, idx + 1, tail);
            _ths[idx] = pred;
            _sqe[idx] = 0.0d;
            _tps[idx] = act == 0 ? 0.0d : w;
            _fps[idx] = act == 0 ? w : 0.0d;
            _n++;
            if (_n > _nBins) {
                mergeOneBin();
            }
        }

        /**
         * Merges {@code other} into this builder. Bins are then merged until at most
         * {@code _nBins} remain and no two adjacent bins are duplicates.
         */
        public void reduce(AUCBuilder other) {
            // Merge the two ascending lists from the back into this builder's double-sized arrays,
            // so unread elements of this builder are never overwritten.
            int x = _n - 1;
            int y = other._n - 1;
            while (x + y + 1 >= 0) {
                boolean takeSelf = y < 0 || (x >= 0 && _ths[x] >= other._ths[y]);
                AUCBuilder src = takeSelf ? this : other;
                int from = takeSelf ? x : y;
                int to = x + y + 1;
                _ths[to] = src._ths[from];
                _sqe[to] = src._sqe[from];
                _tps[to] = src._tps[from];
                _fps[to] = src._fps[from];
                if (takeSelf) {
                    x--;
                } else {
                    y--;
                }
            }
            _n += other._n;
            // All bin positions moved; any cached cheapest-merge index is stale.
            _ssx = -1;
            while (_n > _nBins || dups()) {
                mergeOneBin();
            }
        }

        /** Merges the cheapest adjacent pair of bins into one, shrinking {@code _n} by one. */
        private void mergeOneBin() {
            int ssx = find_smallest();
            double k0 = k(ssx);
            double k1 = k(ssx + 1);
            // The squared-error increase must use the ORIGINAL bin centres. This is the same cost
            // find_smallest_impl used to pick this pair, so compute it before overwriting _ths[ssx].
            double delta = compute_delta_error(_ths[ssx + 1], k1, _ths[ssx], k0);
            _ths[ssx] = ((_ths[ssx] * k0) + (_ths[ssx + 1] * k1)) / (k0 + k1);
            _sqe[ssx] = _sqe[ssx] + _sqe[ssx + 1] + delta;
            _tps[ssx] += _tps[ssx + 1];
            _fps[ssx] += _fps[ssx + 1];
            // Slide the tail left over the absorbed bin at ssx + 1.
            int tail = _n - ssx - 2;
            System.arraycopy(_ths, ssx + 2, _ths, ssx + 1, tail);
            System.arraycopy(_sqe, ssx + 2, _sqe, ssx + 1, tail);
            System.arraycopy(_tps, ssx + 2, _tps, ssx + 1, tail);
            System.arraycopy(_fps, ssx + 2, _fps, ssx + 1, tail);
            _n--;
            _ssx = -1;
        }

        /** Returns the cached cheapest-merge index, computing and caching it when unknown. */
        private int find_smallest() {
            if (_ssx == -1) {
                _ssx = find_smallest_impl();
                return _ssx;
            }
            assert _ssx == find_smallest_impl();
            return _ssx;
        }

        /**
         * Scans all adjacent pairs for the one whose merge yields the least total squared error.
         * Returns the first zero-cost (duplicate) pair immediately, or -1 when fewer than two bins
         * exist.
         */
        private int find_smallest_impl() {
            double best = Double.MAX_VALUE;
            int bestIdx = -1;
            int n = _n;
            for (int i = 0; i < n - 1; i++) {
                double derr = compute_delta_error(_ths[i + 1], k(i + 1), _ths[i], k(i));
                if (derr == 0.0d) {
                    return i;
                }
                double sqe = _sqe[i] + _sqe[i + 1] + derr;
                if (sqe < best) {
                    bestIdx = i;
                    best = sqe;
                }
            }
            return bestIdx;
        }

        /**
         * Returns true if some adjacent pair has zero merge cost. In that case the first such
         * pair is cached in {@code _ssx}.
         */
        private boolean dups() {
            int n = _n;
            for (int i = 0; i < n - 1; i++) {
                if (compute_delta_error(_ths[i + 1], k(i + 1), _ths[i], k(i)) == 0.0d) {
                    _ssx = i;
                    return true;
                }
            }
            return false;
        }

        /**
         * Increase in squared error from merging a bin (ths1, weight n1) with a bin
         * (ths0, weight n0).
         *
         * <p>Thresholds are compared at float precision, so values closer than a float ULP have
         * zero merge cost.
         */
        private double compute_delta_error(double ths1, double n1, double ths0, double n0) {
            double delta = ((float) ths1) - ((float) ths0);
            return (((delta * delta) * n0) * n1) / (n0 + n1);
        }

        /** Total weight held by bin {@code i}. */
        private double k(int i) {
            return _tps[i] + _fps[i];
        }
    }

    /* loaded from: input_file:hex/AUC2$AUC_Impl.class */
    /**
     * Map/reduce task that builds an {@link AUCBuilder} histogram from two aligned columns.
     * The first column holds predictions; the second holds 0/1 actuals.
     */
    private static class AUC_Impl extends MRTask<AUC_Impl> {
        final int _nBins;
        AUCBuilder _bldr;

        AUC_Impl(int nBins) {
            _nBins = nBins;
        }

        /** Builds a fresh histogram for this chunk pair, skipping rows where either value is NA. */
        @Override
        public void map(Chunk preds, Chunk actuals) {
            final AUCBuilder builder = new AUCBuilder(_nBins);
            _bldr = builder;
            final int rows = preds._len;
            for (int row = 0; row < rows; row++) {
                if (preds.isNA(row) || actuals.isNA(row)) {
                    continue;
                }
                builder.perRow(preds.atd(row), (int) actuals.at8(row), 1.0d);
            }
        }

        /** Folds the other task's histogram into this one. */
        @Override
        public void reduce(AUC_Impl other) {
            _bldr.reduce(other._bldr);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hex/AUC2$Pair.class */
    /** A single row for exact AUC computation: a predicted probability and its 0/1 actual. */
    public static class Pair {
        final double _prob;
        final byte _act;

        Pair(double prob, byte act) {
            _prob = prob;
            _act = act;
        }
    }

    /* loaded from: input_file:hex/AUC2$ThresholdCriterion.class */
    /**
     * Metrics computed from the confusion matrix at a single AUC bin. Each constant implements
     * {@link #exec(double, double, double, double)} over the weighted counts (tp, fp, fn, tn).
     * {@code _isInt} is true for the constants that simply return one of those counts.
     */
    public enum ThresholdCriterion {
        f1(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                double prec = precision.exec(tp, fp, fn, tn);
                double rec = tpr.exec(tp, fp, fn, tn);
                return (2.0d * (prec * rec)) / (prec + rec);
            }
        },
        f2(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                double prec = precision.exec(tp, fp, fn, tn);
                double rec = tpr.exec(tp, fp, fn, tn);
                return (5.0d * (prec * rec)) / ((4.0d * prec) + rec);
            }
        },
        f0point5(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                double prec = precision.exec(tp, fp, fn, tn);
                double rec = tpr.exec(tp, fp, fn, tn);
                return (1.25d * (prec * rec)) / ((0.25d * prec) + rec);
            }
        },
        accuracy(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return (tn + tp) / (((tp + fn) + tn) + fp);
            }
        },
        precision(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return tp / (tp + fp);
            }
        },
        recall(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return tp / (tp + fn);
            }
        },
        specificity(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return tn / (tn + fp);
            }
        },
        absolute_MCC(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                double numer = (tp * tn) - (fp * fn);
                if (numer == 0.0d) {
                    return 0.0d;
                }
                double mcc = numer / Math.sqrt((((tp + fp) * (tp + fn)) * (tn + fp)) * (tn + fn));
                // Plain assert: a static field inside an enum-constant body is illegal before Java 16.
                assert Math.abs(mcc) <= 1.0d : tp + " " + fp + " " + fn + " " + tn;
                return Math.abs(mcc);
            }
        },
        min_per_class_accuracy(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return Math.min(tp / (tp + fn), tn / (tn + fp));
            }
        },
        mean_per_class_accuracy(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return 0.5d * ((tp / (tp + fn)) + (tn / (tn + fp)));
            }
        },
        tns(true) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return tn;
            }
        },
        fns(true) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return fn;
            }
        },
        fps(true) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return fp;
            }
        },
        tps(true) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return tp;
            }
        },
        tnr(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return tn / (fp + tn);
            }
        },
        fnr(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return fn / (fn + tp);
            }
        },
        fpr(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return fp / (fp + tn);
            }
        },
        tpr(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return tp / (tp + fn);
            }
        };

        /** True for criteria that return a raw count rather than a rate. */
        public final boolean _isInt;
        public static final ThresholdCriterion[] VALUES = values();

        ThresholdCriterion(boolean isInt) {
            this._isInt = isInt;
        }

        /** Evaluates this criterion on weighted confusion-matrix counts. */
        abstract double exec(double tp, double fp, double fn, double tn);

        /** Evaluates this criterion at bin {@code i} of {@code auc2}. */
        public double exec(AUC2 auc2, int i) {
            return exec(auc2.tp(i), auc2.fp(i), auc2.fn(i), auc2.tn(i));
        }

        /**
         * Returns the largest value of this criterion over all bins. Returns NaN when no bin
         * produces a comparable (non-NaN) value, for example with no bins at all.
         */
        public double max_criterion(AUC2 auc2) {
            int idx = max_criterion_idx(auc2);
            // -1 would otherwise be used as an array index and throw.
            return idx == -1 ? Double.NaN : exec(auc2, idx);
        }

        /** Returns the first bin index maximizing this criterion, or -1 if every value is NaN. */
        public int max_criterion_idx(AUC2 auc2) {
            double best = -Double.MAX_VALUE;
            int bestIdx = -1;
            for (int i = 0; i < auc2._nBins; i++) {
                double v = exec(auc2, i);
                if (v > best) {
                    best = v;
                    bestIdx = i;
                }
            }
            return bestIdx;
        }
    }

    /** Threshold of bin {@code i}. */
    public double threshold(int i) {
        final double[] thresholds = _ths;
        return thresholds[i];
    }

    /** Weighted true positives when thresholding at bin {@code i}. */
    public double tp(int i) {
        final double[] cumulative = _tps;
        return cumulative[i];
    }

    /** Weighted false positives when thresholding at bin {@code i}. */
    public double fp(int i) {
        final double[] cumulative = _fps;
        return cumulative[i];
    }

    /** Weighted true negatives at bin {@code i}: all negatives not counted as false positives. */
    public double tn(int i) {
        final double falsePos = _fps[i];
        return _n - falsePos;
    }

    /** Weighted false negatives at bin {@code i}: all positives not counted as true positives. */
    public double fn(int i) {
        final double truePos = _tps[i];
        return _p - truePos;
    }

    /** Largest F1 score over all bin thresholds. */
    public double maxF1() {
        final ThresholdCriterion criterion = ThresholdCriterion.f1;
        return criterion.max_criterion(this);
    }

    /**
     * Computes a binned AUC with {@link #NBINS} bins.
     *
     * @param preds   prediction column
     * @param actuals 0/1 actual column
     */
    public AUC2(Vec preds, Vec actuals) {
        this(NBINS, preds, actuals);
    }

    /**
     * Computes a binned AUC by running a distributed histogram pass over the two columns.
     *
     * @param nBins   maximum number of histogram bins
     * @param preds   prediction column
     * @param actuals 0/1 actual column
     */
    AUC2(int nBins, Vec preds, Vec actuals) {
        this(new AUC_Impl(nBins).doAll(preds, actuals)._bldr);
    }

    /**
     * Finalizes a histogram into an immutable AUC summary.
     *
     * <p>The builder's bins are in ascending threshold order. They are copied, flipped to
     * descending order, and converted to running totals, so entry {@code i} holds the weighted
     * true/false positives for thresholding at {@code _ths[i]}.
     *
     * @param aUCBuilder populated histogram; only its first {@code _n} bins are read
     */
    public AUC2(AUCBuilder aUCBuilder) {
        _nBins = aUCBuilder._n;
        if (!$assertionsDisabled && _nBins < 1) {
            throw new AssertionError("Must have >= 1 bins for AUC calculation, but got " + _nBins);
        }
        _ths = reversedPrefix(aUCBuilder._ths, _nBins);
        _tps = reversedPrefix(aUCBuilder._tps, _nBins);
        _fps = reversedPrefix(aUCBuilder._fps, _nBins);
        _p = accumulateInPlace(_tps);
        _n = accumulateInPlace(_fps);
        _auc = compute_auc();
        _gini = (2.0d * _auc) - 1.0d;
        _max_idx = DEFAULT_CM.max_criterion_idx(this);
    }

    /** Returns a new array holding {@code src[0..len)} in reverse order. */
    private static double[] reversedPrefix(double[] src, int len) {
        double[] out = new double[len];
        for (int i = 0; i < len; i++) {
            out[i] = src[len - 1 - i];
        }
        return out;
    }

    /** Replaces each element with the running sum up to and including it; returns the total. */
    private static double accumulateInPlace(double[] values) {
        double sum = 0.0d;
        for (int i = 0; i < values.length; i++) {
            sum += values[i];
            values[i] = sum;
        }
        return sum;
    }

    /**
     * Computes the ROC area by the trapezoid rule over the cumulative (fp, tp) points, starting
     * from (0, 0), normalized by {@code _p * _n}.
     *
     * <p>Degenerate cases:
     * <ul>
     *   <li>NaN when there are no bins;</li>
     *   <li>1.0 when there is no negative weight;</li>
     *   <li>0.0 when there is no positive weight.</li>
     * </ul>
     */
    private double compute_auc() {
        if (_nBins == 0) {
            // Nothing to integrate; previously this indexed [-1] when assertions were disabled.
            return Double.NaN;
        }
        if (_fps[_nBins - 1] == 0.0d) {
            return 1.0d;
        }
        if (_tps[_nBins - 1] == 0.0d) {
            return 0.0d;
        }
        double prevTp = 0.0d;
        double prevFp = 0.0d;
        double area = 0.0d;
        for (int i = 0; i < _nBins; i++) {
            area += ((_fps[i] - prevFp) * (_tps[i] + prevTp)) / 2.0d;
            prevTp = _tps[i];
            prevFp = _fps[i];
        }
        return (area / _p) / _n;
    }

    /**
     * Builds the 2x2 confusion matrix at bin {@code i} as {@code [[tn, fp], [fn, tp]]}.
     * Rows are actual 0/1; columns are predicted 0/1.
     */
    public double[][] buildCM(int i) {
        // Must be a double[][] literal; a double[] literal of double[] elements does not compile.
        return new double[][] {{tn(i), fp(i)}, {fn(i), tp(i)}};
    }

    /** Confusion matrix at the default threshold, or {@code null} when none was found. */
    public double[][] defaultCM() {
        if (_max_idx == -1) {
            return null;
        }
        return buildCM(_max_idx);
    }

    /** Threshold that maximizes {@link #DEFAULT_CM}, falling back to 0.5 when none was found. */
    public double defaultThreshold() {
        return _max_idx == -1 ? 0.5d : _ths[_max_idx];
    }

    /**
     * Weighted misclassification rate at the default threshold, or NaN when none was found.
     */
    public double defaultErr() {
        if (_max_idx != -1) {
            double misclassified = fp(_max_idx) + fn(_max_idx);
            return misclassified / (_p + _n);
        }
        return Double.NaN;
    }

    /**
     * Computes the exact (unbinned) AUC by loading both columns into memory and sorting.
     *
     * @param vec  predicted probabilities, all within [0, 1]
     * @param vec2 integer actuals, all within [0, 1]
     * @throws IllegalArgumentException if the value ranges are violated, the columns differ in
     *     length, or there are too many rows to hold in a Java array
     */
    public static double perfectAUC(Vec vec, Vec vec2) {
        if (vec2.min() < 0.0d || vec2.max() > 1.0d || !vec2.isInt()) {
            throw new IllegalArgumentException("Actuals are either 0 or 1");
        }
        if (vec.min() < 0.0d || vec.max() > 1.0d) {
            throw new IllegalArgumentException("Probabilities are between 0 and 1");
        }
        long rows = vec.length();
        if (vec2.length() != rows) {
            throw new IllegalArgumentException(
                "Probabilities and actuals must have the same length: " + rows + " vs " + vec2.length());
        }
        if (rows > Integer.MAX_VALUE) {
            // A plain (int) cast would silently truncate or go negative here.
            throw new IllegalArgumentException("Too many rows for exact AUC: " + rows);
        }
        Pair[] pairs = new Pair[(int) rows];
        // Reader is created per-Vec: qualified inner-class creation (decompiled as getClass() + new).
        Vec.Reader probs = vec.new Reader();
        Vec.Reader acts = vec2.new Reader();
        for (int i = 0; i < pairs.length; i++) {
            pairs[i] = new Pair(probs.at(i), (byte) acts.at8(i));
        }
        return perfectAUC(pairs);
    }

    /**
     * Computes the exact (unbinned) AUC over in-memory arrays.
     *
     * @param dArr  predicted probabilities
     * @param dArr2 actuals; each value is truncated to a byte (expected 0 or 1)
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public static double perfectAUC(double[] dArr, double[] dArr2) {
        if (dArr.length != dArr2.length) {
            // Previously a shorter dArr2 threw ArrayIndexOutOfBounds and a longer one was ignored.
            throw new IllegalArgumentException(
                "Probabilities and actuals must have the same length: " + dArr.length + " vs " + dArr2.length);
        }
        Pair[] pairs = new Pair[dArr.length];
        for (int i = 0; i < pairs.length; i++) {
            pairs[i] = new Pair(dArr[i], (byte) dArr2[i]);
        }
        return perfectAUC(pairs);
    }

    /**
     * Exact AUC by sorting rows by descending probability and integrating the ROC curve.
     *
     * <p>Rows with tied probabilities form a single diagonal segment (trapezoid). The result is
     * NaN when there are no positives or no negatives.
     */
    private static double perfectAUC(Pair[] pairArr) {
        // Descending probability; ties put actual=1 first. Double.compare keeps the ordering
        // consistent for NaN, which the hand-written comparison did not (TimSort may throw).
        Arrays.sort(pairArr, new Comparator<Pair>() {
            @Override
            public int compare(Pair a, Pair b) {
                int byProb = Double.compare(b._prob, a._prob);
                return byProb != 0 ? byProb : b._act - a._act;
            }
        });
        // Counters are long and products are formed in double: int*int overflowed past ~46k rows.
        long tpPrev = 0;
        long fpPrev = 0;
        long tp = 0;
        long fp = 0;
        double prob = 1.0d;
        double area = 0.0d;
        for (Pair pair : pairArr) {
            if (pair._prob != prob) {
                // Close the previous tie group with a trapezoid.
                area += ((double) (fp - fpPrev) * (tp + tpPrev)) / 2.0d;
                tpPrev = tp;
                fpPrev = fp;
                prob = pair._prob;
            }
            if (pair._act == 1) {
                tp++;
            } else {
                fp++;
            }
        }
        // Final group: rectangle plus right triangle.
        double width = fp - fpPrev;
        area += (double) tpPrev * width;
        area += ((double) (tp - tpPrev) * width) / 2.0d;
        return (area / tp) / fp;
    }

    static {
        // Mirrors javac's synthetic assertion flag (this file appears to be decompiled output);
        // read by this class's own assertion checks.
        $assertionsDisabled = !AUC2.class.desiredAssertionStatus();
        // F1 is the criterion maximized to pick the default threshold / confusion matrix.
        DEFAULT_CM = ThresholdCriterion.f1;
    }
}
