package org.tensorflow.framework.optimizers;

import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.math.Mul;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.family.TType;

/* loaded from: input_file:org/tensorflow/framework/optimizers/Nadam.class */
/**
 * Nadam optimizer: an Adam-style update whose momentum coefficient follows a step-dependent
 * schedule.
 *
 * <p>For every trained variable the optimizer keeps a first-moment slot ({@value #FIRST_MOMENT})
 * and a second-moment slot ({@value #SECOND_MOMENT}). At step {@code t} (1-based) the momentum
 * coefficient is {@code betaOne * (1 - 0.5 * 0.96^(0.004 * t))}. The running product of these
 * coefficients is stored in a scalar graph variable, and the second moment is bias-corrected by
 * {@code 1 - betaTwo^t}.
 *
 * <p>The step counter is a graph variable rather than a Java field. A training op that is built
 * once and executed many times therefore advances its schedule on every run instead of staying
 * frozen at the first step.
 */
public class Nadam extends Optimizer {
    /** Default learning rate. */
    public static final float LEARNING_RATE_DEFAULT = 0.001f;
    /** Default small constant added to the update denominator for numerical stability. */
    public static final float EPSILON_DEFAULT = 1.0E-8f;
    /** Default decay rate applied to the first-moment estimate. */
    public static final float BETA_ONE_DEFAULT = 0.9f;
    /** Default decay rate applied to the second-moment estimate. */
    public static final float BETA_TWO_DEFAULT = 0.999f;
    /** Slot name of the per-variable first-moment estimate. */
    public static final String FIRST_MOMENT = "m";
    /** Slot name of the per-variable second-moment estimate. */
    public static final String SECOND_MOMENT = "v";
    /** Slot name of a per-variable momentum slot (created with ones; not read by {@code applyDense}). */
    public static final String MOMENTUM = "momentum";

    /** Base of the exponential in the momentum schedule. */
    private static final float DECAY_BASE = 0.96f;
    /** Per-step scale of the exponent in the momentum schedule. */
    private static final float DECAY = 0.004f;

    private final float learningRate;
    private final float betaOne;
    private final float betaTwo;
    private final float epsilon;

    private Constant<TFloat32> learningRateConst;
    private Constant<TFloat32> epsilonConst;
    private Constant<TFloat32> betaOneConst;
    private Constant<TFloat32> betaTwoConst;

    /** betaOne^t; maintained in {@code finish} but not read by the update itself. */
    private Variable<TFloat32> betaOnePower;
    /** betaTwo^t; maintained in {@code finish} but not read by the update itself. */
    private Variable<TFloat32> betaTwoPower;
    /** Running product of the per-step momentum coefficients (starts at 1). */
    private Variable<TFloat32> momentum;
    /** Number of completed training steps, stored in the graph so it advances on every run. */
    private Variable<TFloat32> step;

    // Per-step coefficients built in prepare() and consumed by applyDense().
    private Operand<TFloat32> mT1;
    private Operand<TFloat32> oneMinusBeta1;
    private Operand<TFloat32> oneMinusBeta2;
    private Operand<TFloat32> oneMinusMT;
    private Operand<TFloat32> oneMinusMScheduleNew;
    private Operand<TFloat32> oneMinusMScheduleNext;
    private Operand<TFloat32> vTPrimeDenominator;

    /**
     * Creates a Nadam optimizer with all default hyper-parameters.
     *
     * @param graph the graph handed to {@link Optimizer}
     */
    public Nadam(Graph graph) {
        this(graph, LEARNING_RATE_DEFAULT, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT);
    }

    /**
     * Creates a Nadam optimizer with the given learning rate and default betas/epsilon.
     *
     * @param graph the graph handed to {@link Optimizer}
     * @param learningRate the learning rate
     */
    public Nadam(Graph graph, float learningRate) {
        this(graph, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT);
    }

    /**
     * Creates a Nadam optimizer.
     *
     * @param graph the graph handed to {@link Optimizer}
     * @param learningRate the learning rate
     * @param betaOne decay rate of the first-moment estimate
     * @param betaTwo decay rate of the second-moment estimate
     * @param epsilon small constant added to the update denominator
     */
    public Nadam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) {
        super(graph);
        this.learningRate = learningRate;
        this.betaOne = betaOne;
        this.betaTwo = betaTwo;
        this.epsilon = epsilon;
    }

    /**
     * Creates a named Nadam optimizer with the given learning rate and default betas/epsilon.
     *
     * @param graph the graph handed to {@link Optimizer}
     * @param name the optimizer name handed to {@link Optimizer}
     * @param learningRate the learning rate
     */
    public Nadam(Graph graph, String name, float learningRate) {
        this(graph, name, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT);
    }

    /**
     * Creates a named Nadam optimizer.
     *
     * @param graph the graph handed to {@link Optimizer}
     * @param name the optimizer name handed to {@link Optimizer}
     * @param learningRate the learning rate
     * @param betaOne decay rate of the first-moment estimate
     * @param betaTwo decay rate of the second-moment estimate
     * @param epsilon small constant added to the update denominator
     */
    public Nadam(Graph graph, String name, float learningRate, float betaOne, float betaTwo, float epsilon) {
        super(graph, name);
        this.learningRate = learningRate;
        this.betaOne = betaOne;
        this.betaTwo = betaTwo;
        this.epsilon = epsilon;
    }

    /**
     * Creates the per-variable slots and the optimizer-wide scalar variables.
     *
     * @param variables the variables being trained
     */
    @Override // org.tensorflow.framework.optimizers.Optimizer
    protected void createSlots(List<Output<? extends TType>> variables) {
        for (Output<? extends TType> variable : variables) {
            createNadamSlot(variable.asOutput());
        }
        betaOnePower = createScalarVariable("beta1_power", betaOne);
        betaTwoPower = createScalarVariable("beta2_power", betaTwo);
        momentum = createScalarVariable("momentum", 1.0f);
        // Step counter read by prepare() and incremented by finish(); starts at 0 so the first
        // run uses t = 1.
        step = createScalarVariable("step", 0.0f);
    }

    /** Creates a named scalar float variable in the init scope and registers its initial value. */
    private Variable<TFloat32> createScalarVariable(String name, float initialValue) {
        Variable<TFloat32> scalar = tf.withInitScope().withName(name).variable(Shape.scalar(), TFloat32.class);
        tf.withInitScope().assign(scalar, tf.constant(initialValue));
        return scalar;
    }

    /** Creates the m (zeros), v (zeros) and momentum (ones) slots for one variable. */
    private <T extends TType> void createNadamSlot(Output<T> variable) {
        createSlot(variable, FIRST_MOMENT, filledLike(variable, 0.0f));
        createSlot(variable, SECOND_MOMENT, filledLike(variable, 0.0f));
        createSlot(variable, MOMENTUM, filledLike(variable, 1.0f));
    }

    /** Builds a tensor with the variable's shape and type, filled with {@code value}. */
    private <T extends TType> Operand<T> filledLike(Output<T> variable, float value) {
        return tf.fill(tf.shape(variable), tf.dtypes.cast(tf.constant(value), variable.type()));
    }

    /**
     * Builds the step-dependent coefficients shared by every variable update.
     *
     * <p>Also updates the running momentum product: {@code momentum <- momentum * mT}.
     *
     * @param scopeName unused
     * @return always empty; the momentum update is reached through the coefficients instead
     */
    @Override // org.tensorflow.framework.optimizers.Optimizer
    protected Optional<Op> prepare(String scopeName) {
        Constant<TFloat32> one = tf.constant(1.0f);

        learningRateConst = tf.constant(learningRate);
        betaOneConst = tf.constant(betaOne);
        betaTwoConst = tf.constant(betaTwo);
        epsilonConst = tf.constant(epsilon);

        // t and t + 1 are read from the graph at run time rather than baked in as constants.
        Operand<TFloat32> localStep = tf.math.add(step, one);
        Operand<TFloat32> nextStep = tf.math.add(step, tf.constant(2.0f));

        Operand<TFloat32> mT = momentumSchedule(one, localStep);
        mT1 = momentumSchedule(one, nextStep);

        Operand<TFloat32> mScheduleNew = tf.assign(momentum, tf.math.mul(momentum, mT), Assign.useLocking(true));
        Operand<TFloat32> mScheduleNext = tf.math.mul(mScheduleNew, mT1);

        oneMinusBeta1 = tf.math.sub(one, betaOneConst);
        oneMinusBeta2 = tf.math.sub(one, betaTwoConst);
        oneMinusMT = tf.math.sub(one, mT);
        oneMinusMScheduleNew = tf.math.sub(one, mScheduleNew);
        oneMinusMScheduleNext = tf.math.sub(one, mScheduleNext);
        vTPrimeDenominator = tf.math.sub(one, tf.math.pow(betaTwoConst, localStep));
        return Optional.empty();
    }

    /** Builds {@code betaOne * (1 - 0.5 * DECAY_BASE^(DECAY * t))} for the given step {@code t}. */
    private Operand<TFloat32> momentumSchedule(Operand<TFloat32> one, Operand<TFloat32> t) {
        Operand<TFloat32> decayed = tf.math.pow(tf.constant(DECAY_BASE), tf.math.mul(tf.constant(DECAY), t));
        return tf.math.mul(betaOneConst, tf.math.sub(one, tf.math.mul(tf.constant(0.5f), decayed)));
    }

    /**
     * Builds the Nadam update for one variable.
     *
     * @param deps the ops factory to build the update with
     * @param gradient gradient of the loss with respect to {@code variable}
     * @param variable the variable being updated
     * @return the op assigning the new value to {@code variable}
     */
    @Override // org.tensorflow.framework.optimizers.Optimizer
    protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
        Class<T> type = gradient.type();
        Variable<T> m = getSlot(variable, FIRST_MOMENT).get();
        Variable<T> v = getSlot(variable, SECOND_MOMENT).get();

        // gPrime = g / (1 - mScheduleNew)
        Operand<T> gPrime = deps.math.div(gradient, castTo(deps, oneMinusMScheduleNew, type));

        // m <- betaOne * m + (1 - betaOne) * g;  mTPrime = m / (1 - mScheduleNext)
        Operand<T> mT = deps.assign(
                m,
                deps.math.add(
                        deps.math.mul(castTo(deps, betaOneConst, type), m),
                        deps.math.mul(castTo(deps, oneMinusBeta1, type), gradient)),
                Assign.useLocking(true));
        Operand<T> mTPrime = deps.math.div(mT, castTo(deps, oneMinusMScheduleNext, type));

        // v <- betaTwo * v + (1 - betaTwo) * g^2;  vTPrime = v / (1 - betaTwo^t)
        Operand<T> vT = deps.assign(
                v,
                deps.math.add(
                        deps.math.mul(castTo(deps, betaTwoConst, type), v),
                        deps.math.mul(castTo(deps, oneMinusBeta2, type), deps.math.square(gradient))),
                Assign.useLocking(true));
        Operand<T> vTPrime = deps.math.div(vT, castTo(deps, vTPrimeDenominator, type));

        // mTBar = (1 - mT) * gPrime + mT1 * mTPrime
        Operand<T> mTBar = deps.math.add(
                deps.math.mul(castTo(deps, oneMinusMT, type), gPrime),
                deps.math.mul(castTo(deps, mT1, type), mTPrime));

        // variable <- variable - lr * mTBar / (sqrt(vTPrime) + epsilon)
        Operand<T> step = deps.math.div(
                deps.math.mul(castTo(deps, learningRateConst, type), mTBar),
                deps.math.add(deps.math.sqrt(vTPrime), castTo(deps, epsilonConst, type)));
        return deps.assign(variable, deps.math.sub(variable, step), Assign.useLocking(true));
    }

    /** Casts a float coefficient to the dtype of the variable being updated. */
    private static <T extends TType> Operand<T> castTo(Ops ops, Operand<TFloat32> value, Class<T> type) {
        return ops.dtypes.cast(value, type);
    }

    /**
     * Adds the per-step bookkeeping updates and groups all updates into the training op.
     *
     * @param updateOperations the per-variable update ops; this list is appended to
     * @param name name of the resulting op
     * @return the grouped training op
     */
    @Override // org.tensorflow.framework.optimizers.Optimizer
    public Op finish(List<Op> updateOperations, String name) {
        // The increment waits for every update built so far. All reads of the old step value
        // (which feed those updates) therefore complete before it changes. The op is built
        // immediately, before anything is appended to the list.
        Op stepIncrement = tf.withControlDependencies(updateOperations).assignAdd(step, tf.constant(1.0f));
        updateOperations.add(tf.assign(betaOnePower, tf.math.mul(betaOnePower, betaOneConst)));
        updateOperations.add(tf.assign(betaTwoPower, tf.math.mul(betaTwoPower, betaTwoConst)));
        updateOperations.add(stepIncrement);
        return super.finish(updateOperations, name);
    }

    /** {@inheritDoc} */
    @Override // org.tensorflow.framework.optimizers.Optimizer
    public String getOptimizerName() {
        return "Nadam";
    }
}
