package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.function.DoubleUnaryOperator;
import java.util.logging.Logger;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.optimisers.util.ShrinkingMatrix;
import org.tribuo.math.optimisers.util.ShrinkingVector;

/**
 * An implementation of the RMSProp gradient optimiser.
 * <p>
 * For every parameter tensor this optimiser keeps an accumulator {@code s} of squared gradients.
 * On each {@link #step} it first adds {@code (1 - rho) * g * g} into the accumulator, then
 * multiplies each gradient element {@code g} by {@code weight * lr / (epsilon + sqrt(s))}.
 * The learning rate decays with the iteration count as
 * {@code lr = initialLearningRate / (1 + decay * iteration)}.
 * <p>
 * The accumulators are {@link ShrinkingVector} / {@link ShrinkingMatrix} instances constructed
 * with {@code 1 - rho}. NOTE(review): presumably these apply the {@code rho} decay to the stored
 * average; confirm against those classes.
 * <p>
 * Instances are stateful (accumulators plus an iteration counter) and perform no synchronisation.
 */
public class RMSProp implements StochasticGradientOptimiser {
    private static final Logger logger = Logger.getLogger(RMSProp.class.getName());

    @Config(mandatory = true, description = "Learning rate to scale the gradients by.")
    private double initialLearningRate;

    @Config(description = "Momentum parameter.")
    private double rho = 0.9;

    @Config(description = "Epsilon for numerical stability.")
    private double epsilon = 1.0E-8;

    // The decay is applied to the learning rate in step(), not to any momentum term.
    @Config(description = "Decay factor for the learning rate.")
    private double decay = 0.0;

    /** Cached {@code 1 - rho}, derived in {@link #postConfig()}. */
    private double invRho;

    /** Number of {@link #step} calls since construction or the last {@link #reset()}. */
    private int iteration = 0;

    /** Per-tensor squared-gradient accumulators; {@code null} until {@link #initialise} runs. */
    private Tensor[] gradsSquared;

    /** Maps a gradient element {@code g} to {@code (1 - rho) * g * g}; built in {@link #postConfig()}. */
    private DoubleUnaryOperator square;

    /**
     * Constructs an RMSProp optimiser.
     *
     * @param initialLearningRate The learning rate before any decay is applied.
     * @param rho                 Smoothing parameter; {@code 1 - rho} scales each squared gradient
     *                            added to the accumulator.
     * @param epsilon             Added to {@code sqrt(accumulator)} to avoid division by zero.
     * @param decay               Learning-rate decay; the rate at iteration {@code t} is
     *                            {@code initialLearningRate / (1 + decay * t)}.
     */
    public RMSProp(double initialLearningRate, double rho, double epsilon, double decay) {
        this.initialLearningRate = initialLearningRate;
        this.rho = rho;
        this.epsilon = epsilon;
        this.decay = decay;
        postConfig();
    }

    /**
     * Constructs an RMSProp optimiser with {@code epsilon = 1e-8} and no learning-rate decay.
     *
     * @param initialLearningRate The learning rate.
     * @param rho                 Smoothing parameter; {@code 1 - rho} scales each squared gradient
     *                            added to the accumulator.
     */
    public RMSProp(double initialLearningRate, double rho) {
        this(initialLearningRate, rho, 1.0E-8, 0.0);
    }

    /**
     * No-argument constructor. It leaves the field defaults in place and does not call
     * {@link #postConfig()}.
     * NOTE(review): presumably used by the OLCUT configuration system, which injects the
     * {@code @Config} fields reflectively; confirm.
     */
    private RMSProp() {
    }

    /**
     * Derives {@code 1 - rho} and the squaring function from the configured {@code rho}.
     * It is called from the public constructors.
     * NOTE(review): presumably OLCUT also calls it after field injection; confirm.
     */
    public void postConfig() {
        this.invRho = 1.0 - this.rho;
        this.square = g -> this.invRho * g * g;
    }

    /**
     * Allocates one squared-gradient accumulator per parameter tensor.
     * <p>
     * The accumulator array is built locally and only installed once every tensor has been
     * converted. A failure therefore leaves the previous optimiser state untouched.
     *
     * @param parameters The parameters to be optimised; an empty copy is used as the template.
     * @throws IllegalStateException If a parameter tensor is neither a {@link DenseVector} nor a
     *                               {@link DenseMatrix}.
     */
    @Override // org.tribuo.math.StochasticGradientOptimiser
    public void initialise(Parameters parameters) {
        Tensor[] accumulators = parameters.getEmptyCopy();
        for (int i = 0; i < accumulators.length; i++) {
            Tensor template = accumulators[i];
            if (template instanceof DenseVector) {
                accumulators[i] = new ShrinkingVector((DenseVector) template, this.invRho, false);
            } else if (template instanceof DenseMatrix) {
                accumulators[i] = new ShrinkingMatrix((DenseMatrix) template, this.invRho, false);
            } else {
                String kind = template == null ? "null" : template.getClass().getName();
                throw new IllegalStateException("Unknown Tensor subclass " + kind);
            }
        }
        this.gradsSquared = accumulators;
    }

    /**
     * Applies one RMSProp update, rescaling the supplied gradients in place.
     *
     * @param updates The gradient tensors, positionally aligned with the tensors passed to
     *                {@link #initialise}. They are modified in place.
     * @param weight  Multiplier applied to every scaled gradient element.
     * @return The same {@code updates} array, now holding the scaled updates.
     * @throws IllegalStateException If {@link #initialise} has not been called since construction
     *                               or the last {@link #reset()}.
     */
    @Override // org.tribuo.math.StochasticGradientOptimiser
    public Tensor[] step(Tensor[] updates, double weight) {
        if (this.gradsSquared == null) {
            throw new IllegalStateException(
                    "RMSProp.step called before initialise (or after reset); call initialise first.");
        }
        double learningRate = this.initialLearningRate / (1.0 + (this.decay * this.iteration));
        DoubleUnaryOperator scale = s -> (weight * learningRate) / (this.epsilon + Math.sqrt(s));
        for (int i = 0; i < updates.length; i++) {
            Tensor accumulator = this.gradsSquared[i];
            Tensor gradient = updates[i];
            // Fold (1 - rho) * g^2 into the running squared-gradient accumulator first...
            accumulator.intersectAndAddInPlace(gradient, this.square);
            // ...then scale each gradient element by weight * lr / (epsilon + sqrt(accumulator)).
            gradient.hadamardProductInPlace(accumulator, scale);
        }
        this.iteration++;
        return updates;
    }

    @Override
    public String toString() {
        return "RMSProp(initialLearningRate=" + this.initialLearningRate + ",rho=" + this.rho + ",epsilon=" + this.epsilon + ",decay=" + this.decay + ")";
    }

    /**
     * Discards the accumulators and resets the iteration counter.
     * {@link #initialise} must be called again before the next {@link #step}.
     */
    @Override // org.tribuo.math.StochasticGradientOptimiser
    public void reset() {
        this.gradsSquared = null;
        this.iteration = 0;
    }

    /**
     * Returns a fresh optimiser with the same configuration.
     * Accumulators and the iteration counter are not copied.
     *
     * @return A new, uninitialised {@code RMSProp}.
     */
    @Override // org.tribuo.math.StochasticGradientOptimiser
    public RMSProp copy() {
        return new RMSProp(this.initialLearningRate, this.rho, this.epsilon, this.decay);
    }

    /**
     * Returns the provenance of this optimiser's configuration.
     *
     * @return A configured-object provenance recorded under "StochasticGradientOptimiser".
     */
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "StochasticGradientOptimiser");
    }

    /**
     * Alias produced by the decompiler when it renamed {@code getProvenance}.
     *
     * @return The same value as {@link #getProvenance()}.
     * @deprecated Use {@link #getProvenance()}; kept only for source compatibility.
     */
    @Deprecated
    public ConfiguredObjectProvenance m39getProvenance() {
        return getProvenance();
    }
}
