package greycatMLTest.neuralnet;

import greycat.Callback;
import greycat.Graph;
import greycat.GraphBuilder;
import greycat.ml.neuralnet.layer.Layer;
import greycat.ml.neuralnet.layer.Layers;
import greycat.ml.neuralnet.loss.Loss;
import greycat.ml.neuralnet.loss.Losses;
import greycat.ml.neuralnet.process.ExMatrix;
import greycat.ml.neuralnet.process.ProcessGraph;
import greycat.struct.DMatrix;
import greycat.struct.EGraph;
import greycat.struct.ENode;
import greycat.struct.matrix.VolatileDMatrix;
import org.junit.Test;

/* loaded from: input_file:greycatMLTest/neuralnet/TestFeedForward.class */
public class TestFeedForward {
    /**
     * Absolute tolerance used by {@link #testdouble(double, double)}.
     * At 1e-16 it is smaller than the spacing between adjacent doubles in [0.5, 1),
     * so for expected values in that range the comparison is effectively exact equality.
     */
    private static final double EPS = 1.0E-16d;

    /**
     * Builds a two-layer network with fixed weights and biases, runs one forward pass,
     * checks the summed loss, backpropagates, applies one update step with factor 0.5 and
     * checks every updated parameter against hard-coded expected values.
     *
     * NOTE(review): all checks run inside the connect callback; if the graph invokes the
     * callback asynchronously, a failure would not reach the test runner - confirm.
     */
    @Test
    public void testcalc() {
        final Graph graph = GraphBuilder.newBuilder().build();
        graph.connect(new Callback<Boolean>() {
            public void on(Boolean connected) {
                // Type codes 17 and 4 are raw GreyCat type bytes (presumably EGRAPH and INT -
                // TODO confirm against greycat.Type).
                EGraph network = (EGraph) graph.newNode(0L, 0L).getOrCreate("nn", (byte) 17);
                ENode firstLayerNode = network.newNode();
                ENode secondLayerNode = network.newNode();
                // Both layers share the same "activation" (1) and "type" (0) codes; their
                // meaning is defined by Layers.loadLayer and is not visible here.
                firstLayerNode.set("activation", (byte) 4, 1);
                secondLayerNode.set("activation", (byte) 4, 1);
                firstLayerNode.set("type", (byte) 4, 0);
                secondLayerNode.set("type", (byte) 4, 0);

                // 2x1 column fed into the first layer.
                VolatileDMatrix input = VolatileDMatrix.empty(2, 1);
                input.set(0, 0, 0.05d);
                input.set(1, 0, 0.1d);

                // 2x1 column the network output is compared against by the loss.
                VolatileDMatrix expectedOutput = VolatileDMatrix.empty(2, 1);
                expectedOutput.set(0, 0, 0.01d);
                expectedOutput.set(1, 0, 0.99d);

                // First layer parameters, stored on its node under "weights" and "bias".
                ExMatrix firstWeights = new ExMatrix(firstLayerNode, "weights");
                firstWeights.init(2, 2);
                firstWeights.set(0, 0, 0.15d);
                firstWeights.set(0, 1, 0.2d);
                firstWeights.set(1, 0, 0.25d);
                firstWeights.set(1, 1, 0.3d);
                ExMatrix firstBias = new ExMatrix(firstLayerNode, "bias");
                firstBias.init(2, 1);
                firstBias.set(0, 0, 0.35d);
                firstBias.set(1, 0, 0.35d);

                // Second layer parameters.
                ExMatrix secondWeights = new ExMatrix(secondLayerNode, "weights");
                secondWeights.init(2, 2);
                secondWeights.set(0, 0, 0.4d);
                secondWeights.set(0, 1, 0.45d);
                secondWeights.set(1, 0, 0.5d);
                secondWeights.set(1, 1, 0.55d);
                ExMatrix secondBias = new ExMatrix(secondLayerNode, "bias");
                secondBias.init(2, 1);
                secondBias.set(0, 0, 0.6d);
                secondBias.set(1, 0, 0.6d);

                Loss lossUnit = Losses.getUnit(0);
                Layer firstLayer = Layers.loadLayer(firstLayerNode);
                Layer secondLayer = Layers.loadLayer(secondLayerNode);
                ProcessGraph processGraph = new ProcessGraph(true);

                // Forward pass through both layers, then the summed loss against the expected output.
                TestFeedForward.testdouble(
                        Losses.sumOfLosses(
                                processGraph.applyLoss(
                                        lossUnit,
                                        secondLayer.forward(
                                                firstLayer.forward(ExMatrix.createFromW(input), processGraph),
                                                processGraph),
                                        ExMatrix.createFromW(expectedOutput),
                                        true)),
                        0.2983711087600027d);

                processGraph.backpropagate();

                // One update step with factor 0.5 on every parameter of both layers.
                ExMatrix[] firstLayerParameters = firstLayer.getLayerParameters();
                ExMatrix[] secondLayerParameters = secondLayer.getLayerParameters();
                for (ExMatrix parameter : firstLayerParameters) {
                    TestFeedForward.applyLearningRate(parameter, 0.5d);
                }
                for (ExMatrix parameter : secondLayerParameters) {
                    TestFeedForward.applyLearningRate(parameter, 0.5d);
                }

                // Index 0 is read as a 2x2 matrix and index 1 as a 2x1 column below
                // (presumably weights and bias respectively - confirm against getLayerParameters).
                ExMatrix updatedFirstWeights = firstLayerParameters[0];
                ExMatrix updatedFirstBias = firstLayerParameters[1];
                ExMatrix updatedSecondWeights = secondLayerParameters[0];
                ExMatrix updatedSecondBias = secondLayerParameters[1];

                TestFeedForward.testdouble(updatedFirstWeights.get(0, 0), 0.1497807161327628d);
                TestFeedForward.testdouble(updatedFirstWeights.get(0, 1), 0.19956143226552567d);
                TestFeedForward.testdouble(updatedFirstWeights.get(1, 0), 0.24975114363236958d);
                TestFeedForward.testdouble(updatedFirstWeights.get(1, 1), 0.29950228726473915d);
                TestFeedForward.testdouble(updatedFirstBias.get(0, 0), 0.3456143226552565d);
                TestFeedForward.testdouble(updatedFirstBias.get(1, 0), 0.3450228726473914d);
                TestFeedForward.testdouble(updatedSecondWeights.get(0, 0), 0.35891647971788465d);
                TestFeedForward.testdouble(updatedSecondWeights.get(0, 1), 0.4086661860762334d);
                TestFeedForward.testdouble(updatedSecondWeights.get(1, 0), 0.5113012702387375d);
                TestFeedForward.testdouble(updatedSecondWeights.get(1, 1), 0.5613701211079891d);
                TestFeedForward.testdouble(updatedSecondBias.get(0, 0), 0.5307507191857215d);
                TestFeedForward.testdouble(updatedSecondBias.get(1, 0), 0.6190491182582781d);
            }
        });
    }

    /**
     * Checks that two doubles are equal within {@link #EPS}.
     *
     * Kept public because it is called from the anonymous callback in {@link #testcalc()}.
     *
     * @param d  the computed value
     * @param d2 the expected value
     * @throws RuntimeException if the values differ by more than {@code EPS}, or if the
     *                          difference is NaN (for example when either value is NaN)
     */
    public static void testdouble(double d, double d2) {
        // Phrased as "not within tolerance" rather than "difference > EPS": every ordered
        // comparison with NaN is false, so the latter would let a NaN result pass silently.
        // The equality short-circuit keeps identical infinities accepted (Inf - Inf is NaN).
        boolean withinTolerance = d == d2 || Math.abs(d - d2) <= EPS;
        if (!withinTolerance) {
            System.out.println("d1: " + d + " d2: " + d2);
            throw new RuntimeException("d1 != d2");
        }
    }

    /**
     * Applies one in-place update step to a parameter matrix: each of its
     * {@code length()} entries is decreased by {@code d} times the entry at the same index
     * of the matrix returned by {@code getDw()} (presumably the accumulated gradient),
     * after which that matrix is filled with zeros.
     *
     * Kept public because it is called from the anonymous callback in {@link #testcalc()}.
     *
     * @param exMatrix the parameter matrix to update
     * @param d        the factor applied to each {@code getDw()} entry before subtracting
     */
    public static void applyLearningRate(ExMatrix exMatrix, double d) {
        int length = exMatrix.length();
        DMatrix dw = exMatrix.getDw();
        for (int i = 0; i < length; i++) {
            exMatrix.unsafeSet(i, exMatrix.unsafeGet(i) - (d * dw.unsafeGet(i)));
        }
        // Reset so the next backpropagation starts from a zeroed matrix.
        dw.fill(0.0d);
    }
}
