package org.tensorflow.framework.losses;

import org.tensorflow.Operand;
import org.tensorflow.framework.losses.impl.LossTuple;
import org.tensorflow.framework.losses.impl.LossesHelper;
import org.tensorflow.framework.op.FrameworkOps;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.ClipByValue;
import org.tensorflow.op.core.ReduceAll;
import org.tensorflow.op.core.ReduceMax;
import org.tensorflow.op.core.ReduceSum;
import org.tensorflow.op.math.Abs;
import org.tensorflow.op.math.Equal;
import org.tensorflow.op.math.Mean;
import org.tensorflow.op.math.Minimum;
import org.tensorflow.op.math.Sub;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/losses/Losses.class */
/**
 * Static functions that build the element-wise loss computations used by the loss classes in this
 * package.
 *
 * <p>Unless stated otherwise, each function casts {@code labels} to the data type of {@code
 * predictions}, aligns the two operands with {@link LossesHelper#squeezeOrExpandDimensions}
 * (without sample weights), and reduces over the last axis ({@code -1}). No sample weighting and
 * no batch-level reduction is applied here.
 */
public class Losses {
    /** Small constant used to avoid division by zero and {@code log(0)}. */
    public static final float EPSILON = 1.0E-7f;
    /** Axis value {@code -1}; presumably the channel axis for channels-last data. */
    public static final int CHANNELS_LAST = -1;
    /** Axis value {@code 1}; presumably the channel axis for channels-first data. */
    public static final int CHANNELS_FIRST = 1;

    /**
     * Computes the mean absolute error: {@code mean(abs(labels - predictions), axis=-1)}.
     *
     * @param ops the TensorFlow Ops
     * @param labels the true values; cast to the type of {@code predictions}
     * @param predictions the predicted values
     * @param <T> the data type of the predictions and of the result
     * @return the loss, reduced over the last axis
     */
    public static <T extends TNumber> Operand<T> meanAbsoluteError(
            Ops ops, Operand<? extends TNumber> labels, Operand<T> predictions) {
        LossTuple<T> tuple = alignInputs(ops, labels, predictions);
        Operand<T> error = ops.math.sub(tuple.getLabels(), tuple.getTarget());
        return ops.math.mean(ops.math.abs(error), ops.constant(-1), Mean.keepDims(false));
    }

    /**
     * Computes the mean squared error: {@code mean(square(predictions - labels), axis=-1)}.
     *
     * @param ops the TensorFlow Ops
     * @param labels the true values; cast to the type of {@code predictions}
     * @param predictions the predicted values
     * @param <T> the data type of the predictions and of the result
     * @return the loss, reduced over the last axis
     */
    public static <T extends TNumber> Operand<T> meanSquaredError(
            Ops ops, Operand<? extends TNumber> labels, Operand<T> predictions) {
        LossTuple<T> tuple = alignInputs(ops, labels, predictions);
        Operand<T> squared = ops.math.squaredDifference(tuple.getTarget(), tuple.getLabels());
        return ops.math.mean(squared, ops.constant(-1));
    }

    /**
     * Computes the mean absolute percentage error:
     * {@code 100 * mean(abs((labels - predictions) / max(abs(labels), EPSILON)), axis=-1)}.
     *
     * @param ops the TensorFlow Ops
     * @param labels the true values; cast to the type of {@code predictions}
     * @param predictions the predicted values
     * @param <T> the data type of the predictions and of the result
     * @return the loss, reduced over the last axis
     */
    public static <T extends TNumber> Operand<T> meanAbsolutePercentageError(
            Ops ops, Operand<? extends TNumber> labels, Operand<T> predictions) {
        Class<T> type = predictions.type();
        LossTuple<T> tuple = alignInputs(ops, labels, predictions);
        Operand<T> trueValues = tuple.getLabels();
        Operand<T> epsilon = CastHelper.cast(ops, ops.constant(EPSILON), type);
        // Clamp the denominator away from zero so zero-valued labels do not divide by zero.
        Operand<T> denominator = ops.math.maximum(ops.math.abs(trueValues), epsilon);
        Operand<T> relativeError =
                ops.math.abs(ops.math.div(ops.math.sub(trueValues, tuple.getTarget()), denominator));
        Operand<T> hundred = CastHelper.cast(ops, ops.constant(100), type);
        return ops.math.mul(hundred, ops.math.mean(relativeError, ops.constant(-1)));
    }

    /**
     * Computes the mean squared logarithmic error:
     * {@code mean(square(log(max(predictions, EPSILON) + 1) - log(max(labels, EPSILON) + 1)),
     * axis=-1)}.
     *
     * @param ops the TensorFlow Ops
     * @param labels the true values; cast to the type of {@code predictions}
     * @param predictions the predicted values
     * @param <T> the data type of the predictions and of the result
     * @return the loss, reduced over the last axis
     */
    public static <T extends TNumber> Operand<T> meanSquaredLogarithmicError(
            Ops ops, Operand<? extends TNumber> labels, Operand<T> predictions) {
        Class<T> type = predictions.type();
        LossTuple<T> tuple = alignInputs(ops, labels, predictions);
        Operand<T> epsilon = CastHelper.cast(ops, ops.constant(EPSILON), type);
        Operand<T> one = CastHelper.cast(ops, ops.constant(1), type);
        // Clamp at EPSILON before the log so non-positive inputs stay finite.
        Operand<T> logPredictions =
                ops.math.log(ops.math.add(ops.math.maximum(tuple.getTarget(), epsilon), one));
        Operand<T> logLabels =
                ops.math.log(ops.math.add(ops.math.maximum(tuple.getLabels(), epsilon), one));
        return ops.math.mean(ops.math.squaredDifference(logPredictions, logLabels), ops.constant(-1));
    }

    /**
     * Computes the binary cross-entropy loss, averaged over the last axis.
     *
     * @param ops the TensorFlow Ops
     * @param labels the true values; cast to the type of {@code predictions}
     * @param predictions the predicted values: logits if {@code fromLogits}, else probabilities
     * @param fromLogits whether {@code predictions} are logits
     * @param labelSmoothing if non-zero, labels are replaced by
     *     {@code labels * (1 - labelSmoothing) + 0.5 * labelSmoothing}
     * @param <T> the data type of the predictions and of the result
     * @return the loss, reduced over the last axis
     */
    public static <T extends TNumber> Operand<T> binaryCrossentropy(
            Ops ops,
            Operand<? extends TNumber> labels,
            Operand<T> predictions,
            boolean fromLogits,
            float labelSmoothing) {
        LossTuple<T> tuple = alignInputs(ops, labels, predictions);
        Operand<T> trueValues = tuple.getLabels();
        if (labelSmoothing != 0.0f) {
            trueValues = smoothBinaryLabels(ops, trueValues, labelSmoothing);
        }
        Operand<T> crossEntropy =
                binaryCrossentropyHelper(ops, trueValues, tuple.getTarget(), fromLogits);
        return ops.math.mean(crossEntropy, ops.constant(-1));
    }

    /**
     * Element-wise binary cross-entropy. For logits this delegates to the framework's
     * {@code sigmoidCrossEntropyWithLogits}. Otherwise predictions are clipped to
     * {@code [EPSILON, 1 - EPSILON]} and
     * {@code -(y * log(p + EPSILON) + (1 - y) * log(1 - p + EPSILON))} is computed.
     *
     * @param ops the TensorFlow Ops
     * @param labels the (already aligned) true values
     * @param predictions the (already aligned) predictions
     * @param fromLogits whether {@code predictions} are logits
     * @param <T> the data type
     * @return the element-wise loss (not reduced)
     */
    private static <T extends TNumber> Operand<T> binaryCrossentropyHelper(
            Ops ops, Operand<T> labels, Operand<T> predictions, boolean fromLogits) {
        if (fromLogits) {
            return FrameworkOps.create(ops).nn.sigmoidCrossEntropyWithLogits(labels, predictions);
        }
        Class<T> type = predictions.type();
        Operand<T> one = CastHelper.cast(ops, ops.constant(1), type);
        Operand<T> epsilon = CastHelper.cast(ops, ops.constant(EPSILON), type);
        Operand<T> clipped = ops.clipByValue(predictions, epsilon, ops.math.sub(one, epsilon));
        Operand<T> positiveTerm = ops.math.mul(labels, ops.math.log(ops.math.add(clipped, epsilon)));
        Operand<T> negativeTerm =
                ops.math.mul(
                        ops.math.sub(one, labels),
                        ops.math.log(ops.math.add(ops.math.sub(one, clipped), epsilon)));
        return ops.math.neg(ops.math.add(positiveTerm, negativeTerm));
    }

    /**
     * Computes the categorical cross-entropy loss over {@code axis}.
     *
     * <p>For probabilities (not logits), predictions are first normalized to sum to one along
     * {@code axis}, then clipped to {@code [EPSILON, 1 - EPSILON]} before taking the log.
     *
     * @param ops the TensorFlow Ops
     * @param labels the true values (one-hot or probabilities); cast to the type of predictions
     * @param predictions the predicted values: logits if {@code fromLogits}, else probabilities
     * @param fromLogits whether {@code predictions} are logits
     * @param labelSmoothing if non-zero, labels are replaced by
     *     {@code labels * (1 - labelSmoothing) + labelSmoothing / numClasses}
     * @param axis the class axis
     * @param <T> the data type of the predictions and of the result
     * @return the loss, reduced over {@code axis}
     */
    public static <T extends TNumber> Operand<T> categoricalCrossentropy(
            Ops ops,
            Operand<? extends TNumber> labels,
            Operand<T> predictions,
            boolean fromLogits,
            float labelSmoothing,
            int axis) {
        LossTuple<T> tuple = alignInputs(ops, labels, predictions);
        Operand<T> trueValues = tuple.getLabels();
        Operand<T> predicted = tuple.getTarget();
        if (labelSmoothing != 0.0f) {
            trueValues = smoothCategoricalLabels(ops, trueValues, labelSmoothing);
        }
        if (fromLogits) {
            return FrameworkOps.create(ops).nn.softmaxCrossEntropyWithLogits(trueValues, predicted, axis);
        }
        Class<T> type = predictions.type();
        Operand<T> one = CastHelper.cast(ops, ops.constant(1), type);
        Operand<T> epsilon = CastHelper.cast(ops, ops.constant(EPSILON), type);
        Operand<T> total = ops.reduceSum(predicted, ops.constant(axis), ReduceSum.keepDims(true));
        Operand<T> probabilities =
                ops.clipByValue(ops.math.div(predicted, total), epsilon, ops.math.sub(one, epsilon));
        Operand<T> weightedLog = ops.math.mul(trueValues, ops.math.log(probabilities));
        return ops.math.neg(ops.reduceSum(weightedLog, ops.constant(axis), ReduceSum.keepDims(false)));
    }

    /**
     * Computes the categorical hinge loss:
     * {@code max(0, max((1 - labels) * predictions) - sum(labels * predictions) + 1)}, with both
     * inner reductions taken over the last axis.
     *
     * @param ops the TensorFlow Ops
     * @param labels the true values; cast to the type of {@code predictions}
     * @param predictions the predicted values
     * @param <T> the data type of the predictions and of the result
     * @return the loss, reduced over the last axis
     */
    public static <T extends TNumber> Operand<T> categoricalHinge(
            Ops ops, Operand<? extends TNumber> labels, Operand<T> predictions) {
        Class<T> type = predictions.type();
        LossTuple<T> tuple = alignInputs(ops, labels, predictions);
        Operand<T> trueValues = tuple.getLabels();
        Operand<T> predicted = tuple.getTarget();
        Operand<T> one = CastHelper.cast(ops, ops.constant(1), type);
        Operand<T> negative =
                ops.reduceMax(
                        ops.math.mul(ops.math.sub(one, trueValues), predicted),
                        ops.constant(-1),
                        ReduceMax.keepDims(Boolean.FALSE));
        Operand<T> positive =
                ops.reduceSum(
                        ops.math.mul(trueValues, predicted),
                        ops.constant(-1),
                        ReduceSum.keepDims(Boolean.FALSE));
        Operand<T> zero = CastHelper.cast(ops, ops.constant(0), type);
        return ops.math.maximum(zero, ops.math.add(ops.math.sub(negative, positive), one));
    }

    /**
     * Computes the cosine similarity between labels and predictions: the sum over {@code axis} of
     * the product of their L2-normalized values. The result is not negated, so aligned vectors
     * give a positive value.
     *
     * @param ops the TensorFlow Ops
     * @param labels the true values; cast to the type of {@code predictions}
     * @param predictions the predicted values
     * @param axis the axes along which to normalize and reduce
     * @param <T> the data type of the predictions and of the result
     * @return the cosine similarity, reduced over {@code axis}
     */
    public static <T extends TNumber> Operand<T> cosineSimilarity(
            Ops ops, Operand<? extends TNumber> labels, Operand<T> predictions, int[] axis) {
        LossTuple<T> tuple = alignInputs(ops, labels, predictions);
        Operand<T> normalizedLabels = l2Normalize(ops, tuple.getLabels(), axis);
        Operand<T> normalizedPredictions = l2Normalize(ops, tuple.getTarget(), axis);
        return ops.reduceSum(
                ops.math.mul(normalizedLabels, normalizedPredictions),
                ops.constant(axis),
                ReduceSum.keepDims(Boolean.FALSE));
    }

    /**
     * Computes the hinge loss: {@code mean(max(1 - labels * predictions, 0), axis=-1)}. Labels
     * whose last-axis slice contains only 0/1 are first mapped to -1/1.
     *
     * @param ops the TensorFlow Ops
     * @param labels the true values; cast to the type of {@code predictions}
     * @param predictions the predicted values
     * @param <T> the data type of the predictions and of the result
     * @return the loss, reduced over the last axis
     */
    public static <T extends TNumber> Operand<T> hinge(
            Ops ops, Operand<? extends TNumber> labels, Operand<T> predictions) {
        return ops.math.mean(hingeMargin(ops, labels, predictions), ops.constant(-1));
    }

    /**
     * Computes the Huber loss, averaged over the last axis. With {@code e = predictions - labels}
     * and {@code q = min(abs(e), delta)}, the element-wise loss is
     * {@code 0.5 * q^2 + delta * (abs(e) - q)}.
     *
     * @param ops the TensorFlow Ops
     * @param labels the true values; cast to the type of {@code predictions}
     * @param predictions the predicted values
     * @param delta the point where the loss changes from quadratic to linear
     * @param <T> the data type of the predictions and of the result
     * @return the loss, reduced over the last axis
     */
    public static <T extends TNumber> Operand<T> huber(
            Ops ops, Operand<? extends TNumber> labels, Operand<T> predictions, float delta) {
        Class<T> type = predictions.type();
        LossTuple<T> tuple = alignInputs(ops, labels, predictions);
        Operand<T> error = ops.math.sub(tuple.getTarget(), tuple.getLabels());
        Operand<T> deltaConst = CastHelper.cast(ops, ops.constant(delta), type);
        Operand<T> half = CastHelper.cast(ops, ops.constant(0.5d), type);
        Operand<T> absError = ops.math.abs(error);
        Operand<T> quadratic = ops.math.minimum(absError, deltaConst);
        Operand<T> linear = ops.math.sub(absError, quadratic);
        Operand<T> loss =
                ops.math.add(
                        ops.math.mul(half, ops.math.mul(quadratic, quadratic)),
                        ops.math.mul(deltaConst, linear));
        return ops.math.mean(loss, ops.constant(-1));
    }

    /**
     * Computes the Kullback-Leibler divergence {@code sum(p * log(p / q), axis=-1)}. Here
     * {@code p} is the labels and {@code q} the predictions, both clipped to {@code [EPSILON, 1]}.
     *
     * @param ops the TensorFlow Ops
     * @param labels the true distribution; cast to the type of {@code predictions}
     * @param predictions the predicted distribution
     * @param <T> the data type of the predictions and of the result
     * @return the divergence, reduced over the last axis
     */
    public static <T extends TNumber> Operand<T> kullbackLeiblerDivergence(
            Ops ops, Operand<? extends TNumber> labels, Operand<T> predictions) {
        Class<T> type = predictions.type();
        LossTuple<T> tuple = alignInputs(ops, labels, predictions);
        Operand<T> one = CastHelper.cast(ops, ops.constant(1), type);
        Operand<T> epsilon = CastHelper.cast(ops, ops.constant(EPSILON), type);
        Operand<T> p = ops.clipByValue(tuple.getLabels(), epsilon, one);
        Operand<T> q = ops.clipByValue(tuple.getTarget(), epsilon, one);
        return ops.reduceSum(ops.math.mul(p, ops.math.log(ops.math.div(p, q))), ops.constant(-1));
    }

    /**
     * Computes the log-cosh loss, averaged over the last axis. For {@code x = predictions -
     * labels} it evaluates {@code log(cosh(x))} as {@code x + softplus(-2x) - log(2)}, which
     * avoids computing {@code cosh} directly.
     *
     * @param ops the TensorFlow Ops
     * @param labels the true values; cast to the type of {@code predictions}
     * @param predictions the predicted values
     * @param <T> the data type of the predictions and of the result
     * @return the loss, reduced over the last axis
     */
    public static <T extends TNumber> Operand<T> logCosh(
            Ops ops, Operand<? extends TNumber> labels, Operand<T> predictions) {
        Class<T> type = predictions.type();
        LossTuple<T> tuple = alignInputs(ops, labels, predictions);
        Operand<T> minusTwo = CastHelper.cast(ops, ops.constant(-2), type);
        Operand<T> two = CastHelper.cast(ops, ops.constant(2), type);
        Operand<T> x = ops.math.sub(tuple.getTarget(), tuple.getLabels());
        Operand<T> logCoshX =
                ops.math.sub(
                        ops.math.add(x, ops.math.softplus(ops.math.mul(minusTwo, x))),
                        ops.math.log(two));
        return ops.math.mean(logCoshX, ops.constant(-1));
    }

    /**
     * Computes the Poisson loss:
     * {@code mean(predictions - labels * log(predictions + EPSILON), axis=-1)}.
     *
     * @param ops the TensorFlow Ops
     * @param labels the true values; cast to the type of {@code predictions}
     * @param predictions the predicted values
     * @param <T> the data type of the predictions and of the result
     * @return the loss, reduced over the last axis
     */
    public static <T extends TNumber> Operand<T> poisson(
            Ops ops, Operand<? extends TNumber> labels, Operand<T> predictions) {
        Class<T> type = predictions.type();
        LossTuple<T> tuple = alignInputs(ops, labels, predictions);
        Operand<T> predicted = tuple.getTarget();
        Operand<T> epsilon = CastHelper.cast(ops, ops.constant(EPSILON), type);
        Operand<T> loss =
                ops.math.sub(
                        predicted,
                        ops.math.mul(tuple.getLabels(), ops.math.log(ops.math.add(predicted, epsilon))));
        return ops.math.mean(loss, ops.constant(-1));
    }

    /**
     * Computes the sparse categorical cross-entropy between integer class labels and predictions.
     *
     * <p>Unlike the other losses here, labels are not aligned with
     * {@link LossesHelper#squeezeOrExpandDimensions}; they are cast to {@link TInt64}. If
     * {@code axis} is not the last axis, the class axis of the predictions is transposed to the
     * end. If the labels' rank is not one less than the predictions' rank, the labels are
     * flattened to 1-D and the predictions to {@code [-1, numClasses]}. For predictions of rank
     * 3 or more, the loss is then reshaped back to the (transposed) predictions shape without
     * its class dimension.
     *
     * @param ops the TensorFlow Ops
     * @param labels the integer class indices
     * @param predictions the predicted values: logits if {@code fromLogits}, else probabilities
     *     (clipped to {@code [EPSILON, 1 - EPSILON]} before taking the log)
     * @param fromLogits whether {@code predictions} are logits
     * @param axis the class axis of {@code predictions}; negative values count from the end
     * @param <T> the data type of the predictions and of the result
     * @return the loss
     * @throws IllegalArgumentException if {@code predictions} does not have a known rank of at
     *     least 1
     */
    public static <T extends TNumber> Operand<T> sparseCategoricalCrossentropy(
            Ops ops,
            Operand<? extends TNumber> labels,
            Operand<T> predictions,
            boolean fromLogits,
            int axis) {
        Class<T> type = predictions.type();
        Operand<T> epsilon = CastHelper.cast(ops, ops.constant(EPSILON), type);
        Operand<T> oneMinusEpsilon =
                ops.math.sub(CastHelper.cast(ops, ops.constant(1), type), epsilon);

        Operand<T> logits = predictions;
        if (!fromLogits) {
            logits = ops.math.log(ops.clipByValue(logits, epsilon, oneMinusEpsilon));
        }

        int rank = logits.shape().numDimensions();
        if (rank < 1) {
            // Scalars and unknown ranks previously failed with ArithmeticException /
            // NegativeArraySizeException inside the axis arithmetic below.
            throw new IllegalArgumentException(
                    "predictions must have a known rank of at least 1, got shape " + logits.shape());
        }
        int classAxis = axis % rank;
        if (classAxis < 0) {
            classAxis += rank;
        }
        if (classAxis != rank - 1) {
            logits = ops.linalg.transpose(logits, ops.constant(moveAxisToEnd(classAxis, rank)));
        }
        // Read the shape only AFTER the transpose: the class count and the output shape must
        // come from the layout actually fed to the cross-entropy op.
        Shape logitsShape = logits.shape();

        Operand<TInt64> sparseLabels = CastHelper.cast(ops, labels, TInt64.class);
        boolean flatten = labels.shape().numDimensions() != rank - 1;
        if (flatten) {
            // Flatten labels to 1-D so they pair one-to-one with the rows of the 2-D logits.
            sparseLabels = ops.reshape(sparseLabels, ops.constant(new long[] {-1L}));
            logits = ops.reshape(logits, ops.constant(new long[] {-1L, logitsShape.size(rank - 1)}));
        }

        Operand<?> rawLoss = FrameworkOps.create(ops).nn.sparseSoftmaxCrossEntropyWithLogits(sparseLabels, logits);
        // The cross-entropy is computed from logits of type T, so the result has type T.
        @SuppressWarnings("unchecked")
        Operand<T> loss = (Operand<T>) rawLoss;
        if (flatten && rank >= 3) {
            loss = ops.reshape(loss, ops.constant(logitsShape.take(rank - 1)));
        }
        return loss;
    }

    /**
     * Computes the squared hinge loss:
     * {@code mean(square(max(1 - labels * predictions, 0)), axis=-1)}. Labels whose last-axis
     * slice contains only 0/1 are first mapped to -1/1.
     *
     * @param ops the TensorFlow Ops
     * @param labels the true values; cast to the type of {@code predictions}
     * @param predictions the predicted values
     * @param <T> the data type of the predictions and of the result
     * @return the loss, reduced over the last axis
     */
    public static <T extends TNumber> Operand<T> squaredHinge(
            Ops ops, Operand<? extends TNumber> labels, Operand<T> predictions) {
        return ops.math.mean(ops.math.square(hingeMargin(ops, labels, predictions)), ops.constant(-1));
    }

    /**
     * Casts {@code labels} to the type of {@code predictions} and aligns the two operands with
     * {@link LossesHelper#squeezeOrExpandDimensions}, passing no sample weights.
     */
    private static <T extends TNumber> LossTuple<T> alignInputs(
            Ops ops, Operand<? extends TNumber> labels, Operand<T> predictions) {
        Operand<T> castLabels = CastHelper.cast(ops, labels, predictions.type());
        return LossesHelper.squeezeOrExpandDimensions(ops, castLabels, predictions, null);
    }

    /**
     * Shared hinge term {@code max(1 - labels * predictions, 0)}, computed element-wise after
     * aligning the inputs and converting 0/1 labels to -1/1.
     */
    private static <T extends TNumber> Operand<T> hingeMargin(
            Ops ops, Operand<? extends TNumber> labels, Operand<T> predictions) {
        Class<T> type = predictions.type();
        LossTuple<T> tuple = alignInputs(ops, labels, predictions);
        Operand<T> one = CastHelper.cast(ops, ops.constant(1), type);
        Operand<T> signedLabels = maybeConvertLabels(ops, tuple.getLabels());
        Operand<T> margin = ops.math.sub(one, ops.math.mul(signedLabels, tuple.getTarget()));
        Operand<T> zero = CastHelper.cast(ops, ops.constant(0), type);
        return ops.math.maximum(margin, zero);
    }

    /**
     * Applies binary label smoothing: {@code labels * (1 - labelSmoothing) + 0.5 *
     * labelSmoothing}.
     */
    private static <T extends TNumber> Operand<T> smoothBinaryLabels(
            Ops ops, Operand<T> labels, float labelSmoothing) {
        Class<T> type = labels.type();
        Operand<T> keep = CastHelper.cast(ops, ops.constant(1.0f - labelSmoothing), type);
        Operand<T> offset = CastHelper.cast(ops, ops.constant(0.5f * labelSmoothing), type);
        return ops.math.add(ops.math.mul(labels, keep), offset);
    }

    /**
     * Applies categorical label smoothing: {@code labels * (1 - labelSmoothing) + labelSmoothing
     * / numClasses}. {@code numClasses} is the static size of the labels' last dimension.
     * NOTE(review): that size is read from the static shape and would be -1 if unknown — confirm
     * callers always provide a known class dimension.
     */
    private static <T extends TNumber> Operand<T> smoothCategoricalLabels(
            Ops ops, Operand<T> labels, float labelSmoothing) {
        Class<T> type = labels.type();
        Shape shape = labels.shape();
        Operand<T> smoothing = CastHelper.cast(ops, ops.constant(labelSmoothing), type);
        Operand<T> numClasses =
                CastHelper.cast(ops, ops.constant(shape.size(shape.numDimensions() - 1)), type);
        Operand<T> keep = CastHelper.cast(ops, ops.constant(1.0f - labelSmoothing), type);
        return ops.math.add(ops.math.mul(labels, keep), ops.math.div(smoothing, numClasses));
    }

    /**
     * Normalizes {@code x} along {@code axis} by its L2 norm: {@code x * rsqrt(max(sum(x^2),
     * 1e-12))}. The squared sum is clamped at {@code 1e-12} so that all-zero slices do not divide
     * by zero.
     *
     * @param ops the TensorFlow Ops
     * @param x the operand to normalize
     * @param axis the axes along which to compute the norm
     * @param <T> the data type
     * @return the normalized operand, with the same shape as {@code x}
     */
    public static <T extends TNumber> Operand<T> l2Normalize(Ops ops, Operand<T> x, int[] axis) {
        Operand<T> squareSum =
                ops.reduceSum(ops.math.square(x), ops.constant(axis), ReduceSum.keepDims(Boolean.TRUE));
        Operand<T> floor = CastHelper.cast(ops, ops.constant(1.0E-12f), x.type());
        return ops.math.mul(x, ops.math.rsqrt(ops.math.maximum(squareSum, floor)));
    }

    /**
     * Converts binary 0/1 labels to -1/1 via {@code 2 * labels - 1}. The decision is made per
     * slice along the last axis: a slice is converted only if every element in it equals 0 or 1.
     * Otherwise the labels are returned unchanged.
     */
    private static <T extends TNumber> Operand<T> maybeConvertLabels(Ops ops, Operand<T> labels) {
        Class<T> type = labels.type();
        Operand<T> one = CastHelper.cast(ops, ops.constant(1), type);
        Operand<T> zero = CastHelper.cast(ops, ops.constant(0), type);
        Operand<T> two = CastHelper.cast(ops, ops.constant(2), type);
        Operand<T> converted = ops.math.sub(ops.math.mul(two, labels), one);
        return ops.select(
                ops.reduceAll(
                        ops.math.logicalOr(ops.math.equal(labels, zero), ops.math.equal(labels, one)),
                        ops.constant(-1),
                        ReduceAll.keepDims(true)),
                converted,
                labels);
    }

    /**
     * Builds a permutation of {@code [0, rank)} that keeps every axis in order except
     * {@code axis}, which is moved to the end.
     *
     * @param axis the axis to move; expected in {@code [0, rank)}
     * @param rank the number of dimensions
     * @return the permutation, e.g. {@code moveAxisToEnd(1, 4) == {0, 2, 3, 1}}
     */
    private static int[] moveAxisToEnd(int axis, int rank) {
        int[] permutation = new int[rank];
        int next = 0;
        for (int dim = 0; dim < rank; dim++) {
            if (dim != axis) {
                permutation[next++] = dim;
            }
        }
        permutation[rank - 1] = axis;
        return permutation;
    }
}
