package org.tensorflow.framework.metrics.impl;

import java.util.List;
import org.tensorflow.Operand;
import org.tensorflow.framework.metrics.Mean;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/metrics/impl/MeanBaseMetricWrapper.class */
/**
 * A {@link Mean} metric whose per-update values are produced by a {@link LossMetric}.
 *
 * <p>On each {@link #updateStateList} call:
 * <ul>
 *   <li>labels and predictions are cast to the metric's internal type;</li>
 *   <li>the configured loss is evaluated on them;</li>
 *   <li>the resulting losses are handed to {@link Mean#updateStateList} for averaging.</li>
 * </ul>
 *
 * <p>The loss must be configured with {@link #setLoss} before the first update.
 *
 * @param <T> the internal numeric type of the metric
 */
public class MeanBaseMetricWrapper<T extends TNumber> extends Mean<T> {
    /** The loss function evaluated on every update; {@code null} until {@link #setLoss} is called. */
    protected LossMetric loss;

    /**
     * Creates the wrapper without a loss function; call {@link #setLoss} before updating state.
     *
     * <p>NOTE(review): the decompiler reports this constructor was {@code protected} in the original
     * bytecode. It is left {@code public} here so the visible interface is unchanged.
     *
     * @param name the metric name, forwarded to {@link Mean}
     * @param seed forwarded to {@link Mean}; presumably the random seed used for initialization — confirm
     * @param type the type class forwarded to {@link Mean}; presumably what {@code getInternalType()} returns
     */
    public MeanBaseMetricWrapper(String name, long seed, Class<T> type) {
        super(name, seed, type);
    }

    /**
     * Returns the configured loss function.
     *
     * @return the loss function, or {@code null} if {@link #setLoss} has not been called
     */
    public LossMetric getLoss() {
        return this.loss;
    }

    /**
     * Sets the loss function evaluated on every update.
     *
     * <p>NOTE(review): the decompiler reports this method was {@code protected} in the original
     * bytecode. It is left {@code public} here so the visible interface is unchanged.
     *
     * @param lossMetric the loss function to use
     */
    public void setLoss(LossMetric lossMetric) {
        this.loss = lossMetric;
    }

    /**
     * Evaluates the configured loss on {@code labels} and {@code predictions} and accumulates the
     * result into the running mean.
     *
     * @param ops the TensorFlow Ops accessor
     * @param labels the label values; must not be {@code null}
     * @param predictions the predicted values; must not be {@code null}
     * @param sampleWeights forwarded unchanged to {@link Mean#updateStateList}; presumably optional
     *     sample weights — confirm nullability against {@code Mean}
     * @return the ops produced by {@link Mean#updateStateList}
     * @throws IllegalArgumentException if {@code labels} or {@code predictions} is {@code null}
     * @throws IllegalStateException if no loss function has been set via {@link #setLoss}
     */
    @Override // org.tensorflow.framework.metrics.BaseMetric, org.tensorflow.framework.metrics.Metric
    public List<Op> updateStateList(
            Ops ops,
            Operand<? extends TNumber> labels,
            Operand<? extends TNumber> predictions,
            Operand<? extends TNumber> sampleWeights) {
        if (labels == null || predictions == null) {
            throw new IllegalArgumentException("missing required inputs for labels and predictions");
        }
        init(ops);
        LossMetric lossFn = this.loss;
        if (lossFn == null) {
            // Fail fast with a clear message instead of an NPE deep inside the call expression.
            throw new IllegalStateException(
                    "loss function has not been set; call setLoss before updateStateList");
        }
        Class<T> internalType = getInternalType();
        Operand<T> castLabels = CastHelper.cast(ops, labels, internalType);
        Operand<T> castPredictions = CastHelper.cast(ops, predictions, internalType);
        Operand<T> losses = lossFn.call(ops, castLabels, castPredictions, internalType);
        // The losses are cast to the predictions' type before being handed to Mean.
        return super.updateStateList(
                ops, CastHelper.cast(ops, losses, predictions.type()), sampleWeights);
    }
}
