package org.tensorflow.framework.optimizers;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.framework.optimizers.Optimizer;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Scope;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.train.ApplyAdam;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.family.TType;

/* loaded from: input_file:org/tensorflow/framework/optimizers/Adam.class */
/**
 * Optimizer that updates variables with the Adam algorithm, delegating each per-variable update
 * to the {@code ApplyAdam} training op.
 *
 * <p>For every variable two zero-filled slots with the variable's shape and type are created,
 * named {@value #FIRST_MOMENT} and {@value #SECOND_MOMENT}. Two scalar {@code TFloat32}
 * accumulators, {@code beta1_power} and {@code beta2_power}, are initialized to {@code betaOne}
 * and {@code betaTwo}. They are multiplied by those coefficients once per step, after all of that
 * step's variable updates have run.
 */
public class Adam extends Optimizer {
    /** Slot name of the per-variable first-moment accumulator (passed as ApplyAdam's {@code m}). */
    public static final String FIRST_MOMENT = "m";
    /** Slot name of the per-variable second-moment accumulator (passed as ApplyAdam's {@code v}). */
    public static final String SECOND_MOMENT = "v";
    /** Learning rate used when none is supplied. */
    public static final float LEARNING_RATE_DEFAULT = 0.001f;
    /** Epsilon used when none is supplied. */
    public static final float EPSILON_DEFAULT = 1.0E-8f;
    /** Beta one used when none is supplied. */
    public static final float BETA_ONE_DEFAULT = 0.9f;
    /** Beta two used when none is supplied. */
    public static final float BETA_TWO_DEFAULT = 0.999f;

    private final float learningRate;
    private final float betaOne;
    private final float betaTwo;
    private final float epsilon;

    // Graph constants for the hyper-parameters, rebuilt by prepare() for each minimize call.
    private Constant<TFloat32> learningRateConst;
    private Constant<TFloat32> epsilonConst;
    private Constant<TFloat32> betaOneConst;
    private Constant<TFloat32> betaTwoConst;

    // Running products betaOne^t / betaTwo^t, created in createSlots() and advanced in finish().
    private Variable<TFloat32> betaOnePower;
    private Variable<TFloat32> betaTwoPower;

    /**
     * Creates an Adam optimizer with all default hyper-parameters.
     *
     * @param graph the graph handed to the base {@link Optimizer}
     */
    public Adam(Graph graph) {
        this(graph, LEARNING_RATE_DEFAULT, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT);
    }

    /**
     * Creates an Adam optimizer with the given learning rate and default betas and epsilon.
     *
     * @param graph the graph handed to the base {@link Optimizer}
     * @param learningRate the learning rate
     */
    public Adam(Graph graph, float learningRate) {
        this(graph, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT);
    }

    /**
     * Creates an Adam optimizer with explicit hyper-parameters.
     *
     * @param graph the graph handed to the base {@link Optimizer}
     * @param learningRate the learning rate
     * @param betaOne decay coefficient applied to the first-moment accumulator
     * @param betaTwo decay coefficient applied to the second-moment accumulator
     * @param epsilon small constant passed to ApplyAdam
     */
    public Adam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) {
        super(graph);
        this.learningRate = learningRate;
        this.betaOne = betaOne;
        this.betaTwo = betaTwo;
        this.epsilon = epsilon;
    }

    /**
     * Creates a named Adam optimizer with the given learning rate and default betas and epsilon.
     *
     * @param graph the graph handed to the base {@link Optimizer}
     * @param name the name handed to the base {@link Optimizer}
     * @param learningRate the learning rate
     */
    public Adam(Graph graph, String name, float learningRate) {
        this(graph, name, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT);
    }

    /**
     * Creates a named Adam optimizer with explicit hyper-parameters.
     *
     * @param graph the graph handed to the base {@link Optimizer}
     * @param name the name handed to the base {@link Optimizer}
     * @param learningRate the learning rate
     * @param betaOne decay coefficient applied to the first-moment accumulator
     * @param betaTwo decay coefficient applied to the second-moment accumulator
     * @param epsilon small constant passed to ApplyAdam
     */
    public Adam(Graph graph, String name, float learningRate, float betaOne, float betaTwo, float epsilon) {
        super(graph, name);
        this.learningRate = learningRate;
        this.betaOne = betaOne;
        this.betaTwo = betaTwo;
        this.epsilon = epsilon;
    }

    /**
     * Builds an Adam optimizer on the scope's graph and returns the op that minimizes {@code loss}.
     *
     * @param scope scope whose execution environment must be a {@link Graph}
     * @param loss the operand to minimize
     * @param learningRate the learning rate
     * @param betaOne decay coefficient applied to the first-moment accumulator
     * @param betaTwo decay coefficient applied to the second-moment accumulator
     * @param epsilon small constant passed to ApplyAdam
     * @param options optional settings; the last one with a non-null {@code sharedName} supplies the
     *     name passed to {@code minimize}
     * @param <T> the loss data type
     * @return the minimize op
     * @throws IllegalArgumentException if the scope's environment is not a {@link Graph}
     */
    public static <T extends TType> Op createAdamMinimize(Scope scope, Operand<T> loss, float learningRate, float betaOne, float betaTwo, float epsilon, Optimizer.Options... options) {
        if (!(scope.env() instanceof Graph)) {
            throw new IllegalArgumentException("Optimizers are only supported on Graphs");
        }
        // The instanceof check above guarantees this narrowing cast succeeds.
        Adam adam = new Adam((Graph) scope.env(), learningRate, betaOne, betaTwo, epsilon);
        String sharedName = null;
        for (Optimizer.Options option : options) {
            if (option.sharedName != null) {
                sharedName = option.sharedName;
            }
        }
        return sharedName == null ? adam.minimize(loss) : adam.minimize(loss, sharedName);
    }

    /**
     * Creates the moment slots for every variable, plus the two scalar beta-power accumulators.
     * The accumulators are initialized to {@code betaOne} and {@code betaTwo} through graph
     * initializers.
     *
     * @param variables the variables being optimized
     */
    @Override // org.tensorflow.framework.optimizers.Optimizer
    protected void createSlots(List<Output<? extends TType>> variables) {
        for (Output<? extends TType> variable : variables) {
            createAdamSlot(variable.asOutput());
        }
        this.betaOnePower = this.tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.class);
        this.graph.addInitializer(this.tf.assign(this.betaOnePower, this.tf.constant(this.betaOne)));
        this.betaTwoPower = this.tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.class);
        this.graph.addInitializer(this.tf.assign(this.betaTwoPower, this.tf.constant(this.betaTwo)));
    }

    /**
     * Materializes the hyper-parameters as graph constants for this minimize call.
     *
     * @param name unused here
     * @return always {@link Optional#empty()}
     */
    @Override // org.tensorflow.framework.optimizers.Optimizer
    protected Optional<Op> prepare(String name) {
        this.betaOneConst = this.tf.constant(this.betaOne);
        this.betaTwoConst = this.tf.constant(this.betaTwo);
        this.learningRateConst = this.tf.constant(this.learningRate);
        this.epsilonConst = this.tf.constant(this.epsilon);
        return Optional.empty();
    }

    /**
     * Creates the two zero-filled moment slots for one variable.
     *
     * @param variable the variable whose shape and data type the slots copy
     */
    private <T extends TType> void createAdamSlot(Output<T> variable) {
        Operand<T> firstMomentInit = this.tf.fill(this.tf.shape(variable), this.tf.dtypes.cast(this.tf.constant(0.0f), variable.type()));
        createSlot(variable.asOutput(), FIRST_MOMENT, firstMomentInit);
        Operand<T> secondMomentInit = this.tf.fill(this.tf.shape(variable), this.tf.dtypes.cast(this.tf.constant(0.0f), variable.type()));
        createSlot(variable.asOutput(), SECOND_MOMENT, secondMomentInit);
    }

    /**
     * Applies one ApplyAdam update to {@code variable}. All hyper-parameters and beta powers are
     * cast to the gradient's data type.
     *
     * @param gradient the gradient of the loss with respect to {@code variable}
     * @param variable the variable to update
     * @return the ApplyAdam op
     */
    @Override // org.tensorflow.framework.optimizers.Optimizer
    protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
        Class<T> type = gradient.type();
        Variable<T> firstMoment = getSlot(variable, FIRST_MOMENT).get();
        Variable<T> secondMoment = getSlot(variable, SECOND_MOMENT).get();
        return this.tf.train.applyAdam(
                variable,
                firstMoment,
                secondMoment,
                this.tf.dtypes.cast(this.betaOnePower, type),
                this.tf.dtypes.cast(this.betaTwoPower, type),
                this.tf.dtypes.cast(this.learningRateConst, type),
                this.tf.dtypes.cast(this.betaOneConst, type),
                this.tf.dtypes.cast(this.betaTwoConst, type),
                this.tf.dtypes.cast(this.epsilonConst, type),
                gradient);
    }

    /**
     * Advances the beta-power accumulators and then delegates to the base implementation.
     *
     * @param list the update ops produced by {@link #applyDense}; the two accumulator assigns are
     *     appended to it
     * @param str the name handed to the base implementation
     * @return the op returned by the base implementation
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // org.tensorflow.framework.optimizers.Optimizer
    public Op finish(List<Op> list, String str) {
        // The ApplyAdam ops read betaOnePower/betaTwoPower. Without an ordering constraint the
        // assigns below could run first, and some variables would be updated with beta^(t+1).
        // Gate the assigns on a snapshot of the update ops so the dependency set is fixed before
        // the list is extended.
        List<Op> applyOps = new ArrayList<>(list);
        list.add(this.tf.withControlDependencies(applyOps).assign(this.betaOnePower, this.tf.math.mul(this.betaOnePower, this.betaOneConst)));
        list.add(this.tf.withControlDependencies(applyOps).assign(this.betaTwoPower, this.tf.math.mul(this.betaTwoPower, this.betaTwoConst)));
        return super.finish(list, str);
    }

    @Override
    public String toString() {
        return "Adam{learningRate=" + this.learningRate + ", betaOne=" + this.betaOne + ", betaTwo=" + this.betaTwo + ", epsilon=" + this.epsilon + '}';
    }

    /** @return the constant optimizer name {@code "Adam"} */
    @Override // org.tensorflow.framework.optimizers.Optimizer
    public String getOptimizerName() {
        return "Adam";
    }
}
