package org.tensorflow.framework.op.math;

import org.tensorflow.Operand;
import org.tensorflow.op.Scope;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Range;
import org.tensorflow.op.core.Rank;
import org.tensorflow.op.core.ReduceMax;
import org.tensorflow.op.core.ReduceSum;
import org.tensorflow.op.core.Reshape;
import org.tensorflow.op.core.Select;
import org.tensorflow.op.core.Shape;
import org.tensorflow.op.core.StopGradient;
import org.tensorflow.op.core.ZerosLike;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Exp;
import org.tensorflow.op.math.IsFinite;
import org.tensorflow.op.math.Log;
import org.tensorflow.op.math.Sub;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.family.TFloating;

/* loaded from: input_file:org/tensorflow/framework/op/math/ReduceLogSumExp.class */
/**
 * Graph-building helper for {@code log(sum(exp(input)))} reduced over a set of axes.
 *
 * <p>The computation is stabilised by subtracting the maximum of each reduced slice before
 * exponentiating and adding it back after the logarithm. Non-finite maxima are replaced by zero
 * (so an all-{@code inf} slice does not turn into {@code inf - inf = NaN}), and the shift is
 * wrapped in {@link StopGradient} so it does not contribute to gradients.
 */
public class ReduceLogSumExp {

    /**
     * Computes {@code log(sum(exp(input)))} across the given axes.
     *
     * @param scope current graph scope
     * @param input tensor to reduce
     * @param axes dimensions to reduce; {@code null} reduces over every dimension of {@code input}
     * @param keepDims if {@code true}, the reduced dimensions are retained (forwarded as the
     *     {@code keepDims} option of the underlying sum reduction)
     * @param <T> floating-point data type of {@code input}
     * @return the reduced tensor
     */
    public static <T extends TFloating> Operand<T> reduceLogSumExp(
            Scope scope, Operand<T> input, int[] axes, boolean keepDims) {
        Operand<TInt32> reducedDims = reductionDims(scope, input, axes);

        // Always keep the reduced dims on the maximum so that it broadcasts against `input` along
        // the reduced axes. Dropping them (keepDims == false) misaligns the subtraction below
        // whenever a non-leading axis is reduced, producing a shape error or wrong values.
        Operand<T> rawMax = reduceMaxWithDims(scope, input, axes, true, reducedDims);

        // Replace non-finite maxima with 0 and block gradients through the shift.
        Operand<T> shift =
                StopGradient.create(
                        scope,
                        Select.create(
                                scope,
                                IsFinite.create(scope, rawMax),
                                rawMax,
                                ZerosLike.create(scope, rawMax)));

        Operand<T> shiftedExp = Exp.create(scope, Sub.create(scope, input, shift));
        Operand<T> result =
                Log.create(
                        scope, reduceSumWithDims(scope, shiftedExp, axes, keepDims, reducedDims));

        if (!keepDims) {
            // The shift still carries the kept (size-1) axes; reshape it to the reduced result.
            shift = Reshape.create(scope, shift, Shape.create(scope, result));
        }
        return mayReduceToScalar(scope, keepDims, axes, Add.create(scope, result, shift));
    }

    /**
     * Sums {@code input} over {@code reducedDims}, then applies {@link #mayReduceToScalar}.
     *
     * @param scope current graph scope
     * @param input tensor to reduce
     * @param axes the caller's axes ({@code null} means "all"); only used by the scalar check
     * @param keepDims forwarded as the {@code keepDims} option of {@link ReduceSum}
     * @param reducedDims the concrete dimensions to reduce over
     * @return the summed tensor
     */
    private static <T extends TFloating> Operand<T> reduceSumWithDims(
            Scope scope, Operand<T> input, int[] axes, boolean keepDims, Operand<TInt32> reducedDims) {
        Operand<T> sum = ReduceSum.create(scope, input, reducedDims, ReduceSum.keepDims(keepDims));
        return mayReduceToScalar(scope, keepDims, axes, sum);
    }

    /**
     * Takes the maximum of {@code input} over {@code reducedDims}, then applies {@link
     * #mayReduceToScalar}.
     *
     * @param scope current graph scope
     * @param input tensor to reduce
     * @param axes the caller's axes ({@code null} means "all"); only used by the scalar check
     * @param keepDims forwarded as the {@code keepDims} option of {@link ReduceMax}
     * @param reducedDims the concrete dimensions to reduce over
     * @return the max-reduced tensor
     */
    private static <T extends TFloating> Operand<T> reduceMaxWithDims(
            Scope scope, Operand<T> input, int[] axes, boolean keepDims, Operand<TInt32> reducedDims) {
        Operand<T> max = ReduceMax.create(scope, input, reducedDims, ReduceMax.keepDims(keepDims));
        return mayReduceToScalar(scope, keepDims, axes, max);
    }

    /**
     * Reshapes {@code output} to a scalar when a full reduction (no axes given, dims not kept) was
     * performed but the static shape of {@code output} is not fully known.
     *
     * @param scope current graph scope
     * @param keepDims whether reduced dimensions were kept
     * @param axes the caller's axes; {@code null} means every dimension was reduced
     * @param output the reduction result
     * @return {@code output}, or {@code output} reshaped to a scalar in the case described above
     */
    private static <T extends TFloating> Operand<T> mayReduceToScalar(
            Scope scope, boolean keepDims, int[] axes, Operand<T> output) {
        boolean shapeNotFullyKnown =
                output.shape().numDimensions() == org.tensorflow.ndarray.Shape.UNKNOWN_SIZE
                        || output.shape().hasUnknownDimension();
        if (shapeNotFullyKnown && !keepDims && axes == null) {
            return Reshape.create(
                    scope, output, Constant.tensorOf(scope, org.tensorflow.ndarray.Shape.scalar()));
        }
        return output;
    }

    /**
     * Builds the list of dimensions to reduce over.
     *
     * <p>If {@code axes} is given it is used verbatim. Otherwise every dimension of {@code input}
     * is selected. The list is a constant {@code [0, rank)} when the rank is statically known, and
     * a runtime {@code Range(0, Rank(input), 1)} when it is not.
     *
     * @param scope current graph scope
     * @param input tensor whose dimensions are being reduced
     * @param axes explicit axes, or {@code null} for all dimensions
     * @return an {@code int32} vector of dimensions to reduce
     */
    private static <T extends TFloating> Operand<TInt32> reductionDims(
            Scope scope, Operand<T> input, int[] axes) {
        if (axes != null) {
            return Constant.vectorOf(scope, axes);
        }
        long rank = input.shape().numDimensions();
        if (rank == org.tensorflow.ndarray.Shape.UNKNOWN_SIZE) {
            return Range.create(
                    scope, Constant.scalarOf(scope, 0), Rank.create(scope, input), Constant.scalarOf(scope, 1));
        }
        int[] allDims = new int[(int) rank];
        for (int i = 0; i < allDims.length; i++) {
            allDims[i] = i;
        }
        return Constant.vectorOf(scope, allDims);
    }
}
