package greycat.ml.neuralnet.layer;

import greycat.ml.neuralnet.activation.Activation;
import greycat.ml.neuralnet.activation.Activations;
import greycat.ml.neuralnet.process.ExMatrix;
import greycat.ml.neuralnet.process.ProcessGraph;
import greycat.struct.DMatrix;
import greycat.struct.DoubleArray;
import greycat.struct.ENode;
import greycat.struct.matrix.MatrixOps;
import greycat.struct.matrix.RandomGenerator;

/**
 * Fully-connected layer whose trainable state lives on a host {@link ENode}.
 *
 * <p>The forward pass computes
 * {@code activation(weights * input + expand(bias, input.columns()))}. The weights matrix is
 * initialized as {@code outputs x inputs} and the bias as {@code outputs x 1}; both are backed by
 * the host node under the {@code "weights"} and {@code "bias"} attributes, so the layer can be
 * rebuilt from a previously persisted node by the constructor.
 */
class FeedForward implements Layer {
    // Attribute names on the host node. Declared final so they cannot be reassigned at runtime,
    // which would silently redirect where the layer reads/writes its state.
    private static final String WEIGHTS = "weights";
    private static final String BIAS = "bias";
    private static final String ACTIVATION = "activation";
    private static final String ACTIVATION_PARAM = "activation_param";

    private final ExMatrix weights;
    private final ExMatrix bias;
    private Activation activation;
    private final ENode host;
    // Lazily built {weights, bias} array returned by getLayerParameters().
    private ExMatrix[] params = null;

    /**
     * Binds the layer to {@code eNode}, restoring weights, bias and activation from its attributes.
     *
     * <p>The activation is looked up with the stored {@code "activation"} id (defaulting to 0 when
     * absent) and the optional stored {@code "activation_param"} array.
     *
     * @param eNode host node holding the layer state; must not be null
     * @throws RuntimeException if {@code eNode} is null
     */
    // NOTE(review): the decompiler reports this constructor was package-private in the original.
    public FeedForward(ENode eNode) {
        if (eNode == null) {
            throw new RuntimeException("Host node can't be null");
        }
        this.weights = new ExMatrix(eNode, WEIGHTS);
        this.bias = new ExMatrix(eNode, BIAS);
        DoubleArray storedParams = (DoubleArray) eNode.get(ACTIVATION_PARAM);
        int activationId = ((Integer) eNode.getWithDefault(ACTIVATION, 0)).intValue();
        this.activation = Activations.getUnit(activationId, storedParams != null ? storedParams.extract() : null);
        this.host = eNode;
    }

    /**
     * Sizes the layer, selects and persists its activation, then randomizes it via
     * {@link #reInit(RandomGenerator, double)}.
     *
     * @param inputs           number of input features (weights columns)
     * @param outputs          number of output features (weights and bias rows)
     * @param activationUnit   activation id passed to {@link Activations#getUnit}
     * @param activationParams optional activation parameters; persisted only when non-null
     * @param random           generator used for random initialization; may be null to skip it
     * @param std              standard deviation for random initialization; 0 skips it
     * @return this layer
     */
    @Override
    public Layer init(int inputs, int outputs, int activationUnit, double[] activationParams, RandomGenerator random, double std) {
        // NOTE(review): (byte) 4 is presumably greycat's INT type tag and 0 the feed-forward layer
        // type id in Layers — confirm against greycat.Type / Layers and prefer named constants.
        this.host.set(Layers.TYPE, (byte) 4, 0);
        this.weights.init(outputs, inputs);
        this.bias.init(outputs, 1);
        this.activation = Activations.getUnit(activationUnit, activationParams);
        this.host.set(ACTIVATION, (byte) 4, Integer.valueOf(activationUnit));
        if (activationParams != null) {
            // NOTE(review): (byte) 6 is presumably greycat's DOUBLE_ARRAY type tag — confirm.
            ((DoubleArray) this.host.getOrCreate(ACTIVATION_PARAM, (byte) 6)).initWith(activationParams);
        }
        // NOTE(review): when activationParams is null, any previously stored "activation_param"
        // stays on the node, so reloading via the constructor may restore stale parameters.
        return reInit(random, std);
    }

    /**
     * Refills weights and bias with random values of the given standard deviation.
     *
     * <p>Nothing is changed when {@code random} is null or {@code std} is 0.
     *
     * @param random generator for the random values
     * @param std    standard deviation passed to {@link MatrixOps#fillWithRandomStd}
     * @return this layer
     */
    @Override
    public Layer reInit(RandomGenerator random, double std) {
        if (random != null && std != 0.0d) {
            MatrixOps.fillWithRandomStd(this.weights, random, std);
            MatrixOps.fillWithRandomStd(this.bias, random, std);
        }
        return this;
    }

    /** Copies {@code dMatrix} into the layer's weights matrix. */
    public void setWeights(DMatrix dMatrix) {
        MatrixOps.copy(dMatrix, this.weights);
    }

    /** Copies {@code dMatrix} into the layer's bias matrix. */
    public void setBias(DMatrix dMatrix) {
        MatrixOps.copy(dMatrix, this.bias);
    }

    /**
     * Records {@code activation(weights * input + bias)} on {@code processGraph}.
     *
     * <p>The bias column is expanded to {@code input.columns()} columns before the addition
     * (presumably one column per sample — confirm against ProcessGraph.expand).
     *
     * @param exMatrix     layer input
     * @param processGraph graph on which the operations are recorded
     * @return the activated output matrix
     */
    @Override
    public ExMatrix forward(ExMatrix exMatrix, ProcessGraph processGraph) {
        ExMatrix weighted = processGraph.mul(this.weights, exMatrix);
        ExMatrix expandedBias = processGraph.expand(this.bias, exMatrix.columns());
        return processGraph.activate(this.activation, processGraph.add(weighted, expandedBias));
    }

    /** Returns the trainable parameters {@code {weights, bias}}; the same array on every call. */
    @Override
    public ExMatrix[] getLayerParameters() {
        if (this.params == null) {
            this.params = new ExMatrix[]{this.weights, this.bias};
        }
        return this.params;
    }

    /** No-op: this layer keeps no state between forward passes beyond its parameters. */
    @Override
    public void resetState() {
    }

    /** Number of input features, i.e. the column count of the weights matrix. */
    @Override
    public int inputDimensions() {
        return this.weights.columns();
    }

    /** Number of output features, i.e. the row count of the weights matrix. */
    @Override
    public int outputDimensions() {
        return this.weights.rows();
    }
}
