package org.tensorflow.framework.regularizers;

import org.tensorflow.Operand;
import org.tensorflow.framework.losses.impl.LossesHelper;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.ReduceSum;
import org.tensorflow.types.family.TNumber;

/**
 * A regularizer that applies both an L1 and an L2 penalty.
 *
 * <p>The penalty is computed as {@code l1 * reduceSum(abs(x)) + l2 * reduceSum(square(x))}, where
 * each sum is reduced over all axes of the input. A term whose factor is exactly {@code 0} is
 * omitted from the graph. If both factors are {@code 0}, the result is a scalar zero cast to the
 * input's type.
 */
public class L1L2 extends AbstractRegularizer {
    private final float l1;
    private final float l2;

    /** Creates a regularizer with L1 and L2 factors both equal to {@code 0.01}. */
    public L1L2() {
        this(0.01f, 0.01f);
    }

    /**
     * Creates a regularizer with the given factors and a {@code null} name.
     *
     * @param l1 the L1 regularization factor
     * @param l2 the L2 regularization factor
     * @throws IllegalArgumentException if either factor is NaN or infinite
     */
    public L1L2(float l1, float l2) {
        this(null, l1, l2);
    }

    /**
     * Creates a regularizer with the given name and factors.
     *
     * @param name the name passed to the superclass constructor
     * @param l1 the L1 regularization factor
     * @param l2 the L2 regularization factor
     * @throws IllegalArgumentException if either factor is NaN or infinite
     */
    public L1L2(String name, float l1, float l2) {
        super(name);
        this.l1 = checkPenalty("L1", l1);
        this.l2 = checkPenalty("L2", l2);
    }

    /** Returns {@code value} unchanged, or throws if it is NaN or +/- infinity. */
    private static float checkPenalty(String label, float value) {
        if (!Float.isFinite(value)) {
            throw new IllegalArgumentException(String.format(
                    "%s Value: %f is not a valid regularization penalty number, "
                            + "a positive/negative infinity or NaN is not a proper value",
                    label, value));
        }
        return value;
    }

    /**
     * Computes the combined L1/L2 penalty for {@code operand}.
     *
     * @param ops the TensorFlow ops used to build the computation
     * @param operand the tensor to regularize
     * @param <R> the numeric type of {@code operand} and of the result
     * @return the penalty, with the same element type as {@code operand}
     */
    @Override // org.tensorflow.framework.regularizers.Regularizer
    public <R extends TNumber> Operand<R> call(Ops ops, Operand<R> operand) {
        Class<R> type = operand.type();
        // Start from zero. When both factors are 0 neither branch below fires and this zero is
        // returned as-is.
        Operand<R> regularization = CastHelper.cast(ops, ops.constant(0), type);

        float l1Factor = getL1();
        if (l1Factor != 0.0f) {
            Operand<R> absSum =
                    ops.reduceSum(ops.math.abs(operand), LossesHelper.allAxes(ops, operand));
            Operand<R> l1Term = ops.math.mul(CastHelper.cast(ops, ops.constant(l1Factor), type), absSum);
            regularization = ops.math.add(regularization, l1Term);
        }

        float l2Factor = getL2();
        if (l2Factor != 0.0f) {
            Operand<R> squareSum =
                    ops.reduceSum(ops.math.square(operand), LossesHelper.allAxes(ops, operand));
            Operand<R> l2Term = ops.math.mul(CastHelper.cast(ops, ops.constant(l2Factor), type), squareSum);
            regularization = ops.math.add(regularization, l2Term);
        }
        return regularization;
    }

    /**
     * Returns the L1 regularization factor.
     *
     * @return the L1 factor
     */
    public float getL1() {
        return this.l1;
    }

    /**
     * Returns the L2 regularization factor.
     *
     * @return the L2 factor
     */
    public float getL2() {
        return this.l2;
    }
}
