package org.tensorflow.framework.metrics;

import java.util.ArrayList;
import java.util.List;
import org.tensorflow.Operand;
import org.tensorflow.framework.initializers.Zeros;
import org.tensorflow.framework.losses.impl.LossTuple;
import org.tensorflow.framework.losses.impl.LossesHelper;
import org.tensorflow.framework.metrics.impl.WeightsBroadcastOps;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.AssignAdd;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.math.Mean;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/metrics/MeanTensor.class */
/**
 * Metric that computes the element-wise (optionally weighted) mean of the tensors passed to
 * {@link #updateStateList}.
 *
 * <p>Two accumulator variables, {@code total} and {@code count}, are created with the shape of the
 * values supplied to the first {@code updateStateList} call. Every later call must supply values of
 * exactly that shape. {@link #result} returns {@code total / count} computed with {@code divNoNan},
 * so elements that were never counted yield zero instead of NaN.
 *
 * @param <T> the data type of the accumulator variables
 */
public class MeanTensor<T extends TNumber> extends BaseMetric {
    /** Base name of the variable accumulating the (weighted) sum of the values. */
    public static final String TOTAL = "total";
    /** Base name of the variable accumulating the (weighted) number of values. */
    public static final String COUNT = "count";
    private final String totalName;
    private final String countName;
    private final Class<T> type;
    /** Shape of the accumulators; fixed by the first {@code updateStateList} call. */
    private Shape shape;
    private Variable<T> total;
    private Variable<T> count;
    /** Ops that reset {@code total} / {@code count} back to zeros. */
    private Assign<T> totalInitializer;
    private Assign<T> countInitializer;

    /**
     * Creates a MeanTensor metric without an explicit name.
     *
     * @param seed value forwarded to the {@code BaseMetric} constructor (assumed to be the random
     *     seed — confirm against BaseMetric)
     * @param type the data type of the accumulator variables
     */
    public MeanTensor(long seed, Class<T> type) {
        this(null, seed, type);
    }

    /**
     * Creates a MeanTensor metric.
     *
     * @param name the metric name forwarded to {@code BaseMetric}; {@code null} is passed through
     *     unchanged (presumably BaseMetric then picks a default — confirm)
     * @param seed value forwarded to the {@code BaseMetric} constructor (assumed to be the random
     *     seed — confirm against BaseMetric)
     * @param type the data type of the accumulator variables
     */
    public MeanTensor(String name, long seed, Class<T> type) {
        super(name, seed);
        this.type = type;
        this.totalName = getVariableName(TOTAL);
        this.countName = getVariableName(COUNT);
    }

    /**
     * Creates the {@code total} and {@code count} variables, zero-filled with the recorded shape,
     * together with the assign ops that reset them.
     *
     * <p>This is a no-op when the metric is already initialized or when no shape is known yet,
     * that is, before the first {@code updateStateList} call.
     *
     * @param ops the TensorFlow Ops accessor; it is passed to {@code checkIsGraph} first
     */
    @Override
    protected void init(Ops ops) {
        checkIsGraph(ops);
        if (isInitialized() || shape == null) {
            return;
        }
        setTF(ops);
        Operand<T> zeros = new Zeros<T>().call(ops, ops.constant(shape), type);
        if (total == null) {
            // Variables are created in the init scope so they live outside any control-flow context.
            total = ops.withName(totalName).withInitScope().variable(zeros);
            totalInitializer = ops.assign(total, zeros);
        }
        if (count == null) {
            count = ops.withName(countName).withInitScope().variable(zeros);
            countInitializer = ops.assign(count, zeros);
        }
        setInitialized(true);
    }

    /**
     * Builds the ops that accumulate {@code values} (optionally weighted) into the metric state.
     *
     * <p>Without weights, every element adds {@code 1} to {@code count} and its value to
     * {@code total}. With weights, the weights are first squeezed or expanded toward the rank of
     * the values, then broadcast to them. If broadcasting is impossible, the values are averaged
     * over their trailing dimensions so that their rank matches the weights.
     *
     * @param ops the TensorFlow Ops accessor
     * @param values the values to accumulate; the first call fixes the expected shape
     * @param sampleWeights optional weights; may be {@code null}
     * @return the {@code assignAdd} ops for {@code count} and {@code total}, in that order
     * @throws IllegalArgumentException if {@code values} does not have the shape recorded on the
     *     first call, or if the weights can neither be broadcast to the values nor be matched to
     *     them by reducing the values' trailing dimensions
     */
    @Override
    public List<Op> updateStateList(
            Ops ops, Operand<? extends TNumber> values, Operand<? extends TNumber> sampleWeights) {
        if (shape == null) {
            shape = values.shape();
        }
        init(ops);
        // Validate before building any further graph ops for this input.
        if (!shape.equals(values.shape())) {
            throw new IllegalArgumentException(String.format("MeanTensor input values must always have the same shape. Expected shape (set during the first call): %s. Got %s", shape.toString(), values.shape().toString()));
        }
        Operand<T> batchValues = CastHelper.cast(ops, values, type);
        // Each element counts as one observation unless weights say otherwise.
        Operand<T> numValues = ops.onesLike(batchValues);
        if (sampleWeights != null) {
            Operand<T> weights = CastHelper.cast(ops, sampleWeights, type);
            LossTuple<T> aligned =
                    LossesHelper.squeezeOrExpandDimensions(ops, null, batchValues, weights);
            batchValues = aligned.getTarget();
            weights = aligned.getSampleWeights();
            try {
                weights = WeightsBroadcastOps.broadcastWeights(ops, weights, batchValues);
            } catch (IllegalArgumentException e) {
                // The weights cannot be broadcast to the values: reduce the values to the rank of
                // the weights by averaging over the trailing axes [weightsRank, valuesRank).
                int valuesRank = batchValues.shape().numDimensions();
                int weightsRank = weights.shape().numDimensions();
                if (weightsRank < 0 || weightsRank > valuesRank) {
                    // Unknown rank, or weights of higher rank: reducing the values cannot help.
                    throw e;
                }
                batchValues =
                        ops.math.mean(batchValues, ops.constant(axisRange(weightsRank, valuesRank)));
            }
            numValues = ops.math.mul(numValues, weights);
            batchValues = ops.math.mul(batchValues, weights);
        }
        Ops variableOps = ops.withSubScope("MeanTensor.variables");
        List<Op> updates = new ArrayList<>();
        updates.add(variableOps.assignAdd(count, numValues));
        updates.add(variableOps.assignAdd(total, batchValues));
        return updates;
    }

    /**
     * Returns the consecutive axes {@code from, from + 1, ..., to - 1}; empty when
     * {@code from == to}.
     */
    private static int[] axisRange(int from, int to) {
        int[] axes = new int[to - from];
        for (int i = 0; i < axes.length; i++) {
            axes[i] = from + i;
        }
        return axes;
    }

    /**
     * Computes the current element-wise mean, {@code total / count}, using {@code divNoNan}.
     *
     * @param ops the TensorFlow Ops accessor
     * @param resultType the data type of the returned operand
     * @param <U> the data type of the returned operand
     * @return the mean; a scalar zero if the metric has not been initialized yet
     */
    @Override
    public <U extends TNumber> Operand<U> result(Ops ops, Class<U> resultType) {
        init(ops);
        if (!isInitialized()) {
            return CastHelper.cast(ops, ops.constant(0), resultType);
        }
        return CastHelper.cast(ops, ops.math.divNoNan(total, count), resultType);
    }

    /**
     * Returns the variable accumulating the (weighted) sum of the values.
     *
     * @return the total variable, or {@code null} if the metric has not been initialized yet
     */
    public Variable<T> getTotal() {
        return total;
    }

    /**
     * Returns the variable accumulating the (weighted) number of values.
     *
     * @return the count variable, or {@code null} if the metric has not been initialized yet
     */
    public Variable<T> getCount() {
        return count;
    }

    /**
     * Returns an op that resets {@code count} and {@code total} to zeros.
     *
     * @param ops the TensorFlow Ops accessor
     * @return a no-op that depends on both reset assignments, or a plain no-op if the metric has
     *     not been initialized yet
     */
    @Override
    public Op resetStates(Ops ops) {
        init(ops);
        Ops resetOps = ops.withSubScope("resetStates");
        if (!isInitialized()) {
            return resetOps.noOp();
        }
        List<Op> resets = new ArrayList<>();
        resets.add(countInitializer);
        resets.add(totalInitializer);
        return resetOps.withControlDependencies(resets).noOp();
    }
}
