package greycat.ml.neuralnet.learner;

import greycat.ml.common.matrix.MatrixOps;
import greycat.ml.neuralnet.layer.Layer;
import greycat.ml.neuralnet.process.ExMatrix;
import greycat.struct.DMatrix;
import greycat.struct.ENode;

/* loaded from: input_file:greycat/ml/neuralnet/learner/Momentum.class */
/**
 * Momentum learner.
 *
 * <p>On every {@link #update(Layer[])} each parameter's step cache is decayed by
 * {@link #decayRate} and the scaled gradient is added to it. The weights are then
 * scaled by {@code 1 - learningRate * regularization / steps} and have the step cache
 * subtracted. Finally the gradient accumulator is cleared.
 *
 * <p>{@link #setParams(double[])} persists the hyper-parameters in the backing node under
 * the keys {@code "learningrate"}, {@code "regularizationrate"} and {@value #DECAY_RATE}.
 */
class Momentum extends AbstractLearner {
    /** Backend key under which the momentum decay rate is stored. */
    static final String DECAY_RATE = "decayrate";
    /** Decay rate used when the backend node holds no value for {@link #DECAY_RATE}. */
    static final double DECAY_RATE_DEF = 0.9d;

    /**
     * Type tag passed to {@code ENode.set} together with the {@code Double} hyper-parameters.
     * Presumably greycat's DOUBLE type id; TODO confirm against {@code greycat.Type}.
     */
    private static final byte DOUBLE_TYPE = 5;

    /** Factor applied to the previous step cache when a new step is accumulated. */
    double decayRate;

    /**
     * Creates a momentum learner backed by {@code eNode}.
     *
     * <p>Decompiled as package-private originally. It is left {@code public}, which is
     * harmless because the class itself is package-private.
     *
     * @param eNode backend node. The decay rate is read from it, falling back to
     *     {@link #DECAY_RATE_DEF}.
     */
    public Momentum(ENode eNode) {
        super(eNode);
        this.decayRate = ((Double) eNode.getWithDefault(DECAY_RATE, Double.valueOf(DECAY_RATE_DEF))).doubleValue();
    }

    /**
     * Sets and persists the hyper-parameters.
     *
     * @param dArr exactly three values: {@code {learning rate, regularization rate, decay rate}}
     * @throws IllegalArgumentException if {@code dArr} does not contain exactly three values.
     *     Nothing is modified in that case.
     */
    @Override
    public void setParams(double[] dArr) {
        if (dArr.length != 3) {
            throw new IllegalArgumentException(
                    "Momentum needs 3 params: {learning rate, regularization rate, decay Rate}, got "
                            + dArr.length);
        }
        this.learningRate = dArr[0];
        this.regularization = dArr[1];
        this.decayRate = dArr[2];
        this._backend.set("learningrate", DOUBLE_TYPE, Double.valueOf(this.learningRate));
        this._backend.set("regularizationrate", DOUBLE_TYPE, Double.valueOf(this.regularization));
        this._backend.set(DECAY_RATE, DOUBLE_TYPE, Double.valueOf(this.decayRate));
    }

    /**
     * Applies one momentum step to every parameter of every layer, then clears the
     * accumulated gradients.
     *
     * <p>NOTE(review): divides by {@code steps} (inherited from {@link AbstractLearner}). This
     * presumably counts the gradients accumulated since the last update, and callers are
     * assumed to guarantee it is non-zero; confirm.
     *
     * @param layerArr layers whose parameters are updated in place
     */
    @Override
    protected void update(Layer[] layerArr) {
        // Multiplier applied to the weights each update (shrinkage from the regularization rate).
        final double weightScale = 1.0d - ((this.learningRate * this.regularization) / this.steps);
        // Multiplier applied to the accumulated gradient before it enters the step cache.
        final double gradScale = this.learningRate / this.steps;
        for (Layer layer : layerArr) {
            for (ExMatrix param : layer.getLayerParameters()) {
                DMatrix w = param.getW();
                DMatrix dw = param.getDw();
                DMatrix stepCache = param.getStepCache();
                // Assuming addInPlace(a, alpha, b, beta) computes a = alpha*a + beta*b:
                // stepCache = decayRate * stepCache + gradScale * dw
                MatrixOps.addInPlace(stepCache, this.decayRate, dw, gradScale);
                // w = weightScale * w - stepCache
                MatrixOps.addInPlace(w, weightScale, stepCache, -1.0d);
                // Reset the gradient accumulator for the next batch.
                dw.fill(0.0d);
            }
        }
    }
}
