package org.tensorflow.framework.losses;

import org.tensorflow.Operand;
import org.tensorflow.framework.losses.impl.AbstractLoss;
import org.tensorflow.framework.losses.impl.LossesHelper;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/losses/BinaryCrossentropy.class */
/**
 * Binary cross-entropy loss between labels and predictions.
 *
 * <p>The {@code fromLogits} flag and the {@code labelSmoothing} factor are fixed at construction.
 * The per-element loss is delegated to {@link Losses#binaryCrossentropy}. This class adds an
 * optional range check on the predictions plus the sample-weighting/reduction step performed by
 * {@link LossesHelper#computeWeightedLoss}.
 */
public class BinaryCrossentropy extends AbstractLoss {
    /** Default value for {@code fromLogits}: predictions are treated as probabilities. */
    public static final boolean FROM_LOGITS_DEFAULT = false;

    /** Default label smoothing factor (no smoothing applied by this class's defaults). */
    public static final float LABEL_SMOOTHING_DEFAULT = 0.0f;

    /** Whether predictions are interpreted as logits rather than values in {@code [0, 1]}. */
    private final boolean fromLogits;

    /** Label smoothing factor in {@code [0, 1]}, forwarded to {@link Losses#binaryCrossentropy}. */
    private final float labelSmoothing;

    /**
     * Creates a loss with a {@code null} name, {@link #FROM_LOGITS_DEFAULT},
     * {@link #LABEL_SMOOTHING_DEFAULT} and {@code REDUCTION_DEFAULT}.
     */
    public BinaryCrossentropy() {
        this(null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT);
    }

    /**
     * Creates a loss with a {@code null} name, {@link #FROM_LOGITS_DEFAULT} and
     * {@link #LABEL_SMOOTHING_DEFAULT}.
     *
     * @param reduction the reduction applied to the weighted loss
     */
    public BinaryCrossentropy(Reduction reduction) {
        this(null, FROM_LOGITS_DEFAULT, LABEL_SMOOTHING_DEFAULT, reduction);
    }

    /**
     * Creates a loss with a {@code null} name, {@link #LABEL_SMOOTHING_DEFAULT} and
     * {@code REDUCTION_DEFAULT}.
     *
     * @param fromLogits whether predictions are logits
     */
    public BinaryCrossentropy(boolean fromLogits) {
        this(null, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT);
    }

    /**
     * Creates a named loss with {@link #LABEL_SMOOTHING_DEFAULT} and {@code REDUCTION_DEFAULT}.
     *
     * @param name the name of this loss (passed to {@link AbstractLoss})
     * @param fromLogits whether predictions are logits
     */
    public BinaryCrossentropy(String name, boolean fromLogits) {
        this(name, fromLogits, LABEL_SMOOTHING_DEFAULT, REDUCTION_DEFAULT);
    }

    /**
     * Creates a loss with a {@code null} name and {@code REDUCTION_DEFAULT}.
     *
     * @param fromLogits whether predictions are logits
     * @param labelSmoothing label smoothing factor; must lie in {@code [0, 1]}
     * @throws IllegalArgumentException if {@code labelSmoothing} is outside {@code [0, 1]} or NaN
     */
    public BinaryCrossentropy(boolean fromLogits, float labelSmoothing) {
        this(null, fromLogits, labelSmoothing, REDUCTION_DEFAULT);
    }

    /**
     * Creates a named loss with {@code REDUCTION_DEFAULT}.
     *
     * @param name the name of this loss (passed to {@link AbstractLoss})
     * @param fromLogits whether predictions are logits
     * @param labelSmoothing label smoothing factor; must lie in {@code [0, 1]}
     * @throws IllegalArgumentException if {@code labelSmoothing} is outside {@code [0, 1]} or NaN
     */
    public BinaryCrossentropy(String name, boolean fromLogits, float labelSmoothing) {
        this(name, fromLogits, labelSmoothing, REDUCTION_DEFAULT);
    }

    /**
     * Creates a loss with a {@code null} name.
     *
     * @param fromLogits whether predictions are logits
     * @param labelSmoothing label smoothing factor; must lie in {@code [0, 1]}
     * @param reduction the reduction applied to the weighted loss
     * @throws IllegalArgumentException if {@code labelSmoothing} is outside {@code [0, 1]} or NaN
     */
    public BinaryCrossentropy(boolean fromLogits, float labelSmoothing, Reduction reduction) {
        this(null, fromLogits, labelSmoothing, reduction);
    }

    /**
     * Creates a binary cross-entropy loss; all other constructors delegate here.
     *
     * @param name the name of this loss; may be {@code null} (passed through to
     *     {@link AbstractLoss}, which presumably substitutes a default — confirm there)
     * @param fromLogits whether predictions are logits; when {@code false}, {@code call}
     *     range-checks predictions against {@code [0, 1]}
     * @param labelSmoothing label smoothing factor forwarded to {@link Losses#binaryCrossentropy};
     *     must lie in {@code [0, 1]}
     * @param reduction the reduction applied to the weighted loss
     * @throws IllegalArgumentException if {@code labelSmoothing} is outside {@code [0, 1]} or NaN
     */
    public BinaryCrossentropy(String name, boolean fromLogits, float labelSmoothing, Reduction reduction) {
        super(name, reduction);
        // Expressed as a negated in-range test: NaN fails every comparison, so a plain
        // "< 0 || > 1" check would silently accept a NaN smoothing factor.
        if (!(labelSmoothing >= 0.0f && labelSmoothing <= 1.0f)) {
            throw new IllegalArgumentException(
                    "labelSmoothing must be >= 0. and <= 1, found " + labelSmoothing);
        }
        this.fromLogits = fromLogits;
        this.labelSmoothing = labelSmoothing;
    }

    /**
     * Builds the weighted binary cross-entropy loss.
     *
     * <p>When this loss was not configured with {@code fromLogits}, the predictions are first
     * wrapped with {@link LossesHelper#rangeCheck} against {@code [0, 1]}. They are then passed to
     * {@link Losses#binaryCrossentropy} together with {@code fromLogits} and
     * {@code labelSmoothing}. The result is weighted and reduced by
     * {@link LossesHelper#computeWeightedLoss} using this loss's reduction.
     *
     * @param ops the TensorFlow ops factory used to build the graph
     * @param labels the ground-truth values
     * @param predictions the predicted values (probabilities unless {@code fromLogits})
     * @param sampleWeights the weights forwarded to {@link LossesHelper#computeWeightedLoss}
     * @param <T> the numeric data type of the predictions and the resulting loss
     * @return the weighted, reduced loss operand
     */
    @Override // org.tensorflow.framework.losses.Loss
    public <T extends TNumber> Operand<T> call(
            Ops ops, Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
        Operand<T> checkedPredictions = fromLogits
                ? predictions
                : LossesHelper.rangeCheck(
                        ops,
                        "predictions range check [0-1]",
                        predictions,
                        CastHelper.cast(ops, ops.constant(0), predictions.type()),
                        CastHelper.cast(ops, ops.constant(1), predictions.type()));
        Operand<T> losses =
                Losses.binaryCrossentropy(ops, labels, checkedPredictions, fromLogits, labelSmoothing);
        return LossesHelper.computeWeightedLoss(ops, losses, getReduction(), sampleWeights);
    }
}
