package org.tensorflow.framework.activations;

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.tensorflow.Operand;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.op.Ops;
import org.tensorflow.op.math.Mul;
import org.tensorflow.op.nn.LeakyRelu;
import org.tensorflow.op.nn.Relu;
import org.tensorflow.types.family.TNumber;

/**
 * Rectified Linear Unit activation with optional leak ({@code alpha}), upper clipping
 * ({@code maxValue}) and activation threshold ({@code threshold}).
 *
 * <p>With the default parameters the result is the element-wise {@code max(x, 0)}. See
 * {@link #relu(Ops, Operand, float, float, float)} for the exact composition. All configuration
 * fields are {@code final}.
 */
public class ReLU extends AbstractActivation {
    /** Value returned by {@link #getName()} and written under the {@code "name"} config key. */
    public static final String NAME = "relu";
    /** Default slope below the threshold: {@code 0}, i.e. no leak. */
    public static final float ALPHA_DEFAULT = 0.0f;
    /** Default upper bound: {@link Float#NaN}, meaning the output is not clipped from above. */
    public static final float MAX_VALUE_DEFAULT = Float.NaN;
    /** Default activation threshold: {@code 0}. */
    public static final float THRESHOLD_DEFAULT = 0.0f;

    /** Keys accepted by {@link #ReLU(Map)}; the same keys are written by {@link #getConfig()}. */
    private static final Set<String> allowedConfigKeys =
            new HashSet<>(Arrays.asList("name", "alpha", "max_value", "threshold"));

    private final float alpha;
    private final float maxValue;
    private final float threshold;

    /** Creates a standard ReLU: no leak, no upper bound, threshold {@code 0}. */
    public ReLU() {
        this(ALPHA_DEFAULT, MAX_VALUE_DEFAULT, THRESHOLD_DEFAULT);
    }

    /**
     * Creates a ReLU with explicit parameters.
     *
     * @param alpha slope applied to values below {@code threshold}
     * @param maxValue upper clipping bound; {@link Float#NaN} disables clipping
     * @param threshold values at or below this are treated as inactive
     */
    public ReLU(float alpha, float maxValue, float threshold) {
        this.alpha = alpha;
        this.maxValue = maxValue;
        this.threshold = threshold;
    }

    /**
     * Creates a ReLU from a configuration map such as the one produced by {@link #getConfig()}.
     *
     * <p>The key set is checked via {@code checkConfigKeys} against {@code name}, {@code alpha},
     * {@code max_value} and {@code threshold} (presumably rejecting unknown keys — behavior
     * defined in {@code AbstractActivation}), and the class name via {@code checkClassName}.
     * A key that is absent <em>or mapped to {@code null}</em> falls back to the corresponding
     * default, so a {@code null} {@code max_value} means "no upper bound".
     *
     * @param map the configuration
     * @throws ClassCastException if a present, non-null value is not a {@link Number}
     */
    public ReLU(Map<String, Object> map) {
        checkConfigKeys(map.keySet(), allowedConfigKeys);
        checkClassName(map);
        this.alpha = floatOrDefault(map, "alpha", ALPHA_DEFAULT);
        this.maxValue = floatOrDefault(map, "max_value", MAX_VALUE_DEFAULT);
        this.threshold = floatOrDefault(map, "threshold", THRESHOLD_DEFAULT);
    }

    /** Reads {@code key} as a float, using {@code defaultValue} when absent or null. */
    private static float floatOrDefault(Map<String, Object> config, String key, float defaultValue) {
        // getOrDefault would return null for a key explicitly mapped to null; handle that here.
        Object value = config.get(key);
        return value == null ? defaultValue : ((Number) value).floatValue();
    }

    /**
     * Applies the standard ReLU, {@code max(operand, 0)}.
     *
     * @param ops the TensorFlow Ops
     * @param operand the input
     * @param <T> the data type of the input and result
     * @return the activation result
     */
    public static <T extends TNumber> Operand<T> relu(Ops ops, Operand<T> operand) {
        return relu(ops, operand, ALPHA_DEFAULT, MAX_VALUE_DEFAULT, THRESHOLD_DEFAULT);
    }

    /**
     * Applies a parameterized ReLU.
     *
     * <p>If {@code alpha != 0}, {@code maxValue} is NaN and {@code threshold == 0}, this
     * delegates directly to the {@code leakyRelu} op. Otherwise the result is built as:
     *
     * <ol>
     *   <li>positive part: {@code x * (x > threshold)} when {@code threshold != 0}; else
     *       {@code relu6(x)} when {@code maxValue == 6}; else {@code relu(x)};
     *   <li>if {@code maxValue} is not NaN (and {@code relu6} was not used), clip to
     *       {@code [0, maxValue]};
     *   <li>if {@code alpha != 0}, subtract {@code alpha * relu(threshold - x)}.
     * </ol>
     *
     * @param ops the TensorFlow Ops
     * @param operand the input
     * @param alpha slope applied to values below {@code threshold}
     * @param maxValue upper clipping bound; {@link Float#NaN} disables clipping
     * @param threshold values at or below this are treated as inactive
     * @param <T> the data type of the input and result
     * @return the activation result
     */
    public static <T extends TNumber> Operand<T> relu(
            Ops ops, Operand<T> operand, float alpha, float maxValue, float threshold) {
        Class<T> type = operand.type();
        boolean clipMax = !Float.isNaN(maxValue);

        Operand<T> negativePart = null;
        if (alpha != 0.0f) {
            if (Float.isNaN(maxValue) && threshold == 0.0f) {
                // Plain leaky ReLU: use the dedicated op.
                return ops.nn.leakyRelu(operand, LeakyRelu.alpha(alpha));
            }
            // negativePart = relu(threshold - x): distance below the threshold, else 0.
            if (threshold != 0.0f) {
                negativePart = ops.nn.relu(
                        ops.math.add(ops.math.neg(operand), CastHelper.cast(ops, ops.constant(threshold), type)));
            } else {
                negativePart = ops.nn.relu(ops.math.neg(operand));
            }
        }

        Operand<T> result;
        if (threshold != 0.0f) {
            // Keep x where x > threshold, zero elsewhere.
            Operand<T> thresholdOp = CastHelper.cast(ops, ops.constant(threshold), type);
            result = ops.math.mul(operand, CastHelper.cast(ops, ops.math.greater(operand, thresholdOp), type));
        } else if (maxValue == 6.0f) {
            // relu6 already bounds the output at 6, so the generic clip below is skipped.
            result = ops.nn.relu6(operand);
            clipMax = false;
        } else {
            result = ops.nn.relu(operand);
        }

        if (clipMax) {
            Operand<T> zero = CastHelper.cast(ops, ops.constant(0), type);
            Operand<T> upper = CastHelper.cast(ops, ops.constant(maxValue), type);
            result = ops.clipByValue(result, zero, upper);
        }

        if (alpha != 0.0f) {
            Operand<T> alphaOp = CastHelper.cast(ops, ops.constant(alpha), type);
            result = ops.math.sub(result, ops.math.mul(alphaOp, negativePart));
        }
        return result;
    }

    /**
     * Returns a new mutable map holding {@code name}, {@code alpha}, {@code max_value} and
     * {@code threshold}, suitable for {@link #ReLU(Map)}.
     */
    @Override // org.tensorflow.framework.activations.AbstractActivation
    public Map<String, Object> getConfig() {
        Map<String, Object> config = new HashMap<>();
        config.put("name", NAME);
        config.put("alpha", alpha);
        config.put("max_value", maxValue);
        config.put("threshold", threshold);
        return config;
    }

    /** Applies this activation's configured ReLU to {@code operand}. */
    @Override // org.tensorflow.framework.activations.Activation
    public <T extends TNumber> Operand<T> call(Ops ops, Operand<T> operand) {
        return relu(ops, operand, alpha, maxValue, threshold);
    }

    /** Returns {@link #NAME}. */
    @Override // org.tensorflow.framework.activations.AbstractActivation
    public String getName() {
        return NAME;
    }

    /** Returns the slope applied to values below the threshold. */
    public float getAlpha() {
        return alpha;
    }

    /** Returns the activation threshold. */
    public float getThreshold() {
        return threshold;
    }

    /** Returns the upper clipping bound; {@link Float#NaN} means no clipping. */
    public float getMaxValue() {
        return maxValue;
    }
}
