package greycat.ml.neuralnet.layer;

import greycat.ml.common.matrix.MatrixOps;
import greycat.ml.neuralnet.activation.Activation;
import greycat.ml.neuralnet.activation.Activations;
import greycat.ml.neuralnet.process.ExMatrix;
import greycat.ml.neuralnet.process.ProcessGraph;
import greycat.struct.ENode;
import java.util.Random;

/* loaded from: input_file:greycat/ml/neuralnet/layer/RNN.class */
/**
 * Recurrent layer whose output at each step is
 * {@code activation(W * [input; context] + b)}; that output is then kept as the
 * context used by the next call to {@link #forward}.
 *
 * <p>Weights, bias, context and the activation settings are stored as attributes of the
 * host {@link ENode} supplied at construction time.
 */
class RNN implements Layer {
    // Attribute keys under which this layer's state is stored on the host node.
    private static final String WEIGHTS = "weights";
    private static final String BIAS = "bias";
    private static final String ACTIVATION = "activation";
    private static final String ACTIVATION_PARAM = "activation_param";
    private static final String CONTEXT = "context";

    // Attribute type codes passed to ENode.set/getOrCreate: in this class TYPE_INT always
    // accompanies int values and TYPE_DOUBLE_ARRAY always accompanies double[] values.
    // NOTE(review): presumably these mirror greycat's type constants — confirm.
    private static final byte TYPE_INT = 4;
    private static final byte TYPE_DOUBLE_ARRAY = 6;

    // Value written under Layers.TYPE to tag the host node with this layer kind.
    // NOTE(review): presumably must match the RNN identifier used by Layers — confirm.
    private static final int LAYER_KIND = 4;

    private final ExMatrix weights;
    private final ExMatrix bias;
    // Recurrent state; replaced by the output of every forward() call.
    private ExMatrix context;
    private Activation activation;
    private final ENode host;
    // Lazily built and cached list of trainable parameters (weights and bias only).
    private ExMatrix[] params = null;

    /**
     * Binds an RNN layer to the state stored on {@code eNode}.
     *
     * @param eNode host node holding the weight, bias and context matrices and the
     *     activation settings
     * @throws IllegalArgumentException if {@code eNode} is null
     */
    public RNN(ENode eNode) {
        if (eNode == null) {
            throw new IllegalArgumentException("Host node can't be null");
        }
        this.weights = new ExMatrix(eNode, WEIGHTS);
        this.bias = new ExMatrix(eNode, BIAS);
        this.context = new ExMatrix(eNode, CONTEXT);
        // The activation unit id defaults to 0 when the node has none stored.
        int activationUnit = ((Integer) eNode.getWithDefault(ACTIVATION, 0)).intValue();
        double[] activationParams =
                (double[]) eNode.getOrCreate(ACTIVATION_PARAM, TYPE_DOUBLE_ARRAY);
        this.activation = Activations.getUnit(activationUnit, activationParams);
        this.host = eNode;
    }

    /**
     * Sizes the matrices, selects the activation and persists the configuration on the
     * host node.
     *
     * @param inputs number of input features per step
     * @param outputs number of output units (also the size of the recurrent context)
     * @param activationUnit activation identifier passed to {@link Activations#getUnit}
     * @param activationParams optional activation parameters; only persisted when non-null
     * @param random source of randomness for weight initialisation, or null to skip it
     * @param std standard deviation for random weight initialisation; 0 skips it
     * @return this layer
     */
    @Override // greycat.ml.neuralnet.layer.Layer
    public Layer init(int inputs, int outputs, int activationUnit, double[] activationParams,
                      Random random, double std) {
        this.host.set(Layers.TYPE, TYPE_INT, LAYER_KIND);
        // W multiplies the stacked vector [input; previous output], hence inputs + outputs columns.
        this.weights.init(outputs, inputs + outputs);
        this.bias.init(outputs, 1);
        // NOTE(review): after a forward() call, 'context' refers to the last output matrix,
        // not the node-backed one, so this re-initialises that output — confirm intended.
        this.context.init(outputs, 1);
        this.activation = Activations.getUnit(activationUnit, activationParams);
        this.host.set(ACTIVATION, TYPE_INT, Integer.valueOf(activationUnit));
        if (activationParams != null) {
            this.host.set(ACTIVATION_PARAM, TYPE_DOUBLE_ARRAY, activationParams);
        }
        // NOTE(review): with null activationParams any previously stored value is kept, so a
        // node reloaded via the constructor may see different params than 'activation' — confirm.
        if (random != null && std != 0.0d) {
            MatrixOps.fillWithRandomStd(this.weights, random, std);
        }
        return this;
    }

    /**
     * Computes one recurrent step and stores its output as the new context.
     *
     * @param exMatrix input vector for this step
     * @param processGraph graph recording the operations
     * @return {@code activation(weights * [exMatrix; context] + bias)}
     */
    @Override // greycat.ml.neuralnet.layer.Layer
    public ExMatrix forward(ExMatrix exMatrix, ProcessGraph processGraph) {
        ExMatrix stacked = processGraph.concatVectors(exMatrix, this.context);
        ExMatrix weighted = processGraph.mul(this.weights, stacked);
        ExMatrix preActivation = processGraph.add(weighted, this.bias);
        ExMatrix output = processGraph.activate(this.activation, preActivation);
        // The output feeds back as the context of the next step.
        this.context = output;
        return output;
    }

    /**
     * Returns the trainable parameters of this layer: weights and bias. The recurrent
     * context is state, not a parameter, so it is not included.
     *
     * @return the cached parameter array (the same instance on every call)
     */
    @Override // greycat.ml.neuralnet.layer.Layer
    public ExMatrix[] getLayerParameters() {
        if (this.params == null) {
            this.params = new ExMatrix[]{this.weights, this.bias};
        }
        return this.params;
    }

    /** Zeroes the values, gradients and step cache of the current context matrix. */
    @Override // greycat.ml.neuralnet.layer.Layer
    public void resetState() {
        this.context.getW().fill(0.0d);
        this.context.getDw().fill(0.0d);
        this.context.getStepCache().fill(0.0d);
    }

    /**
     * Returns the input size, derived from the weight matrix shape
     * ({@code columns = inputs + outputs}).
     */
    @Override // greycat.ml.neuralnet.layer.Layer
    public int inputDimension() {
        return this.weights.columns() - this.weights.rows();
    }

    /** Returns the number of output units (rows of the weight matrix). */
    @Override // greycat.ml.neuralnet.layer.Layer
    public int outputDimension() {
        return this.weights.rows();
    }
}
