package org.tensorflow.framework.activations;

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.ReduceMax;
import org.tensorflow.op.core.ReduceSum;
import org.tensorflow.op.math.Exp;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/activations/Softmax.class */
/**
 * Softmax activation: exponentiates the input and normalizes it by the sum of the exponentials
 * taken along a configurable axis.
 *
 * <p>The axis defaults to {@code -1}. Instances are immutable; the axis is fixed at construction.
 */
public class Softmax extends AbstractActivation {
    /** Name under which this activation is reported by {@link #getName()} and {@link #getConfig()}. */
    public static final String NAME = "softmax";

    /** Config keys accepted by {@link #Softmax(Map)}. */
    private static final Set<String> allowedConfigKeys = new HashSet<>(Arrays.asList("name", "axis"));

    /** Axis used when the caller does not supply one. */
    private static final int AXIS_DEFAULT = -1;

    private final int axis;

    /** Creates a softmax activation that uses the default axis ({@code -1}). */
    public Softmax() {
        this(AXIS_DEFAULT);
    }

    /**
     * Creates a softmax activation over the given axis.
     *
     * @param axis the axis along which the normalization is applied
     */
    public Softmax(int axis) {
        this.axis = axis;
    }

    /**
     * Creates a softmax activation from a configuration map.
     *
     * <p>The map's keys are validated with {@code checkConfigKeys} against {@code "name"} and
     * {@code "axis"}, and the map is passed to {@code checkClassName}. A missing {@code "axis"}
     * entry falls back to {@code -1}.
     *
     * @param config the configuration; the {@code "axis"} value, when present, must be an
     *     {@link Integer}
     * @throws ClassCastException if the {@code "axis"} value is not an {@link Integer}
     * @throws NullPointerException if {@code config} is null or maps {@code "axis"} to null
     */
    public Softmax(Map<String, Object> config) {
        checkConfigKeys(config.keySet(), allowedConfigKeys);
        checkClassName(config);
        this.axis = (Integer) config.getOrDefault("axis", AXIS_DEFAULT);
    }

    /**
     * Applies softmax to {@code input} using the default axis ({@code -1}).
     *
     * @param ops the TensorFlow ops used to build the computation
     * @param input the input operand
     * @param <T> the numeric type of the input
     * @return the softmax of {@code input}
     */
    public static <T extends TNumber> Operand<T> softmax(Ops ops, Operand<T> input) {
        return softmax(ops, input, ops.constant(AXIS_DEFAULT));
    }

    /**
     * Applies softmax to {@code input}.
     *
     * <p>When the input's static shape has exactly two dimensions, this delegates to
     * {@code ops.nn.softmax(input)} and {@code axis} is not consulted. Otherwise it computes
     * {@code exp(input - max) / sum(exp(input - max))}, where the max and sum are reduced over
     * {@code axis} with dimensions kept.
     *
     * <p>NOTE(review): because the two-dimensional path never reads {@code axis}, a non-default
     * axis has no effect for rank-2 inputs — confirm this is intended.
     *
     * @param ops the TensorFlow ops used to build the computation
     * @param input the input operand
     * @param axis the axis (or axes) over which to reduce on the general path
     * @param <T> the numeric type of the input
     * @return the softmax of {@code input}
     */
    public static <T extends TNumber> Operand<T> softmax(Ops ops, Operand<T> input, Operand<TInt32> axis) {
        if (input.shape().numDimensions() == 2) {
            return ops.nn.softmax(input);
        }
        // Shift by the per-axis maximum before exponentiating so the largest exponent is zero.
        Operand<T> shifted = ops.math.sub(input, ops.reduceMax(input, axis, ReduceMax.keepDims(true)));
        Exp<T> exponentials = ops.math.exp(shifted);
        Operand<T> normalizer = ops.reduceSum(exponentials, axis, ReduceSum.keepDims(true));
        return ops.math.div(exponentials, normalizer);
    }

    /**
     * Returns this activation's configuration.
     *
     * @return a new mutable map holding {@code "name"} ({@link #NAME}) and {@code "axis"}
     */
    @Override
    public Map<String, Object> getConfig() {
        Map<String, Object> config = new HashMap<>();
        config.put("name", NAME);
        config.put("axis", axis);
        return config;
    }

    /**
     * Applies this activation to {@code input} using the axis set at construction.
     *
     * @param ops the TensorFlow ops used to build the computation
     * @param input the input operand
     * @param <T> the numeric type of the input
     * @return the softmax of {@code input}
     */
    @Override
    public <T extends TNumber> Operand<T> call(Ops ops, Operand<T> input) {
        return softmax(ops, input, ops.constant(axis));
    }

    /**
     * Returns the name of this activation.
     *
     * @return {@link #NAME}
     */
    @Override
    public String getName() {
        return NAME;
    }

    /**
     * Returns the axis this activation was constructed with.
     *
     * @return the configured axis
     */
    public int getAxis() {
        return axis;
    }
}
