package greycat.ml.neuralnet.learner;

import greycat.ml.neuralnet.layer.Layer;
import greycat.ml.neuralnet.process.ExMatrix;
import greycat.struct.DMatrix;
import greycat.struct.ENode;

/**
 * RMSProp learner.
 *
 * <p>For every parameter matrix of every layer, {@link #update(Layer[])} keeps an exponentially
 * decaying average of the squared gradient in the matrix's step cache. It divides each clipped
 * gradient by {@code sqrt(cache + smoothEpsilon)} before subtracting it from the weights. It also
 * multiplies the weights by an L2 shrinkage factor derived from the regularization rate.
 *
 * <p>The RMSProp-specific hyper-parameters are read from, and persisted to, the backing
 * {@link ENode} under the keys {@value #SMOOTH_EPSILON}, {@value #DECAY_RATE} and
 * {@value #GRADIENT_CLIP_RATE}. The learning rate and regularization fields are inherited from
 * {@link AbstractLearner}.
 */
class RMSProp extends AbstractLearner {
    static final String SMOOTH_EPSILON = "smoothepsilon";
    static final double SMOOTH_EPSILON_DEF = 1.0E-8d;
    static final String DECAY_RATE = "decayrate";
    static final double DECAY_RATE_DEF = 0.9999d;
    public static final String GRADIENT_CLIP_RATE = "gradientclip";
    public static final double GRADIENT_CLIP_DEF = 5.0d;

    /** Added to the cached squared-gradient average before the square root, to avoid division by zero. */
    private double smoothEpsilon;
    /** Weight given to the previous cache value in the moving average of squared gradients. */
    private double decayRate;
    /** Each averaged gradient is clamped to {@code [-gradientClip, gradientClip]} before being applied. */
    private double gradientClip;

    /**
     * Creates the learner and loads its RMSProp-specific hyper-parameters from {@code eNode}.
     * The {@code *_DEF} constants of this class are used as the defaults.
     *
     * @param eNode backing node holding the persisted learner state
     */
    public RMSProp(ENode eNode) {
        super(eNode);
        this.smoothEpsilon =
                ((Double) eNode.getWithDefault(SMOOTH_EPSILON, Double.valueOf(SMOOTH_EPSILON_DEF))).doubleValue();
        this.decayRate =
                ((Double) eNode.getWithDefault(DECAY_RATE, Double.valueOf(DECAY_RATE_DEF))).doubleValue();
        // Use the declared constant rather than a duplicated literal so the default lives in one place.
        this.gradientClip =
                ((Double) eNode.getWithDefault(GRADIENT_CLIP_RATE, Double.valueOf(GRADIENT_CLIP_DEF))).doubleValue();
    }

    /**
     * Applies one RMSProp step to every parameter matrix of every given layer, then resets the
     * accumulated gradients to zero.
     *
     * <p>Accumulated gradients are divided by {@code steps}, which presumably counts the
     * back-propagation passes accumulated since the last update (inherited from
     * {@link AbstractLearner}; confirm there).
     *
     * @param layerArr layers whose parameters are updated in place
     */
    @Override
    protected void update(Layer[] layerArr) {
        // Loop-invariant values, read once instead of once per weight.
        final double steps = this.steps;
        final double lr = this.learningRate;
        final double decay = this.decayRate;
        final double eps = this.smoothEpsilon;
        final double clip = this.gradientClip;
        // Multiplicative L2 weight-decay factor applied to every weight before the gradient step.
        final double shrink = 1.0d - ((lr * this.regularization) / steps);

        for (Layer layer : layerArr) {
            for (ExMatrix param : layer.getLayerParameters()) {
                DMatrix w = param.getW();
                DMatrix dw = param.getDw();
                DMatrix cache = param.getStepCache();
                int length = w.length();
                for (int k = 0; k < length; k++) {
                    // Gradient averaged over the accumulated steps.
                    double grad = dw.unsafeGet(k) / steps;
                    // The moving average deliberately uses the *unclipped* gradient; clipping only
                    // affects the step actually applied to the weight below.
                    cache.unsafeSet(k, (cache.unsafeGet(k) * decay) + ((1.0d - decay) * grad * grad));
                    // Two independent checks (not else-if), preserving the original clamping semantics.
                    if (grad > clip) {
                        grad = clip;
                    }
                    if (grad < (-clip)) {
                        grad = -clip;
                    }
                    w.unsafeSet(k, (w.unsafeGet(k) * shrink) - ((lr * grad) / Math.sqrt(cache.unsafeGet(k) + eps)));
                }
                // Reset the accumulated gradients for the next accumulation window.
                dw.fill(0.0d);
            }
        }
    }

    /**
     * Sets all hyper-parameters at once and persists them to the backing node.
     *
     * @param dArr exactly five values, in order: learning rate, regularization rate, smooth
     *     epsilon, decay rate, gradient clip
     * @throws IllegalArgumentException if {@code dArr} does not contain exactly five values
     */
    @Override
    public void setParams(double[] dArr) {
        if (dArr.length != 5) {
            throw new IllegalArgumentException(
                    "RMSProp needs 5 params: {learning rate, regularization rate, smooth Epsilon, decay Rate, gradient Clip}, got "
                            + dArr.length);
        }
        this.learningRate = dArr[0];
        this.regularization = dArr[1];
        this.smoothEpsilon = dArr[2];
        this.decayRate = dArr[3];
        this.gradientClip = dArr[4];
        // NOTE(review): (byte) 5 is the type tag passed to ENode.set; presumably greycat's
        // Type.DOUBLE. Confirm, and consider referencing the named constant.
        this._backend.set("learningrate", (byte) 5, Double.valueOf(this.learningRate));
        this._backend.set("regularizationrate", (byte) 5, Double.valueOf(this.regularization));
        this._backend.set(SMOOTH_EPSILON, (byte) 5, Double.valueOf(this.smoothEpsilon));
        this._backend.set(DECAY_RATE, (byte) 5, Double.valueOf(this.decayRate));
        this._backend.set(GRADIENT_CLIP_RATE, (byte) 5, Double.valueOf(this.gradientClip));
    }
}
