package greycatMLTest.neuralnet;

import greycat.ml.neuralnet.activation.Activation;
import greycat.ml.neuralnet.activation.Activations;
import greycat.ml.neuralnet.loss.Loss;
import greycat.ml.neuralnet.loss.Losses;
import greycat.ml.neuralnet.process.ExMatrix;
import greycat.struct.DMatrix;
import greycat.struct.matrix.MatrixOps;
import greycat.struct.matrix.TransposeType;
import greycat.struct.matrix.VolatileDMatrix;

/* loaded from: input_file:greycatMLTest/neuralnet/TestCalc.class */
/**
 * Hand-checkable walkthrough of one forward/backward pass through a 2-input, 2-hidden,
 * 2-output network built from fixed weights.
 *
 * <p>{@link #main} runs a forward pass, prints the intermediate values and the summed loss,
 * back-propagates the loss gradient through every operation in reverse order, applies one
 * plain gradient-descent step to the weights and biases, and prints the updated parameters.
 *
 * <p>NOTE(review): the constants (inputs 0.05/0.10, targets 0.01/0.99, weights 0.15..0.55,
 * biases 0.35/0.60, learning rate 0.5) look like the well-known "step by step backpropagation"
 * example, so the printed numbers can presumably be compared against it — confirm before
 * relying on that. Unlike that example, each bias here is a vector with one entry per neuron,
 * and each entry is updated independently.
 */
public class TestCalc {

    /** Step size used for the single gradient-descent update. */
    private static final double LEARNING_RATE = 0.5d;

    public static void main(String[] args) {
        ExMatrix input = matrix(2, 1, 0.05d, 0.1d);
        ExMatrix target = matrix(2, 1, 0.01d, 0.99d);
        ExMatrix hiddenWeights = matrix(2, 2,
                0.15d, 0.2d,
                0.25d, 0.3d);
        ExMatrix hiddenBias = matrix(2, 1, 0.35d, 0.35d);
        ExMatrix outputWeights = matrix(2, 2,
                0.4d, 0.45d,
                0.5d, 0.55d);
        ExMatrix outputBias = matrix(2, 1, 0.6d, 0.6d);

        // NOTE(review): unit ids are bare ints; presumably 0 = sum-of-squares loss and
        // 1 = sigmoid activation — confirm against Losses / Activations.
        Loss loss = Losses.getUnit(0);
        // The (double[]) cast disambiguates the null "unit arguments" parameter.
        Activation act = Activations.getUnit(1, (double[]) null);

        // Forward pass: hidden = act(Wh * x + bh), output = act(Wo * hidden + bo).
        ExMatrix hiddenProduct = mul(hiddenWeights, input);
        ExMatrix hiddenNet = add(hiddenProduct, hiddenBias);
        ExMatrix hiddenOut = activation(act, hiddenNet);
        ExMatrix outputProduct = mul(outputWeights, hiddenOut);
        ExMatrix outputNet = add(outputProduct, outputBias);
        ExMatrix output = activation(act, outputNet);
        DMatrix errors = loss.forward(output, target);

        System.out.println("");
        System.out.println("Step 1, before activation: [" + hiddenNet.get(0, 0) + " , " + hiddenNet.get(1, 0) + "]");
        System.out.println("Step 2, after activation: [" + hiddenOut.get(0, 0) + " , " + hiddenOut.get(1, 0) + "]");
        System.out.println("Step 3, before activation: [" + outputNet.get(0, 0) + " , " + outputNet.get(1, 0) + "]");
        System.out.println("Step 4, after activation, actual output: [" + output.get(0, 0) + " , " + output.get(1, 0) + "]");
        // Print the scalar total rather than the DMatrix object itself.
        System.out.println("Total Error: " + sum(errors));
        System.out.println("");

        // Backward pass: visit the forward operations in reverse order. Every gradient is
        // accumulated before any parameter changes, so the hidden-layer gradients are computed
        // from the ORIGINAL output weights.
        loss.backward(output, target);
        backpropActivation(act, outputNet, output);
        backpropAdd(outputProduct, outputBias, outputNet);
        backpropMult(outputWeights, hiddenOut, outputProduct);
        backpropActivation(act, hiddenNet, hiddenOut);
        backpropAdd(hiddenProduct, hiddenBias, hiddenNet);
        backpropMult(hiddenWeights, input, hiddenProduct);

        applyLearningRate(hiddenWeights, LEARNING_RATE);
        applyLearningRate(outputWeights, LEARNING_RATE);
        applyLearningRate(hiddenBias, LEARNING_RATE);
        applyLearningRate(outputBias, LEARNING_RATE);

        System.out.println("After learning: ");
        System.out.println("w1: " + hiddenWeights.get(0, 0));
        System.out.println("w2: " + hiddenWeights.get(0, 1));
        System.out.println("w3: " + hiddenWeights.get(1, 0));
        System.out.println("w4: " + hiddenWeights.get(1, 1));
        System.out.println("Bias: ");
        System.out.println("b1-a: " + hiddenBias.get(0, 0));
        System.out.println("b1-b: " + hiddenBias.get(1, 0));
        System.out.println("");
        System.out.println("w5: " + outputWeights.get(0, 0));
        System.out.println("w6: " + outputWeights.get(0, 1));
        System.out.println("w7: " + outputWeights.get(1, 0));
        System.out.println("w8: " + outputWeights.get(1, 1));
        System.out.println("Bias: ");
        System.out.println("b2-a: " + outputBias.get(0, 0));
        System.out.println("b2-b: " + outputBias.get(1, 0));
    }

    /**
     * Builds an {@link ExMatrix} whose weights are taken from {@code rowMajorValues}.
     *
     * @param rows number of rows
     * @param columns number of columns
     * @param rowMajorValues exactly {@code rows * columns} values, listed row by row
     * @return a new matrix wrapping those values
     * @throws IllegalArgumentException if the number of values does not match the shape
     */
    private static ExMatrix matrix(int rows, int columns, double... rowMajorValues) {
        if (rowMajorValues.length != rows * columns) {
            throw new IllegalArgumentException("expected " + (rows * columns) + " values for a "
                    + rows + "x" + columns + " matrix, got " + rowMajorValues.length);
        }
        VolatileDMatrix weights = VolatileDMatrix.empty(rows, columns);
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < columns; c++) {
                weights.set(r, c, rowMajorValues[r * columns + c]);
            }
        }
        return ExMatrix.createFromW(weights);
    }

    /** Returns the sum of every element of {@code m}, iterated in storage order. */
    private static double sum(DMatrix m) {
        double total = 0.0d;
        int length = m.length();
        for (int i = 0; i < length; i++) {
            total += m.unsafeGet(i);
        }
        return total;
    }

    /**
     * Performs one gradient-descent step ({@code w -= rate * dw}) on every element, then
     * clears the accumulated gradient.
     */
    private static void applyLearningRate(ExMatrix exMatrix, double d) {
        int length = exMatrix.length();
        DMatrix dw = exMatrix.getDw();
        for (int i = 0; i < length; i++) {
            exMatrix.unsafeSet(i, exMatrix.unsafeGet(i) - (d * dw.unsafeGet(i)));
        }
        dw.fill(0.0d);
    }

    /** Forward matrix product {@code left * right}, wrapped as a fresh {@link ExMatrix}. */
    private static ExMatrix mul(ExMatrix left, ExMatrix right) {
        return ExMatrix.createFromW(MatrixOps.multiply(left, right));
    }

    /** Forward element-wise sum {@code left + right}, wrapped as a fresh {@link ExMatrix}. */
    private static ExMatrix add(ExMatrix left, ExMatrix right) {
        return ExMatrix.createFromW(MatrixOps.add(left, right));
    }

    /**
     * Backward step for {@code sum = left + right}: the gradient of {@code sum} flows unchanged
     * into both operands' gradients (accumulated, not overwritten).
     */
    private static void backpropAdd(ExMatrix left, ExMatrix right, ExMatrix sum) {
        MatrixOps.addtoMatrix(left.getDw(), sum.getDw());
        MatrixOps.addtoMatrix(right.getDw(), sum.getDw());
    }

    /**
     * Backward step for {@code product = left * right}. It accumulates two gradients:
     * {@code left.dw += product.dw * right^T} and {@code right.dw += left^T * product.dw}.
     */
    private static void backpropMult(ExMatrix left, ExMatrix right, ExMatrix product) {
        DMatrix leftGrad = MatrixOps.multiplyTranspose(TransposeType.NOTRANSPOSE, product.getDw(), TransposeType.TRANSPOSE, right.getW());
        DMatrix rightGrad = MatrixOps.multiplyTranspose(TransposeType.TRANSPOSE, left.getW(), TransposeType.NOTRANSPOSE, product.getDw());
        MatrixOps.addtoMatrix(left.getDw(), leftGrad);
        MatrixOps.addtoMatrix(right.getDw(), rightGrad);
    }

    /** Forward step: applies {@code activation} element-wise to {@code exMatrix} into a new matrix. */
    private static ExMatrix activation(Activation activation, ExMatrix exMatrix) {
        ExMatrix result = ExMatrix.empty(exMatrix.rows(), exMatrix.columns());
        int length = exMatrix.length();
        for (int i = 0; i < length; i++) {
            result.unsafeSet(i, activation.forward(exMatrix.unsafeGet(i)));
        }
        return result;
    }

    /**
     * Backward step for {@code output = activation(input)}, applied element-wise. For each
     * element it accumulates {@code activation.backward(input, output) * output.dw} into
     * {@code input.dw}.
     *
     * <p>NOTE(review): {@code backward} presumably returns the derivative evaluated at the
     * pre-activation value, given the already-computed activation value — confirm.
     *
     * @param activation the activation that produced {@code exMatrix2} from {@code exMatrix}
     * @param exMatrix the pre-activation input; its gradient is updated
     * @param exMatrix2 the post-activation output; its gradient is read
     */
    public static void backpropActivation(Activation activation, ExMatrix exMatrix, ExMatrix exMatrix2) {
        DMatrix inputGrad = exMatrix.getDw();
        DMatrix inputValues = exMatrix.getW();
        DMatrix outputGrad = exMatrix2.getDw();
        DMatrix outputValues = exMatrix2.getW();
        int length = exMatrix.length();
        for (int i = 0; i < length; i++) {
            double localDerivative = activation.backward(inputValues.unsafeGet(i), outputValues.unsafeGet(i));
            inputGrad.unsafeSet(i, inputGrad.unsafeGet(i) + localDerivative * outputGrad.unsafeGet(i));
        }
    }
}
