package org.tensorflow.framework.losses.impl;

import java.util.Arrays;
import java.util.Collections;
import org.tensorflow.Operand;
import org.tensorflow.framework.losses.Reduction;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.AssertThat;
import org.tensorflow.op.core.Rank;
import org.tensorflow.op.core.ReduceAll;
import org.tensorflow.op.core.ReduceSum;
import org.tensorflow.op.core.SetDiff1d;
import org.tensorflow.op.core.Squeeze;
import org.tensorflow.op.math.Equal;
import org.tensorflow.op.math.LogicalAnd;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/losses/impl/LossesHelper.class */
/**
 * Helper methods shared by the loss implementations.
 *
 * <p>The helpers align the ranks of labels, predictions and sample weights, apply a {@link
 * Reduction} to a weighted loss, and check tensor values against a range or an allowed value set.
 * Static shape information is used when it is available; otherwise graph operations on the dynamic
 * rank or size are emitted instead.
 */
public class LossesHelper {

    /**
     * Aligns the last dimension of {@code labels} with {@code predictions}; equivalent to the
     * four-argument overload with {@code null} sample weights.
     *
     * @param ops the TensorFlow Ops
     * @param labels the labels, may be {@code null}
     * @param predictions the predictions
     * @param <T> the data type of the operands
     * @return a tuple holding the (possibly squeezed) labels and predictions
     */
    public static <T extends TNumber> LossTuple<T> squeezeOrExpandDimensions(
            Ops ops, Operand<T> labels, Operand<T> predictions) {
        return squeezeOrExpandDimensions(ops, labels, predictions, null);
    }

    /**
     * Aligns the ranks of labels, predictions and sample weights.
     *
     * <p>Labels: when both static ranks are known and predictions have exactly one more dimension
     * than labels whose trailing size is not 1 (e.g. {@code (batch, numClasses)} vs. {@code
     * (batch)}), nothing is changed; otherwise {@link #removeSqueezableDimensions(Ops, Operand,
     * Operand)} is applied.
     *
     * <p>Sample weights: scalar weights are returned unchanged. With static ranks, weights having
     * one more dimension than the predictions lose their last axis, and weights having one fewer
     * dimension gain a trailing axis. When either rank is unknown, the adjustment is expressed with
     * graph ops on the dynamic ranks (compared against the original {@code predictions}).
     *
     * @param ops the TensorFlow Ops
     * @param labels the labels, may be {@code null}
     * @param predictions the predictions
     * @param sampleWeights the sample weights, may be {@code null}
     * @param <T> the data type of the operands
     * @return a tuple holding the adjusted labels, predictions and sample weights
     */
    public static <T extends TNumber> LossTuple<T> squeezeOrExpandDimensions(
            Ops ops, Operand<T> labels, Operand<T> predictions, Operand<T> sampleWeights) {
        Shape predictionsShape = predictions.shape();
        long predictionsRank = predictionsShape.numDimensions();
        LossTuple<T> tuple = new LossTuple<>(labels, predictions, sampleWeights);

        if (labels != null) {
            long labelsRank = labels.shape().numDimensions();
            boolean staticRanks =
                    labelsRank != Shape.UNKNOWN_SIZE && predictionsRank != Shape.UNKNOWN_SIZE;
            // A (batch, numClasses) prediction paired with (batch) labels is already aligned.
            if (!staticRanks || predictionsRank - labelsRank != 1 || predictionsShape.size(-1) == 1) {
                tuple = removeSqueezableDimensions(ops, labels, predictions);
            }
        }

        if (sampleWeights == null) {
            return tuple;
        }
        long weightsRank = sampleWeights.shape().numDimensions();
        if (weightsRank == 0) {
            // Scalar weights need no alignment.
            return new LossTuple<>(tuple.getLabels(), tuple.getTarget(), sampleWeights);
        }

        if (predictionsRank == Shape.UNKNOWN_SIZE || weightsRank == Shape.UNKNOWN_SIZE) {
            // At least one rank is only known at runtime: decide with graph ops.
            Rank weightsRankOp = ops.rank(sampleWeights);
            Operand<T> adjusted =
                    ops.select(
                            ops.math.equal(weightsRankOp, ops.constant(0)),
                            sampleWeights,
                            maybeAdjustWeights(
                                    ops, sampleWeights, ops.math.sub(weightsRankOp, ops.rank(predictions))));
            return new LossTuple<>(tuple.getLabels(), tuple.getTarget(), adjusted);
        }

        if (weightsRank - predictionsRank == 1) {
            // Only the trailing axis is removed; other size-1 axes (e.g. a batch of 1) are kept.
            sampleWeights = squeezeLastAxis(ops, sampleWeights);
        } else if (predictionsRank - weightsRank == 1) {
            sampleWeights = ops.expandDims(sampleWeights, ops.constant(-1L));
        }
        return new LossTuple<>(tuple.getLabels(), tuple.getTarget(), sampleWeights);
    }

    /**
     * Builds the dynamic weight adjustment: squeeze the last axis when {@code rankDiff == 1},
     * otherwise delegate to {@link #maybeExpandWeights}.
     *
     * <p>NOTE(review): {@code select} is not lazy, so the squeeze branch is presumably evaluated
     * even when it is not selected — confirm it cannot fail for weights whose last size is not 1.
     *
     * @param ops the TensorFlow Ops
     * @param sampleWeights the sample weights
     * @param rankDiff {@code rank(sampleWeights) - rank(predictions)}
     * @return the selected weights
     */
    private static <T extends TNumber> Operand<T> maybeAdjustWeights(
            Ops ops, Operand<T> sampleWeights, Operand<TInt32> rankDiff) {
        return ops.select(
                ops.math.equal(rankDiff, ops.constant(1)),
                squeezeLastAxis(ops, sampleWeights),
                maybeExpandWeights(ops, sampleWeights, rankDiff));
    }

    /**
     * Builds the dynamic weight expansion: add a trailing axis when {@code rankDiff == -1},
     * otherwise keep the weights unchanged.
     *
     * @param ops the TensorFlow Ops
     * @param sampleWeights the sample weights
     * @param rankDiff {@code rank(sampleWeights) - rank(predictions)}
     * @return the selected weights
     */
    private static <T extends TNumber> Operand<T> maybeExpandWeights(
            Ops ops, Operand<T> sampleWeights, Operand<TInt32> rankDiff) {
        return ops.select(
                ops.math.equal(rankDiff, ops.constant(-1)),
                ops.expandDims(sampleWeights, ops.constant(-1)),
                sampleWeights);
    }

    /**
     * Removes a squeezable trailing dimension, assuming an expected rank difference of 0.
     *
     * @param ops the TensorFlow Ops
     * @param labels the labels
     * @param predictions the predictions
     * @param <T> the data type of the operands
     * @return a tuple holding the (possibly squeezed) labels and predictions
     */
    public static <T extends TNumber> LossTuple<T> removeSqueezableDimensions(
            Ops ops, Operand<T> labels, Operand<T> predictions) {
        return removeSqueezableDimensions(ops, labels, predictions, 0);
    }

    /**
     * Removes a trailing dimension of size 1 from labels or predictions when their rank difference
     * is off by one from {@code expectedRankDiff}.
     *
     * <p>When both static ranks are known, predictions lose their last axis if {@code rank(predictions)
     * - rank(labels) == expectedRankDiff + 1}, and labels lose theirs if the difference is {@code
     * expectedRankDiff - 1}; in each case only when that last size is compatible with 1. When a rank
     * is unknown, every operand of unknown rank has its last axis squeezed; no lazy conditional is
     * available, so this squeeze is unconditional.
     *
     * @param ops the TensorFlow Ops
     * @param labels the labels
     * @param predictions the predictions
     * @param expectedRankDiff the expected value of {@code rank(predictions) - rank(labels)}
     * @param <T> the data type of the operands
     * @return a tuple holding the (possibly squeezed) labels and predictions
     */
    public static <T extends TNumber> LossTuple<T> removeSqueezableDimensions(
            Ops ops, Operand<T> labels, Operand<T> predictions, int expectedRankDiff) {
        Ops scoped = ops.withSubScope("removeSqueezableDimensions");
        Shape predictionsShape = predictions.shape();
        int predictionsRank = predictionsShape.numDimensions();
        Shape labelsShape = labels.shape();
        int labelsRank = labelsShape.numDimensions();

        if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) {
            // Both ranks are static, so the rank difference is meaningful.
            int rankDiff = predictionsRank - labelsRank;
            if (rankDiff == expectedRankDiff + 1
                    && predictionsRank > 0
                    && Shape.isCompatible(predictionsShape.size(-1), 1L)) {
                predictions = squeezeLastAxis(scoped, predictions);
            } else if (rankDiff == expectedRankDiff - 1
                    && labelsRank > 0
                    && Shape.isCompatible(labelsShape.size(-1), 1L)) {
                labels = squeezeLastAxis(scoped, labels);
            }
            return new LossTuple<>(labels, predictions);
        }

        // At least one rank is unknown: only operands of unknown rank are touched.
        if (predictionsRank == Shape.UNKNOWN_SIZE
                && Shape.isCompatible(predictionsShape.size(-1), 1L)) {
            predictions = squeezeLastAxis(scoped, predictions);
        }
        if (labelsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(labelsShape.size(-1), 1L)) {
            labels = squeezeLastAxis(scoped, labels);
        }
        return new LossTuple<>(labels, predictions);
    }

    /** Squeezes only the last axis of {@code operand}, leaving other size-1 axes in place. */
    private static <T extends TNumber> Operand<T> squeezeLastAxis(Ops ops, Operand<T> operand) {
        return ops.squeeze(operand, Squeeze.axis(Collections.singletonList(-1L)));
    }

    /**
     * Multiplies {@code loss} by the sample weights and applies {@code reduction}.
     *
     * @param ops the TensorFlow Ops
     * @param loss the unweighted loss
     * @param reduction the reduction to apply
     * @param sampleWeight the sample weights; {@code null} means a weight of 1
     * @param <T> the data type of the loss
     * @return the weighted, reduced loss, cast to the type of {@code loss}
     */
    public static <T extends TNumber> Operand<T> computeWeightedLoss(
            Ops ops, Operand<T> loss, Reduction reduction, Operand<T> sampleWeight) {
        Class<T> type = loss.type();
        if (sampleWeight == null) {
            sampleWeight = CastHelper.cast(ops, ops.constant(1), type);
        }
        LossTuple<T> tuple = squeezeOrExpandDimensions(ops, null, loss, sampleWeight);
        Operand<T> weightedLoss =
                ops.math.mul(tuple.getTarget(), CastHelper.cast(ops, tuple.getSampleWeights(), type));
        return CastHelper.cast(ops, reduceWeightedLoss(ops, weightedLoss, reduction), type);
    }

    /**
     * Applies {@code reduction} to a weighted loss.
     *
     * <p>{@code NONE} returns the input unchanged; any other reduction sums over all axes, and
     * {@code AUTO} / {@code SUM_OVER_BATCH_SIZE} additionally divide by the number of elements.
     *
     * @param ops the TensorFlow Ops
     * @param weightedLoss the weighted loss
     * @param reduction the reduction to apply
     * @return the reduced loss
     */
    private static <T extends TNumber> Operand<T> reduceWeightedLoss(
            Ops ops, Operand<T> weightedLoss, Reduction reduction) {
        if (reduction == Reduction.NONE) {
            return weightedLoss;
        }
        Operand<T> loss =
                ops.reduceSum(weightedLoss, allAxes(ops, weightedLoss), ReduceSum.keepDims(Boolean.FALSE));
        if (reduction == Reduction.AUTO || reduction == Reduction.SUM_OVER_BATCH_SIZE) {
            long numElements = weightedLoss.shape().size();
            if (numElements != Shape.UNKNOWN_SIZE) {
                loss = safeMean(ops, loss, numElements);
            } else {
                // The static element count is unknown (it would be -1): use the runtime size.
                Operand<T> count =
                        CastHelper.cast(ops, ops.shape.size(ops.shape(weightedLoss)), weightedLoss.type());
                loss = ops.math.divNoNan(loss, count);
            }
        }
        return loss;
    }

    /**
     * Sums {@code losses} over all axes and divides by {@code numElements}, yielding 0 instead of
     * NaN when {@code numElements} is 0.
     *
     * @param ops the TensorFlow Ops
     * @param losses the losses
     * @param numElements the divisor; callers should pass a known (non-negative) count
     * @param <T> the data type of the losses
     * @return the mean
     */
    public static <T extends TNumber> Operand<T> safeMean(
            Ops ops, Operand<T> losses, long numElements) {
        Operand<T> total = ops.reduceSum(losses, allAxes(ops, losses));
        return ops.math.divNoNan(total, CastHelper.cast(ops, ops.constant(numElements), losses.type()));
    }

    /**
     * Returns the indices of every axis of {@code op}: a constant when the rank is static, otherwise
     * a {@code range} op over the dynamic rank.
     *
     * @param ops the TensorFlow Ops
     * @param op the operand
     * @param <T> the data type of the operand
     * @return the axis indices {@code [0, rank)}
     */
    public static <T extends TNumber> Operand<TInt32> allAxes(Ops ops, Operand<T> op) {
        int rank = op.shape().numDimensions();
        if (rank == Shape.UNKNOWN_SIZE) {
            return ops.range(ops.constant(0), ops.rank(op), ops.constant(1));
        }
        int[] axes = new int[rank];
        for (int i = 0; i < rank; i++) {
            axes[i] = i;
        }
        return ops.constant(axes);
    }

    /**
     * Checks that every value lies in {@code [minValue, maxValue]}.
     *
     * <p>In graph mode an {@code AssertThat} control dependency is attached and an identity of
     * {@code values} is returned; in eager mode the check runs immediately.
     *
     * @param ops the TensorFlow Ops
     * @param prefix a prefix for the error message
     * @param values the values to check
     * @param minValue the inclusive lower bound
     * @param maxValue the inclusive upper bound
     * @param <T> the data type of the values
     * @return {@code values}, or an identity of it guarded by the assertion in graph mode
     * @throws IllegalArgumentException in eager mode, if any value is out of range
     */
    public static <T extends TNumber> Operand<T> rangeCheck(
            Ops ops, String prefix, Operand<T> values, Operand<T> minValue, Operand<T> maxValue) {
        Operand<TInt32> axes = allAxes(ops, values);
        LogicalAnd inRange =
                ops.math.logicalAnd(
                        ops.reduceAll(ops.math.greaterEqual(values, minValue), axes),
                        ops.reduceAll(ops.math.lessEqual(values, maxValue), axes));
        if (ops.scope().env().isGraph()) {
            return ops.withSubScope("rangeCheck")
                    .withControlDependencies(
                            Collections.singletonList(
                                    ops.assertThat(
                                            inRange,
                                            Arrays.asList(
                                                    ops.constant(prefix),
                                                    ops.constant(": values out of range, "),
                                                    ops.constant("minimum = "),
                                                    minValue,
                                                    ops.constant(", maximum = "),
                                                    maxValue))))
                    .identity(values);
        }
        if (inRange.asTensor().getBoolean()) {
            return values;
        }
        throw new IllegalArgumentException(String.format("%s : values out of range", prefix));
    }

    /**
     * Checks that every value of {@code values} is contained in {@code allowedValues}.
     *
     * <p>If the static size of the set difference is known, the check is decided immediately.
     * Otherwise, in graph mode an {@code AssertThat} control dependency is attached; in eager mode
     * the check runs immediately.
     *
     * @param ops the TensorFlow Ops
     * @param prefix a prefix for the error message
     * @param values the values to check
     * @param allowedValues the allowed values
     * @param <T> the data type of the values
     * @return {@code values}, or an identity of it guarded by the assertion in graph mode
     * @throws IllegalArgumentException if a value is found outside the allowed set
     */
    public static <T extends TNumber> Operand<T> valueCheck(
            Ops ops, String prefix, Operand<T> values, Operand<T> allowedValues) {
        Operand<T> flat = ops.reshape(values, ops.constant(Shape.of(values.shape().size())));
        SetDiff1d diff = ops.setDiff1d(flat, allowedValues, TInt32.class);
        long diffSize = diff.out().shape().size();
        if (diffSize != Shape.UNKNOWN_SIZE) {
            if (diffSize != 0) {
                throw new IllegalArgumentException(
                        String.format("%s : values not in value set", prefix));
            }
            return values;
        }
        // Size only known at runtime.
        Equal noExtras = ops.math.equal(ops.shape.size(ops.shape(diff.out())), ops.constant(0));
        if (ops.scope().env().isGraph()) {
            return ops.withSubScope("valueCheck")
                    .withControlDependencies(
                            Collections.singletonList(
                                    ops.assertThat(
                                            noExtras,
                                            Arrays.asList(
                                                    ops.constant(prefix),
                                                    ops.constant(": values not in value set, values = "),
                                                    values))))
                    .identity(values);
        }
        if (noExtras.asTensor().getBoolean()) {
            return values;
        }
        throw new IllegalArgumentException(String.format("%s : values not in value set", prefix));
    }
}
