package org.tensorflow.framework.op.math;

import java.util.Arrays;
import java.util.Collections;
import org.tensorflow.Operand;
import org.tensorflow.framework.losses.impl.LossTuple;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Scope;
import org.tensorflow.op.core.AssertThat;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Identity;
import org.tensorflow.op.core.OnesLike;
import org.tensorflow.op.core.ReduceAll;
import org.tensorflow.op.core.ReduceMax;
import org.tensorflow.op.core.ScatterNd;
import org.tensorflow.op.core.Squeeze;
import org.tensorflow.op.core.Stack;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.GreaterEqual;
import org.tensorflow.op.math.Less;
import org.tensorflow.op.math.Maximum;
import org.tensorflow.types.TBool;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/op/math/ConfusionMatrix.class */
public class ConfusionMatrix {
    public static <T extends TNumber> Operand<T> confusionMatrix(Scope scope, Operand<T> operand, Operand<T> operand2) {
        return confusionMatrix(scope, operand, operand2, null, null);
    }

    public static <T extends TNumber> Operand<T> confusionMatrix(Scope scope, Operand<T> operand, Operand<T> operand2, Operand<T> operand3) {
        return confusionMatrix(scope, operand, operand2, operand3, null);
    }

    public static <T extends TNumber> Operand<T> confusionMatrix(Scope scope, Operand<T> operand, Operand<T> operand2, Operand<T> operand3, Operand<TInt64> operand4) {
        Operand create;
        Scope withSubScope = scope.withSubScope("confusionMatrix");
        LossTuple removeSqueezableDimensions = removeSqueezableDimensions(scope, operand, operand2, 0);
        Cast create2 = Cast.create(withSubScope, removeSqueezableDimensions.getLabels(), TInt64.class, new Cast.Options[0]);
        Cast create3 = Cast.create(withSubScope, removeSqueezableDimensions.getTarget(), TInt64.class, new Cast.Options[0]);
        Constant scalarOf = Constant.scalarOf(withSubScope, 0L);
        Constant scalarOf2 = Constant.scalarOf(withSubScope, 1L);
        Operand create4 = Identity.create(withSubScope.withControlDependencies(Collections.singletonList(AssertThat.create(withSubScope, ReduceAll.create(withSubScope, GreaterEqual.create(withSubScope, create2, scalarOf), Axes.allAxes(scope, create2), new ReduceAll.Options[0]), Collections.singletonList(Constant.scalarOf(withSubScope, "labels contains negative values")), new AssertThat.Options[0]))), create2);
        Operand create5 = Identity.create(withSubScope.withControlDependencies(Collections.singletonList(AssertThat.create(withSubScope, ReduceAll.create(withSubScope, GreaterEqual.create(withSubScope, create3, scalarOf), Axes.allAxes(scope, create3), new ReduceAll.Options[0]), Collections.singletonList(Constant.scalarOf(withSubScope, "predictions contains negative values")), new AssertThat.Options[0]))), create3);
        if (operand4 == null) {
            create = Add.create(withSubScope, Maximum.create(withSubScope, ReduceMax.create(withSubScope, create5, scalarOf, new ReduceMax.Options[0]), ReduceMax.create(withSubScope, create4, scalarOf, new ReduceMax.Options[0])), scalarOf2);
        } else {
            create = Cast.create(withSubScope, operand4, TInt64.class, new Cast.Options[0]);
            Less create6 = Less.create(withSubScope, create4, create);
            create4 = Identity.create(withSubScope.withControlDependencies(Collections.singletonList(AssertThat.create(withSubScope, ReduceAll.create(scope, create6, Axes.allAxes(scope, create6), new ReduceAll.Options[]{ReduceAll.keepDims(false)}), Collections.singletonList(Constant.scalarOf(withSubScope, "labels out of bounds")), new AssertThat.Options[0]))), create4);
            Less create7 = Less.create(withSubScope, create5, create);
            create5 = Identity.create(withSubScope.withControlDependencies(Collections.singletonList(AssertThat.create(withSubScope, ReduceAll.create(scope, create7, Axes.allAxes(scope, create7), new ReduceAll.Options[]{ReduceAll.keepDims(false)}), Collections.singletonList(Constant.scalarOf(withSubScope, "predictions  out of bounds")), new AssertThat.Options[0]))), create5);
        }
        if (operand3 != null && !operand2.shape().isCompatibleWith(operand3.shape())) {
            throw new IllegalArgumentException(String.format("predictions.shape() [%s], is not compatible with weights.shape() [ %s].", operand2.shape(), operand3.shape()));
        }
        return ScatterNd.create(withSubScope, Stack.create(withSubScope, Arrays.asList(create4, create5), new Stack.Options[]{Stack.axis(1L)}), operand3 == null ? OnesLike.create(withSubScope, operand2) : operand3, Stack.create(withSubScope, Arrays.asList(create, create), new Stack.Options[0]));
    }

    private static <T extends TNumber> LossTuple<T> removeSqueezableDimensions(Scope scope, Operand<T> operand, Operand<T> operand2, int i) {
        Scope withSubScope = scope.withSubScope("removeSqueezableDimensions");
        Shape shape = operand2.shape();
        int numDimensions = shape.numDimensions();
        Shape shape2 = operand.shape();
        int numDimensions2 = shape2.numDimensions();
        if (numDimensions == Shape.UNKNOWN_SIZE && numDimensions2 == Shape.UNKNOWN_SIZE) {
            if (numDimensions == Shape.UNKNOWN_SIZE && Shape.isCompatible(shape.size(-1), 1L)) {
                operand2 = Squeeze.create(withSubScope, operand2, new Squeeze.Options[]{Squeeze.axis(Collections.singletonList(-1L))});
            }
            if (numDimensions2 == Shape.UNKNOWN_SIZE && Shape.isCompatible(shape2.size(-1), 1L)) {
                operand = Squeeze.create(withSubScope, operand, new Squeeze.Options[]{Squeeze.axis(Collections.singletonList(-1L))});
            }
            return new LossTuple<>(operand, operand2);
        }
        int i2 = numDimensions - numDimensions2;
        if (i2 == i + 1 && Shape.isCompatible(shape.size(-1), 1L)) {
            operand2 = Squeeze.create(withSubScope, operand2, new Squeeze.Options[0]);
        } else if (i2 == i - 1 && Shape.isCompatible(shape2.size(-1), 1L)) {
            operand = Squeeze.create(withSubScope, operand, new Squeeze.Options[0]);
        }
        return new LossTuple<>(operand, operand2);
    }
}
