package org.tensorflow.framework.metrics;

import java.util.Arrays;
import java.util.List;
import org.tensorflow.Operand;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/metrics/Metrics.class */
public class Metrics {
    public static <T extends TNumber> Operand<T> topKCategoricalAccuracy(Ops ops, Operand<? extends TNumber> operand, Operand<T> operand2, long j) {
        return CastHelper.cast(ops, ops.nn.inTopK(CastHelper.cast(ops, operand2, TFloat32.class), ops.math.argMax(operand, ops.constant(-1)), ops.constant(j)), operand2.type());
    }

    public static <T extends TNumber, U extends TNumber> Operand<T> sparseTopKCategoricalAccuracy(Ops ops, Operand<U> operand, Operand<T> operand2, int i) {
        Operand cast = CastHelper.cast(ops, operand, operand2.type());
        int numDimensions = operand2.shape().numDimensions();
        int numDimensions2 = cast.shape().numDimensions();
        Operand cast2 = CastHelper.cast(ops, operand2, TFloat32.class);
        if (numDimensions != Shape.UNKNOWN_SIZE && numDimensions2 != Shape.UNKNOWN_SIZE) {
            if (numDimensions > 2) {
                cast2 = ops.reshape(cast2, ops.constant(cast2.shape().takeLast(1).prepend(Shape.UNKNOWN_SIZE)));
            }
            if (numDimensions2 > 1) {
                cast = ops.reshape(cast, ops.constant(Shape.of(new long[]{Shape.UNKNOWN_SIZE})));
            }
        }
        return CastHelper.cast(ops, ops.nn.inTopK(cast2, CastHelper.cast(ops, cast, TInt32.class), ops.constant(i)), operand2.type());
    }
}
