package org.wlld.rnnJumpNerveEntity;

import java.util.Map;
import org.wlld.i.ActiveFunction;
import org.wlld.i.OutBack;
import org.wlld.matrixTools.Matrix;

/* loaded from: input_file:org/wlld/rnnJumpNerveEntity/OutNerve.class */
public class OutNerve extends Nerve {
    /**
     * Target matrices consulted by {@link #inputMatrix} during training, keyed by the int passed
     * to that method. Must be supplied via {@link #setMatrixMap} before matrix training.
     */
    private Map<Integer, Matrix> matrixMapE;
    /** When true, training steps print the target and actual outputs to stdout. */
    private final boolean isShowLog;
    /** When true, {@link #input} hands the raw weighted sum to {@code sendSoftMax} instead of activating it here. */
    private final boolean isSoftMax;

    /**
     * Creates an output nerve.
     *
     * <p>All arguments except {@code z3} and {@code z4} are forwarded unchanged to the
     * {@link Nerve} constructor (their meaning is defined there — NOTE(review): not visible here).
     *
     * @param z3 enables debug logging to stdout ({@link #isShowLog})
     * @param z4 routes outputs through the softmax path ({@link #isSoftMax})
     * @throws Exception propagated from the {@link Nerve} constructor
     */
    public OutNerve(int i, double d, boolean z, ActiveFunction activeFunction, boolean z2, boolean z3, int i2, double d2, boolean z4, int i3, int i4, int i5, int i6, int i7, int i8) throws Exception {
        super(i, "OutNerve", d, z, activeFunction, z2, i2, d2, i3, i4, i5, i6, i7, i8, false, 0);
        this.isShowLog = z3;
        this.isSoftMax = z4;
    }

    /**
     * Sets this nerve's gradient to a value computed elsewhere (presumably by the softmax stage,
     * per the method name) and applies the weight update.
     *
     * @param d the gradient to use as-is
     * @param j event id forwarded to {@code updatePower}
     * @param iArr forwarded to {@code updatePower}
     * @param i forwarded to {@code updatePower}
     * @throws Exception propagated from {@code updatePower}
     */
    public void getGBySoftMax(double d, long j, int[] iArr, int i) throws Exception {
        this.gradient = d;
        updatePower(j, iArr, i);
    }

    /**
     * Installs the target matrices used by {@link #inputMatrix} when training.
     *
     * @param map target matrices keyed by the int passed to {@link #inputMatrix}
     */
    public void setMatrixMap(Map<Integer, Matrix> map) {
        this.matrixMapE = map;
    }

    /**
     * Once all inputs for event {@code j} have arrived, computes the weighted sum, releases the
     * buffered parameters and forwards the sum to {@code sendSoftMaxBack}.
     */
    @Override // org.wlld.rnnJumpNerveEntity.Nerve
    protected void sendAppointTestMessage(long j, double d, Matrix matrix, OutBack outBack, String str) throws Exception {
        if (insertParameter(j, d)) {
            double calculation = calculation(j);
            destroyParameter(j);
            sendSoftMaxBack(j, calculation, matrix, outBack, str);
        }
    }

    /**
     * Back-propagates an error term through this nerve's activation derivative (evaluated at the
     * last stored output) and applies the weight update.
     *
     * @param d upstream error term
     * @throws Exception propagated from {@code updatePower}
     */
    public void backMatrixError(double d, long j, int[] iArr, int i) throws Exception {
        this.gradient = this.activeFunction.functionG(this.outNub) * d;
        updatePower(j, iArr, i);
    }

    /**
     * Accepts one input value for event {@code j}; once all inputs have arrived, produces output.
     *
     * <p>In softmax mode the raw weighted sum goes to {@code sendSoftMax}. Otherwise the sum is
     * activated. When {@code z} is false the activated value is reported to {@code outBack}. When
     * {@code z} is true, the target for this nerve is read from {@code map} (missing → 0.0), the
     * gradient is computed and the weights are updated.
     *
     * @param z true for a training pass, false for inference
     * @param map target values keyed by nerve id; must be non-null when {@code z} is true
     * @param outBack result receiver; required when {@code z} is false
     * @throws Exception if {@code outBack} is missing during inference, or propagated from helpers
     */
    @Override // org.wlld.rnnJumpNerveEntity.Nerve
    public void input(long j, double d, boolean z, Map<Integer, Double> map, OutBack outBack, Matrix matrix, int[] iArr, int i, int i2) throws Exception {
        if (insertParameter(j, d)) {
            double calculation = calculation(j);
            if (this.isSoftMax) {
                if (!z) {
                    destroyParameter(j);
                }
                sendSoftMax(j, calculation, z, map, outBack, matrix, iArr, i);
                return;
            }
            double function = this.activeFunction.function(calculation);
            if (!z) {
                destroyParameter(j);
                if (outBack == null) {
                    throw new Exception("not find outBack");
                }
                outBack.getBack(function, getId(), j);
                return;
            }
            this.outNub = function;
            // Single lookup; an absent key means a target of 0.0.
            this.E = map.getOrDefault(getId(), 0.0d);
            if (this.isShowLog) {
                System.out.println("E==" + this.E + ",out==" + function + ",nerveId==" + getId());
            }
            this.gradient = outGradient();
            updatePower(j, iArr, i);
        }
    }

    /**
     * Convolves the incoming matrix. On inference ({@code z} false) the result is reported to
     * {@code outBack}. On training, the result is compared against the target matrix registered
     * under key {@code i}, and the element-wise difference is back-propagated.
     *
     * @param i key of the target matrix in {@link #matrixMapE}
     * @throws Exception if {@code outBack} is missing during inference, if no target map or no
     *     target for {@code i} is configured, or if the target and output sizes differ
     */
    @Override // org.wlld.rnnJumpNerveEntity.Nerve
    protected void inputMatrix(long j, Matrix matrix, boolean z, int i, OutBack outBack) throws Exception {
        Matrix conv = conv(matrix);
        if (!z) {
            if (outBack == null) {
                throw new Exception("not find outBack");
            }
            outBack.getBackMatrix(conv, getId(), j);
            return;
        }
        // Fail with a descriptive message instead of a bare NullPointerException when the
        // training targets were never configured or do not contain this key.
        if (this.matrixMapE == null) {
            throw new Exception("target matrix map is not set; call setMatrixMap before training");
        }
        Matrix matrix2 = this.matrixMapE.get(i);
        if (matrix2 == null) {
            throw new Exception("no target matrix registered for key " + i);
        }
        if (this.isShowLog) {
            System.out.println("E========" + i);
            System.out.println(conv.getString());
        }
        if (matrix2.getX() != conv.getX() || matrix2.getY() != conv.getY()) {
            throw new Exception("Wrong size setting of image in templateConfig");
        }
        backMatrix(getGradient(conv, matrix2));
    }

    /**
     * Returns the element-wise difference {@code target - actual}. The caller guarantees that
     * both matrices have the same dimensions.
     *
     * @param matrix actual output
     * @param matrix2 target
     */
    private Matrix getGradient(Matrix matrix, Matrix matrix2) throws Exception {
        Matrix matrix3 = new Matrix(matrix.getX(), matrix.getY());
        for (int i = 0; i < matrix2.getX(); i++) {
            for (int i2 = 0; i2 < matrix2.getY(); i2++) {
                matrix3.setNub(i, i2, matrix2.getNumber(i, i2) - matrix.getNumber(i, i2));
            }
        }
        return matrix3;
    }

    /**
     * Output-layer gradient: activation derivative at the current output times
     * {@code (target - output)}.
     */
    private double outGradient() {
        return this.activeFunction.functionG(this.outNub) * (this.E - this.outNub);
    }
}
