package greycat.ml.neuralnet.layer;

import greycat.ml.common.matrix.MatrixOps;
import greycat.ml.neuralnet.activation.Activation;
import greycat.ml.neuralnet.activation.Activations;
import greycat.ml.neuralnet.process.ExMatrix;
import greycat.ml.neuralnet.process.ProcessGraph;
import greycat.struct.ENode;
import java.util.Random;

/* loaded from: input_file:greycat/ml/neuralnet/layer/GRU.class */
/**
 * Gated recurrent unit (GRU) layer.
 *
 * <p>The layer owns three groups of parameters (mix gate, reset gate, candidate
 * state), each made of an input-to-hidden matrix, a hidden-to-hidden matrix and
 * a bias, plus a recurrent {@code context} vector. All of them are
 * {@link ExMatrix} instances created on the host {@link ENode} under the
 * attribute names declared below.
 *
 * <p>The class keeps mutable recurrent state ({@code context}) and does no
 * synchronization of its own.
 */
class GRU implements Layer {
    // Attribute names under which each matrix is created on the host node.
    private static final String IHMIX = "ihmix";
    private static final String HHMIX = "hhmix";
    private static final String BMIX = "bmix";
    private static final String IHNEW = "ihnew";
    private static final String HHNEW = "hhnew";
    private static final String BNEW = "bnew";
    private static final String IHRESET = "ihreset";
    private static final String HHRESET = "hhreset";
    private static final String BRESET = "breset";
    private static final String CONTEXT = "context";

    // Mix gate parameters: input->hidden, hidden->hidden, bias.
    private final ExMatrix ihmix;
    private final ExMatrix hhmix;
    private final ExMatrix bmix;
    // Candidate ("new") state parameters.
    private final ExMatrix ihnew;
    private final ExMatrix hhnew;
    private final ExMatrix bnew;
    // Reset gate parameters.
    private final ExMatrix ihreset;
    private final ExMatrix hhreset;
    private final ExMatrix breset;
    // Recurrent state; deliberately not final, forward() replaces it with its output.
    private ExMatrix context;
    private final ENode host;

    // NOTE(review): unit ids 1 and 3 are inlined constants; presumably sigmoid and
    // tanh respectively - confirm against Activations.
    private final Activation fMix = Activations.getUnit(1, null);
    private final Activation fReset = Activations.getUnit(1, null);
    private final Activation fNew = Activations.getUnit(3, null);

    // Lazily built by getLayerParameters().
    private ExMatrix[] params = null;

    /**
     * Binds the layer to its host node and creates one matrix per parameter on it.
     *
     * @param eNode node that stores the layer's matrices; must not be {@code null}
     * @throws RuntimeException if {@code eNode} is {@code null}
     */
    public GRU(ENode eNode) {
        if (eNode == null) {
            throw new RuntimeException("Host node can't be null");
        }
        this.host = eNode;
        this.ihmix = new ExMatrix(eNode, IHMIX);
        this.hhmix = new ExMatrix(eNode, HHMIX);
        this.bmix = new ExMatrix(eNode, BMIX);
        this.ihnew = new ExMatrix(eNode, IHNEW);
        this.hhnew = new ExMatrix(eNode, HHNEW);
        this.bnew = new ExMatrix(eNode, BNEW);
        this.ihreset = new ExMatrix(eNode, IHRESET);
        this.hhreset = new ExMatrix(eNode, HHRESET);
        this.breset = new ExMatrix(eNode, BRESET);
        this.context = new ExMatrix(eNode, CONTEXT);
    }

    /**
     * Sizes every matrix of the layer and optionally randomizes the weights.
     *
     * <p>Input-to-hidden matrices are initialised as {@code i2 x i}, hidden-to-hidden
     * matrices as {@code i2 x i2}, biases and the context as {@code i2 x 1}. Only the
     * six weight matrices are passed to {@code MatrixOps.fillWithRandomStd}; biases
     * and context are not.
     *
     * @param i      input size (second argument of the input-to-hidden {@code init} calls)
     * @param i2     output / hidden size
     * @param i3     not used by this layer
     * @param dArr   not used by this layer
     * @param random random source for weight initialisation; when {@code null} the
     *               random fill is skipped
     * @param d      value forwarded to {@code fillWithRandomStd} (presumably a standard
     *               deviation - confirm against MatrixOps); when {@code 0} the random
     *               fill is skipped
     * @return this layer
     */
    @Override // greycat.ml.neuralnet.layer.Layer
    public Layer init(int i, int i2, int i3, double[] dArr, Random random, double d) {
        final int inputs = i;
        final int outputs = i2;

        // NOTE(review): (byte) 4 and 2 are inlined constants, presumably the INT
        // attribute type and the GRU layer id - confirm against Type / Layers.
        this.host.set(Layers.TYPE, (byte) 4, 2);

        this.ihmix.init(outputs, inputs);
        this.hhmix.init(outputs, outputs);
        this.bmix.init(outputs, 1);

        this.ihnew.init(outputs, inputs);
        this.hhnew.init(outputs, outputs);
        this.bnew.init(outputs, 1);

        this.ihreset.init(outputs, inputs);
        this.hhreset.init(outputs, outputs);
        this.breset.init(outputs, 1);

        this.context.init(outputs, 1);

        if (random != null && d != 0.0d) {
            MatrixOps.fillWithRandomStd(this.ihmix, random, d);
            MatrixOps.fillWithRandomStd(this.hhmix, random, d);
            MatrixOps.fillWithRandomStd(this.ihnew, random, d);
            MatrixOps.fillWithRandomStd(this.hhnew, random, d);
            MatrixOps.fillWithRandomStd(this.ihreset, random, d);
            MatrixOps.fillWithRandomStd(this.hhreset, random, d);
        }
        return this;
    }

    /**
     * Runs one recurrent step through the given process graph.
     *
     * <pre>
     * mix       = fMix  (ihmix   * x + hhmix   * context + bmix)
     * reset     = fReset(ihreset * x + hhreset * context + breset)
     * candidate = fNew  (ihnew   * x + hhnew * (reset .* context) + bnew)
     * output    = mix .* context + (1 - mix) .* candidate
     * </pre>
     *
     * <p>The graph operations are issued in a fixed order (mix gate, reset gate,
     * retained context, candidate, blend); keep that order when editing, since the
     * graph may record operations as they are called.
     *
     * <p>Side effect: {@code context} is replaced by the returned matrix, so the next
     * call uses this step's output as its recurrent state.
     *
     * @param exMatrix     input of the current step
     * @param processGraph graph through which every operation is performed
     * @return the output of this step, which is also the new context
     */
    @Override // greycat.ml.neuralnet.layer.Layer
    public ExMatrix forward(ExMatrix exMatrix, ProcessGraph processGraph) {
        // Mix gate.
        ExMatrix mixFromInput = processGraph.mul(this.ihmix, exMatrix);
        ExMatrix mixFromContext = processGraph.mul(this.hhmix, this.context);
        ExMatrix mix = processGraph.activate(this.fMix,
                processGraph.add(processGraph.add(mixFromInput, mixFromContext), this.bmix));

        // Reset gate.
        ExMatrix resetFromInput = processGraph.mul(this.ihreset, exMatrix);
        ExMatrix resetFromContext = processGraph.mul(this.hhreset, this.context);
        ExMatrix reset = processGraph.activate(this.fReset,
                processGraph.add(processGraph.add(resetFromInput, resetFromContext), this.breset));

        // Share of the previous context that is carried over unchanged.
        ExMatrix retained = processGraph.elmul(mix, this.context);
        ExMatrix complement = processGraph.oneMinus(mix);

        // Candidate state, computed from the input and the reset-gated context.
        ExMatrix candidateFromInput = processGraph.mul(this.ihnew, exMatrix);
        ExMatrix gatedContext = processGraph.elmul(reset, this.context);
        ExMatrix candidateFromContext = processGraph.mul(this.hhnew, gatedContext);
        ExMatrix candidate = processGraph.activate(this.fNew,
                processGraph.add(processGraph.add(candidateFromInput, candidateFromContext), this.bnew));

        ExMatrix output = processGraph.add(retained, processGraph.elmul(complement, candidate));

        // Roll the output over as the recurrent state of the next step.
        this.context = output;
        return output;
    }

    /**
     * Returns the nine trainable matrices (weights and biases of the mix, candidate
     * and reset groups, in that order). The recurrent context is not included.
     *
     * <p>The array is created on first call and the same instance is returned
     * afterwards; callers should not modify it.
     *
     * @return the cached parameter array
     */
    @Override // greycat.ml.neuralnet.layer.Layer
    public ExMatrix[] getLayerParameters() {
        if (this.params == null) {
            this.params = new ExMatrix[]{
                this.ihmix, this.hhmix, this.bmix,
                this.ihnew, this.hhnew, this.bnew,
                this.ihreset, this.hhreset, this.breset
            };
        }
        return this.params;
    }

    /**
     * Zeroes the values, gradients and step cache of the current context matrix.
     *
     * <p>NOTE(review): after a call to {@link #forward} the context is the matrix
     * returned by that call, so that is the matrix cleared here - confirm this is
     * the intended target.
     */
    @Override // greycat.ml.neuralnet.layer.Layer
    public void resetState() {
        this.context.getW().fill(0.0d);
        this.context.getDw().fill(0.0d);
        this.context.getStepCache().fill(0.0d);
    }

    /**
     * @return the column count of the mix gate's input-to-hidden matrix
     */
    @Override // greycat.ml.neuralnet.layer.Layer
    public int inputDimension() {
        return this.ihmix.columns();
    }

    /**
     * @return the row count of the mix gate's input-to-hidden matrix
     */
    @Override // greycat.ml.neuralnet.layer.Layer
    public int outputDimension() {
        return this.ihmix.rows();
    }
}
