package org.tensorflow.framework.constraints;

import org.tensorflow.Operand;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.op.Ops;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/constraints/MinMaxNorm.class */
/**
 * Constraint that rescales weights so that their norm along the configured axes is pulled toward
 * the interval {@code [minValue, maxValue]}.
 *
 * <p>For an input {@code w} whose norm is {@code n} (as computed by the {@code norm(...)} helper
 * inherited from {@code AbstractConstraint} over {@link #getAxes()}), {@link #call} returns
 *
 * <pre>{@code
 * w * (rate * clip(n, minValue, maxValue) + (1 - rate) * n) / (EPSILON + n)
 * }</pre>
 *
 * <p>A {@code rate} of {@code 1.0} rescales toward the clipped norm in a single step. Smaller
 * rates blend the clipped norm with the current norm. {@code EPSILON} keeps the division finite
 * when the norm is zero.
 *
 * <p>The axes array is defensively copied on construction and on access. Later changes to the
 * caller's array therefore do not affect this constraint.
 */
public class MinMaxNorm extends AbstractConstraint {
    public static final double MIN_VALUE_DEFAULT = 0.0d;
    public static final double MAX_VALUE_DEFAULT = 1.0d;
    public static final double RATE_DEFAULT = 1.0d;
    public static final int AXIS_DEFAULT = 0;

    /** Small positive term added to the norm so the rescaling never divides by zero. */
    private static final float EPSILON = 1.0E-7f;

    private final double minValue;
    private final double maxValue;
    private final double rate;
    private final int[] axes;

    /**
     * Creates a constraint using {@link #MIN_VALUE_DEFAULT}, {@link #MAX_VALUE_DEFAULT},
     * {@link #RATE_DEFAULT} and {@link #AXIS_DEFAULT}.
     */
    public MinMaxNorm() {
        this(MIN_VALUE_DEFAULT, MAX_VALUE_DEFAULT, RATE_DEFAULT, AXIS_DEFAULT);
    }

    /**
     * Creates a constraint with the given norm bounds, using {@link #RATE_DEFAULT} and
     * {@link #AXIS_DEFAULT}.
     *
     * @param minValue the lower bound the norm is clipped to
     * @param maxValue the upper bound the norm is clipped to
     */
    public MinMaxNorm(double minValue, double maxValue) {
        this(minValue, maxValue, RATE_DEFAULT, AXIS_DEFAULT);
    }

    /**
     * Creates a constraint that computes the norm along a single axis.
     *
     * @param minValue the lower bound the norm is clipped to
     * @param maxValue the upper bound the norm is clipped to
     * @param rate the weight given to the clipped norm; {@code 1 - rate} is given to the current norm
     * @param axis the axis along which the norm is computed
     */
    public MinMaxNorm(double minValue, double maxValue, double rate, int axis) {
        this(minValue, maxValue, rate, new int[] {axis});
    }

    /**
     * Creates a constraint that computes the norm along the given axes.
     *
     * @param minValue the lower bound the norm is clipped to
     * @param maxValue the upper bound the norm is clipped to
     * @param rate the weight given to the clipped norm; {@code 1 - rate} is given to the current norm
     * @param axes the axes along which the norm is computed; the array is copied
     */
    public MinMaxNorm(double minValue, double maxValue, double rate, int[] axes) {
        this.minValue = minValue;
        this.maxValue = maxValue;
        this.rate = rate;
        // Copy so later mutation of the caller's array cannot change this constraint.
        // NOTE(review): a null array is stored as null; whether norm() accepts null axes
        // depends on AbstractConstraint -- confirm.
        this.axes = copyAxes(axes);
    }

    /**
     * Applies the min/max norm rescaling to {@code operand}.
     *
     * @param ops the TensorFlow ops accessor used to build the graph
     * @param operand the weights to constrain
     * @param <T> the numeric tensor type of the weights
     * @return {@code operand} scaled by {@code desiredNorm / (EPSILON + norm)}
     */
    @Override
    public <T extends TNumber> Operand<T> call(Ops ops, Operand<T> operand) {
        Class<T> type = operand.type();
        Operand<T> norms = norm(ops, operand, getAxes());

        Operand<T> rateT = ops.dtypes.cast(ops.constant(getRate()), type);
        // (1 - rate) is computed in the tensor's dtype, matching the weights' precision.
        Operand<T> oneMinusRate =
                ops.math.sub(ops.dtypes.cast(ops.constant(1), type), rateT);

        // desired = rate * clip(norms, min, max) + (1 - rate) * norms
        Operand<T> clipped = clip(ops, norms, getMinValue(), getMaxValue());
        Operand<T> desired =
                ops.math.add(ops.math.mul(rateT, clipped), ops.math.mul(oneMinusRate, norms));

        // EPSILON keeps the division finite when a norm is zero.
        Operand<T> denominator = ops.math.add(CastHelper.cast(ops, ops.constant(EPSILON), type), norms);
        return ops.math.mul(operand, ops.math.div(desired, denominator));
    }

    /** @return the lower bound the norm is clipped to */
    public double getMinValue() {
        return this.minValue;
    }

    /** @return the upper bound the norm is clipped to */
    public double getMaxValue() {
        return this.maxValue;
    }

    /** @return the weight given to the clipped norm when blending with the current norm */
    public double getRate() {
        return this.rate;
    }

    /** @return a copy of the axes along which the norm is computed */
    public int[] getAxes() {
        return copyAxes(this.axes);
    }

    /** Returns a copy of {@code source}, or {@code null} if {@code source} is {@code null}. */
    private static int[] copyAxes(int[] source) {
        return source == null ? null : source.clone();
    }
}
