package org.tensorflow.framework.metrics;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.tensorflow.Operand;
import org.tensorflow.framework.initializers.Zeros;
import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum;
import org.tensorflow.framework.metrics.impl.MetricsHelper;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.Slice;
import org.tensorflow.op.core.Variable;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/metrics/Recall.class */
/**
 * Metric that computes the recall of predictions with respect to labels.
 *
 * <p>Two state variables are tracked, each with one entry per threshold: true positives and false
 * negatives. They are updated through {@link MetricsHelper#updateConfusionMatrixVariables}, and
 * the result is {@code truePositives / (truePositives + falseNegatives)}. The division uses
 * {@code divNoNan}.
 *
 * @param <T> the numeric type of the state variables
 */
public class Recall<T extends TNumber> extends BaseMetric {
    /** Threshold used when neither thresholds nor {@code topK} are supplied. */
    public static final float DEFAULT_THRESHOLD = 0.5f;
    /** Base name of the true-positives state variable. */
    public static final String TRUE_POSITIVES = "TRUE_POSITIVES";
    /** Base name of the false-negatives state variable. */
    public static final String FALSE_NEGATIVES = "FALSE_NEGATIVES";

    /**
     * Threshold used when {@code topK} is set but no thresholds are supplied. This mirrors Keras,
     * which uses a very large negative threshold in the top-k case.
     */
    private static final float TOP_K_DEFAULT_THRESHOLD = -1.0E10f;

    private final float[] thresholds;
    private final Integer topK;
    private final Integer classId;
    private final String truePositivesName;
    private final String falseNegativesName;
    private final Class<T> type;
    /** Assign ops that reset the state variables to zero; used by {@link #resetStates}. */
    private final List<Op> initializers = new ArrayList<>();
    private Variable<T> truePositives;
    private Variable<T> falseNegatives;

    /** Creates a Recall metric with the default name, threshold, topK and classId. */
    public Recall(long seed, Class<T> type) {
        this((String) null, (float[]) null, (Integer) null, (Integer) null, seed, type);
    }

    /** Creates a named Recall metric with the default threshold, topK and classId. */
    public Recall(String name, long seed, Class<T> type) {
        this(name, (float[]) null, (Integer) null, (Integer) null, seed, type);
    }

    /** Creates a Recall metric with a single threshold. */
    public Recall(float threshold, long seed, Class<T> type) {
        this((String) null, threshold, (Integer) null, (Integer) null, seed, type);
    }

    /** Creates a Recall metric with the given thresholds. */
    public Recall(float[] thresholds, long seed, Class<T> type) {
        this((String) null, thresholds, (Integer) null, (Integer) null, seed, type);
    }

    /** Creates a named Recall metric with a single threshold. */
    public Recall(String name, float threshold, long seed, Class<T> type) {
        this(name, threshold, (Integer) null, (Integer) null, seed, type);
    }

    /** Creates a named Recall metric with the given thresholds. */
    public Recall(String name, float[] thresholds, long seed, Class<T> type) {
        this(name, thresholds, (Integer) null, (Integer) null, seed, type);
    }

    /** Creates a Recall metric with the given topK and classId and the default threshold. */
    public Recall(Integer topK, Integer classId, long seed, Class<T> type) {
        this((String) null, (float[]) null, topK, classId, seed, type);
    }

    /** Creates a named Recall metric with the given topK and classId and the default threshold. */
    public Recall(String name, Integer topK, Integer classId, long seed, Class<T> type) {
        this(name, (float[]) null, topK, classId, seed, type);
    }

    /** Creates a Recall metric with a single threshold, topK and classId. */
    public Recall(float threshold, Integer topK, Integer classId, long seed, Class<T> type) {
        this((String) null, new float[] {threshold}, topK, classId, seed, type);
    }

    /** Creates a Recall metric with the given thresholds, topK and classId. */
    public Recall(float[] thresholds, Integer topK, Integer classId, long seed, Class<T> type) {
        this((String) null, thresholds, topK, classId, seed, type);
    }

    /** Creates a named Recall metric with a single threshold, topK and classId. */
    public Recall(
            String name, float threshold, Integer topK, Integer classId, long seed, Class<T> type) {
        this(name, new float[] {threshold}, topK, classId, seed, type);
    }

    /**
     * Creates a Recall metric.
     *
     * @param name the metric name, forwarded to {@link BaseMetric}; may be null
     * @param thresholds the thresholds applied to predictions. If null, a single default threshold
     *     is used: {@link #DEFAULT_THRESHOLD} when {@code topK} is null, otherwise -1e10. The array
     *     is copied, so later changes by the caller have no effect.
     * @param topK if non-null, forwarded to the confusion-matrix update as the top-k setting
     * @param classId if non-null, forwarded to the confusion-matrix update as the class id
     * @param seed forwarded to {@link BaseMetric}
     * @param type the numeric type of the state variables
     */
    public Recall(
            String name, float[] thresholds, Integer topK, Integer classId, long seed, Class<T> type) {
        super(name, seed);
        this.type = type;
        this.truePositivesName = getVariableName(TRUE_POSITIVES);
        this.falseNegativesName = getVariableName(FALSE_NEGATIVES);
        if (thresholds == null) {
            float defaultThreshold = topK == null ? DEFAULT_THRESHOLD : TOP_K_DEFAULT_THRESHOLD;
            this.thresholds = new float[] {defaultThreshold};
        } else {
            // Defensive copy: the array is read again on every state update.
            this.thresholds = thresholds.clone();
        }
        this.topK = topK;
        this.classId = classId;
    }

    /**
     * Lazily creates the zero-initialized state variables (one entry per threshold) and records
     * the assign ops used to reset them. Later calls return immediately once initialized.
     */
    @Override // org.tensorflow.framework.metrics.BaseMetric
    protected void init(Ops ops) {
        checkIsGraph(ops);
        if (isInitialized()) {
            return;
        }
        setTF(ops);
        Operand<T> zeros =
                new Zeros().call(ops, ops.constant(Shape.of(new long[] {this.thresholds.length})), this.type);
        if (this.truePositives == null) {
            this.truePositives = ops.withName(this.truePositivesName).withInitScope().variable(zeros);
            this.initializers.add(ops.assign(this.truePositives, zeros));
        }
        if (this.falseNegatives == null) {
            this.falseNegatives = ops.withName(this.falseNegativesName).withInitScope().variable(zeros);
            this.initializers.add(ops.assign(this.falseNegatives, zeros));
        }
        setInitialized(true);
    }

    /**
     * Returns a no-op that depends on the zero-assign ops for both state variables.
     *
     * @param ops the TensorFlow Ops
     * @return an op whose execution resets the metric state
     */
    @Override // org.tensorflow.framework.metrics.Metric
    public Op resetStates(Ops ops) {
        init(ops);
        return ops.withSubScope("resetStates").withControlDependencies(this.initializers).noOp();
    }

    /**
     * Builds the ops that accumulate true positives and false negatives for a batch.
     *
     * @param ops the TensorFlow Ops
     * @param labels the ground-truth labels, cast to {@code T}
     * @param predictions the predictions, cast to {@code T}
     * @param sampleWeights optional sample weights, cast to {@code T} when non-null
     * @return the update ops produced by {@link MetricsHelper#updateConfusionMatrixVariables}
     */
    @Override // org.tensorflow.framework.metrics.BaseMetric, org.tensorflow.framework.metrics.Metric
    public List<Op> updateStateList(
            Ops ops,
            Operand<? extends TNumber> labels,
            Operand<? extends TNumber> predictions,
            Operand<? extends TNumber> sampleWeights) {
        init(ops);
        HashMap<ConfusionMatrixEnum, Variable<T>> confusionVariables = new HashMap<>();
        confusionVariables.put(ConfusionMatrixEnum.TRUE_POSITIVES, this.truePositives);
        confusionVariables.put(ConfusionMatrixEnum.FALSE_NEGATIVES, this.falseNegatives);
        Operand<T> typedPredictions = CastHelper.cast(ops, predictions, this.type);
        Operand<T> typedLabels = CastHelper.cast(ops, labels, this.type);
        Operand<T> typedSampleWeights =
                sampleWeights != null ? CastHelper.cast(ops, sampleWeights, this.type) : null;
        return MetricsHelper.updateConfusionMatrixVariables(
                ops,
                confusionVariables,
                typedLabels,
                typedPredictions,
                ops.constant(this.thresholds),
                this.topK,
                this.classId,
                typedSampleWeights,
                false,
                null);
    }

    /**
     * Computes recall as {@code truePositives / (truePositives + falseNegatives)} using
     * {@code divNoNan}.
     *
     * @param ops the TensorFlow Ops
     * @param cls the type of the returned operand
     * @return the recall per threshold, or a scalar when exactly one threshold is configured
     */
    @Override // org.tensorflow.framework.metrics.Metric
    public <U extends TNumber> Operand<U> result(Ops ops, Class<U> cls) {
        init(ops);
        Operand<T> recall =
                ops.math.divNoNan(this.truePositives, ops.math.add(this.truePositives, this.falseNegatives));
        if (this.thresholds.length == 1) {
            // Take element 0 as a scalar (Keras' `result[0]`). The slice size must be 1:
            // a size of 0 would produce an empty tensor.
            Operand<T> first =
                    ops.slice(recall, ops.constant(new int[] {0}), ops.constant(new int[] {1}));
            recall = ops.reshape(first, ops.constant(Shape.scalar()));
        }
        return CastHelper.cast(ops, recall, cls);
    }

    /** Returns a copy of the configured thresholds. */
    public float[] getThresholds() {
        return this.thresholds.clone();
    }

    /** Returns the configured topK, or null if not set. */
    public Integer getTopK() {
        return this.topK;
    }

    /** Returns the configured class id, or null if not set. */
    public Integer getClassId() {
        return this.classId;
    }

    /** Returns the true-positives variable, or null before {@link #init} has run. */
    public Variable<T> getTruePositives() {
        return this.truePositives;
    }

    /** Returns the false-negatives variable, or null before {@link #init} has run. */
    public Variable<T> getFalseNegatives() {
        return this.falseNegatives;
    }

    /** Returns the graph name of the true-positives variable. */
    public String getTruePositivesName() {
        return this.truePositivesName;
    }

    /** Returns the graph name of the false-negatives variable. */
    public String getFalseNegativesName() {
        return this.falseNegativesName;
    }
}
