package org.tensorflow.framework.activations;

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.tensorflow.Operand;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.op.Ops;
import org.tensorflow.op.nn.Elu;
import org.tensorflow.types.family.TNumber;

/**
 * Exponential linear unit (ELU) activation.
 *
 * <p>Delegates to TensorFlow's {@code Elu} op (which computes {@code x} for {@code x > 0} and
 * {@code exp(x) - 1} otherwise). When {@code alpha != 1}, the non-positive part of that result
 * is additionally scaled by {@code alpha}, giving {@code alpha * (exp(x) - 1)} for {@code x <= 0}.
 */
public class ELU extends AbstractActivation {
    /** Name of this activation; also written under the {@code "name"} key by {@link #getConfig()}. */
    public static final String NAME = "elu";

    /** Config keys accepted by {@link #ELU(Map)}; validated via {@code checkConfigKeys}. */
    private static final Set<String> allowedConfigKeys = new HashSet<>(Arrays.asList("name", "alpha"));

    /** Scale used when none is supplied; {@code 1.0} means the raw {@code Elu} op is used as-is. */
    private static final double ALPHA_DEFAULT = 1.0;

    /** Scale applied to the non-positive branch of the ELU output. */
    private final double alpha;

    /** Creates an ELU activation with the default {@code alpha} of {@code 1.0}. */
    public ELU() {
        this(ALPHA_DEFAULT);
    }

    /**
     * Creates an ELU activation with the given scale.
     *
     * @param alpha scale applied to the output for non-positive inputs
     */
    public ELU(double alpha) {
        this.alpha = alpha;
    }

    /**
     * Creates an ELU activation from a configuration map.
     *
     * <p>Only the keys {@code "name"} and {@code "alpha"} are allowed. A missing {@code "alpha"}
     * falls back to {@code 1.0}.
     *
     * @param config configuration map; {@code "alpha"}, if present, must hold a {@link Number}
     * @throws ClassCastException if the {@code "alpha"} value is not a {@link Number}
     */
    public ELU(Map<String, Object> config) {
        checkConfigKeys(config.keySet(), allowedConfigKeys);
        checkClassName(config);
        this.alpha = ((Number) config.getOrDefault("alpha", ALPHA_DEFAULT)).doubleValue();
    }

    /**
     * Applies the ELU activation to {@code operand}.
     *
     * @param ops the TensorFlow Ops used to build the graph
     * @param operand the input
     * @param alpha scale for the non-positive branch; {@code 1.0} returns the raw {@code Elu} op
     * @param <T> the numeric data type of the input and result
     * @return the activated operand, of the same type as {@code operand}
     */
    public static <T extends TNumber> Operand<T> elu(Ops ops, Operand<T> operand, double alpha) {
        Operand<T> result = ops.nn.elu(operand);
        if (alpha == 1.0) {
            // The raw Elu op already has an implicit alpha of 1; no extra ops needed.
            return result;
        }
        Class<T> type = operand.type();
        Operand<T> zero = CastHelper.cast(ops, ops.constant(0), type);
        Operand<T> scale = CastHelper.cast(ops, ops.constant(alpha), type);
        // Elu output is positive exactly where the input is positive, so only the other
        // (exp(x) - 1) branch gets rescaled.
        return ops.select(ops.math.greater(result, zero), result, ops.math.mul(result, scale));
    }

    /**
     * Returns the configuration of this activation.
     *
     * @return a new mutable map containing {@code "name"} and {@code "alpha"}
     */
    @Override
    public Map<String, Object> getConfig() {
        Map<String, Object> config = new HashMap<>();
        config.put("name", NAME);
        config.put("alpha", alpha);
        return config;
    }

    /**
     * Applies this activation, using the configured {@code alpha}.
     *
     * @param ops the TensorFlow Ops used to build the graph
     * @param operand the input
     * @param <T> the numeric data type of the input and result
     * @return the activated operand
     */
    @Override
    public <T extends TNumber> Operand<T> call(Ops ops, Operand<T> operand) {
        return elu(ops, operand, alpha);
    }

    /** {@inheritDoc} */
    @Override
    public String getName() {
        return NAME;
    }

    /**
     * Returns the scale applied to the non-positive branch.
     *
     * @return the alpha value
     */
    public double getAlpha() {
        return alpha;
    }
}
