package org.tensorflow.framework.optimizers;

import java.util.Iterator;
import java.util.List;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.train.ApplyCenteredRmsProp;
import org.tensorflow.op.train.ApplyRmsProp;
import org.tensorflow.types.family.TType;

/* loaded from: input_file:org/tensorflow/framework/optimizers/RMSProp.class */
/**
 * Optimizer that applies the RMSProp update rule to each trainable variable.
 *
 * <p>For every variable it creates an {@value #RMS} slot filled with ones and a {@value #MOMENTUM}
 * slot filled with zeros. When {@code centered} is {@code true}, it also creates an {@value #MG}
 * slot filled with zeros. In that case updates are built with {@code ApplyCenteredRmsProp}
 * instead of {@code ApplyRmsProp}.
 *
 * <p>Hyperparameters are stored as {@code float}. They are cast to the gradient's data type each
 * time an update op is built.
 */
public class RMSProp extends Optimizer {
    /** Default learning rate: {@value}. */
    public static final float LEARNING_RATE_DEFAULT = 0.001f;
    /** Default decay (passed as {@code rho} to the update op): {@value}. */
    public static final float DECAY_DEFAULT = 0.9f;
    /** Default momentum: {@value}. */
    public static final float MOMENTUM_DEFAULT = 0.0f;
    /** Default epsilon: {@value}. */
    public static final float EPSILON_DEFAULT = 1.0E-10f;
    /** Default for the {@code centered} flag: {@value}. */
    public static final boolean CENTERED_DEFAULT = false;
    /** Slot name for the per-variable accumulator initialized to ones. */
    public static final String RMS = "rms";
    /** Slot name for the extra accumulator created only when {@code centered} is set. */
    public static final String MG = "mg";
    /** Slot name for the per-variable momentum accumulator initialized to zeros. */
    public static final String MOMENTUM = "momentum";

    private final float learningRate;
    private final float decay;
    private final float momentum;
    private final float epsilon;
    private final boolean centered;

    /**
     * Creates an RMSProp optimizer with all default hyperparameters.
     *
     * @param graph the graph the optimizer builds its ops in
     */
    public RMSProp(Graph graph) {
        this(graph, LEARNING_RATE_DEFAULT, DECAY_DEFAULT, MOMENTUM_DEFAULT, EPSILON_DEFAULT,
            CENTERED_DEFAULT);
    }

    /**
     * Creates an RMSProp optimizer with the given learning rate and default values for the rest.
     *
     * @param graph the graph the optimizer builds its ops in
     * @param learningRate the learning rate
     */
    public RMSProp(Graph graph, float learningRate) {
        this(graph, learningRate, DECAY_DEFAULT, MOMENTUM_DEFAULT, EPSILON_DEFAULT,
            CENTERED_DEFAULT);
    }

    /**
     * Creates an RMSProp optimizer with explicit hyperparameters.
     *
     * @param graph the graph the optimizer builds its ops in
     * @param learningRate the learning rate
     * @param decay the decay factor, passed as {@code rho} to the update op
     * @param momentum the momentum coefficient
     * @param epsilon the epsilon value passed to the update op
     * @param centered whether to create the {@value #MG} slot and use the centered update op
     */
    public RMSProp(Graph graph, float learningRate, float decay, float momentum, float epsilon,
            boolean centered) {
        super(graph);
        this.learningRate = learningRate;
        this.decay = decay;
        this.momentum = momentum;
        this.epsilon = epsilon;
        this.centered = centered;
    }

    /**
     * Creates a named RMSProp optimizer with the given learning rate and default values for the
     * rest.
     *
     * @param graph the graph the optimizer builds its ops in
     * @param name the name forwarded to the {@link Optimizer} constructor
     * @param learningRate the learning rate
     */
    public RMSProp(Graph graph, String name, float learningRate) {
        this(graph, name, learningRate, DECAY_DEFAULT, MOMENTUM_DEFAULT, EPSILON_DEFAULT,
            CENTERED_DEFAULT);
    }

    /**
     * Creates a named RMSProp optimizer with explicit hyperparameters.
     *
     * @param graph the graph the optimizer builds its ops in
     * @param name the name forwarded to the {@link Optimizer} constructor
     * @param learningRate the learning rate
     * @param decay the decay factor, passed as {@code rho} to the update op
     * @param momentum the momentum coefficient
     * @param epsilon the epsilon value passed to the update op
     * @param centered whether to create the {@value #MG} slot and use the centered update op
     */
    public RMSProp(Graph graph, String name, float learningRate, float decay, float momentum,
            float epsilon, boolean centered) {
        super(graph, name);
        this.learningRate = learningRate;
        this.decay = decay;
        this.momentum = momentum;
        this.epsilon = epsilon;
        this.centered = centered;
    }

    /** Creates the RMSProp slots for every variable in {@code variables}, in list order. */
    @Override // org.tensorflow.framework.optimizers.Optimizer
    protected void createSlots(List<Output<? extends TType>> variables) {
        for (Output<? extends TType> variable : variables) {
            createRMSPropSlot(variable);
        }
    }

    /**
     * Creates the slots for a single variable. Each slot is filled to the variable's shape and cast
     * to its data type: {@value #RMS} gets ones, {@value #MOMENTUM} gets zeros, and {@value #MG}
     * (only when centered) gets zeros.
     */
    private <T extends TType> void createRMSPropSlot(Output<T> variable) {
        Class<T> type = variable.type();
        createSlot(variable.asOutput(), RMS,
            tf.fill(tf.shape(variable), scalarOfType(tf, 1.0f, type)));
        createSlot(variable.asOutput(), MOMENTUM,
            tf.fill(tf.shape(variable), scalarOfType(tf, 0.0f, type)));
        if (centered) {
            createSlot(variable.asOutput(), MG,
                tf.fill(tf.shape(variable), scalarOfType(tf, 0.0f, type)));
        }
    }

    /**
     * Builds the RMSProp update op for one variable. The op is {@code ApplyCenteredRmsProp} when
     * centered and {@code ApplyRmsProp} otherwise, and is always built with locking enabled.
     *
     * @param ops the op builder used to create the update
     * @param gradient the gradient passed as the op's {@code grad} input; its type drives the casts
     * @param variable the variable being updated; also the key for slot lookup
     * @return the update op
     */
    @Override // org.tensorflow.framework.optimizers.Optimizer
    protected <T extends TType> Op applyDense(Ops ops, Output<T> gradient, Output<T> variable) {
        Variable<T> rmsSlot = getSlot(variable, RMS).get();
        Variable<T> momentumSlot = getSlot(variable, MOMENTUM).get();
        Class<T> type = gradient.type();
        if (centered) {
            Variable<T> mgSlot = getSlot(variable, MG).get();
            return ops.train.applyCenteredRmsProp(
                variable,
                mgSlot,
                rmsSlot,
                momentumSlot,
                scalarOfType(ops, learningRate, type),
                scalarOfType(ops, decay, type),
                scalarOfType(ops, momentum, type),
                scalarOfType(ops, epsilon, type),
                gradient,
                ApplyCenteredRmsProp.useLocking(true));
        }
        return ops.train.applyRmsProp(
            variable,
            rmsSlot,
            momentumSlot,
            scalarOfType(ops, learningRate, type),
            scalarOfType(ops, decay, type),
            scalarOfType(ops, momentum, type),
            scalarOfType(ops, epsilon, type),
            gradient,
            ApplyRmsProp.useLocking(true));
    }

    /** Creates a scalar {@code float} constant with {@code ops} and casts it to {@code type}. */
    private static <T extends TType> Cast<T> scalarOfType(Ops ops, float value, Class<T> type) {
        return ops.dtypes.cast(ops.constant(value), type);
    }

    /** Returns a summary of the hyperparameters. */
    @Override
    public String toString() {
        return "RMSProp{learningRate=" + learningRate
            + ", decay=" + decay
            + ", momentum=" + momentum
            + ", epsilon=" + epsilon
            + ", centered=" + centered
            + "}";
    }

    /** Returns the fixed optimizer name {@code "RMSProp"}. */
    @Override // org.tensorflow.framework.optimizers.Optimizer
    public String getOptimizerName() {
        return "RMSProp";
    }
}
