package org.tensorflow.framework.metrics.impl;

import java.util.Arrays;
import java.util.Collections;
import org.tensorflow.Operand;
import org.tensorflow.framework.op.FrameworkOps;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.AssertThat;
import org.tensorflow.op.core.Rank;
import org.tensorflow.op.math.Equal;
import org.tensorflow.types.TBool;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.class */
/**
 * Helpers that check whether a weights operand can be broadcast to a values operand and that
 * perform the broadcast.
 *
 * <p>Checks are done statically (from the shapes known at graph-construction time) whenever
 * possible; otherwise a runtime assertion op is added.
 */
public class WeightsBroadcastOps {
    private static final String ASSERT_BROADCASTABLE_ERROR_PREFIX = "weights can not be broadcast to values.";

    /** Value reported by {@code Shape.numDimensions()} when the rank is not statically known. */
    private static final int UNKNOWN_RANK = -1;

    /**
     * Asserts that the weights can be broadcast to the values.
     *
     * <p>Weights are considered broadcastable when they are a scalar, or when they have the same
     * rank as the values and each weights dimension is either 1 or equal to the matching values
     * dimension.
     *
     * <p>If both ranks are statically known, everything that can be decided statically is decided
     * here: a definite mismatch throws, and a fully verified shape yields a no-op. Dimensions whose
     * size is not statically known cannot be verified here, so in that case (and when a rank is
     * unknown) a runtime assertion is returned instead.
     *
     * @param ops the TensorFlow Ops used to create operations
     * @param operand the weights
     * @param operand2 the values the weights are applied to
     * @param <T> the data type of the weights and values
     * @return a no-op when the static checks fully verify the shapes, otherwise an assertion op
     *     created with {@code ops.assertThat} that checks the shapes at runtime
     * @throws IllegalArgumentException if the static shapes show that the weights can not be
     *     broadcast to the values
     */
    public static <T extends TNumber> Op assertBroadcastable(Ops ops, Operand<T> operand, Operand<T> operand2) {
        // The dynamic shape/rank ops are created up front, in this fixed order, so the set of ops
        // added to the graph does not depend on which branch below is taken.
        Operand<TInt32> weightsShape = ops.shape(operand);
        Operand<TInt32> weightsRank = ops.rank(operand);
        Shape weightsShapeStatic = operand.shape();
        int weightsRankStatic = weightsShapeStatic.numDimensions();

        Operand<TInt32> valuesShape = ops.shape(operand2);
        Operand<TInt32> valuesRank = ops.rank(operand2);
        Shape valuesShapeStatic = operand2.asOutput().shape();
        int valuesRankStatic = valuesShapeStatic.numDimensions();

        if (weightsRankStatic != UNKNOWN_RANK && valuesRankStatic != UNKNOWN_RANK) {
            if (weightsRankStatic == 0) {
                // A scalar weight can always be broadcast.
                return staticCheckSuccess(ops, "staticScalarCheckSuccess");
            }
            if (weightsRankStatic != valuesRankStatic) {
                throw new IllegalArgumentException(String.format("%s values.rank=%d. weights.rank=%d.  values.shape=%s. weights.shape=%s.", ASSERT_BROADCASTABLE_ERROR_PREFIX, valuesRankStatic, weightsRankStatic, valuesShapeStatic, weightsShapeStatic));
            }

            boolean allDimsVerified = true;
            for (int i = 0; i < valuesRankStatic; i++) {
                long weightsDim = weightsShapeStatic.size(i);
                if (weightsDim == 1) {
                    // A weights dimension of 1 broadcasts against any values dimension.
                    continue;
                }
                long valuesDim = valuesShapeStatic.size(i);
                // Unknown dimension sizes are reported as negative values (-1); they must not be
                // compared as if they were real sizes, so defer them to the runtime check.
                if (weightsDim < 0 || valuesDim < 0) {
                    allDimsVerified = false;
                    continue;
                }
                if (weightsDim != valuesDim) {
                    throw new IllegalArgumentException(String.format("%s Mismatch at dim %s. values.shape=%s weights.shape=%s.", ASSERT_BROADCASTABLE_ERROR_PREFIX, i, valuesShapeStatic, weightsShapeStatic));
                }
            }
            if (allDimsVerified) {
                return staticCheckSuccess(ops, "staticDimsCheckSuccess");
            }
            // Otherwise fall through: at least one dimension can only be checked at runtime.
        }

        // Dynamic checks.
        Operand<TBool> isScalar = ops.math.equal(weightsRank, ops.constant(0));
        Operand<TBool> validNonscalar = hasValidNonscalarShape(ops, weightsRank, weightsShape, valuesRank, valuesShape);
        // select evaluates both branches, so the scalar case is short-circuited by returning
        // isScalar itself (true) rather than the non-scalar result.
        Operand<TBool> isValidShape = ops.select(isScalar, isScalar, validNonscalar);
        return ops.assertThat(
                isValidShape,
                Arrays.<Operand<?>>asList(
                        ops.constant(ASSERT_BROADCASTABLE_ERROR_PREFIX),
                        ops.constant("weights.shape="),
                        weightsShape,
                        ops.constant("values.shape="),
                        valuesShape,
                        ops.constant("isScalar="),
                        isScalar));
    }

    /**
     * Creates the no-op returned when a static check has fully verified the shapes.
     *
     * @param ops the TensorFlow Ops
     * @param scopeName the sub-scope name under which the no-op is created
     * @return a no-op with no control dependencies
     */
    private static Op staticCheckSuccess(Ops ops, String scopeName) {
        return ops.withSubScope(scopeName).withControlDependencies(Collections.emptyList()).noOp();
    }

    /**
     * Builds the runtime check for non-scalar weights: the ranks must be equal and the dimensions
     * must be valid (see {@link #hasValidDims}).
     *
     * @param ops the TensorFlow Ops
     * @param weightsRank the dynamic rank of the weights
     * @param weightsShape the dynamic shape of the weights
     * @param valuesRank the dynamic rank of the values
     * @param valuesShape the dynamic shape of the values
     * @return a boolean operand that is the result of the dimension check when the ranks are equal,
     *     and the (false) rank comparison otherwise
     */
    private static Operand<TBool> hasValidNonscalarShape(Ops ops, Operand<TInt32> weightsRank, Operand<TInt32> weightsShape, Operand<TInt32> valuesRank, Operand<TInt32> valuesShape) {
        Ops scoped = ops.withSubScope("hasValidNonscalarShape");
        Operand<TBool> isSameRank = scoped.math.equal(valuesRank, weightsRank);
        return scoped.select(isSameRank, hasValidDims(scoped, weightsShape, valuesShape), isSameRank);
    }

    /**
     * Builds the runtime check that the weights dimensions are valid for the values dimensions.
     *
     * <p>Each values dimension {@code d} is turned into the row {@code [d, 1]}, each weights
     * dimension into a one-element row, and the set difference of the weights rows against the
     * values rows is taken; the check passes when that difference has no elements.
     * NOTE(review): this relies on {@code sets.difference} operating row-wise on the last
     * dimension - confirm against the framework's set ops.
     *
     * @param ops the TensorFlow Ops
     * @param weightsShape the dynamic shape of the weights
     * @param valuesShape the dynamic shape of the values
     * @return a boolean operand that is true when the set difference is empty
     */
    private static Operand<TBool> hasValidDims(Ops ops, Operand<TInt32> weightsShape, Operand<TInt32> valuesShape) {
        // The sub-scope name is part of the generated op names; it is kept exactly as is.
        Ops scoped = ops.withSubScope("hasInvalidDims");
        Operand<TInt32> valuesShape2d = scoped.expandDims(valuesShape, scoped.constant(-1));
        Operand<TInt32> zero = scoped.constant(0);
        Operand<TInt32> weightsShape2d = scoped.expandDims(weightsShape, scoped.constant(-1));
        Operand<TInt32> validDims = scoped.concat(Arrays.<Operand<TInt32>>asList(valuesShape2d, scoped.onesLike(valuesShape2d)), scoped.constant(1));
        Operand<TInt32> invalidDims = FrameworkOps.create(scoped).sets.difference(weightsShape2d, validDims);
        Operand<TInt32> numInvalidDims = scoped.size(invalidDims, TInt32.class);
        return scoped.math.equal(zero, numInvalidDims);
    }

    /**
     * Broadcasts the weights to the shape of the values.
     *
     * <p>When both static shapes are fully known and compatible, the weights are returned
     * unchanged. Otherwise the weights are multiplied by a ones-tensor shaped like the values,
     * with a control dependency on {@link #assertBroadcastable}.
     *
     * @param ops the TensorFlow Ops used to create operations
     * @param operand the weights
     * @param operand2 the values the weights are applied to
     * @param <T> the data type of the weights and values
     * @return the weights, either unchanged or multiplied by ones shaped like the values
     * @throws IllegalArgumentException if the static shapes show that the weights can not be
     *     broadcast to the values
     */
    public static <T extends TNumber> Operand<T> broadcastWeights(Ops ops, Operand<T> operand, Operand<T> operand2) {
        Ops scoped = ops.withSubScope("broadcast_weights");
        Operand<T> typedValues = CastHelper.cast(scoped, operand2, operand.type());
        Shape weightsShape = operand.shape();
        Shape valuesShape = typedValues.shape();
        if (!weightsShape.hasUnknownDimension() && !valuesShape.hasUnknownDimension() && weightsShape.isCompatibleWith(valuesShape)) {
            return operand;
        }
        Ops assertScope = scoped.withSubScope("assertBroadcastable");
        Op broadcastCheck = assertBroadcastable(scoped, operand, typedValues);
        // The multiplication only runs after the broadcast check has been evaluated.
        Ops guarded = assertScope.withControlDependencies(Collections.singletonList(broadcastCheck));
        return guarded.math.mul(operand, scoped.onesLike(typedValues));
    }
}
