package org.tensorflow.framework.losses;

import org.tensorflow.Operand;
import org.tensorflow.framework.losses.impl.LossesHelper;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/losses/SparseCategoricalCrossentropy.class */
/**
 * Sparse categorical crossentropy loss.
 *
 * <p>The per-element loss is produced by {@link Losses#sparseCategoricalCrossentropy}. It is then
 * weighted and reduced by {@link LossesHelper#computeWeightedLoss} using this loss's
 * {@link Reduction}.
 *
 * <p>NOTE(review): labels are presumably integer class indices (the "sparse" encoding) rather than
 * one-hot vectors. Confirm this against {@link Losses#sparseCategoricalCrossentropy}.
 */
public class SparseCategoricalCrossentropy extends Loss {
    /** Default value for {@code fromLogits}: predictions are range-checked to {@code [0, 1]}. */
    public static final boolean FROM_LOGITS_DEFAULT = false;

    /** Default axis passed to the crossentropy computation (presumably the last dimension). */
    public static final int AXIS_DEFAULT = -1;

    /**
     * Whether predictions are logits.
     *
     * <p>When {@code false}, predictions are range-checked to {@code [0, 1]} before use.
     */
    private final boolean fromLogits;

    /** Axis forwarded to {@link Losses#sparseCategoricalCrossentropy}. */
    private final int axis;

    /**
     * Creates a loss with a {@code null} name and default settings.
     *
     * <p>The defaults are {@link #FROM_LOGITS_DEFAULT}, {@code REDUCTION_DEFAULT} and
     * {@link #AXIS_DEFAULT}.
     *
     * @param tf the TensorFlow ops
     */
    public SparseCategoricalCrossentropy(Ops tf) {
        this(tf, null, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT);
    }

    /**
     * Creates a named loss with default settings.
     *
     * <p>The defaults are {@link #FROM_LOGITS_DEFAULT}, {@code REDUCTION_DEFAULT} and
     * {@link #AXIS_DEFAULT}.
     *
     * @param tf the TensorFlow ops
     * @param name the name of this loss, passed through to {@link Loss}
     */
    public SparseCategoricalCrossentropy(Ops tf, String name) {
        this(tf, name, FROM_LOGITS_DEFAULT, REDUCTION_DEFAULT, AXIS_DEFAULT);
    }

    /**
     * Creates a loss with a {@code null} name and the given reduction.
     *
     * @param tf the TensorFlow ops
     * @param reduction the reduction applied to the computed loss
     */
    public SparseCategoricalCrossentropy(Ops tf, Reduction reduction) {
        this(tf, null, FROM_LOGITS_DEFAULT, reduction, AXIS_DEFAULT);
    }

    /**
     * Creates a named loss with the given reduction.
     *
     * @param tf the TensorFlow ops
     * @param name the name of this loss, passed through to {@link Loss}
     * @param reduction the reduction applied to the computed loss
     */
    public SparseCategoricalCrossentropy(Ops tf, String name, Reduction reduction) {
        this(tf, name, FROM_LOGITS_DEFAULT, reduction, AXIS_DEFAULT);
    }

    /**
     * Creates a named loss with the given {@code fromLogits} setting.
     *
     * @param tf the TensorFlow ops
     * @param name the name of this loss, passed through to {@link Loss}
     * @param fromLogits whether predictions are logits; when {@code false} they are range-checked
     *     to {@code [0, 1]}
     */
    public SparseCategoricalCrossentropy(Ops tf, String name, boolean fromLogits) {
        this(tf, name, fromLogits, REDUCTION_DEFAULT, AXIS_DEFAULT);
    }

    /**
     * Creates a loss with a {@code null} name and the given {@code fromLogits} setting.
     *
     * @param tf the TensorFlow ops
     * @param fromLogits whether predictions are logits; when {@code false} they are range-checked
     *     to {@code [0, 1]}
     */
    public SparseCategoricalCrossentropy(Ops tf, boolean fromLogits) {
        this(tf, null, fromLogits, REDUCTION_DEFAULT, AXIS_DEFAULT);
    }

    /**
     * Creates a loss with a {@code null} name, the given {@code fromLogits} setting and the given
     * reduction.
     *
     * @param tf the TensorFlow ops
     * @param fromLogits whether predictions are logits; when {@code false} they are range-checked
     *     to {@code [0, 1]}
     * @param reduction the reduction applied to the computed loss
     */
    public SparseCategoricalCrossentropy(Ops tf, boolean fromLogits, Reduction reduction) {
        this(tf, null, fromLogits, reduction, AXIS_DEFAULT);
    }

    /**
     * Creates a fully configured loss. All other constructors delegate here.
     *
     * @param tf the TensorFlow ops
     * @param name the name of this loss, passed through to {@link Loss} (may be {@code null})
     * @param fromLogits whether predictions are logits; when {@code false} they are range-checked
     *     to {@code [0, 1]}
     * @param reduction the reduction applied to the computed loss
     * @param axis the axis forwarded to {@link Losses#sparseCategoricalCrossentropy}
     */
    public SparseCategoricalCrossentropy(
            Ops tf, String name, boolean fromLogits, Reduction reduction, int axis) {
        super(tf, name, reduction);
        this.fromLogits = fromLogits;
        this.axis = axis;
    }

    /**
     * Builds the weighted, reduced sparse categorical crossentropy between labels and predictions.
     *
     * @param labels the ground-truth labels
     * @param predictions the predictions; wrapped in a {@code [0, 1]} range check first unless
     *     this loss was created with {@code fromLogits == true}
     * @param sampleWeights sample weights forwarded to {@link LossesHelper#computeWeightedLoss}
     * @param <T> the data type of the predictions and of the result
     * @return the loss after weighting and reduction
     */
    @Override // org.tensorflow.framework.losses.Loss
    public <T extends TNumber> Operand<T> call(
            Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
        Ops tf = getTF();
        // Logits are unbounded, so only probability-style predictions are range-checked.
        Operand<T> checkedPredictions =
                fromLogits ? predictions : checkProbabilityRange(tf, predictions);
        Operand<T> losses =
                Losses.sparseCategoricalCrossentropy(tf, labels, checkedPredictions, fromLogits, axis);
        return LossesHelper.computeWeightedLoss(tf, losses, getReduction(), sampleWeights);
    }

    /**
     * Wraps {@code predictions} in a range check against {@code [0, 1]}.
     *
     * <p>The bounds are cast to the predictions' data type. NOTE(review): the exact failure
     * behavior is defined by {@link LossesHelper#rangeCheck}, presumably a runtime assertion.
     */
    private static <T extends TNumber> Operand<T> checkProbabilityRange(
            Ops tf, Operand<T> predictions) {
        return LossesHelper.rangeCheck(
                tf,
                "predictions range check [0-1]",
                predictions,
                CastHelper.cast(tf, tf.constant(0), predictions.type()),
                CastHelper.cast(tf, tf.constant(1), predictions.type()));
    }
}
