package org.tensorflow.framework.metrics.impl;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.tensorflow.Operand;
import org.tensorflow.framework.losses.impl.LossesHelper;
import org.tensorflow.framework.metrics.exceptions.NotBroadcastableException;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.AssertThat;
import org.tensorflow.op.core.Rank;
import org.tensorflow.op.core.Select;
import org.tensorflow.op.math.Equal;
import org.tensorflow.op.math.Mean;
import org.tensorflow.types.TBool;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.family.TIntegral;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/metrics/impl/MetricsHelper.class */
/**
 * Static helpers shared by metric implementations.
 *
 * <p>Provides a broadcast-compatibility check for sample weights, weight broadcasting, and mean
 * reductions over numeric and boolean operands.
 */
public class MetricsHelper {
    /**
     * Large negative float constant ({@code -1e10}).
     *
     * <p>It is not referenced inside this class. It is presumably used by callers as a finite
     * stand-in for negative infinity.
     */
    public static final float NEG_INF = -1.0E10f;

    /** Prefix shared by every broadcast-failure message, both thrown and asserted at run time. */
    private static final String ASSERT_BROADCAST_ERROR_PREFIX = "weights can not be broadcast to values.";

    /**
     * Creates an op that asserts that sample weights can be broadcast to values.
     *
     * <p>The weights are broadcastable when either of these holds:
     *
     * <ul>
     *   <li>they are a scalar;
     *   <li>they have the same rank as the values, and every weights dimension either equals the
     *       matching values dimension or is 1.
     * </ul>
     *
     * <p>If both static shapes are fully known, the check runs while the graph is being built. A
     * no-op is returned on success, and a {@link NotBroadcastableException} is thrown on failure.
     * Otherwise an assertion op is returned that performs the equivalent check at run time.
     *
     * @param ops the TensorFlow Ops
     * @param operand the sample weights
     * @param operand2 the values the weights should broadcast to
     * @param <T> the data type of the operands
     * @return a no-op for a successful static check, or a run-time assertion op
     * @throws NotBroadcastableException if both static shapes are known and not broadcastable
     */
    public static <T extends TNumber> Op assertBroadcastable(Ops ops, Operand<T> operand, Operand<T> operand2) {
        Shape weightsShape = operand.shape();
        Shape valuesShape = operand2.shape();
        boolean fullyKnown =
                !weightsShape.isUnknown()
                        && !valuesShape.isUnknown()
                        && !weightsShape.hasUnknownDimension()
                        && !valuesShape.hasUnknownDimension();
        if (!fullyKnown) {
            return assertBroadcastableDynamic(ops, operand, operand2);
        }

        int weightsRank = weightsShape.numDimensions();
        int valuesRank = valuesShape.numDimensions();
        if (weightsRank == 0) {
            // A scalar weight broadcasts to any values shape.
            return ops.withSubScope("staticScalarCheckSuccess")
                    .withControlDependencies(Collections.emptyList())
                    .noOp();
        }
        if (weightsRank != valuesRank) {
            throw new NotBroadcastableException(
                    String.format(
                            "%s values.rank=%d. weights.rank=%d.  values.shape=%s. weights.shape=%s.",
                            ASSERT_BROADCAST_ERROR_PREFIX,
                            valuesRank,
                            weightsRank,
                            valuesShape.toString(),
                            weightsShape.toString()));
        }
        for (int dim = 0; dim < valuesRank; dim++) {
            long weightsDim = weightsShape.size(dim);
            if (valuesShape.size(dim) != weightsDim && weightsDim != 1) {
                throw new NotBroadcastableException(
                        String.format(
                                "%s Mismatch at dim %d. values.shape=%s weights.shape=%s.",
                                ASSERT_BROADCAST_ERROR_PREFIX,
                                dim,
                                valuesShape.toString(),
                                weightsShape.toString()));
            }
        }
        return ops.withSubScope("staticDimsCheckSuccess")
                .withControlDependencies(Collections.emptyList())
                .noOp();
    }

    /**
     * Builds the run-time variant of {@link #assertBroadcastable}.
     *
     * <p>It is used when at least one static shape is not fully known.
     *
     * @param ops the TensorFlow Ops
     * @param sampleWeights the sample weights
     * @param values the values the weights should broadcast to
     * @param <T> the data type of the operands
     * @return an assertion op that fails at run time when the weights are not broadcastable
     */
    private static <T extends TNumber> Op assertBroadcastableDynamic(
            Ops ops, Operand<T> sampleWeights, Operand<T> values) {
        Operand<TInt32> weightsShape = ops.shape(sampleWeights);
        Operand<TInt32> weightsRank = ops.rank(sampleWeights);
        Operand<TInt32> valuesShape = ops.shape(values);
        Operand<TInt32> valuesRank = ops.rank(values);

        Operand<TBool> isScalar = ops.math.equal(weightsRank, ops.constant(0));
        List<Operand<?>> data =
                Arrays.asList(
                        ops.constant(ASSERT_BROADCAST_ERROR_PREFIX),
                        ops.constant("weights.shape="),
                        weightsShape,
                        ops.constant("values.shape="),
                        valuesShape,
                        ops.constant("isScalar="),
                        isScalar);

        // select is not lazy: both branches are always computed, so the non-scalar check below
        // runs even for a scalar weight. Broadcasting a scalar weight against the values first
        // gives that check a same-rank tensor. Its result is discarded by the final select
        // whenever isScalar is true.
        Operand<T> reshapedWeights =
                ops.select(isScalar, ops.math.mul(sampleWeights, ops.onesLike(values)), sampleWeights);
        Operand<TBool> validNonscalar =
                canBroadcastNonscalarShapes(
                        ops, ops.rank(reshapedWeights), ops.shape(reshapedWeights), valuesRank, valuesShape);
        Operand<TBool> isValidShape = ops.select(isScalar, isScalar, validNonscalar);

        return ops.withSubScope("broadcastWeights-dynamic").assertThat(isValidShape, data);
    }

    /**
     * Returns a boolean tensor that is true when the ranks match and every dimension is
     * broadcast-compatible.
     *
     * <p>NOTE(review): select evaluates both branches, so {@code canBroadcastDims} still runs when
     * the ranks differ. Confirm that the set-difference op tolerates shape vectors of different
     * lengths, otherwise that case fails with an op error instead of the assertion message.
     *
     * @param ops the TensorFlow Ops
     * @param weightsRank rank of the (possibly reshaped) weights
     * @param weightsShape shape vector of the (possibly reshaped) weights
     * @param valuesRank rank of the values
     * @param valuesShape shape vector of the values
     * @param <T> the integer type of the rank/shape operands
     * @return a boolean tensor, true when the weights shape can broadcast to the values shape
     */
    private static <T extends TNumber> Operand<TBool> canBroadcastNonscalarShapes(
            Ops ops,
            Operand<T> weightsRank,
            Operand<T> weightsShape,
            Operand<T> valuesRank,
            Operand<T> valuesShape) {
        Ops scoped = ops.withSubScope("canBroadcastNonscalarShapes");
        Operand<TBool> isSameRank = scoped.math.equal(valuesRank, weightsRank);
        return scoped.select(isSameRank, canBroadcastDims(scoped, weightsShape, valuesShape), isSameRank);
    }

    /**
     * Returns a boolean tensor that is true when each weights dimension either equals the matching
     * values dimension or is 1.
     *
     * <p>For every dimension {@code i}, it pairs the allowed values {@code [valuesShape[i], 1]}
     * with the single weights entry {@code [weightsShape[i]]}. It then counts the weights entries
     * that are not in their allowed set. This assumes {@code SetsOps.difference} is a row-wise set
     * difference (TODO confirm).
     *
     * @param ops the TensorFlow Ops
     * @param weightsShape shape vector of the weights
     * @param valuesShape shape vector of the values
     * @param <T> the integer type of the shape operands
     * @return a boolean tensor, true when no dimension is invalid
     */
    private static <T extends TNumber> Operand<TBool> canBroadcastDims(
            Ops ops, Operand<T> weightsShape, Operand<T> valuesShape) {
        Ops scoped = ops.withSubScope("canBroadcastDims");
        Operand<T> valuesShape2d = scoped.expandDims(valuesShape, scoped.constant(-1));
        // Row i becomes [valuesShape[i], 1]: the two weights sizes allowed for dimension i.
        Operand<T> validDims =
                scoped.concat(Arrays.asList(valuesShape2d, scoped.onesLike(valuesShape2d)), scoped.constant(1));
        Operand<T> weightsShape2d = scoped.expandDims(weightsShape, scoped.constant(-1));
        return scoped.math.equal(
                scoped.constant(0), scoped.size(SetsOps.difference(scoped, weightsShape2d, validDims)));
    }

    /**
     * Broadcasts sample weights to the shape of the values.
     *
     * <p>If both static shapes are fully known and compatible, the weights are returned unchanged.
     * Otherwise the weights are multiplied by a ones tensor shaped like the values. That multiply
     * has a control dependency on {@link #assertBroadcastable}.
     *
     * @param ops the TensorFlow Ops
     * @param operand the sample weights
     * @param operand2 the values whose shape the weights are broadcast to
     * @param <T> the data type of the operands
     * @return the weights, unchanged or broadcast to the values' shape
     * @throws NotBroadcastableException if both static shapes are known and not broadcastable
     */
    public static <T extends TNumber> Operand<T> broadcastWeights(Ops ops, Operand<T> operand, Operand<T> operand2) {
        Shape weightsShape = operand.shape();
        Shape valuesShape = operand2.shape();
        if (!weightsShape.hasUnknownDimension()
                && !valuesShape.hasUnknownDimension()
                && weightsShape.isCompatibleWith(valuesShape)) {
            return operand;
        }
        // Build the ones tensor once and reuse it for both the check and the multiply.
        Operand<T> ones = ops.onesLike(operand2);
        Op broadcastCheck = assertBroadcastable(ops, operand, ones);
        return ops.withSubScope("broadcastWeights")
                .withControlDependencies(Collections.singletonList(broadcastCheck))
                .math.mul(operand, ones);
    }

    /**
     * Computes the mean of an operand over all of its axes, without keeping reduced dimensions.
     *
     * @param ops the TensorFlow Ops
     * @param operand the operand to reduce
     * @param <T> the data type of the operand
     * @return the mean
     */
    public static <T extends TNumber> Operand<T> mean(Ops ops, Operand<T> operand) {
        return mean(ops, operand, null, false);
    }

    /**
     * Computes the mean of an operand over the given axes, without keeping reduced dimensions.
     *
     * @param ops the TensorFlow Ops
     * @param operand the operand to reduce
     * @param operand2 the axes to reduce over; if null, the axes from
     *     {@code LossesHelper.allAxes} are used
     * @param <T> the data type of the operand
     * @return the mean
     */
    public static <T extends TNumber> Operand<T> mean(Ops ops, Operand<T> operand, Operand<? extends TIntegral> operand2) {
        return mean(ops, operand, operand2, false);
    }

    /**
     * Computes the mean of an operand over all of its axes.
     *
     * @param ops the TensorFlow Ops
     * @param operand the operand to reduce
     * @param z whether to keep the reduced dimensions with size 1
     * @param <T> the data type of the operand
     * @return the mean
     */
    public static <T extends TNumber> Operand<T> mean(Ops ops, Operand<T> operand, boolean z) {
        return mean(ops, operand, null, z);
    }

    /**
     * Computes the mean of an operand over the given axes.
     *
     * @param ops the TensorFlow Ops
     * @param operand the operand to reduce
     * @param operand2 the axes to reduce over; if null, the axes from
     *     {@code LossesHelper.allAxes} are used (presumably every axis of {@code operand})
     * @param z whether to keep the reduced dimensions with size 1
     * @param <T> the data type of the operand
     * @return the mean
     */
    public static <T extends TNumber> Operand<T> mean(
            Ops ops, Operand<T> operand, Operand<? extends TIntegral> operand2, boolean z) {
        Operand<? extends TIntegral> axes = operand2 != null ? operand2 : LossesHelper.allAxes(ops, operand);
        return ops.math.mean(operand, axes, Mean.keepDims(z));
    }

    /**
     * Computes the mean of a boolean operand over all axes, as a {@code TFloat64}.
     *
     * <p>Reduced dimensions are not kept.
     *
     * @param ops the TensorFlow Ops
     * @param operand the boolean operand to reduce
     * @return the mean of the operand cast to {@code TFloat64}
     */
    public static Operand<TFloat64> booleanMean(Ops ops, Operand<TBool> operand) {
        return booleanMean(ops, operand, null, false);
    }

    /**
     * Computes the mean of a boolean operand over the given axes, as a {@code TFloat64}.
     *
     * <p>Reduced dimensions are not kept.
     *
     * @param ops the TensorFlow Ops
     * @param operand the boolean operand to reduce
     * @param operand2 the axes to reduce over; if null, all axes are used
     * @return the mean of the operand cast to {@code TFloat64}
     */
    public static Operand<TFloat64> booleanMean(Ops ops, Operand<TBool> operand, Operand<? extends TIntegral> operand2) {
        return booleanMean(ops, operand, operand2, false);
    }

    /**
     * Computes the mean of a boolean operand over all axes, as a {@code TFloat64}.
     *
     * @param ops the TensorFlow Ops
     * @param operand the boolean operand to reduce
     * @param z whether to keep the reduced dimensions with size 1
     * @return the mean of the operand cast to {@code TFloat64}
     */
    public static Operand<TFloat64> booleanMean(Ops ops, Operand<TBool> operand, boolean z) {
        return booleanMean(ops, operand, null, z);
    }

    /**
     * Computes the mean of a boolean operand over the given axes, as a {@code TFloat64}.
     *
     * <p>The operand is first cast to {@code TFloat64} with {@code CastHelper}. Presumably true
     * becomes 1.0 and false becomes 0.0, so the result is the fraction of true elements.
     *
     * @param ops the TensorFlow Ops
     * @param operand the boolean operand to reduce
     * @param operand2 the axes to reduce over; if null, all axes are used
     * @param z whether to keep the reduced dimensions with size 1
     * @return the mean of the operand cast to {@code TFloat64}
     */
    public static Operand<TFloat64> booleanMean(
            Ops ops, Operand<TBool> operand, Operand<? extends TIntegral> operand2, boolean z) {
        Operand<TFloat64> asFloat = CastHelper.cast(ops, operand, TFloat64.class);
        return mean(ops, asFloat, operand2, z);
    }
}
