package com.top.bpnn;

import com.top.matrix.Matrix;
import com.top.utils.MatrixUtil;
import java.util.Map;
import java.util.Random;

/* loaded from: input_file:com/top/bpnn/BPNeuralNetworkFactory.class */
/**
 * Trains and evaluates a three-layer (input / hidden / output) back-propagation network.
 *
 * <p>Training is full-batch: every iteration pushes all sample rows through the network in one
 * matrix product, then updates weights and thresholds with a gradient step plus a momentum term
 * taken from the previous iteration's step.
 */
public class BPNeuralNetworkFactory {

    /**
     * Trains a network on {@code matrix}.
     *
     * <p>Each row of {@code matrix} is one sample. The first {@code inputLayerNeuronCount} columns
     * are the inputs, and the next {@code outputLayerNeuronCount} columns are the expected outputs.
     * Inputs and outputs are normalized separately with {@link MatrixUtil#normalize}. The min/max
     * returned by that call are stored on the model so that {@link #computeBP} can apply the same
     * scaling.
     *
     * <p>Training stops when the error from {@code computeE} falls to {@code precision} or below,
     * or when the iteration counter (starting at 1) reaches {@code maxTimes}.
     *
     * <p>NOTE(review): this method never calls {@code setBpParameter} on the returned model, but
     * {@link #computeBP} reads {@code getBpParameter()}. Presumably the caller attaches the
     * parameters; confirm.
     *
     * @param bPParameter topology, normalization range and training hyper-parameters
     * @param matrix      training samples, one per row
     * @return the trained model (weights, thresholds, normalization bounds, last error, iteration count)
     * @throws Exception if the column count of {@code matrix} is not input + output neuron count,
     *                   or if a matrix/activation operation fails
     */
    public BPModel trainBP(BPParameter bPParameter, Matrix matrix) throws Exception {
        ActivationFunction activationFunction = bPParameter.getActivationFunction();
        int inputCount = bPParameter.getInputLayerNeuronCount();
        int hiddenCount = bPParameter.getHiddenLayerNeuronCount();
        int outputCount = bPParameter.getOutputLayerNeuronCount();
        double normalizationMin = bPParameter.getNormalizationMin();
        double normalizationMax = bPParameter.getNormalizationMax();
        double step = bPParameter.getStep();
        double momentumFactor = bPParameter.getMomentumFactor();
        double precision = bPParameter.getPrecision();
        int maxTimes = bPParameter.getMaxTimes();
        if (matrix.getMatrixColCount() != inputCount + outputCount) {
            throw new IllegalArgumentException("神经元个数不符，请修改");
        }

        Random random = new Random();
        Matrix weightIJ = randomMatrix(inputCount, hiddenCount, random);
        Matrix weightJP = randomMatrix(hiddenCount, outputCount, random);
        Matrix b1 = randomMatrix(1, hiddenCount, random);
        Matrix b2 = randomMatrix(1, outputCount, random);

        int rowCount = matrix.getMatrixRowCount();
        Matrix rawInput = matrix.subMatrix(0, rowCount, 0, inputCount);
        Matrix rawTarget = matrix.subMatrix(0, rowCount, inputCount, outputCount);
        Map<String, Object> inputNormalization = MatrixUtil.normalize(rawInput, normalizationMin, normalizationMax);
        Matrix input = (Matrix) inputNormalization.get("res");
        Map<String, Object> outputNormalization = MatrixUtil.normalize(rawTarget, normalizationMin, normalizationMax);
        Matrix target = (Matrix) outputNormalization.get("res");

        // Steps applied in the previous iteration (excluding momentum), used for the momentum
        // term. They stay null until the first update has been made.
        Matrix prevDeltaWeightIJ = null;
        Matrix prevDeltaWeightJP = null;
        Matrix prevDeltaB1 = null;
        Matrix prevDeltaB2 = null;

        int times = 1;
        double error = 0.0d;
        while (times < maxTimes) {
            // Forward pass; each threshold row is added to every sample row.
            Matrix hiddenNet = addRowVector(input.multiple(weightIJ), b1);
            Matrix hiddenOut = computeValue(hiddenNet, activationFunction);
            Matrix outputNet = addRowVector(hiddenOut.multiple(weightJP), b2);
            Matrix residual = target.subtract(computeValue(outputNet, activationFunction));
            error = computeE(residual);
            if (Math.abs(error) <= precision) {
                System.out.println("满足精度");
                break;
            }

            // Backward pass. Each delta is the negative gradient of the error with respect to a
            // layer's net input, so adding step * (gradient term) moves downhill. Everything is
            // computed from the pre-update weights.
            Matrix outputDelta = residual.pointMultiple(computeDerivative(outputNet, activationFunction));
            Matrix hiddenDelta = outputDelta.multiple(weightJP.transpose())
                    .pointMultiple(computeDerivative(hiddenNet, activationFunction));

            Matrix deltaWeightJP = hiddenOut.transpose().multiple(outputDelta).multiple(step);
            Matrix deltaWeightIJ = input.transpose().multiple(hiddenDelta).multiple(step);
            // Each threshold feeds every sample row, so its gradient is the column sum of the
            // layer delta. This keeps the update 1 x neurons, matching the threshold's shape.
            Matrix deltaB2 = columnSums(outputDelta).multiple(step);
            Matrix deltaB1 = columnSums(hiddenDelta).multiple(step);

            weightIJ = applyDelta(weightIJ, deltaWeightIJ, prevDeltaWeightIJ, momentumFactor);
            weightJP = applyDelta(weightJP, deltaWeightJP, prevDeltaWeightJP, momentumFactor);
            b1 = applyDelta(b1, deltaB1, prevDeltaB1, momentumFactor);
            b2 = applyDelta(b2, deltaB2, prevDeltaB2, momentumFactor);

            prevDeltaWeightIJ = deltaWeightIJ;
            prevDeltaWeightJP = deltaWeightJP;
            prevDeltaB1 = deltaB1;
            prevDeltaB2 = deltaB2;
            times++;
        }

        BPModel bPModel = new BPModel();
        bPModel.setInputMax((Matrix) inputNormalization.get("max"));
        bPModel.setInputMin((Matrix) inputNormalization.get("min"));
        bPModel.setOutputMax((Matrix) outputNormalization.get("max"));
        bPModel.setOutputMin((Matrix) outputNormalization.get("min"));
        bPModel.setWeightIJ(weightIJ);
        bPModel.setWeightJP(weightJP);
        bPModel.setB1(b1);
        bPModel.setB2(b2);
        bPModel.setError(error);
        bPModel.setTimes(times);
        System.out.println("循环次数：" + times + "，误差：" + error);
        return bPModel;
    }

    /**
     * Runs a forward pass of a trained model and maps the outputs back to the original scale.
     *
     * <p>Each input column is scaled linearly with the per-column min/max stored on the model.
     * NOTE(review): a column that was constant during training (max == min) divides by zero here;
     * confirm how {@link MatrixUtil#normalize} treats that case.
     *
     * @param bPModel trained model; its {@code getBpParameter()} must be non-null
     * @param matrix  samples to evaluate, one per row, with one column per input neuron
     * @return the network outputs after {@link MatrixUtil#inverseNormalize}
     * @throws Exception if the column count of {@code matrix} does not match the input layer,
     *                   or if a matrix/activation operation fails
     */
    public Matrix computeBP(BPModel bPModel, Matrix matrix) throws Exception {
        BPParameter parameter = bPModel.getBpParameter();
        if (matrix.getMatrixColCount() != parameter.getInputLayerNeuronCount()) {
            throw new IllegalArgumentException("输入矩阵纬度有误");
        }
        ActivationFunction activationFunction = parameter.getActivationFunction();
        double normalizationMin = parameter.getNormalizationMin();
        double normalizationMax = parameter.getNormalizationMax();
        Matrix inputMin = bPModel.getInputMin();
        Matrix inputMax = bPModel.getInputMax();

        int rows = matrix.getMatrixRowCount();
        int cols = matrix.getMatrixColCount();
        double[][] scaled = new double[rows][cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                double min = inputMin.getValOfIdx(0, c);
                double max = inputMax.getValOfIdx(0, c);
                scaled[r][c] = normalizationMin
                        + (((matrix.getValOfIdx(r, c) - min) / (max - min)) * (normalizationMax - normalizationMin));
            }
        }

        Matrix hiddenOut = computeValue(
                addRowVector(new Matrix(scaled).multiple(bPModel.getWeightIJ()), bPModel.getB1()), activationFunction);
        Matrix output = computeValue(
                addRowVector(hiddenOut.multiple(bPModel.getWeightJP()), bPModel.getB2()), activationFunction);
        return MatrixUtil.inverseNormalize(output, normalizationMax, normalizationMin,
                bPModel.getOutputMax(), bPModel.getOutputMin());
    }

    /** Returns a {@code rows x cols} matrix of values drawn uniformly from [-1, 1). */
    private static Matrix randomMatrix(int rows, int cols, Random random) {
        double[][] values = new double[rows][cols];
        for (double[] row : values) {
            for (int c = 0; c < cols; c++) {
                row[c] = (2.0d * random.nextDouble()) - 1.0d;
            }
        }
        return new Matrix(values);
    }

    /** Adds the 1 x k {@code row} to every row of {@code matrix} (via {@code extend(2, rows)}). */
    private static Matrix addRowVector(Matrix matrix, Matrix row) throws Exception {
        return matrix.plus(row.extend(2, matrix.getMatrixRowCount()));
    }

    /** Returns a 1 x cols matrix holding the sum of each column of {@code matrix}. */
    private static Matrix columnSums(Matrix matrix) throws Exception {
        int rows = matrix.getMatrixRowCount();
        int cols = matrix.getMatrixColCount();
        double[][] sums = new double[1][cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                sums[0][c] += matrix.getValOfIdx(r, c);
            }
        }
        return new Matrix(sums);
    }

    /**
     * Returns {@code current + delta}, plus {@code momentumFactor * previousDelta} when a
     * previous step exists (i.e. from the second update onwards).
     */
    private static Matrix applyDelta(Matrix current, Matrix delta, Matrix previousDelta, double momentumFactor)
            throws Exception {
        Matrix updated = current.plus(delta);
        return previousDelta == null ? updated : updated.plus(previousDelta.multiple(momentumFactor));
    }

    /**
     * Applies the activation function to every element of {@code matrix}.
     *
     * @throws Exception if the matrix has no backing array, or if the activation call fails
     */
    private static Matrix computeValue(Matrix matrix, ActivationFunction activationFunction) throws Exception {
        if (matrix.getMatrix() == null) {
            throw new IllegalArgumentException("参数值为空");
        }
        int rows = matrix.getMatrixRowCount();
        int cols = matrix.getMatrixColCount();
        double[][] values = new double[rows][cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                values[r][c] = activationFunction.computeValue(matrix.getValOfIdx(r, c));
            }
        }
        return new Matrix(values);
    }

    /**
     * Applies the activation function's derivative to every element of {@code matrix}, which
     * holds net (pre-activation) values.
     *
     * @throws Exception if the matrix has no backing array, or if the derivative call fails
     */
    private static Matrix computeDerivative(Matrix matrix, ActivationFunction activationFunction) throws Exception {
        if (matrix.getMatrix() == null) {
            throw new IllegalArgumentException("参数值为空");
        }
        int rows = matrix.getMatrixRowCount();
        int cols = matrix.getMatrixColCount();
        double[][] values = new double[rows][cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                values[r][c] = activationFunction.computeDerivative(matrix.getValOfIdx(r, c));
            }
        }
        return new Matrix(values);
    }

    /**
     * Training error: 0.5 * {@code residual.square().sumAll()}. Presumably {@code square()} is
     * element-wise, which makes this half the sum of squared residuals; confirm against Matrix.
     */
    private static double computeE(Matrix residual) {
        return 0.5d * residual.square().sumAll();
    }
}
