package org.tensorflow.framework.metrics.impl;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.tensorflow.Operand;
import org.tensorflow.framework.initializers.Zeros;
import org.tensorflow.framework.metrics.BaseMetric;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.Variable;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.class */
public abstract class SensitivitySpecificityBase<T extends TNumber> extends BaseMetric {
    public static final int DEFAULT_NUM_THRESHOLDS = 200;
    public static final String TRUE_POSITIVES = "TRUE_POSITIVES";
    public static final String FALSE_POSITIVES = "FALSE_POSITIVES";
    public static final String TRUE_NEGATIVES = "TRUE_NEGATIVES";
    public static final String FALSE_NEGATIVES = "FALSE_NEGATIVES";
    protected final int numThresholds;
    protected final float[] thresholds;
    private final String truePositivesName;
    private final String falsePositivesName;
    private final String trueNegativesName;
    private final String falseNegativesName;
    private final Zeros<T> zeros;
    private final Class<T> internalType;
    protected Variable<T> truePositives;
    protected Variable<T> falsePositives;
    protected Variable<T> trueNegatives;
    protected Variable<T> falseNegatives;

    /* JADX INFO: Access modifiers changed from: protected */
    public SensitivitySpecificityBase(String str, int i, long j, Class<T> cls) {
        super(str, j);
        this.zeros = new Zeros<>();
        if (i <= 0) {
            throw new IllegalArgumentException("numThresholds must be > 0.");
        }
        this.internalType = cls;
        this.truePositivesName = getVariableName("TRUE_POSITIVES");
        this.falsePositivesName = getVariableName("FALSE_POSITIVES");
        this.trueNegativesName = getVariableName("TRUE_NEGATIVES");
        this.falseNegativesName = getVariableName("FALSE_NEGATIVES");
        this.numThresholds = i;
        if (this.numThresholds == 1) {
            this.thresholds = new float[]{0.5f};
            return;
        }
        this.thresholds = new float[i];
        for (int i2 = 0; i2 < i - 2; i2++) {
            this.thresholds[i2 + 1] = (i2 + 1.0f) / (i - 1);
        }
        this.thresholds[i - 1] = 1.0f;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    @Override // org.tensorflow.framework.metrics.BaseMetric
    public void init(Ops ops) {
        checkIsGraph(ops);
        if (isInitialized()) {
            return;
        }
        setTF(ops);
        Operand<T> call = this.zeros.call(ops, ops.constant(Shape.of(new long[]{this.numThresholds})), this.internalType);
        if (getTruePositives() == null) {
            this.truePositives = ops.withName(this.truePositivesName).withInitScope().variable(call, new Variable.Options[0]);
        }
        if (getFalsePositives() == null) {
            this.falsePositives = ops.withName(this.falsePositivesName).withInitScope().variable(call, new Variable.Options[0]);
        }
        if (getTrueNegatives() == null) {
            this.trueNegatives = ops.withInitScope().withName(this.trueNegativesName).variable(call, new Variable.Options[0]);
        }
        if (getFalseNegatives() == null) {
            this.falseNegatives = ops.withInitScope().withName(this.falseNegativesName).variable(call, new Variable.Options[0]);
        }
        setInitialized(true);
    }

    @Override // org.tensorflow.framework.metrics.BaseMetric, org.tensorflow.framework.metrics.Metric
    public List<Op> updateStateList(Ops ops, Operand<? extends TNumber> operand, Operand<? extends TNumber> operand2, Operand<? extends TNumber> operand3) {
        init(ops);
        Operand cast = CastHelper.cast(ops, operand, this.internalType);
        Operand cast2 = CastHelper.cast(ops, operand2, this.internalType);
        Operand cast3 = operand3 != null ? CastHelper.cast(ops, operand3, this.internalType) : null;
        HashMap hashMap = new HashMap();
        hashMap.put(ConfusionMatrixEnum.TRUE_POSITIVES, getTruePositives());
        hashMap.put(ConfusionMatrixEnum.FALSE_POSITIVES, getFalsePositives());
        hashMap.put(ConfusionMatrixEnum.TRUE_NEGATIVES, getTrueNegatives());
        hashMap.put(ConfusionMatrixEnum.FALSE_NEGATIVES, getFalseNegatives());
        return MetricsHelper.updateConfusionMatrixVariables(ops, hashMap, cast, cast2, ops.constant(this.thresholds), null, null, cast3, false, null);
    }

    @Override // org.tensorflow.framework.metrics.Metric
    public Op resetStates(Ops ops) {
        Operand<T> call = this.zeros.call(ops, ops.constant(Shape.of(new long[]{this.numThresholds})), this.internalType);
        ArrayList arrayList = new ArrayList();
        if (getTruePositives() != null) {
            arrayList.add(ops.withName(this.truePositivesName).assign(getTruePositives(), call, new Assign.Options[0]));
        }
        if (getFalsePositives() != null) {
            arrayList.add(ops.withName(this.falsePositivesName).assign(getFalsePositives(), call, new Assign.Options[0]));
        }
        if (getTrueNegatives() != null) {
            arrayList.add(ops.withName(this.trueNegativesName).assign(getTrueNegatives(), call, new Assign.Options[0]));
        }
        if (getFalseNegatives() != null) {
            arrayList.add(ops.withName(this.falseNegativesName).assign(getFalseNegatives(), call, new Assign.Options[0]));
        }
        return arrayList.size() == 1 ? (Op) arrayList.get(0) : ops.withControlDependencies(arrayList).noOp();
    }

    public Variable<T> getTruePositives() {
        return this.truePositives;
    }

    public Variable<T> getFalsePositives() {
        return this.falsePositives;
    }

    public Variable<T> getTrueNegatives() {
        return this.trueNegatives;
    }

    public Variable<T> getFalseNegatives() {
        return this.falseNegatives;
    }

    public int getNumThresholds() {
        return this.numThresholds;
    }

    public float[] getThresholds() {
        return this.thresholds;
    }

    public String getTruePositivesName() {
        return this.truePositivesName;
    }

    public String getFalsePositivesName() {
        return this.falsePositivesName;
    }

    public String getTrueNegativesName() {
        return this.trueNegativesName;
    }

    public String getFalseNegativesName() {
        return this.falseNegativesName;
    }

    public Class<T> getType() {
        return this.internalType;
    }

    public Class<T> getInternalType() {
        return this.internalType;
    }
}
