package greycat.ml.neuralnet.layer;

import greycat.ml.neuralnet.process.ExMatrix;
import greycat.ml.neuralnet.process.ProcessGraph;
import greycat.struct.ENode;
import greycat.struct.matrix.MatrixOps;
import greycat.struct.matrix.RandomGenerator;

/* loaded from: input_file:greycat/ml/neuralnet/layer/Linear.class */
/**
 * A fully-connected layer without bias: the forward pass multiplies a single weight matrix by the
 * input through the {@link ProcessGraph}.
 *
 * <p>The weight matrix is stored on the host {@link ENode} under the {@code "weights"} attribute.
 * Its column count is reported as {@link #inputDimensions()} and its row count as
 * {@link #outputDimensions()}.
 */
class Linear implements Layer {

    /** Attribute name under which the weight matrix is stored on the host node. */
    private static final String WEIGHTS = "weights";

    /** Node backing this layer's persisted state; never null. */
    private final ENode host;

    /** Weight matrix; rows = output dimensions, columns = input dimensions. */
    private final ExMatrix weights;

    /** Lazily built array of trainable parameters, returned by {@link #getLayerParameters()}. */
    private ExMatrix[] params = null;

    /**
     * Creates a linear layer backed by the given node.
     *
     * @param hostNode node holding the layer's state
     * @throws RuntimeException if {@code hostNode} is null
     */
    public Linear(ENode hostNode) {
        if (hostNode == null) {
            throw new RuntimeException("Host node can't be null");
        }
        this.weights = new ExMatrix(hostNode, WEIGHTS);
        this.host = hostNode;
    }

    /**
     * Tags the host node with this layer's type, sizes the weight matrix and optionally fills it
     * with random values.
     *
     * @param inputs number of input dimensions (becomes the weight matrix's column count)
     * @param outputs number of output dimensions (becomes the weight matrix's row count)
     * @param ignoredCode not used by this layer
     * @param ignoredParams not used by this layer
     * @param randomGenerator random source; if null, the weights are not randomized
     * @param std standard deviation passed to {@link MatrixOps#fillWithRandomStd}; if 0, the
     *     weights are not randomized
     * @return this layer
     */
    @Override
    public Layer init(int inputs, int outputs, int ignoredCode, double[] ignoredParams,
                      RandomGenerator randomGenerator, double std) {
        // NOTE(review): (byte) 4 and 1 are the attribute type code and the layer-type value
        // recorded under Layers.TYPE; their exact meaning is defined outside this file — confirm.
        this.host.set(Layers.TYPE, (byte) 4, 1);
        this.weights.init(outputs, inputs);
        return reInit(randomGenerator, std);
    }

    /**
     * Re-randomizes the weights with the given standard deviation.
     *
     * <p>Does nothing if {@code randomGenerator} is null or {@code std} is 0.
     *
     * @param randomGenerator random source
     * @param std standard deviation of the random values
     * @return this layer
     */
    @Override
    public Layer reInit(RandomGenerator randomGenerator, double std) {
        if (randomGenerator != null && std != 0.0d) {
            MatrixOps.fillWithRandomStd(this.weights, randomGenerator, std);
        }
        return this;
    }

    /**
     * Computes {@code weights * input} by delegating to {@code processGraph.mul}.
     *
     * @param input input matrix
     * @param processGraph graph recording the operation
     * @return the product produced by the process graph
     */
    @Override
    public ExMatrix forward(ExMatrix input, ProcessGraph processGraph) {
        return processGraph.mul(this.weights, input);
    }

    /**
     * Returns this layer's trainable parameters: only the weight matrix.
     *
     * <p>The array is created on first call and cached. The same array instance is returned on
     * every later call, so callers should not modify it.
     *
     * @return a single-element array containing the weight matrix
     */
    @Override
    public ExMatrix[] getLayerParameters() {
        if (this.params == null) {
            this.params = new ExMatrix[]{this.weights};
        }
        return this.params;
    }

    /** No-op: this layer keeps no state besides its weights and the cached parameter array. */
    @Override
    public void resetState() {
    }

    /** @return the number of columns of the weight matrix */
    @Override
    public int inputDimensions() {
        return this.weights.columns();
    }

    /** @return the number of rows of the weight matrix */
    @Override
    public int outputDimensions() {
        return this.weights.rows();
    }
}
