package org.tensorflow.framework.metrics;

import org.tensorflow.Operand;
import org.tensorflow.framework.metrics.impl.LossMetric;
import org.tensorflow.framework.metrics.impl.MeanBaseMetricWrapper;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.class */
/**
 * Metric that computes how often integer (sparse) targets fall within the top-{@code k}
 * predictions, by delegating each batch to {@link Metrics#sparseTopKCategoricalAccuracy}.
 *
 * <p>Per-batch values are handed to {@link MeanBaseMetricWrapper} via {@code setLoss(this)};
 * presumably the wrapper accumulates them as a running mean — confirm against the superclass.
 *
 * @param <T> the internal numeric type used for computation
 */
public class SparseTopKCategoricalAccuracy<T extends TNumber> extends MeanBaseMetricWrapper<T>
        implements LossMetric {

    /** Number of top predictions considered when no explicit {@code k} is supplied. */
    public static final int DEFAULT_K = 5;

    /** Number of top-ranked predictions in which a target must appear to count as correct. */
    private final int k;

    /**
     * Creates the metric with a {@code null} name and {@link #DEFAULT_K}.
     *
     * @param seed forwarded unchanged to the superclass constructor (presumably a random seed —
     *     TODO confirm)
     * @param type the internal numeric type of the metric
     */
    public SparseTopKCategoricalAccuracy(long seed, Class<T> type) {
        this(null, DEFAULT_K, seed, type);
    }

    /**
     * Creates the metric with a {@code null} name and the given {@code k}.
     *
     * @param k number of top predictions to consider
     * @param seed forwarded unchanged to the superclass constructor
     * @param type the internal numeric type of the metric
     */
    public SparseTopKCategoricalAccuracy(int k, long seed, Class<T> type) {
        this(null, k, seed, type);
    }

    /**
     * Creates the metric with the given name and {@link #DEFAULT_K}.
     *
     * @param name the metric name, forwarded to the superclass (may be {@code null}; how a
     *     {@code null} name is handled is decided by the superclass)
     * @param seed forwarded unchanged to the superclass constructor
     * @param type the internal numeric type of the metric
     */
    public SparseTopKCategoricalAccuracy(String name, long seed, Class<T> type) {
        this(name, DEFAULT_K, seed, type);
    }

    /**
     * Creates the metric.
     *
     * @param name the metric name, forwarded to the superclass (may be {@code null})
     * @param k number of top predictions to consider
     * @param seed forwarded unchanged to the superclass constructor
     * @param type the internal numeric type of the metric
     */
    public SparseTopKCategoricalAccuracy(String name, int k, long seed, Class<T> type) {
        super(name, seed, type);
        this.k = k;
        // This instance is its own per-batch "loss" function: the wrapper invokes call(...) below.
        setLoss(this);
    }

    /**
     * Computes the sparse top-{@code k} categorical accuracy for one batch.
     *
     * <p>Both inputs are cast to the metric's internal type before delegation. The result is then
     * cast to {@code type}.
     *
     * @param ops the TensorFlow ops accessor; passed to {@code init} before any op is built
     * @param labels the integer target classes (first positional argument to {@code Metrics})
     * @param predictions the predicted scores (second positional argument to {@code Metrics})
     * @param type the requested result type
     * @param <U> the result numeric type
     * @return the per-batch accuracy operand, cast to {@code type}
     */
    @Override // org.tensorflow.framework.metrics.impl.LossMetric
    public <U extends TNumber> Operand<U> call(
            Ops ops,
            Operand<? extends TNumber> labels,
            Operand<? extends TNumber> predictions,
            Class<U> type) {
        init(ops);
        return CastHelper.cast(
                ops,
                Metrics.sparseTopKCategoricalAccuracy(
                        ops,
                        CastHelper.cast(ops, labels, getInternalType()),
                        CastHelper.cast(ops, predictions, getInternalType()),
                        this.k),
                type);
    }
}
