package org.tensorflow.framework.constraints;

import org.tensorflow.Operand;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/constraints/MaxNorm.class */
/**
 * Weight constraint that rescales weights so their norm does not exceed a maximum value.
 *
 * <p>{@link #call} computes {@code weights * clip(norm, 0, maxValue) / (epsilon + norm)}. Here
 * {@code norm} is produced by the inherited {@code norm} helper over {@link #getAxes()}
 * (presumably an L2 norm along those axes — TODO confirm against {@code AbstractConstraint}).
 * {@code epsilon} is a small constant that keeps the division defined when the norm is zero.
 */
public class MaxNorm extends AbstractConstraint {
    /** Maximum norm used by constructors that do not take a max value. */
    public static final double MAX_VALUE_DEFAULT = 2.0d;

    /** Axis used by constructors that do not take an axis. */
    public static final int AXIS_DEFAULT = 0;

    /** Added to the norm in the denominator so that a zero norm does not cause a division by zero. */
    private static final float NORM_EPSILON = 1.0E-7f;

    private final double maxValue;
    private final int[] axes;

    /** Creates a constraint with {@link #MAX_VALUE_DEFAULT} and {@link #AXIS_DEFAULT}. */
    public MaxNorm() {
        this(MAX_VALUE_DEFAULT, AXIS_DEFAULT);
    }

    /**
     * Creates a constraint with the given maximum norm and {@link #AXIS_DEFAULT}.
     *
     * @param maxValue the upper bound the norm is clipped to
     */
    public MaxNorm(double maxValue) {
        this(maxValue, AXIS_DEFAULT);
    }

    /**
     * Creates a constraint with the given maximum norm, computed along a single axis.
     *
     * @param maxValue the upper bound the norm is clipped to
     * @param axis the axis along which the norm is computed
     */
    public MaxNorm(double maxValue, int axis) {
        this(maxValue, new int[] {axis});
    }

    /**
     * Creates a constraint with the given maximum norm, computed along the given axes.
     *
     * @param maxValue the upper bound the norm is clipped to
     * @param axes the axes along which the norm is computed; the array is copied, so later changes
     *     to it by the caller do not affect this constraint
     */
    public MaxNorm(double maxValue, int[] axes) {
        this.maxValue = maxValue;
        // Defensive copy: keeps this constraint independent of the caller's array.
        this.axes = axes == null ? null : axes.clone();
    }

    /**
     * Applies the constraint to the given weights.
     *
     * @param ops the TensorFlow ops used to build the computation
     * @param operand the weights to constrain
     * @param <T> the numeric data type of the weights
     * @return {@code operand * clip(norm, 0, maxValue) / (epsilon + norm)}
     */
    @Override // org.tensorflow.framework.constraints.Constraint
    public <T extends TNumber> Operand<T> call(Ops ops, Operand<T> operand) {
        Class<T> type = operand.type();
        Operand<T> norms = norm(ops, operand, getAxes());
        // Target norm: the current norm limited to the range [0, maxValue].
        Operand<T> desired = clip(ops, norms, 0.0d, getMaxValue());
        Operand<T> epsilon = CastHelper.cast(ops, ops.constant(NORM_EPSILON), type);
        return ops.math.mul(operand, ops.math.div(desired, ops.math.add(epsilon, norms)));
    }

    /**
     * Returns the maximum norm.
     *
     * @return the upper bound the norm is clipped to
     */
    public double getMaxValue() {
        return this.maxValue;
    }

    /**
     * Returns a copy of the axes along which the norm is computed.
     *
     * @return a fresh copy of the axes array, so callers cannot mutate this constraint's state
     */
    public int[] getAxes() {
        return this.axes == null ? null : this.axes.clone();
    }
}
