package greycat.ml.neuralnet;

import greycat.ml.common.matrix.VolatileDMatrix;
import greycat.ml.neuralnet.layer.Layer;
import greycat.ml.neuralnet.layer.Layers;
import greycat.ml.neuralnet.learner.Learner;
import greycat.ml.neuralnet.learner.Learners;
import greycat.ml.neuralnet.loss.Loss;
import greycat.ml.neuralnet.loss.Losses;
import greycat.ml.neuralnet.process.ExMatrix;
import greycat.ml.neuralnet.process.ProcessGraph;
import greycat.struct.EGraph;
import greycat.struct.ENode;
import java.util.Random;

/* loaded from: input_file:greycat/ml/neuralnet/NeuralNet.class */
/**
 * A neural network made of a sequential stack of {@link Layer}s whose configuration and layers
 * are persisted in an {@link EGraph}.
 *
 * <p>The graph root stores the network configuration: training/reporting loss ids, learner id,
 * random seed and the standard deviation handed to layer initialisation. Every other node of the
 * graph stores one layer. Layers are assumed to sit at indices 1..n in the order they were added,
 * with the root at index 0 (it is the first node this class creates) — NOTE(review): confirm
 * against the EGraph implementation.
 */
public class NeuralNet {
    private static final String TRAIN_LOSS = "train_loss";
    private static final String REPORTING_LOSS = "reporting_loss";
    private static final String LEARNER = "learner";
    private static final String SEED = "seed";
    private static final String STD = "std";
    private static final double STD_DEF = 0.08d;

    // Attribute type tags passed to ENode.set. Each one is paired with the boxed type the
    // constructor reads the attribute back as: 3 -> Long, 4 -> Integer, 5 -> Double.
    private static final byte TYPE_LONG = 3;
    private static final byte TYPE_INT = 4;
    private static final byte TYPE_DOUBLE = 5;

    private final EGraph backend;
    private final ENode root;
    private Layer[] layers;
    private Loss trainLoss;
    // NOTE(review): reportingLoss is stored and persisted but never used inside this class.
    private Loss reportingLoss;
    private Learner learner;
    private final Random random;
    private double std;

    /**
     * Creates a network backed by {@code eGraph}, restoring any previously persisted state.
     *
     * <p>If the graph is empty, a fresh root node is created and registered as the graph root,
     * and the network starts with no layers. Otherwise the configuration is read from the root,
     * and every non-root node is loaded as a layer.
     *
     * <p>Missing attributes fall back to defaults: loss/learner id 0, a time-based seed and a
     * standard deviation of {@value #STD_DEF}. Defaults are not written back, so a graph that
     * never had {@link #setRandom} called picks a new seed on every load.
     *
     * @param eGraph the graph holding (or about to hold) the network
     */
    public NeuralNet(EGraph eGraph) {
        this.backend = eGraph;
        // Every node except the root holds one layer.
        int layerCount = eGraph.size() - 1;
        if (layerCount < 0) {
            this.root = eGraph.newNode();
            eGraph.setRoot(this.root);
            layerCount = 0;
        } else {
            this.root = eGraph.root();
        }

        this.trainLoss = Losses.getUnit(readIntOrZero(TRAIN_LOSS));
        this.reportingLoss = Losses.getUnit(readIntOrZero(REPORTING_LOSS));
        this.learner = Learners.getUnit(readIntOrZero(LEARNER), this.root);

        long seed = ((Long) this.root.getWithDefault(SEED, Long.valueOf(System.currentTimeMillis()))).longValue();
        this.random = new Random(seed);
        this.std = ((Double) this.root.getWithDefault(STD, Double.valueOf(STD_DEF))).doubleValue();

        this.layers = new Layer[layerCount];
        for (int i = 0; i < layerCount; i++) {
            // The root occupies index 0, so layer nodes start at index 1. Reading node(i)
            // would load the root as a layer and skip the last real layer.
            this.layers[i] = Layers.loadLayer(eGraph.node(i + 1));
        }
    }

    /** Reads an int attribute from the root node, defaulting to 0 when it is absent. */
    private int readIntOrZero(String key) {
        return ((Integer) this.root.getWithDefault(key, Integer.valueOf(0))).intValue();
    }

    /**
     * Selects the loss used by {@link #learn} and persists the choice on the root node.
     *
     * @param i loss id understood by {@code Losses.getUnit}
     */
    public void setTrainLoss(int i) {
        this.trainLoss = Losses.getUnit(i);
        this.root.set(TRAIN_LOSS, TYPE_INT, Integer.valueOf(i));
    }

    /**
     * Selects the reporting loss and persists the choice on the root node.
     *
     * @param i loss id understood by {@code Losses.getUnit}
     */
    public void setReportingLoss(int i) {
        this.reportingLoss = Losses.getUnit(i);
        this.root.set(REPORTING_LOSS, TYPE_INT, Integer.valueOf(i));
    }

    /**
     * Selects and configures the learner, and persists its id on the root node so that it
     * survives a reload, the same way the loss setters persist their ids.
     *
     * @param i learner id understood by {@code Learners.getUnit}
     * @param dArr learner parameters, or {@code null} to keep the learner's own defaults
     * @param i2 value passed to {@code Learner.setFrequency}
     */
    public void setLearner(int i, double[] dArr, int i2) {
        this.learner = Learners.getUnit(i, this.root);
        this.root.set(LEARNER, TYPE_INT, Integer.valueOf(i));
        if (dArr != null) {
            this.learner.setParams(dArr);
        }
        this.learner.setFrequency(i2);
    }

    /**
     * Re-seeds the random generator and sets the standard deviation used when initialising new
     * layers. Both values are persisted on the root node.
     *
     * @param j random seed
     * @param d standard deviation passed to {@code Layer.init} by {@link #addLayer}
     */
    public void setRandom(long j, double d) {
        this.random.setSeed(j);
        this.std = d;
        this.root.set(SEED, TYPE_LONG, Long.valueOf(j));
        this.root.set(STD, TYPE_DOUBLE, Double.valueOf(d));
    }

    /**
     * Appends a new layer, backed by a fresh graph node, to the end of the network.
     *
     * @param i layer type id passed to {@code Layers.createLayer}
     * @param i2 input dimension; must equal the output dimension of the current last layer
     * @param i3 forwarded to {@code Layer.init} (presumably the output dimension — confirm)
     * @param i4 forwarded to {@code Layer.init} (presumably an activation id — confirm)
     * @param dArr forwarded to {@code Layer.init} (presumably activation parameters — confirm)
     * @return this network, for chaining
     * @throws IllegalArgumentException if {@code i2} does not match the previous layer's output
     */
    public NeuralNet addLayer(int i, int i2, int i3, int i4, double[] dArr) {
        if (this.layers.length > 0) {
            int previousOutput = this.layers[this.layers.length - 1].outputDimension();
            if (previousOutput != i2) {
                throw new IllegalArgumentException("Layers last output size (" + previousOutput
                        + ") is different than current layer input (" + i2 + ")");
            }
        }
        Layer created = Layers.createLayer(this.backend.newNode(), i);
        created.init(i2, i3, i4, dArr, this.random, this.std);
        appendLayer(created);
        return this;
    }

    /**
     * Runs one training step on a single sample. It does a forward pass through all layers,
     * computes the training loss against {@code dArr2}, backpropagates, and then hands the
     * layers to the learner via {@code stepUpdate}.
     *
     * @param dArr input vector (wrapped as a column matrix)
     * @param dArr2 expected output vector (wrapped as a column matrix)
     * @return the loss value returned by {@code ProcessGraph.applyLoss}
     */
    public double learn(double[] dArr, double[] dArr2) {
        ProcessGraph graph = new ProcessGraph(true);
        ExMatrix input = ExMatrix.createFromW(VolatileDMatrix.wrap(dArr, dArr.length, 1));
        ExMatrix target = ExMatrix.createFromW(VolatileDMatrix.wrap(dArr2, dArr2.length, 1));
        double loss = graph.applyLoss(this.trainLoss, internalForward(graph, input), target);
        graph.backpropagate();
        this.learner.stepUpdate(this.layers);
        return loss;
    }

    /** Lets the learner apply its final update to all layers (via {@code finalUpdate}). */
    public final void finalLearn() {
        this.learner.finalUpdate(this.layers);
    }

    /** Calls {@code resetState} on every layer, in order. */
    public void resetState() {
        for (Layer layer : this.layers) {
            layer.resetState();
        }
    }

    /**
     * Runs a forward pass without backpropagation support (the process graph is created with
     * {@code false}) and returns the output data.
     *
     * @param dArr input vector (wrapped as a column matrix)
     * @return the data of the last layer's output matrix; with no layers, the input's data
     */
    public double[] predict(double[] dArr) {
        ExMatrix input = ExMatrix.createFromW(VolatileDMatrix.wrap(dArr, dArr.length, 1));
        return internalForward(new ProcessGraph(false), input).data();
    }

    /** Feeds {@code exMatrix} through every layer in order and returns the final output. */
    private ExMatrix internalForward(ProcessGraph processGraph, ExMatrix exMatrix) {
        ExMatrix current = exMatrix;
        for (Layer layer : this.layers) {
            current = layer.forward(current, processGraph);
        }
        return current;
    }

    /** Replaces the layer array with a copy that has {@code layer} appended at the end. */
    private void appendLayer(Layer layer) {
        Layer[] grown = new Layer[this.layers.length + 1];
        System.arraycopy(this.layers, 0, grown, 0, this.layers.length);
        grown[this.layers.length] = layer;
        this.layers = grown;
    }
}
