package org.tensorflow.framework.metrics;

import java.util.Collections;
import java.util.List;
import org.tensorflow.Operand;
import org.tensorflow.framework.initializers.Zeros;
import org.tensorflow.framework.losses.impl.LossesHelper;
import org.tensorflow.framework.metrics.impl.MetricsHelper;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.AssignAdd;
import org.tensorflow.op.core.ReduceSum;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.linalg.MatrixDiagPart;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.DivNoNan;
import org.tensorflow.op.math.NotEqual;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/metrics/MeanIoU.class */
/**
 * Metric computing the mean Intersection-Over-Union (IoU) over {@code numClasses} classes.
 *
 * <p>A running {@code numClasses x numClasses} confusion matrix is accumulated in a graph variable
 * by {@link #updateStateList}. {@link #result} derives, for every class, {@code TP / (TP + FP + FN)}
 * from that matrix and averages it over the classes whose denominator is non-zero.
 *
 * @param <T> the data type of the accumulated confusion-matrix variable
 */
public class MeanIoU<T extends TNumber> extends BaseMetric {
    /** Key passed to {@code getVariableName} to build the confusion-matrix variable name. */
    public static final String TOTAL_CONFUSION_MATRIX = "TOTAL_CONFUSION_MATRIX";

    /** Name used for the confusion-matrix variable and for the reset op. */
    private final String totalCMName;
    /** Data type of the confusion-matrix variable and of the cast inputs. */
    private final Class<T> type;
    /** Initializer producing the all-zero matrix used at creation and on reset. */
    private final Zeros<T> zeros;
    /** Number of classes; the confusion matrix is {@code numClasses x numClasses}. */
    private final long numClasses;
    /** Shape of the confusion-matrix variable: {@code [numClasses, numClasses]}. */
    private final Shape variableShape;
    /** Assign-to-zero op built in {@link #init}; not read anywhere in this class. */
    private Assign<T> initializer;
    /** Accumulated confusion matrix; created lazily by {@link #init}. */
    private Variable<T> totalConfusionMatrix;

    /**
     * Creates a MeanIoU metric with a {@code null} name (presumably replaced by a default name in
     * {@code BaseMetric} — confirm).
     *
     * @param numClasses the number of classes the confusion matrix covers
     * @param seed the seed forwarded to {@code BaseMetric}
     * @param type the data type of the confusion-matrix variable
     */
    protected MeanIoU(long numClasses, long seed, Class<T> type) {
        this(null, numClasses, seed, type);
    }

    /**
     * Creates a MeanIoU metric.
     *
     * @param name the metric name forwarded to {@code BaseMetric}
     * @param numClasses the number of classes the confusion matrix covers
     * @param seed the seed forwarded to {@code BaseMetric}
     * @param type the data type of the confusion-matrix variable
     */
    protected MeanIoU(String name, long numClasses, long seed, Class<T> type) {
        super(name, seed);
        this.zeros = new Zeros<>();
        this.type = type;
        this.totalCMName = getVariableName(TOTAL_CONFUSION_MATRIX);
        this.numClasses = numClasses;
        this.variableShape = Shape.of(numClasses, numClasses);
    }

    /**
     * Lazily creates the zero-initialized confusion-matrix variable. Later calls, once the metric
     * reports itself initialized, return immediately.
     *
     * @param ops the TensorFlow ops (validated by {@code checkIsGraph})
     */
    @Override // org.tensorflow.framework.metrics.BaseMetric
    protected void init(Ops ops) {
        checkIsGraph(ops);
        if (isInitialized()) {
            return;
        }
        setTF(ops);
        Operand<T> zeroMatrix = zeros.call(ops, ops.constant(variableShape), type);
        totalConfusionMatrix = ops.withName(totalCMName).withInitScope().variable(zeroMatrix);
        initializer = ops.assign(totalConfusionMatrix, zeroMatrix);
        setInitialized(true);
    }

    /**
     * Builds an op that overwrites the accumulated confusion matrix with zeros.
     *
     * @param ops the TensorFlow ops
     * @return the assign op, named with the confusion-matrix variable name
     */
    @Override // org.tensorflow.framework.metrics.Metric
    public Op resetStates(Ops ops) {
        init(ops);
        Operand<T> zeroMatrix = zeros.call(ops, ops.constant(variableShape), type);
        return ops.withName(totalCMName).assign(totalConfusionMatrix, zeroMatrix);
    }

    /**
     * Builds the op that adds this batch's confusion matrix to the running total.
     *
     * <p>Labels, predictions and (optional) weights are cast to {@code T} and flattened to rank 1
     * when their static rank is greater than 1.
     *
     * @param ops the TensorFlow ops
     * @param labels the ground-truth class ids
     * @param predictions the predicted class ids; must have the same number of elements as labels
     * @param sampleWeights optional weights, may be {@code null}; rank 0 or the same rank as labels
     * @return a single-element list holding the assign-add op
     * @throws IllegalArgumentException if the weights rank is neither 0 nor the labels rank, or if
     *     the statically known sizes of labels and predictions differ
     */
    @Override // org.tensorflow.framework.metrics.BaseMetric, org.tensorflow.framework.metrics.Metric
    public List<Op> updateStateList(
            Ops ops,
            Operand<? extends TNumber> labels,
            Operand<? extends TNumber> predictions,
            Operand<? extends TNumber> sampleWeights) {
        init(ops);
        if (sampleWeights != null) {
            long weightsRank = sampleWeights.shape().numDimensions();
            long labelsRank = labels.shape().numDimensions();
            if (weightsRank != 0
                    && weightsRank != Shape.UNKNOWN_SIZE
                    && labelsRank != Shape.UNKNOWN_SIZE
                    && weightsRank != labelsRank) {
                throw new IllegalArgumentException(String.format(
                        "Weights must either have rank 0, or the same rank as labels, weights rank = %d, labels rank = %d",
                        weightsRank, labelsRank));
            }
        }
        long labelsSize = labels.shape().size();
        long predictionsSize = predictions.shape().size();
        // Shape.size() reports UNKNOWN_SIZE when a shape is not fully known while building the
        // graph; a static comparison is only meaningful when both sizes are known.
        if (labelsSize != Shape.UNKNOWN_SIZE
                && predictionsSize != Shape.UNKNOWN_SIZE
                && labelsSize != predictionsSize) {
            throw new IllegalArgumentException(String.format(
                    "labels and predictions must have the same size, labels size = %d, predictions size = %d",
                    labelsSize, predictionsSize));
        }
        Operand<T> tLabels = flattenIfNeeded(ops, CastHelper.cast(ops, labels, type));
        Operand<T> tPredictions = flattenIfNeeded(ops, CastHelper.cast(ops, predictions, type));
        Operand<T> tWeights =
                sampleWeights == null ? null : flattenIfNeeded(ops, CastHelper.cast(ops, sampleWeights, type));
        Operand<T> batchMatrix = MetricsHelper.confusionMatrix(
                ops, tLabels, tPredictions, ops.constant(numClasses), tWeights, type);
        AssignAdd<T> update = ops.assignAdd(totalConfusionMatrix, batchMatrix);
        return Collections.singletonList(update);
    }

    /** Flattens {@code operand} to rank 1 when its static rank is greater than 1. */
    private Operand<T> flattenIfNeeded(Ops ops, Operand<T> operand) {
        return operand.shape().numDimensions() > 1 ? ops.shape.flatten(operand) : operand;
    }

    /**
     * Computes the mean IoU from the accumulated confusion matrix.
     *
     * <p>For each class the IoU is {@code diag / (sumAxis0 + sumAxis1 - diag)}, i.e. {@code TP / (TP
     * + FP + FN)}. Classes whose denominator is zero contribute {@code 0} to the sum (via {@code
     * divNoNan}) and are excluded from the count. If no class is valid, the result is {@code 0}.
     *
     * @param ops the TensorFlow ops
     * @param resultType the data type of the returned operand
     * @param <U> the result data type
     * @return the mean IoU, cast to {@code resultType}
     */
    @Override // org.tensorflow.framework.metrics.Metric
    public <U extends TNumber> Operand<U> result(Ops ops, Class<U> resultType) {
        init(ops);
        ReduceSum<T> sumAxis0 = ops.reduceSum(totalConfusionMatrix, ops.constant(0));
        ReduceSum<T> sumAxis1 = ops.reduceSum(totalConfusionMatrix, ops.constant(1));
        // Main diagonal (k = 0, zero padding): the true-positive count of each class.
        MatrixDiagPart<T> truePositives = ops.linalg.matrixDiagPart(
                totalConfusionMatrix,
                ops.constant(0),
                CastHelper.cast(ops, ops.constant(0), totalConfusionMatrix.type()));
        // (TP + FN) + (TP + FP) - TP == TP + FP + FN: the per-class union size.
        Add<T> union = ops.math.add(sumAxis0, ops.math.sub(sumAxis1, truePositives));
        // Count classes that actually appear (non-zero union) so empty classes don't dilute the mean.
        NotEqual nonEmpty = ops.math.notEqual(union, CastHelper.cast(ops, ops.constant(0), union.type()));
        Cast<T> nonEmptyMask = ops.dtypes.cast(nonEmpty, type);
        ReduceSum<T> validClassCount = ops.reduceSum(nonEmptyMask, LossesHelper.allAxes(ops, union));
        DivNoNan<T> iou = ops.math.divNoNan(truePositives, union);
        ReduceSum<T> iouSum = ops.reduceSum(iou, LossesHelper.allAxes(ops, iou));
        return CastHelper.cast(ops, ops.math.divNoNan(iouSum, validClassCount), resultType);
    }
}
