package org.tensorflow.op.nn;

import java.util.ArrayList;
import java.util.Collections;
import org.tensorflow.Operand;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Scope;
import org.tensorflow.op.core.All;
import org.tensorflow.op.core.AssertThat;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Rank;
import org.tensorflow.op.core.Reshape;
import org.tensorflow.op.core.Shapes;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.math.Equal;
import org.tensorflow.op.math.Sub;
import org.tensorflow.types.TBfloat16;
import org.tensorflow.types.TFloat16;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.class */
/**
 * Sparse softmax cross-entropy loss built on top of the raw
 * {@link org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits} op.
 *
 * <p>Rank-2 logits are passed straight to the raw op. Logits of any other rank are flattened to
 * rank 2 (and labels to rank 1), and the resulting loss is reshaped back to the shape of the labels.
 */
public class SparseSoftmaxCrossEntropyWithLogits {

    /**
     * Computes sparse softmax cross-entropy between {@code logits} and {@code labels}.
     *
     * @param scope the current scope; ops are created in a sub-scope named after the raw op
     * @param labels class indices whose shape must equal the shape of {@code logits} without its
     *     last dimension
     * @param logits unscaled per-class scores whose last dimension indexes the classes; unless
     *     {@code logits} has rank 2, the size of that last dimension must be statically known
     * @param <T> the label data type
     * @param <U> the logits data type
     * @return the loss, shaped like {@code labels} and of the same data type as {@code logits}
     * @throws IllegalArgumentException if {@code logits} is a scalar, or if the statically known
     *     ranks or shapes of {@code labels} and {@code logits} are incompatible
     */
    // Raw types: the logits fed to the raw op may have been re-typed by the float32 up-cast.
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static <T extends TNumber, U extends TNumber> Operand sparseSoftmaxCrossEntropyWithLogits(
            Scope scope, Operand<T> labels, Operand<U> logits) {
        Scope subScope =
                scope.withSubScope(org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits.OP_NAME);

        // 16-bit float logits (float16 and bfloat16) are fed to the raw op as float32. The loss is
        // cast back to the original logits type before it is returned, on every path.
        Class<U> logitsType = logits.asOutput().type();
        boolean upcast = logitsType == TFloat16.class || logitsType == TBfloat16.class;
        Operand preciseLogits = logits;
        if (upcast) {
            preciseLogits = Cast.create(subScope, logits, TFloat32.class);
        }

        Shape labelsShape = labels.shape();
        Shape logitsShape = logits.shape();

        // Reject scalar logits first, so that take() below is never asked for -1 dimensions.
        if (logitsShape.numDimensions() == 0) {
            throw new IllegalArgumentException(
                    String.format("Logits cannot be scalars - received shape %s.", logitsShape));
        }
        // Ranks can be compared whenever both are known, even if some dimension sizes are not.
        if (!logitsShape.isUnknown()
                && !labelsShape.isUnknown()
                && labelsShape.numDimensions() != logitsShape.numDimensions() - 1) {
            throw new IllegalArgumentException(String.format("Rank mismatch: Rank of labels (received %s) should equal rank of logits minus 1 (received %s).", labelsShape, logitsShape));
        }
        // Static shape of logits without its class dimension (unknown if the logits rank is unknown).
        Shape logitsOuterShape =
                logitsShape.isUnknown()
                        ? Shape.unknown()
                        : logitsShape.take(logitsShape.numDimensions() - 1);
        boolean staticShapesFullyDefined =
                !labelsShape.hasUnknownDimension() && !logitsOuterShape.hasUnknownDimension();
        if (staticShapesFullyDefined && !labelsShape.equals(logitsOuterShape)) {
            throw new IllegalArgumentException(String.format("Shape mismatch: The shape of labels (received %s) should equal the shape of logits except for the last dimension (received %s).", labelsShape, logitsShape));
        }

        // Rank-2 logits already have the layout the raw op expects: no reshaping required.
        if (logitsShape.numDimensions() == 2) {
            Operand loss =
                    org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits.create(
                                    subScope, preciseLogits, labels)
                            .loss();
            if (upcast) {
                loss = Cast.create(subScope, loss, logitsType);
            }
            return loss;
        }

        org.tensorflow.op.core.Shape<TInt32> labelsShapeOp =
                org.tensorflow.op.core.Shape.create(subScope, labels);

        // When the static shapes cannot prove compatibility, check it at run time.
        // withControlDependencies returns a NEW Scope (subScope itself is unchanged), so the result
        // must be used for every op that should wait on the assertion.
        Scope checkedScope = subScope;
        if (!staticShapesFullyDefined) {
            checkedScope =
                    subScope.withControlDependencies(
                            Collections.singletonList(
                                    outerShapeAssertion(subScope, labelsShapeOp, logits)));
        }

        // Flatten logits to [-1, numClasses] and labels to [-1]. numClasses comes from the static
        // shape; if it is unknown the target becomes [-1, -1], which Reshape does not accept.
        Operand flatLogits =
                Reshape.create(
                        checkedScope,
                        preciseLogits,
                        Constant.arrayOf(checkedScope, -1L, logitsShape.size(-1)));
        Operand<T> flatLabels =
                Reshape.create(checkedScope, labels, Constant.scalarOf(checkedScope, -1));
        Operand loss =
                org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits.create(
                                checkedScope, flatLogits, flatLabels)
                        .loss();
        // Restore the original (batch) shape of the labels on the per-example loss.
        Operand cost = Reshape.create(checkedScope, loss, labelsShapeOp);
        if (upcast) {
            cost = Cast.create(checkedScope, cost, logitsType);
        }
        return cost;
    }

    /**
     * Builds a run-time assertion that {@code labelsShape} equals the shape of {@code logits} with
     * its last dimension dropped.
     *
     * @param scope the scope in which to create the assertion ops
     * @param labelsShape the dynamic shape of the labels
     * @param logits the logits operand
     * @return the assertion op, meant to be used as a control dependency
     */
    private static AssertThat outerShapeAssertion(
            Scope scope, Operand<TInt32> labelsShape, Operand<? extends TNumber> logits) {
        // Keep the first rank(logits) - 1 entries of shape(logits), i.e. drop the class dimension.
        Operand<TInt32> outerRank =
                Sub.create(scope, Rank.create(scope, logits), Constant.scalarOf(scope, 1));
        Operand<TInt32> logitsOuterShape =
                Shapes.take(scope, org.tensorflow.op.core.Shape.create(scope, logits), outerRank);
        // The assertion condition must be a scalar, so reduce the element-wise comparison with All.
        return AssertThat.create(
                scope,
                All.create(
                        scope,
                        Equal.create(scope, labelsShape, logitsOuterShape),
                        Constant.scalarOf(scope, 0)),
                Collections.singletonList(Constant.scalarOf(scope, "Shape mismatch: The shape of labels  should equal the shape of logits except for the last dimension ")));
    }
}
