package greycat.ml.neuralnet.layer;

import greycat.ml.common.matrix.MatrixOps;
import greycat.ml.neuralnet.activation.Activation;
import greycat.ml.neuralnet.activation.Activations;
import greycat.ml.neuralnet.process.ExMatrix;
import greycat.ml.neuralnet.process.ProcessGraph;
import greycat.struct.ENode;
import java.util.Random;

/* loaded from: input_file:greycat/ml/neuralnet/layer/LSTM.class */
/**
 * Long Short-Term Memory recurrent layer whose weights are persisted as {@link ExMatrix} instances
 * stored under fixed keys on a host {@link ENode}.
 *
 * <p>For an input column {@code x} and the previous hidden/cell state {@code h}/{@code c}, one
 * {@link #forward} step computes:
 *
 * <pre>
 *   i  = fInputGate (Wix*x + Wih*h + bi)
 *   f  = fForgetGate(Wfx*x + Wfh*h + bf)
 *   o  = fOutputGate(Wox*x + Woh*h + bo)
 *   g  = fCellInput (Wcx*x + Wch*h + bc)
 *   c' = f (.) c + i (.) g
 *   h' = o (.) fCellOutput(c')
 * </pre>
 *
 * where {@code (.)} is element-wise multiplication. The layer is stateful: {@code h'} and {@code c'}
 * become the context for the next step until {@link #resetState()} is called.
 */
class LSTM implements Layer {
    // Keys under which each matrix is stored on the host node.
    private static final String WIX = "wix";
    private static final String WIH = "wih";
    private static final String BI = "bi";
    private static final String WFX = "wfx";
    private static final String WFH = "wfh";
    private static final String BF = "bf";
    private static final String WOX = "wox";
    private static final String WOH = "woh";
    private static final String BO = "bo";
    private static final String WCX = "wcx";
    private static final String WCH = "wch";
    private static final String BC = "bc";
    private static final String HIDDEN_CONTEXT = "hiddencontext";
    private static final String CELL_CONTEXT = "cellcontext";

    // Trainable parameters: input weights (W?x), recurrent weights (W?h) and biases for each gate.
    private final ExMatrix wix;
    private final ExMatrix wih;
    private final ExMatrix bi;
    private final ExMatrix wfx;
    private final ExMatrix wfh;
    private final ExMatrix bf;
    private final ExMatrix wox;
    private final ExMatrix woh;
    private final ExMatrix bo;
    private final ExMatrix wcx;
    private final ExMatrix wch;
    private final ExMatrix bc;

    // Recurrent state; replaced by the graph outputs of every forward() call, hence not final.
    private ExMatrix hiddenContext;
    private ExMatrix cellContext;

    private final ENode host;

    // NOTE(review): unit ids 1 and 3 presumably map to sigmoid and tanh in Activations -- confirm.
    private final Activation fInputGate = Activations.getUnit(1, null);
    private final Activation fForgetGate = Activations.getUnit(1, null);
    private final Activation fOutputGate = Activations.getUnit(1, null);
    private final Activation fCellInput = Activations.getUnit(3, null);
    private final Activation fCellOutput = Activations.getUnit(3, null);

    // Lazily built, cached view of the trainable parameters (see getLayerParameters()).
    private ExMatrix[] params = null;

    /**
     * Binds the layer's matrices to storage on the given node. The matrices are not sized here;
     * call {@link #init} to (re)initialize them.
     *
     * @param eNode host node holding the layer's matrices; must not be {@code null}
     * @throws IllegalArgumentException if {@code eNode} is {@code null}
     */
    public LSTM(ENode eNode) {
        if (eNode == null) {
            throw new IllegalArgumentException("Host node can't be null");
        }
        this.wix = new ExMatrix(eNode, WIX);
        this.wih = new ExMatrix(eNode, WIH);
        this.bi = new ExMatrix(eNode, BI);
        this.wfx = new ExMatrix(eNode, WFX);
        this.wfh = new ExMatrix(eNode, WFH);
        this.bf = new ExMatrix(eNode, BF);
        this.wox = new ExMatrix(eNode, WOX);
        this.woh = new ExMatrix(eNode, WOH);
        this.bo = new ExMatrix(eNode, BO);
        this.wcx = new ExMatrix(eNode, WCX);
        this.wch = new ExMatrix(eNode, WCH);
        this.bc = new ExMatrix(eNode, BC);
        this.hiddenContext = new ExMatrix(eNode, HIDDEN_CONTEXT);
        this.cellContext = new ExMatrix(eNode, CELL_CONTEXT);
        this.host = eNode;
    }

    /**
     * Records the layer type on the host node and sizes all matrices.
     *
     * <p>Input weights are {@code outputs x inputs}, recurrent weights {@code outputs x outputs},
     * and biases plus both context vectors are {@code outputs x 1}. When {@code random} is non-null
     * and {@code std} is non-zero, all weight matrices are filled from {@link MatrixOps#fillWithRandomStd}
     * and the forget-gate bias is set to 1.0. The other biases are left as produced by {@code init}.
     *
     * @param inputs number of input features (columns of the W?x matrices)
     * @param outputs number of hidden units (rows of every matrix)
     * @param activationUnit ignored: gate activations are fixed for this layer
     * @param activationParams ignored: gate activations are fixed for this layer
     * @param random source of randomness for weight initialization; may be {@code null}
     * @param std scale passed to {@link MatrixOps#fillWithRandomStd}; {@code 0} skips random init
     * @return this layer
     */
    @Override
    public Layer init(int inputs, int outputs, int activationUnit, double[] activationParams, Random random, double std) {
        // NOTE(review): (byte) 4 / 3 look like the attribute type and the LSTM layer id -- confirm against Layers.
        this.host.set(Layers.TYPE, (byte) 4, 3);
        this.wix.init(outputs, inputs);
        this.wih.init(outputs, outputs);
        this.bi.init(outputs, 1);
        this.wfx.init(outputs, inputs);
        this.wfh.init(outputs, outputs);
        this.bf.init(outputs, 1);
        this.wox.init(outputs, inputs);
        this.woh.init(outputs, outputs);
        this.bo.init(outputs, 1);
        this.wcx.init(outputs, inputs);
        this.wch.init(outputs, outputs);
        this.bc.init(outputs, 1);
        this.hiddenContext.init(outputs, 1);
        this.cellContext.init(outputs, 1);
        if (random != null && std != 0.0d) {
            MatrixOps.fillWithRandomStd(this.wix, random, std);
            MatrixOps.fillWithRandomStd(this.wih, random, std);
            MatrixOps.fillWithRandomStd(this.wfx, random, std);
            MatrixOps.fillWithRandomStd(this.wfh, random, std);
            // Forget-gate bias starts at 1.0 so the gate initially tends to keep the cell state.
            this.bf.fill(1.0d);
            MatrixOps.fillWithRandomStd(this.wox, random, std);
            MatrixOps.fillWithRandomStd(this.woh, random, std);
            MatrixOps.fillWithRandomStd(this.wcx, random, std);
            MatrixOps.fillWithRandomStd(this.wch, random, std);
        }
        return this;
    }

    /**
     * Runs one LSTM step on {@code input}, recording every operation on {@code graph}.
     *
     * <p>Side effect: the resulting hidden and cell matrices replace the layer's context, so the
     * next call continues the sequence.
     *
     * @param input input column for this step
     * @param graph graph on which the operations are recorded
     * @return the new hidden state, which is also the layer output
     */
    @Override
    public ExMatrix forward(ExMatrix input, ProcessGraph graph) {
        // Operation order on the graph is kept identical to the original implementation.
        ExMatrix inputGate = gate(graph, this.fInputGate, this.wix, this.wih, this.bi, input);
        ExMatrix forgetGate = gate(graph, this.fForgetGate, this.wfx, this.wfh, this.bf, input);
        ExMatrix outputGate = gate(graph, this.fOutputGate, this.wox, this.woh, this.bo, input);
        ExMatrix retainedCell = graph.elmul(forgetGate, this.cellContext);
        ExMatrix cellInput = gate(graph, this.fCellInput, this.wcx, this.wch, this.bc, input);
        ExMatrix newCell = graph.add(retainedCell, graph.elmul(inputGate, cellInput));
        ExMatrix newHidden = graph.elmul(outputGate, graph.activate(this.fCellOutput, newCell));
        this.hiddenContext = newHidden;
        this.cellContext = newCell;
        return newHidden;
    }

    /**
     * Computes {@code activation(inputWeights*input + recurrentWeights*hiddenContext + bias)} on
     * the graph. The hidden context is read from the current layer state.
     */
    private ExMatrix gate(ProcessGraph graph, Activation activation, ExMatrix inputWeights,
                          ExMatrix recurrentWeights, ExMatrix bias, ExMatrix input) {
        ExMatrix fromInput = graph.mul(inputWeights, input);
        ExMatrix fromHidden = graph.mul(recurrentWeights, this.hiddenContext);
        return graph.activate(activation, graph.add(graph.add(fromInput, fromHidden), bias));
    }

    /**
     * Returns the trainable parameters in the fixed order input, forget, output, cell, each as
     * (W?x, W?h, b?). The context matrices are excluded.
     *
     * <p>The same cached array instance is returned on every call.
     */
    @Override
    public ExMatrix[] getLayerParameters() {
        if (this.params == null) {
            this.params = new ExMatrix[]{this.wix, this.wih, this.bi, this.wfx, this.wfh, this.bf, this.wox, this.woh, this.bo, this.wcx, this.wch, this.bc};
        }
        return this.params;
    }

    /**
     * Zeroes the values, gradients and step caches of the current hidden and cell context
     * matrices, starting a new sequence.
     */
    @Override
    public void resetState() {
        this.hiddenContext.getW().fill(0.0d);
        this.hiddenContext.getDw().fill(0.0d);
        this.hiddenContext.getStepCache().fill(0.0d);
        this.cellContext.getW().fill(0.0d);
        this.cellContext.getDw().fill(0.0d);
        this.cellContext.getStepCache().fill(0.0d);
    }

    /** Returns the number of input features (columns of the input-gate input weights). */
    @Override
    public int inputDimension() {
        return this.wix.columns();
    }

    /** Returns the number of hidden units (rows of the input-gate input weights). */
    @Override
    public int outputDimension() {
        return this.wix.rows();
    }
}
