package org.tensorflow.framework.metrics;

import java.util.List;
import org.tensorflow.Operand;
import org.tensorflow.framework.losses.impl.LossTuple;
import org.tensorflow.framework.losses.impl.LossesHelper;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/metrics/MeanRelativeError.class */
/**
 * Metric that accumulates the mean relative error of predictions with respect to a normalizer.
 *
 * <p>Each call to {@link #updateStateList} computes {@code |labels - predictions| / normalizer}
 * with {@code ops.math.divNoNan} and feeds the resulting values (and optional sample weights)
 * into the underlying {@link Mean} accumulator.
 *
 * <p>The normalizer is supplied in one of three ways:
 * <ul>
 *   <li>as an {@link Operand};
 *   <li>as a {@code float[]};
 *   <li>as a {@code double[]}.
 * </ul>
 * Array normalizers are converted to a constant of the metric's internal type when the metric is
 * first initialized in a graph. An empty array is treated as "no normalizer".
 *
 * @param <T> the data type used for the metric's internal computations
 */
public class MeanRelativeError<T extends TNumber> extends Mean<T> {
    /**
     * The normalizer operand. It is built lazily from one of the array fields in {@link #init}
     * when no operand was supplied. Each update replaces it with the dimension-adjusted operand
     * returned by {@code LossesHelper.removeSqueezableDimensions}.
     */
    private Operand<T> normalizer;
    /** Normalizer values from the {@code float[]} constructors; {@code null} otherwise. */
    private final float[] normalizerFloat;
    /** Normalizer values from the {@code double[]} constructors; {@code null} otherwise. */
    private final double[] normalizerDouble;

    /**
     * Creates the metric with a default (null) name and a {@code float[]} normalizer.
     *
     * @param normalizer the normalizer values; copied defensively
     * @param seed value forwarded unchanged to the {@link Mean} constructor
     * @param type the data type for the metric's internal computations
     */
    protected MeanRelativeError(float[] normalizer, long seed, Class<T> type) {
        this((String) null, normalizer, seed, type);
    }

    /**
     * Creates the metric with a {@code float[]} normalizer.
     *
     * @param name the metric name, forwarded to {@link Mean}; may be {@code null}
     * @param normalizer the normalizer values; copied defensively
     * @param seed value forwarded unchanged to the {@link Mean} constructor
     * @param type the data type for the metric's internal computations
     */
    protected MeanRelativeError(String name, float[] normalizer, long seed, Class<T> type) {
        super(name, seed, type);
        // Copy so that later mutation by the caller cannot change the lazily-built constant.
        this.normalizerFloat = normalizer == null ? null : normalizer.clone();
        this.normalizerDouble = null;
    }

    /**
     * Creates the metric with a default (null) name and a {@code double[]} normalizer.
     *
     * @param normalizer the normalizer values; copied defensively
     * @param seed value forwarded unchanged to the {@link Mean} constructor
     * @param type the data type for the metric's internal computations
     */
    protected MeanRelativeError(double[] normalizer, long seed, Class<T> type) {
        this((String) null, normalizer, seed, type);
    }

    /**
     * Creates the metric with a {@code double[]} normalizer.
     *
     * @param name the metric name, forwarded to {@link Mean}; may be {@code null}
     * @param normalizer the normalizer values; copied defensively
     * @param seed value forwarded unchanged to the {@link Mean} constructor
     * @param type the data type for the metric's internal computations
     */
    protected MeanRelativeError(String name, double[] normalizer, long seed, Class<T> type) {
        super(name, seed, type);
        this.normalizerFloat = null;
        // Copy so that later mutation by the caller cannot change the lazily-built constant.
        this.normalizerDouble = normalizer == null ? null : normalizer.clone();
    }

    /**
     * Creates the metric with a default (null) name and an operand normalizer.
     *
     * @param normalizer the normalizer operand
     * @param seed value forwarded unchanged to the {@link Mean} constructor
     * @param type the data type for the metric's internal computations
     */
    protected MeanRelativeError(Operand<T> normalizer, long seed, Class<T> type) {
        this((String) null, normalizer, seed, type);
    }

    /**
     * Creates the metric with an operand normalizer.
     *
     * @param name the metric name, forwarded to {@link Mean}; may be {@code null}
     * @param normalizer the normalizer operand
     * @param seed value forwarded unchanged to the {@link Mean} constructor
     * @param type the data type for the metric's internal computations
     */
    protected MeanRelativeError(String name, Operand<T> normalizer, long seed, Class<T> type) {
        super(name, seed, type);
        this.normalizer = normalizer;
        this.normalizerFloat = null;
        this.normalizerDouble = null;
    }

    /**
     * Initializes the metric once per instance.
     *
     * <p>This delegates to the superclass and, when no normalizer operand was supplied, builds a
     * constant from the non-empty {@code double[]} or {@code float[]} normalizer. The constant is
     * cast to the internal type. Subsequent calls return immediately once
     * {@link #isInitialized()} is true.
     *
     * @param ops the TensorFlow ops accessor; checked with {@code checkIsGraph}
     */
    @Override
    public void init(Ops ops) {
        checkIsGraph(ops);
        if (isInitialized()) {
            return;
        }
        super.init(ops);
        if (this.normalizer == null) {
            // The constructors set at most one of the array fields; the other stays null.
            if (this.normalizerDouble != null && this.normalizerDouble.length > 0) {
                this.normalizer =
                        CastHelper.cast(ops, ops.constant(this.normalizerDouble), getInternalType());
            } else if (this.normalizerFloat != null && this.normalizerFloat.length > 0) {
                this.normalizer =
                        CastHelper.cast(ops, ops.constant(this.normalizerFloat), getInternalType());
            }
        }
        setInitialized(true);
    }

    /**
     * Builds the ops that update the metric state with one batch.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Labels, predictions and (optional) sample weights are cast to the internal type.
     *   <li>Labels and predictions are aligned with {@code LossesHelper.squeezeOrExpandDimensions}.
     *   <li>The normalizer is aligned with the predictions via
     *       {@code LossesHelper.removeSqueezableDimensions}.
     *   <li>{@code |labels - predictions| / normalizer} is computed with {@code divNoNan} and
     *       passed to {@link Mean}.
     * </ol>
     *
     * @param ops the TensorFlow ops accessor
     * @param labels the ground-truth values
     * @param predictions the predicted values
     * @param sampleWeights optional sample weights; may be {@code null}
     * @return the update ops produced by the underlying {@link Mean}
     * @throws IllegalStateException if no normalizer was supplied (null operand or empty arrays)
     * @throws IllegalArgumentException if the aligned prediction and label shapes are incompatible
     */
    @Override
    public List<Op> updateStateList(
            Ops ops,
            Operand<? extends TNumber> labels,
            Operand<? extends TNumber> predictions,
            Operand<? extends TNumber> sampleWeights) {
        init(ops);
        if (this.normalizer == null) {
            // Fail with a clear message instead of an NPE deep inside LossesHelper.
            throw new IllegalStateException(
                    "MeanRelativeError requires a non-null, non-empty normalizer");
        }
        Operand<T> tLabels = CastHelper.cast(ops, labels, getInternalType());
        Operand<T> tPredictions = CastHelper.cast(ops, predictions, getInternalType());
        Operand<T> tSampleWeights =
                sampleWeights != null ? CastHelper.cast(ops, sampleWeights, getInternalType()) : null;

        LossTuple<T> aligned = LossesHelper.squeezeOrExpandDimensions(ops, tLabels, tPredictions);
        tPredictions = aligned.getTarget();
        tLabels = aligned.getLabels();

        // NOTE(review): this overwrites the field, so getNormalizer() returns the adjusted
        // operand after the first update (behavior kept from the original).
        LossTuple<T> squeezed =
                LossesHelper.removeSqueezableDimensions(ops, this.normalizer, tPredictions);
        this.normalizer = squeezed.getLabels();
        tPredictions = squeezed.getTarget();

        if (!tPredictions.shape().isCompatibleWith(tLabels.shape())) {
            throw new IllegalArgumentException(
                    String.format(
                            "Prediction shape %s is not compatible with labels shape %s",
                            tPredictions.shape(), tLabels.shape()));
        }
        Operand<T> relativeErrors =
                ops.math.divNoNan(ops.math.abs(ops.math.sub(tLabels, tPredictions)), getNormalizer());
        return super.updateStateList(ops, relativeErrors, tSampleWeights);
    }

    /**
     * Returns the current normalizer operand.
     *
     * <p>The result is {@code null} if no operand was supplied and {@link #init} has not yet built
     * one from an array.
     *
     * @return the normalizer operand, possibly {@code null}
     */
    public Operand<T> getNormalizer() {
        return this.normalizer;
    }
}
