package org.tensorflow.framework.op.nn;

import java.util.Arrays;
import org.tensorflow.Operand;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Scope;
import org.tensorflow.op.core.Concat;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Range;
import org.tensorflow.op.core.Rank;
import org.tensorflow.op.core.Reshape;
import org.tensorflow.op.core.Slice;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.linalg.Transpose;
import org.tensorflow.op.math.Sub;
import org.tensorflow.types.TBfloat16;
import org.tensorflow.types.TFloat16;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.class */
public class SoftmaxCrossEntropyWithLogits {

    /**
     * Computes softmax cross entropy between {@code logits} and {@code labels} along {@code axis}.
     *
     * <p>{@code TFloat16}/{@code TBfloat16} logits are computed in {@code TFloat32} and the result
     * is cast back to the logits type. Labels whose type differs from the logits type are cast to
     * the logits type. If {@code axis} is not the last dimension, it is transposed to the end
     * before the loss is computed. The result has the shape of {@code logits} with {@code axis}
     * removed.
     *
     * @param scope current scope
     * @param labels the labels; assumed to have the same shape as {@code logits} (not checked here)
     * @param logits the unscaled log probabilities
     * @param axis the class dimension; negative values count from the end
     * @param <T> the logits (and result) type
     * @param <U> the labels type
     * @return the loss, shaped like {@code logits} without the class dimension
     * @throws IllegalArgumentException if {@code logits} is a scalar
     */
    public static <T extends TNumber, U extends TNumber> Operand<T> softmaxCrossEntropyWithLogits(
            Scope scope, Operand<U> labels, Operand<T> logits, int axis) {
        Scope subScope = scope.withSubScope("SoftmaxCrossEntropyWithLogits");
        int logitsRank = logits.shape().numDimensions();
        if (logitsRank == 0) {
            // Previously surfaced as an ArithmeticException from the modulo below.
            throw new IllegalArgumentException(
                    "logits must have at least one dimension, got a scalar");
        }
        // Normalize a negative axis (e.g. -1) to its non-negative equivalent.
        int classAxis = axis % logitsRank;
        if (classAxis < 0) {
            classAxis += logitsRank;
        }

        Class<T> logitsType = logits.type();
        if (logitsType == TFloat16.class || logitsType == TBfloat16.class) {
            // Compute half-precision inputs in float32, then cast the loss back.
            Operand<TFloat32> result =
                    softmaxCrossEntropyWithLogits(
                            subScope,
                            Cast.create(subScope, labels, TFloat32.class),
                            Cast.create(subScope, logits, TFloat32.class),
                            classAxis);
            return Cast.create(subScope, result, logitsType);
        }
        if (labels.type() != logitsType) {
            return softmaxCrossEntropyWithLogits(
                    subScope, Cast.create(subScope, labels, logitsType), logits, classAxis);
        }

        // The check above guarantees labels and logits now share the same element type.
        @SuppressWarnings("unchecked")
        Operand<T> typedLabels = (Operand<T>) labels;

        Operand<TInt64> rank = Cast.create(subScope, Rank.create(subScope, logits), TInt64.class);
        Shape originalShape = logits.shape();

        // The raw kernel works on the last dimension, so move the class axis there if needed.
        Operand<T> movedLogits = logits;
        Operand<T> movedLabels = typedLabels;
        if (classAxis != logitsRank - 1) {
            movedLogits = moveDimToEnd(subScope, logits, classAxis, rank);
            movedLabels = moveDimToEnd(subScope, typedLabels, classAxis, rank);
        }

        Operand<T> flatLogits = flattenOuterDims(subScope, movedLogits);
        Operand<T> flatLabels = flattenOuterDims(subScope, movedLabels);
        Operand<T> loss =
                org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits.create(
                                subScope, flatLogits, flatLabels)
                        .loss();

        // Restore the leading dims: outputShape = shape(movedLogits)[0 : rank - 1]. The shape is
        // taken at runtime, so unknown dimensions or an unknown rank are handled correctly.
        Operand<TInt64> rankMinusOne =
                Sub.create(subScope, rank, Constant.scalarOf(subScope, 1L));
        Operand<TInt64> outputShape =
                Slice.create(
                        subScope,
                        org.tensorflow.op.core.Shape.create(subScope, movedLogits, TInt64.class),
                        Constant.arrayOf(subScope, 0L),
                        Reshape.create(subScope, rankMinusOne, Constant.arrayOf(subScope, 1L)));
        loss = Reshape.create(subScope, loss, outputShape);

        // In graph mode with a fully known input shape, also reshape to a static shape (the input
        // shape minus the class axis) so downstream shape inference sees concrete dimensions.
        if (subScope.env().isGraph() && !originalShape.hasUnknownDimension()) {
            long[] dims = originalShape.asArray();
            long[] lossDims = new long[dims.length - 1];
            for (int d = 0, out = 0; d < dims.length; d++) {
                if (d != classAxis) {
                    lossDims[out++] = dims[d];
                }
            }
            loss = Reshape.create(subScope, loss, Constant.vectorOf(subScope, lossDims));
        }
        return loss;
    }

    /**
     * Reshapes {@code input} to rank 2, collapsing all leading dimensions into the first one and
     * keeping the last dimension unchanged.
     *
     * @param scope current scope
     * @param input the tensor to flatten
     * @param <T> the element type
     * @return {@code input} reshaped to {@code [prod(leading dims), last dim]}
     */
    private static <T extends TNumber> Operand<T> flattenOuterDims(Scope scope, Operand<T> input) {
        Shape shape = input.shape();
        if (!shape.hasUnknownDimension()) {
            // Static path: every size is known, so the target shape can be computed up front.
            long outer = 1L;
            for (int d = 0; d < shape.numDimensions() - 1; d++) {
                outer *= shape.size(d);
            }
            return Reshape.create(scope, input, Constant.arrayOf(scope, outer, shape.size(-1)));
        }

        // Dynamic path: target shape is [-1, shape(input)[rank - 1]].
        Operand<TInt64> one = Constant.scalarOf(scope, 1L);
        Operand<TInt64> rankMinusOne =
                Sub.create(scope, Cast.create(scope, Rank.create(scope, input), TInt64.class), one);
        // Slice requires 1-D begin/size tensors, so lift the scalar (rank - 1) to a length-1 vector.
        Operand<TInt64> begin = Reshape.create(scope, rankMinusOne, Constant.arrayOf(scope, 1L));
        Operand<TInt64> lastDim =
                Slice.create(
                        scope,
                        org.tensorflow.op.core.Shape.create(scope, input, TInt64.class),
                        begin,
                        Constant.arrayOf(scope, 1L));
        Operand<TInt64> inferOuter = Constant.arrayOf(scope, -1L);
        return Reshape.create(
                scope,
                input,
                Concat.create(scope, Arrays.asList(inferOuter, lastDim), Constant.scalarOf(scope, 0)));
    }

    /**
     * Transposes {@code input} so that dimension {@code dimIndex} becomes the last one. The
     * remaining dimensions keep their relative order.
     *
     * @param scope current scope
     * @param input the tensor to transpose
     * @param dimIndex the non-negative index of the dimension to move
     * @param rank scalar rank of {@code input}; its element type is used for the permutation
     * @param <T> the element type of {@code input}
     * @param <U> the element type of {@code rank}
     * @return {@code input} transposed with permutation
     *     {@code [0, dimIndex) ++ [dimIndex + 1, rank) ++ [dimIndex]}
     */
    private static <T extends TNumber, U extends TNumber> Operand<T> moveDimToEnd(
            Scope scope, Operand<T> input, int dimIndex, Operand<U> rank) {
        Class<U> rankType = rank.type();
        Operand<U> zero = Cast.create(scope, Constant.scalarOf(scope, 0), rankType);
        Operand<U> one = Cast.create(scope, Constant.scalarOf(scope, 1), rankType);
        Operand<U> dim = Cast.create(scope, Constant.scalarOf(scope, dimIndex), rankType);
        Operand<U> dimPlusOne = Cast.create(scope, Constant.scalarOf(scope, dimIndex + 1), rankType);

        Operand<U> before = Range.create(scope, zero, dim, one);
        Operand<U> after = Range.create(scope, dimPlusOne, rank, one);
        Operand<U> moved =
                Cast.create(scope, Constant.vectorOf(scope, new int[] {dimIndex}), rankType);
        Operand<U> perm =
                Concat.create(scope, Arrays.asList(before, after, moved), Constant.scalarOf(scope, 0));
        return Transpose.create(scope, input, perm);
    }
}
