package org.tensorflow.framework.metrics;

import java.util.List;
import org.tensorflow.Operand;
import org.tensorflow.framework.losses.impl.LossTuple;
import org.tensorflow.framework.losses.impl.LossesHelper;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/metrics/RootMeanSquaredError.class */
/**
 * A metric that computes the root mean squared error (RMSE) between labels and predictions.
 *
 * <p>Each state update casts labels, predictions and the optional sample weights to the metric's
 * internal type. It then lets {@link LossesHelper#squeezeOrExpandDimensions} reconcile the
 * dimensions of labels and predictions, and feeds the element-wise squared difference into the
 * underlying {@link Mean}. {@link #result} returns the square root of the accumulated mean.
 *
 * @param <T> the data type used for the metric's internal computations
 */
public class RootMeanSquaredError<T extends TNumber> extends Mean<T> {

    /**
     * Creates the metric without an explicit name ({@code null} is passed to the superclass).
     *
     * @param seed the seed value forwarded to the {@link Mean} constructor
     * @param type the class of the metric's internal data type {@code T}
     */
    public RootMeanSquaredError(long seed, Class<T> type) {
        this(null, seed, type);
    }

    /**
     * Creates the metric.
     *
     * @param name the metric name, forwarded to the {@link Mean} constructor; may be {@code null}
     * @param seed the seed value forwarded to the {@link Mean} constructor
     * @param type the class of the metric's internal data type {@code T}
     */
    public RootMeanSquaredError(String name, long seed, Class<T> type) {
        super(name, seed, type);
    }

    /**
     * Builds the operations that accumulate the squared error between {@code labels} and
     * {@code predictions}.
     *
     * @param tf the TensorFlow Ops used to build the update operations
     * @param labels the ground-truth values
     * @param predictions the predicted values
     * @param sampleWeights optional sample weights forwarded, after casting, to
     *     {@link Mean#updateStateList}; may be {@code null}
     * @return the operations that update the metric's state
     */
    @Override
    public List<Op> updateStateList(
            Ops tf,
            Operand<? extends TNumber> labels,
            Operand<? extends TNumber> predictions,
            Operand<? extends TNumber> sampleWeights) {
        init(tf);
        Operand<T> tLabels = CastHelper.cast(tf, labels, getInternalType());
        Operand<T> tPredictions = CastHelper.cast(tf, predictions, getInternalType());
        Operand<T> tSampleWeights =
                sampleWeights != null ? CastHelper.cast(tf, sampleWeights, getInternalType()) : null;

        // Let LossesHelper squeeze/expand dimensions so labels and predictions line up before the
        // element-wise difference is taken.
        LossTuple<T> aligned = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, tPredictions);

        // Both inputs are already of type T, so the squared difference is too; no extra cast needed.
        Operand<T> squaredError =
                tf.math.squaredDifference(aligned.getTarget(), aligned.getLabels());

        return super.updateStateList(tf, squaredError, tSampleWeights);
    }

    /**
     * Returns the root mean squared error accumulated so far.
     *
     * <p>Computes {@code sqrt(total / count)} with {@code divNoNan}, so the result is 0 instead of
     * NaN while {@code count} is 0.
     *
     * @param tf the TensorFlow Ops used to build the result operation
     * @param type the class of the requested result data type
     * @param <U> the requested result data type
     * @return the RMSE, cast to {@code type}
     */
    @Override
    public <U extends TNumber> Operand<U> result(Ops tf, Class<U> type) {
        init(tf);
        return CastHelper.cast(tf, tf.math.sqrt(tf.math.divNoNan(total, count)), type);
    }
}
