package greycat.ml.neuralnet.learner;

import greycat.ml.common.matrix.MatrixOps;
import greycat.ml.neuralnet.layer.Layer;
import greycat.ml.neuralnet.process.ExMatrix;
import greycat.struct.DMatrix;
import greycat.struct.ENode;

/**
 * Gradient-descent learner with a regularization term applied as a multiplicative factor on the
 * weights.
 *
 * <p>The learning rate and regularization rate are read from the backing {@link ENode} on
 * construction (keys {@code "learningrate"} and {@code "regularizationrate"}, falling back to
 * {@code 0.001} and {@code 1.0E-6}). They are written back to that node by {@link #setParams}.
 */
class GradientDescent extends AbstractLearner {
    /** ENode attribute key holding the learning rate. */
    private static final String LEARNING_RATE = "learningrate";
    /** Learning rate used when the ENode has no stored value. */
    private static final double LEARNING_RATE_DEF = 0.001d;
    /** ENode attribute key holding the regularization rate. */
    private static final String REGULARIZATION_RATE = "regularizationrate";
    /** Regularization rate used when the ENode has no stored value. */
    private static final double REGULARIZATION_RATE_DEF = 1.0E-6d;
    /**
     * Type tag passed to {@code _backend.set} when persisting the double-valued hyper-parameters.
     * NOTE(review): presumably equal to greycat's {@code Type.DOUBLE}; confirm and reference that
     * constant directly.
     */
    private static final byte DOUBLE_TYPE = 5;

    private double learningRate;
    private double regularization;

    /**
     * Creates a learner backed by {@code eNode} and loads its hyper-parameters from it.
     *
     * @param eNode node storing the hyper-parameters; missing values fall back to the defaults
     */
    public GradientDescent(ENode eNode) {
        super(eNode);
        this.learningRate = (Double) eNode.getWithDefault(LEARNING_RATE, LEARNING_RATE_DEF);
        this.regularization = (Double) eNode.getWithDefault(REGULARIZATION_RATE, REGULARIZATION_RATE_DEF);
    }

    /**
     * Applies one update to every parameter matrix of every layer, then zeroes the gradients.
     *
     * <p>Both coefficients are divided by {@code steps} (inherited from {@link AbstractLearner}).
     * NOTE(review): presumably this averages gradients accumulated over {@code steps} samples;
     * confirm that {@code steps} is never 0 when this is called.
     *
     * @param layerArr layers whose parameters are updated in place
     */
    @Override
    protected void update(Layer[] layerArr) {
        // Scale applied to the current weights (regularization shrinkage).
        final double decay = 1.0d - (learningRate * regularization) / steps;
        // Scale applied to the gradient (negative: move against the gradient).
        final double stepScale = -learningRate / steps;
        for (Layer layer : layerArr) {
            for (ExMatrix param : layer.getLayerParameters()) {
                DMatrix w = param.getW();
                DMatrix dw = param.getDw();
                // Presumably computes w <- decay * w + stepScale * dw; confirm against MatrixOps.addInPlace.
                MatrixOps.addInPlace(w, decay, dw, stepScale);
                // Clear the gradient so the next accumulation starts from zero.
                dw.fill(0.0d);
            }
        }
    }

    /**
     * Replaces the hyper-parameters and persists them to the backing node.
     *
     * @param dArr exactly two values: {@code {learning rate, regularization rate}}
     * @throws IllegalArgumentException if {@code dArr} does not have exactly two elements
     */
    @Override
    public void setParams(double[] dArr) {
        if (dArr.length != 2) {
            throw new IllegalArgumentException("Gradient descent needs 2 params: {learning rate, regularization rate}");
        }
        this.learningRate = dArr[0];
        this.regularization = dArr[1];
        this._backend.set(LEARNING_RATE, DOUBLE_TYPE, this.learningRate);
        this._backend.set(REGULARIZATION_RATE, DOUBLE_TYPE, this.regularization);
    }
}
