package org.tensorflow.framework.metrics;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.tensorflow.Operand;
import org.tensorflow.framework.initializers.Zeros;
import org.tensorflow.framework.losses.Losses;
import org.tensorflow.framework.losses.impl.LossesHelper;
import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum;
import org.tensorflow.framework.metrics.impl.MetricsHelper;
import org.tensorflow.framework.metrics.impl.SymbolicShape;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.AssertThat;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.ReduceSum;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Div;
import org.tensorflow.op.math.DivNoNan;
import org.tensorflow.op.math.Mul;
import org.tensorflow.op.math.Sub;
import org.tensorflow.types.TBool;
import org.tensorflow.types.family.TNumber;

/**
 * Approximates the area under a ROC or precision-recall curve (AUC) with a Riemann sum over a
 * fixed, sorted set of thresholds.
 *
 * <p>Four accumulator variables (true/false positives/negatives) have shape {@code
 * [numThresholds]}, or {@code [numThresholds, numLabels]} when {@code multiLabel} is true. They
 * are updated through {@link MetricsHelper#updateConfusionMatrixVariables}. {@link #result} turns
 * the accumulated counts into points on the selected {@link AUCCurve}. It then sums the area
 * between consecutive thresholds using the selected {@link AUCSummationMethod}. The PR curve with
 * {@code INTERPOLATION} uses a separate interpolation formula ({@code interpolatePRAuc}) instead
 * of averaging adjacent heights.
 *
 * @param <T> the data type of the accumulator variables and of the intermediate computations
 */
public class AUC<T extends TNumber> extends BaseMetric {
    public static final float EPSILON = 1.0E-7f;
    public static final String TRUE_POSITIVES = "TRUE_POSITIVES";
    public static final String FALSE_POSITIVES = "FALSE_POSITIVES";
    public static final String TRUE_NEGATIVES = "TRUE_NEGATIVES";
    public static final String FALSE_NEGATIVES = "FALSE_NEGATIVES";
    public static final int DEFAULT_NUM_THRESHOLDS = 200;
    public static final String DEFAULT_NAME = "auc";

    private final int numThresholds;
    private final AUCCurve curve;
    private final AUCSummationMethod summationMethod;
    /** Sorted thresholds, including the sentinel end points {@code -EPSILON} and {@code 1 + EPSILON}. */
    private final float[] thresholds;
    private final boolean multiLabel;
    private final String truePositivesName;
    private final String falsePositivesName;
    private final String trueNegativesName;
    private final String falseNegativesName;
    private final Class<T> type;
    private final Zeros<T> zeros = new Zeros<>();
    /** Number of labels; derived from the predictions' second dimension in {@link #init} when multi-label. */
    private Integer numLabels;
    /** Optional per-label weights; wrapped with a non-negativity assertion in {@link #init}. */
    private Operand<T> labelWeights;
    private Variable<T> truePositives;
    private Variable<T> falsePositives;
    private Variable<T> trueNegatives;
    private Variable<T> falseNegatives;
    /** Shape of the four accumulator variables; {@code null} until {@link #init} creates them. */
    private Shape variableShape;
    /** Shape of the first predictions passed to {@link #updateStateList}; {@code null} before that. */
    private Shape shape;

    /** Creates a ROC/INTERPOLATION AUC named {@value #DEFAULT_NAME} with {@value #DEFAULT_NUM_THRESHOLDS} thresholds. */
    public AUC(long seed, Class<T> type) {
        this(null, DEFAULT_NUM_THRESHOLDS, AUCCurve.ROC, AUCSummationMethod.INTERPOLATION, null, false, null, seed, type);
    }

    /** Creates a ROC/INTERPOLATION AUC with {@value #DEFAULT_NUM_THRESHOLDS} thresholds; a null name means {@value #DEFAULT_NAME}. */
    public AUC(String name, long seed, Class<T> type) {
        this(name, DEFAULT_NUM_THRESHOLDS, AUCCurve.ROC, AUCSummationMethod.INTERPOLATION, null, false, null, seed, type);
    }

    /** Creates a ROC/INTERPOLATION AUC with {@code numThresholds} evenly spaced thresholds. */
    public AUC(int numThresholds, long seed, Class<T> type) {
        this(null, numThresholds, AUCCurve.ROC, AUCSummationMethod.INTERPOLATION, null, false, null, seed, type);
    }

    /** Creates a ROC/INTERPOLATION AUC using the given interior thresholds (each in [0, 1]). */
    public AUC(float[] thresholds, long seed, Class<T> type) {
        this(null, DEFAULT_NUM_THRESHOLDS, AUCCurve.ROC, AUCSummationMethod.INTERPOLATION, thresholds, false, null, seed, type);
    }

    /** Creates a named ROC/INTERPOLATION AUC with {@code numThresholds} evenly spaced thresholds. */
    public AUC(String name, int numThresholds, long seed, Class<T> type) {
        this(name, numThresholds, AUCCurve.ROC, AUCSummationMethod.INTERPOLATION, null, false, null, seed, type);
    }

    /** Creates a named ROC/INTERPOLATION AUC using the given interior thresholds (each in [0, 1]). */
    public AUC(String name, float[] thresholds, long seed, Class<T> type) {
        this(name, DEFAULT_NUM_THRESHOLDS, AUCCurve.ROC, AUCSummationMethod.INTERPOLATION, thresholds, false, null, seed, type);
    }

    /** Creates a named INTERPOLATION AUC for {@code curve} with {@code numThresholds} evenly spaced thresholds. */
    public AUC(String name, int numThresholds, AUCCurve curve, long seed, Class<T> type) {
        this(name, numThresholds, curve, AUCSummationMethod.INTERPOLATION, null, false, null, seed, type);
    }

    /** Creates a named INTERPOLATION AUC for {@code curve} using the given interior thresholds. */
    public AUC(String name, float[] thresholds, AUCCurve curve, long seed, Class<T> type) {
        this(name, DEFAULT_NUM_THRESHOLDS, curve, AUCSummationMethod.INTERPOLATION, thresholds, false, null, seed, type);
    }

    /** Creates an INTERPOLATION AUC for {@code curve} with {@code numThresholds} evenly spaced thresholds. */
    public AUC(int numThresholds, AUCCurve curve, long seed, Class<T> type) {
        this(null, numThresholds, curve, AUCSummationMethod.INTERPOLATION, null, false, null, seed, type);
    }

    /** Creates an INTERPOLATION AUC for {@code curve} using the given interior thresholds. */
    public AUC(float[] thresholds, AUCCurve curve, long seed, Class<T> type) {
        this(null, DEFAULT_NUM_THRESHOLDS, curve, AUCSummationMethod.INTERPOLATION, thresholds, false, null, seed, type);
    }

    /** Creates an AUC for {@code curve}/{@code summationMethod} with {@code numThresholds} evenly spaced thresholds. */
    public AUC(int numThresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class<T> type) {
        this(null, numThresholds, curve, summationMethod, null, false, null, seed, type);
    }

    /** Creates an AUC for {@code curve}/{@code summationMethod} using the given interior thresholds. */
    public AUC(float[] thresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class<T> type) {
        this(null, DEFAULT_NUM_THRESHOLDS, curve, summationMethod, thresholds, false, null, seed, type);
    }

    /** Creates a named AUC for {@code curve}/{@code summationMethod} with {@code numThresholds} evenly spaced thresholds. */
    public AUC(String name, int numThresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class<T> type) {
        this(name, numThresholds, curve, summationMethod, null, false, null, seed, type);
    }

    /** Creates a named AUC for {@code curve}/{@code summationMethod} using the given interior thresholds. */
    public AUC(String name, float[] thresholds, AUCCurve curve, AUCSummationMethod summationMethod, long seed, Class<T> type) {
        this(name, DEFAULT_NUM_THRESHOLDS, curve, summationMethod, thresholds, false, null, seed, type);
    }

    /**
     * Creates an AUC metric.
     *
     * @param name the metric name; {@code null} means {@value #DEFAULT_NAME}
     * @param numThresholds number of thresholds (must be &gt; 1); ignored when {@code thresholds}
     *     is non-null
     * @param curve which curve to integrate ({@code ROC} or {@code PR})
     * @param summationMethod how heights between adjacent thresholds are combined
     * @param thresholds optional interior thresholds, each in [0, 1]; a sorted copy is used and the
     *     caller's array is not modified. When {@code null}, {@code numThresholds - 2} evenly spaced
     *     interior thresholds are generated.
     * @param multiLabel whether the accumulators keep a separate column per label
     * @param labelWeights optional per-label weights used when aggregating multi-label results
     * @param seed seed forwarded to {@link BaseMetric}
     * @param type the data type of the accumulators and intermediate results
     * @throws IllegalArgumentException if a threshold lies outside [0, 1] or is NaN, or if {@code
     *     thresholds} is null and {@code numThresholds <= 1}
     */
    public AUC(
            String name,
            int numThresholds,
            AUCCurve curve,
            AUCSummationMethod summationMethod,
            float[] thresholds,
            boolean multiLabel,
            Operand<T> labelWeights,
            long seed,
            Class<T> type) {
        super(name == null ? DEFAULT_NAME : name, seed);
        this.truePositivesName = getVariableName(TRUE_POSITIVES);
        this.falsePositivesName = getVariableName(FALSE_POSITIVES);
        this.trueNegativesName = getVariableName(TRUE_NEGATIVES);
        this.falseNegativesName = getVariableName(FALSE_NEGATIVES);
        this.curve = curve;
        this.summationMethod = summationMethod;
        this.type = type;
        this.multiLabel = multiLabel;

        float[] interior;
        if (thresholds != null) {
            for (float t : thresholds) {
                // Negated range test so that NaN (for which every comparison is false) is rejected too.
                if (!(t >= 0.0f && t <= 1.0f)) {
                    throw new IllegalArgumentException(String.format(
                            "Threshold values must be in range [0, 1], inclusive. Invalid values: %s",
                            Arrays.toString(thresholds)));
                }
            }
            // Sort a private copy; the caller's array must not be mutated.
            interior = thresholds.clone();
            Arrays.sort(interior);
            this.numThresholds = interior.length + 2;
        } else {
            if (numThresholds <= 1) {
                throw new IllegalArgumentException("numThresholds must be > 1.");
            }
            this.numThresholds = numThresholds;
            interior = new float[numThresholds - 2];
            for (int k = 0; k < interior.length; k++) {
                interior[k] = (k + 1) * 1.0f / (numThresholds - 1);
            }
        }

        // Sentinel end points just outside [0, 1]; presumably these let predictions equal to
        // exactly 0 or 1 fall strictly between two thresholds (TODO confirm against MetricsHelper).
        this.thresholds = new float[this.numThresholds];
        this.thresholds[0] = -EPSILON;
        System.arraycopy(interior, 0, this.thresholds, 1, interior.length);
        this.thresholds[this.numThresholds - 1] = 1.0f + EPSILON;
        this.labelWeights = labelWeights;
    }

    /**
     * Lazily creates the accumulator variables and guards {@code labelWeights} with a
     * non-negativity assertion.
     *
     * <p>Single-label metrics do not need the prediction shape, so they can be initialized at any
     * time. Multi-label metrics need the predictions' second dimension, so initialization is
     * deferred until {@link #updateStateList} has recorded a shape.
     *
     * @throws IllegalArgumentException if multi-label predictions are not rank 2
     */
    @Override
    protected void init(Ops ops) {
        checkIsGraph(ops);
        if (isInitialized()) {
            return;
        }
        if (isMultiLabel() && shape == null) {
            // The label count is unknown until predictions have been seen; try again later.
            return;
        }
        setTF(ops);

        if (labelWeights != null) {
            AssertThat nonNegative = ops.withSubScope("AUC").assertThat(
                    ops.math.greaterEqual(labelWeights, CastHelper.cast(ops, ops.constant(0), labelWeights.type())),
                    Collections.singletonList(ops.constant("All values of labelWeights must be non-negative.")));
            labelWeights = ops.withSubScope("updateState")
                    .withControlDependencies(Collections.singletonList(nonNegative))
                    .identity(labelWeights);
        }

        if (isMultiLabel()) {
            if (shape.numDimensions() != 2) {
                throw new IllegalArgumentException(String.format(
                        "labels must have rank=2 when multiLabel is true. Found rank %d.", shape.numDimensions()));
            }
            int labels = (int) shape.size(1);
            numLabels = labels;
            variableShape = Shape.of(numThresholds, labels);
        } else {
            variableShape = Shape.of(numThresholds);
        }

        Operand<T> initialValue = zeros.call(ops, ops.constant(variableShape), type);
        if (truePositives == null) {
            truePositives = ops.withName(getTruePositivesName()).withInitScope().variable(initialValue);
        }
        if (falsePositives == null) {
            falsePositives = ops.withName(getFalsePositivesName()).withInitScope().variable(initialValue);
        }
        if (trueNegatives == null) {
            trueNegatives = ops.withName(getTrueNegativesName()).withInitScope().variable(initialValue);
        }
        if (falseNegatives == null) {
            falseNegatives = ops.withName(getFalseNegativesName()).withInitScope().variable(initialValue);
        }
        setInitialized(true);
    }

    /**
     * Builds the ops that add one batch to the confusion-matrix accumulators.
     *
     * <p>For multi-label metrics, or when label weights are set, shape assertions are prepended to
     * check that the label dimension is consistent across labels, accumulators and weights.
     *
     * @param labels the ground-truth labels
     * @param predictions the predictions; their shape is recorded on the first call
     * @param sampleWeights optional sample weights, may be {@code null}
     * @return the assertion and update ops
     */
    @Override
    public List<Op> updateStateList(
            Ops ops,
            Operand<? extends TNumber> labels,
            Operand<? extends TNumber> predictions,
            Operand<? extends TNumber> sampleWeights) {
        if (shape == null) {
            shape = predictions.shape();
        }
        init(ops);
        Operand<T> tLabels = CastHelper.cast(ops, labels, type);
        Operand<T> tPredictions = CastHelper.cast(ops, predictions, type);
        Operand<T> tSampleWeights = sampleWeights == null ? null : CastHelper.cast(ops, sampleWeights, type);

        List<Op> updateOperations = new ArrayList<>();
        if (isMultiLabel() || getLabelWeights() != null) {
            List<SymbolicShape<? extends TNumber>> symbols = new ArrayList<>();
            symbols.add(new SymbolicShape<>(tLabels, "N", "L"));
            if (isMultiLabel()) {
                symbols.add(new SymbolicShape<>(truePositives, "T", "L"));
                symbols.add(new SymbolicShape<>(falsePositives, "T", "L"));
                symbols.add(new SymbolicShape<>(trueNegatives, "T", "L"));
                symbols.add(new SymbolicShape<>(falseNegatives, "T", "L"));
            }
            if (getLabelWeights() != null) {
                symbols.add(new SymbolicShape<>(getLabelWeights(), "L"));
            }
            updateOperations.addAll(MetricsHelper.assertShapes(ops, symbols, "Number of labels is not consistent."));
        }

        Map<ConfusionMatrixEnum, Variable<T>> accumulators = new HashMap<>();
        accumulators.put(ConfusionMatrixEnum.TRUE_POSITIVES, truePositives);
        accumulators.put(ConfusionMatrixEnum.FALSE_POSITIVES, falsePositives);
        accumulators.put(ConfusionMatrixEnum.TRUE_NEGATIVES, trueNegatives);
        accumulators.put(ConfusionMatrixEnum.FALSE_NEGATIVES, falseNegatives);
        // Label weights are applied during accumulation only for single-label metrics; for
        // multi-label metrics they are applied when aggregating per-label AUCs in result().
        updateOperations.addAll(MetricsHelper.updateConfusionMatrixVariables(
                ops, accumulators, tLabels, tPredictions, ops.constant(thresholds), null, null,
                tSampleWeights, isMultiLabel(), isMultiLabel() ? null : getLabelWeights()));
        return updateOperations;
    }

    /** Returns {@code max(operand, 0)} element-wise. */
    private Operand<T> positive(Ops ops, Operand<T> operand) {
        return ops.math.maximum(operand, CastHelper.cast(ops, ops.constant(0), operand.type()));
    }

    /** Returns {@code operand > 0} element-wise. */
    private Operand<TBool> isPositive(Ops ops, Operand<T> operand) {
        return ops.math.greater(operand, CastHelper.cast(ops, ops.constant(0), operand.type()));
    }

    /**
     * Slices {@code operand} along its first (threshold) axis only. It takes {@code size} entries
     * starting at {@code begin}, or everything from {@code begin} onward when {@code size} is -1.
     * All remaining axes (e.g. the label axis of multi-label accumulators) are kept whole, because
     * TF {@code Slice} requires {@code begin}/{@code size} to have one entry per input dimension.
     * When the rank is unknown or 0, a rank-1 slice is requested.
     */
    private Operand<T> slice(Ops ops, Operand<T> operand, int begin, int size) {
        int rank = Math.max(1, operand.shape().numDimensions());
        int[] begins = new int[rank];
        int[] sizes = new int[rank];
        Arrays.fill(sizes, -1);
        begins[0] = begin;
        sizes[0] = size;
        return ops.slice(operand, ops.constant(begins), ops.constant(sizes));
    }

    /**
     * Reduces per-threshold area terms to the final AUC. Single-label terms are summed. Multi-label
     * terms are summed over thresholds per label, then averaged over labels, weighted by {@code
     * labelWeights} when present.
     */
    private Operand<T> aggregate(Ops ops, Operand<T> terms) {
        if (!isMultiLabel()) {
            return ops.reduceSum(terms, LossesHelper.allAxes(ops, terms));
        }
        Operand<T> byLabel = ops.reduceSum(terms, ops.constant(0));
        Operand<T> weights = getLabelWeights();
        if (weights == null) {
            return MetricsHelper.mean(ops, byLabel);
        }
        Operand<T> weighted = ops.math.mul(byLabel, weights);
        return ops.math.divNoNan(
                ops.reduceSum(weighted, LossesHelper.allAxes(ops, weighted)),
                ops.reduceSum(weights, LossesHelper.allAxes(ops, weights)));
    }

    /**
     * Computes the PR-curve AUC by interpolating precision between adjacent thresholds, using a
     * linear model of true positives against predicted positives ({@code tp + fp}).
     */
    private Operand<T> interpolatePRAuc(Ops ops) {
        int last = getNumThresholds() - 1;
        Operand<T> tpLeft = slice(ops, truePositives, 0, last);
        Operand<T> tpRight = slice(ops, truePositives, 1, -1);
        Operand<T> dtp = ops.math.sub(tpLeft, tpRight);

        Operand<T> predictedPositives = ops.math.add(truePositives, falsePositives);
        Operand<T> pLeft = slice(ops, predictedPositives, 0, last);
        Operand<T> pRight = slice(ops, predictedPositives, 1, -1);
        Operand<T> dp = ops.math.sub(pLeft, pRight);

        Operand<T> precisionSlope = ops.math.divNoNan(dtp, positive(ops, dp));
        Operand<T> intercept = ops.math.sub(tpRight, ops.math.mul(precisionSlope, pRight));
        // Use the ratio pLeft / pRight only where both are positive, so the log stays finite.
        Operand<T> safeRatio = ops.select(
                ops.math.logicalAnd(isPositive(ops, pLeft), isPositive(ops, pRight)),
                ops.math.divNoNan(pLeft, positive(ops, pRight)),
                ops.onesLike(pRight));

        Operand<T> numerator = ops.math.mul(
                precisionSlope, ops.math.add(dtp, ops.math.mul(intercept, ops.math.log(safeRatio))));
        Operand<T> denominator = positive(ops, ops.math.add(tpRight, slice(ops, falseNegatives, 1, -1)));
        return aggregate(ops, ops.math.divNoNan(numerator, denominator));
    }

    /**
     * Computes the AUC from the accumulated counts.
     *
     * <p>ROC integrates recall over the false-positive rate; PR integrates precision over recall.
     * For a single-label metric queried before any update this yields 0 (all accumulators are
     * zero).
     *
     * @param resultType the data type of the returned scalar
     * @return the AUC value
     * @throws IllegalStateException if this is a multi-label metric and {@link #updateStateList}
     *     has not been called yet, so the number of labels is unknown
     * @throws IllegalArgumentException on an unexpected curve or summation method
     */
    @Override
    public <U extends TNumber> Operand<U> result(Ops ops, Class<U> resultType) {
        init(ops);
        if (truePositives == null) {
            throw new IllegalStateException(
                    "updateStateList must be called before result when multiLabel is true.");
        }
        if (getCurve() == AUCCurve.PR && getSummationMethod() == AUCSummationMethod.INTERPOLATION) {
            return CastHelper.cast(ops, interpolatePRAuc(ops), resultType);
        }

        Operand<T> recall = ops.math.divNoNan(truePositives, ops.math.add(truePositives, falseNegatives));
        Operand<T> x;
        Operand<T> y;
        switch (getCurve()) {
            case ROC:
                // x: false-positive rate, y: recall.
                x = ops.math.divNoNan(falsePositives, ops.math.add(falsePositives, trueNegatives));
                y = recall;
                break;
            case PR:
                // x: recall, y: precision.
                x = recall;
                y = ops.math.divNoNan(truePositives, ops.math.add(truePositives, falsePositives));
                break;
            default:
                throw new IllegalArgumentException("Unexpected AUCCurve value: " + getCurve());
        }

        int last = getNumThresholds() - 1;
        Operand<T> yLeft = slice(ops, y, 0, last);
        Operand<T> yRight = slice(ops, y, 1, -1);
        Operand<T> heights;
        switch (getSummationMethod()) {
            case INTERPOLATION:
                heights = ops.math.div(ops.math.add(yLeft, yRight), CastHelper.cast(ops, ops.constant(2), type));
                break;
            case MINORING:
                heights = ops.math.minimum(yLeft, yRight);
                break;
            case MAJORING:
                heights = ops.math.maximum(yLeft, yRight);
                break;
            default:
                throw new IllegalArgumentException("Unexpected AUCSummationMethod value: " + getSummationMethod());
        }

        // Thresholds are ascending, so x decreases along axis 0 and x[i] - x[i+1] is the width.
        Operand<T> widths = ops.math.sub(slice(ops, x, 0, last), slice(ops, x, 1, -1));
        return CastHelper.cast(ops, aggregate(ops, ops.math.mul(widths, heights)), resultType);
    }

    /**
     * Builds an op that zeroes all accumulators. If the accumulators do not exist yet (a
     * multi-label metric before its first update), the returned op does nothing.
     */
    @Override
    public Op resetStates(Ops ops) {
        init(ops);
        List<Op> resets = new ArrayList<>();
        if (variableShape != null) {
            Operand<T> zero = zeros.call(ops, ops.constant(variableShape), type);
            for (Variable<T> accumulator : Arrays.asList(truePositives, falsePositives, trueNegatives, falseNegatives)) {
                if (accumulator != null) {
                    resets.add(ops.assign(accumulator, zero));
                }
            }
        }
        return ops.withControlDependencies(resets).noOp();
    }

    /** Returns the total number of thresholds, including the two sentinel end points. */
    public int getNumThresholds() {
        return numThresholds;
    }

    /** Returns the curve being integrated. */
    public AUCCurve getCurve() {
        return curve;
    }

    /** Returns the method used to combine heights between adjacent thresholds. */
    public AUCSummationMethod getSummationMethod() {
        return summationMethod;
    }

    /** Returns a copy of the sorted thresholds, including the sentinel end points. */
    public float[] getThresholds() {
        return thresholds.clone();
    }

    /** Returns whether accumulators keep a separate column per label. */
    public boolean isMultiLabel() {
        return multiLabel;
    }

    /** Returns the number of labels, or {@code null} if it has not been determined or set. */
    public Integer getNumLabels() {
        return numLabels;
    }

    /** Sets the number of labels. */
    public void setNumLabels(Integer numLabels) {
        this.numLabels = numLabels;
    }

    /** Returns the label weights (wrapped with a non-negativity assertion once initialized), or {@code null}. */
    public Operand<T> getLabelWeights() {
        return labelWeights;
    }

    /** Returns the true-positives accumulator, or {@code null} before initialization. */
    public Variable<T> getTruePositives() {
        return truePositives;
    }

    /** Returns the false-positives accumulator, or {@code null} before initialization. */
    public Variable<T> getFalsePositives() {
        return falsePositives;
    }

    /** Returns the true-negatives accumulator, or {@code null} before initialization. */
    public Variable<T> getTrueNegatives() {
        return trueNegatives;
    }

    /** Returns the false-negatives accumulator, or {@code null} before initialization. */
    public Variable<T> getFalseNegatives() {
        return falseNegatives;
    }

    /** Returns the graph name of the true-positives variable. */
    public String getTruePositivesName() {
        return truePositivesName;
    }

    /** Returns the graph name of the false-positives variable. */
    public String getFalsePositivesName() {
        return falsePositivesName;
    }

    /** Returns the graph name of the true-negatives variable. */
    public String getTrueNegativesName() {
        return trueNegativesName;
    }

    /** Returns the graph name of the false-negatives variable. */
    public String getFalseNegativesName() {
        return falseNegativesName;
    }
}
