package greycat.ml.neuralnet.process;

import greycat.ml.neuralnet.activation.Activation;
import greycat.ml.neuralnet.loss.Loss;
import greycat.struct.DMatrix;
import greycat.struct.matrix.MatrixOps;
import greycat.struct.matrix.TransposeType;
import greycat.struct.matrix.VolatileDMatrix;
import java.util.ArrayList;
import java.util.List;

/**
 * Applies matrix operations for a neural network and, when enabled, records the steps needed to
 * back-propagate gradients through them.
 *
 * <p>Every operation computes its forward result eagerly. When back-propagation is enabled, it also
 * queues a {@link ProcessStep} that adds the result's gradient ({@code getDw()}) into the gradients
 * of its operands. {@link #backpropagate()} replays the queued steps in reverse registration order
 * and then clears the queue.
 *
 * <p>Gradient steps accumulate ({@code +=}) into operand gradients, so an operand that feeds several
 * operations receives the sum of all their contributions.
 *
 * <p>The step queue is not synchronized; an instance is presumably meant to be used from a single
 * thread.
 */
public class ProcessGraph {

    /** Whether newly applied operations queue a back-propagation step. */
    private boolean applyBackprop;

    /** Back-propagation steps, in the order their forward operations were applied. */
    private final List<ProcessStep> backprop = new ArrayList<>();

    /**
     * @param applyBackprop {@code true} to record back-propagation steps for subsequent operations
     */
    public ProcessGraph(boolean applyBackprop) {
        this.applyBackprop = applyBackprop;
    }

    /**
     * Executes all recorded back-propagation steps, most recent first, then clears the queue.
     */
    public final void backpropagate() {
        for (int i = this.backprop.size() - 1; i >= 0; i--) {
            this.backprop.get(i).execute();
        }
        this.backprop.clear();
    }

    /**
     * Enables or disables recording of back-propagation steps. Steps recorded so far are discarded
     * without being executed.
     *
     * @param applyBackprop {@code true} to record steps for subsequent operations
     */
    public final void setBackPropagation(boolean applyBackprop) {
        this.applyBackprop = applyBackprop;
        this.backprop.clear();
    }

    /**
     * Matrix product {@code a * b}.
     *
     * <p>Backward: {@code dA += dOut * B^T} and {@code dB += A^T * dOut}.
     *
     * @return a new matrix holding the product
     */
    public final ExMatrix mul(final ExMatrix a, final ExMatrix b) {
        final ExMatrix out = ExMatrix.createFromW(MatrixOps.multiply(a, b));
        if (this.applyBackprop) {
            this.backprop.add(new ProcessStep() {
                @Override
                public void execute() {
                    DMatrix gradA = MatrixOps.multiplyTranspose(TransposeType.NOTRANSPOSE, out.getDw(), TransposeType.TRANSPOSE, b.getW());
                    DMatrix gradB = MatrixOps.multiplyTranspose(TransposeType.TRANSPOSE, a.getW(), TransposeType.NOTRANSPOSE, out.getDw());
                    MatrixOps.addtoMatrix(a.getDw(), gradA);
                    MatrixOps.addtoMatrix(b.getDw(), gradB);
                }
            });
        }
        return out;
    }

    /**
     * Repeats a column vector {@code times} times side by side. It does this by multiplying the
     * vector with a {@code 1 x times} matrix of ones, so gradients flow back through
     * {@link #mul(ExMatrix, ExMatrix)}.
     *
     * @param input a single-column matrix; returned as-is (whatever its shape) when
     *              {@code times == 1}
     * @param times number of columns in the result
     * @return {@code input} when {@code times == 1}, otherwise a new {@code rows x times} matrix
     * @throws RuntimeException if {@code times != 1} and {@code input} has more than one column
     */
    public final ExMatrix expand(ExMatrix input, int times) {
        if (times == 1) {
            return input;
        }
        if (input.columns() != 1) {
            throw new RuntimeException("This method does not support expansion for matrices with more than 1 column! ");
        }
        VolatileDMatrix ones = VolatileDMatrix.empty(1, times);
        ones.fill(1.0d);
        return mul(input, ExMatrix.createFromW(ones));
    }

    /**
     * Element-wise sum {@code a + b}.
     *
     * <p>Backward: {@code dA += dOut} and {@code dB += dOut}.
     *
     * @return a new matrix holding the sum
     */
    public final ExMatrix add(final ExMatrix a, final ExMatrix b) {
        final ExMatrix out = ExMatrix.createFromW(MatrixOps.add(a, b));
        if (this.applyBackprop) {
            this.backprop.add(new ProcessStep() {
                @Override
                public void execute() {
                    MatrixOps.addtoMatrix(a.getDw(), out.getDw());
                    MatrixOps.addtoMatrix(b.getDw(), out.getDw());
                }
            });
        }
        return out;
    }

    /**
     * Applies {@code activation} element-wise.
     *
     * <p>Backward: {@code dIn[i] += activation.backward(in[i], out[i]) * dOut[i]}.
     *
     * @return a new matrix with the same shape as {@code input}
     */
    public final ExMatrix activate(final Activation activation, final ExMatrix input) {
        final ExMatrix out = ExMatrix.empty(input.rows(), input.columns());
        final int length = input.length();
        for (int i = 0; i < length; i++) {
            out.unsafeSet(i, activation.forward(input.unsafeGet(i)));
        }
        if (this.applyBackprop) {
            this.backprop.add(new ProcessStep() {
                @Override
                public void execute() {
                    DMatrix inDw = input.getDw();
                    DMatrix inW = input.getW();
                    DMatrix outDw = out.getDw();
                    DMatrix outW = out.getW();
                    for (int i = 0; i < length; i++) {
                        inDw.unsafeSet(i, inDw.unsafeGet(i) + (activation.backward(inW.unsafeGet(i), outW.unsafeGet(i)) * outDw.unsafeGet(i)));
                    }
                }
            });
        }
        return out;
    }

    /**
     * Queues the loss gradient step and optionally computes the loss value.
     *
     * <p>The backward step calls {@code loss.backward(actualOutput, targetOutput)}. It is queued
     * whenever back-propagation is enabled, regardless of {@code calculateLoss}.
     *
     * <p>NOTE(review): the parameter names assume {@link Loss} takes (prediction, target) in that
     * order; confirm against the Loss implementations.
     *
     * @param calculateLoss whether to evaluate {@code loss.forward}
     * @return {@code loss.forward(actualOutput, targetOutput)} if {@code calculateLoss},
     *         otherwise {@code null}
     */
    public final DMatrix applyLoss(final Loss loss, final ExMatrix actualOutput, final ExMatrix targetOutput, boolean calculateLoss) {
        if (this.applyBackprop) {
            this.backprop.add(new ProcessStep() {
                @Override
                public void execute() {
                    loss.backward(actualOutput, targetOutput);
                }
            });
        }
        if (calculateLoss) {
            return loss.forward(actualOutput, targetOutput);
        }
        return null;
    }

    /**
     * Element-wise (Hadamard) product {@code a .* b}.
     *
     * <p>Backward: {@code dA += B .* dOut} and {@code dB += A .* dOut}.
     *
     * @return a new matrix holding the element-wise product
     */
    public ExMatrix elmul(final ExMatrix a, final ExMatrix b) {
        final ExMatrix out = ExMatrix.createFromW(MatrixOps.HadamardMult(a, b));
        if (this.applyBackprop) {
            this.backprop.add(new ProcessStep() {
                @Override
                public void execute() {
                    MatrixOps.addtoMatrix(a.getDw(), MatrixOps.HadamardMult(b.getW(), out.getDw()));
                    MatrixOps.addtoMatrix(b.getDw(), MatrixOps.HadamardMult(a.getW(), out.getDw()));
                }
            });
        }
        return out;
    }

    /**
     * Element-wise complement {@code 1 - input}.
     *
     * <p>Backward: {@code dIn += -1 * dOut}, since d(1 - x)/dx = -1. This assumes
     * {@code MatrixOps.scaleThenAddtoMatrix(dst, src, s)} performs {@code dst += s * src}.
     *
     * @return a new matrix with the same shape as {@code input}
     */
    public ExMatrix oneMinus(final ExMatrix input) {
        final ExMatrix out = new ExMatrix(null, null);
        out.init(input.rows(), input.columns());
        final int length = input.length();
        for (int i = 0; i < length; i++) {
            out.unsafeSet(i, 1.0d - input.unsafeGet(i));
        }
        if (this.applyBackprop) {
            this.backprop.add(new ProcessStep() {
                @Override
                public void execute() {
                    MatrixOps.scaleThenAddtoMatrix(input.getDw(), out.getDw(), -1.0d);
                }
            });
        }
        // Return the computed matrix whether or not back-propagation is on, so callers can chain it.
        return out;
    }

    /**
     * Stacks two matrices with the same number of columns vertically: the rows of {@code top}
     * followed by the rows of {@code bottom}.
     *
     * <p>Step-cache values are copied into the result from whichever operand has a step cache.
     * The result's gradient is not seeded from the operands.
     *
     * <p>Backward: the upper {@code top.rows()} rows of {@code dOut} are added to {@code top}'s
     * gradient, and the remaining rows to {@code bottom}'s. Accumulating (rather than overwriting)
     * keeps contributions made to the operands by steps that ran earlier during back-propagation.
     *
     * @throws RuntimeException if the two matrices have a different number of columns
     */
    public ExMatrix concatVectors(final ExMatrix top, final ExMatrix bottom) {
        if (top.columns() != bottom.columns()) {
            throw new RuntimeException("Expected same column size");
        }
        final int topRows = top.rows();
        final int bottomRows = bottom.rows();
        final int columns = top.columns();

        final ExMatrix result = new ExMatrix(null, null);
        result.init(topRows + bottomRows, columns);

        final DMatrix resultW = result.getW();
        copyRowsInto(top.getW(), resultW, 0, topRows, columns);
        copyRowsInto(bottom.getW(), resultW, topRows, bottomRows, columns);

        if (top.hasStepCache() || bottom.hasStepCache()) {
            final DMatrix resultCache = result.getStepCache();
            // Only read a step cache from an operand that actually has one.
            if (top.hasStepCache()) {
                copyRowsInto(top.getStepCache(), resultCache, 0, topRows, columns);
            }
            if (bottom.hasStepCache()) {
                copyRowsInto(bottom.getStepCache(), resultCache, topRows, bottomRows, columns);
            }
        }

        if (this.applyBackprop) {
            this.backprop.add(new ProcessStep() {
                @Override
                public void execute() {
                    final DMatrix resultDw = result.getDw();
                    addRowsFrom(resultDw, 0, top.getDw(), topRows, columns);
                    addRowsFrom(resultDw, topRows, bottom.getDw(), bottomRows, columns);
                }
            });
        }
        return result;
    }

    /**
     * Copies the first {@code rowCount x columns} values of {@code src} into {@code dst}. Source
     * row {@code r} is written to destination row {@code r + dstRowOffset}.
     */
    private static void copyRowsInto(DMatrix src, DMatrix dst, int dstRowOffset, int rowCount, int columns) {
        for (int r = 0; r < rowCount; r++) {
            for (int c = 0; c < columns; c++) {
                dst.set(r + dstRowOffset, c, src.get(r, c));
            }
        }
    }

    /**
     * Adds {@code rowCount} rows of {@code src}, starting at row {@code srcRowOffset},
     * element-wise into rows {@code 0..rowCount-1} of {@code dst}.
     */
    private static void addRowsFrom(DMatrix src, int srcRowOffset, DMatrix dst, int rowCount, int columns) {
        for (int r = 0; r < rowCount; r++) {
            for (int c = 0; c < columns; c++) {
                dst.set(r, c, dst.get(r, c) + src.get(r + srcRowOffset, c));
            }
        }
    }
}
