package org.wlld.transFormer;

import java.util.ArrayList;
import java.util.List;
import org.wlld.config.TfConfig;
import org.wlld.matrixTools.Matrix;
import org.wlld.transFormer.model.CodecBlockModel;
import org.wlld.transFormer.model.TransFormerModel;
import org.wlld.transFormer.nerve.SensoryNerve;

/* loaded from: input_file:org/wlld/transFormer/TransFormerManager.class */
/**
 * Builds and wires a transformer network from a {@link TfConfig}.
 *
 * <p>The network consists of:
 * <ul>
 *   <li>a stack of {@code allDepth} encoder {@link CodecBlock}s;</li>
 *   <li>a stack of {@code allDepth} decoder {@link CodecBlock}s, each of which is given the last
 *       encoder block;</li>
 *   <li>a {@link FirstDecoderBlock} feeding the first decoder block;</li>
 *   <li>a {@link LineBlock} attached to the last decoder block.</li>
 * </ul>
 *
 * <p>It also provides positional ("time") encoding helpers, and export/import of the trained
 * model.
 */
public class TransFormerManager {
    private final SensoryNerve sensoryNerve;
    private final List<CodecBlock> encoderBlocks = new ArrayList<>();
    private final List<CodecBlock> decoderBlocks = new ArrayList<>();
    private final FirstDecoderBlock firstDecoderBlock;
    private final LineBlock lineBlock;
    private final int maxLength;
    private final boolean selfTimeCode;

    /** Returns the input nerve connected to the first encoder block and to the first decoder block. */
    public SensoryNerve getSensoryNerve() {
        return this.sensoryNerve;
    }

    /**
     * Snapshots the state of every block into a new {@link TransFormerModel}.
     *
     * @return a model holding one {@link CodecBlockModel} per encoder block, one per decoder
     *     block (both in stack order), plus the first-decoder and line-block models
     */
    public TransFormerModel getModel() {
        List<CodecBlockModel> encoderModels = new ArrayList<>(this.encoderBlocks.size());
        for (CodecBlock block : this.encoderBlocks) {
            encoderModels.add(block.getModel());
        }
        // Iterate each stack over its own size instead of assuming both stacks are the same length.
        List<CodecBlockModel> decoderModels = new ArrayList<>(this.decoderBlocks.size());
        for (CodecBlock block : this.decoderBlocks) {
            decoderModels.add(block.getModel());
        }
        TransFormerModel transFormerModel = new TransFormerModel();
        transFormerModel.setEncoderBlockModels(encoderModels);
        transFormerModel.setDecoderBlockModels(decoderModels);
        transFormerModel.setFirstDecoderBlockModel(this.firstDecoderBlock.getModel());
        transFormerModel.setLineBlockModel(this.lineBlock.getModel());
        return transFormerModel;
    }

    /**
     * Returns a copy of {@code matrix} with the sinusoidal positional encoding added.
     *
     * <p>Row index {@code i} is the position. Column {@code j} receives {@code sin(i * w_j)} when
     * {@code j} is even and {@code cos(i * w_j)} when it is odd, where
     * {@code w_j = 1 / 10000^(2 * floor(j / 2) / y)} and {@code y} is the column count.
     *
     * @param matrix input features; it is not modified
     * @return a new matrix of the same size holding input + encoding
     */
    private Matrix addTimeCode(Matrix matrix) throws Exception {
        int x = matrix.getX();
        int y = matrix.getY();
        // The frequency depends only on the column, so compute it once per column rather than once per cell.
        double[] frequencies = new double[y];
        for (int j = 0; j < y; j++) {
            frequencies[j] = 1.0d / Math.pow(10000.0d, (2.0d * (j / 2)) / y); // j / 2 is integer division on purpose
        }
        Matrix result = new Matrix(x, y);
        for (int i = 0; i < x; i++) {
            for (int j = 0; j < y; j++) {
                double angle = frequencies[j] * i;
                double code = (j % 2 == 0) ? Math.sin(angle) : Math.cos(angle);
                result.setNub(i, j, matrix.getNumber(i, j) + code);
            }
        }
        return result;
    }

    /**
     * Returns a copy of {@code matrix} with a linear positional offset added.
     *
     * <p>Every cell of row {@code i} gets {@code i / maxLength} added to it.
     *
     * @param matrix input features; it is not modified
     * @return a new matrix of the same size holding input + offset
     */
    private Matrix addTimeCodeBySelf(Matrix matrix) throws Exception {
        double step = 1.0d / this.maxLength;
        int x = matrix.getX();
        int y = matrix.getY();
        Matrix result = new Matrix(x, y);
        // Start at row 0. Its offset is 0, but the row must still be copied into the new result matrix.
        // Otherwise it keeps whatever new Matrix(x, y) initialises it to (presumably zeros), and the first
        // position's input features are silently dropped.
        for (int i = 0; i < x; i++) {
            double offset = i * step;
            for (int j = 0; j < y; j++) {
                result.setNub(i, j, matrix.getNumber(i, j) + offset);
            }
        }
        return result;
    }

    /**
     * Applies positional encoding and collapses the result into a single row.
     *
     * <p>The encoding is the linear {@code i / maxLength} offset when {@code selfTimeCode} is set,
     * and the sinusoidal encoding otherwise. Entry {@code (0, j)} of the returned row is
     * {@code getAVG()} of column {@code j} of the encoded matrix.
     *
     * @param matrix input features, row index = position (as used by the encoders)
     * @return a {@code 1 x matrix.getY()} matrix
     */
    public Matrix getStartMatrix(Matrix matrix) throws Exception {
        Matrix encoded = this.selfTimeCode ? addTimeCodeBySelf(matrix) : addTimeCode(matrix);
        int columns = encoded.getY();
        Matrix start = new Matrix(1, columns);
        for (int j = 0; j < columns; j++) {
            start.setNub(0, j, encoded.getColumn(j).getAVG());
        }
        return start;
    }

    /**
     * Loads previously exported state into every block.
     *
     * <p>Entry {@code i} of the model's encoder and decoder lists is loaded into encoder and decoder
     * block {@code i}. Extra entries beyond this manager's depth are ignored.
     *
     * @param transFormerModel a model, typically produced by {@link #getModel()}
     * @throws IllegalArgumentException if the model's encoder or decoder list is missing or shorter
     *     than this manager's depth
     * @throws Exception if a block rejects its model
     */
    public void insertModel(TransFormerModel transFormerModel) throws Exception {
        List<CodecBlockModel> encoderBlockModels = transFormerModel.getEncoderBlockModels();
        List<CodecBlockModel> decoderBlockModels = transFormerModel.getDecoderBlockModels();
        requireModelCount("encoder", encoderBlockModels, this.encoderBlocks.size());
        requireModelCount("decoder", decoderBlockModels, this.decoderBlocks.size());
        for (int i = 0; i < this.encoderBlocks.size(); i++) {
            this.encoderBlocks.get(i).insertModel(encoderBlockModels.get(i));
        }
        for (int i = 0; i < this.decoderBlocks.size(); i++) {
            this.decoderBlocks.get(i).insertModel(decoderBlockModels.get(i));
        }
        this.firstDecoderBlock.insertModel(transFormerModel.getFirstDecoderBlockModel());
        this.lineBlock.insertModel(transFormerModel.getLineBlockModel());
    }

    /** Fails fast with a descriptive message when a saved model has too few block entries. */
    private static void requireModelCount(String kind, List<CodecBlockModel> models, int expected) {
        int actual = (models == null) ? 0 : models.size();
        if (actual < expected) {
            throw new IllegalArgumentException("TransFormerModel has " + actual + " " + kind
                    + " block models, but this manager needs " + expected);
        }
    }

    /**
     * Builds the encoder and decoder stacks from {@code tfConfig} and wires them together.
     *
     * <p>Encoder blocks receive depth indices {@code 1..allDepth}. Decoder blocks receive depth
     * indices {@code 2..allDepth + 1}.
     *
     * @param tfConfig network hyper-parameters
     * @throws Exception if the feature dimension is odd, or if {@code multiNumber <= 1},
     *     {@code featureDimension <= 0}, {@code allDepth <= 0} or {@code typeNumber <= 1}
     */
    public TransFormerManager(TfConfig tfConfig) throws Exception {
        int multiNumber = tfConfig.getMultiNumber();
        this.maxLength = tfConfig.getMaxLength();
        this.selfTimeCode = tfConfig.isSelfTimeCode();
        int featureDimension = tfConfig.getFeatureDimension();
        // The message reads "TransFormer word-vector dimension must be even".
        // NOTE(review): this is presumably so the sin/cos columns of addTimeCode pair up. Confirm.
        if (featureDimension % 2 != 0) {
            throw new Exception("TransFormer 词向量维度必须为偶数");
        }
        int allDepth = tfConfig.getAllDepth();
        double studyPoint = tfConfig.getStudyPoint();
        int typeNumber = tfConfig.getTypeNumber();
        boolean isShowLog = tfConfig.isShowLog();
        int regularModel = tfConfig.getRegularModel();
        double regular = tfConfig.getRegular();
        if (multiNumber <= 1 || featureDimension <= 0 || allDepth <= 0 || typeNumber <= 1) {
            // Report every checked parameter, not just two of them.
            throw new Exception("param is null,typeNumber:" + typeNumber + ",featureDimension:" + featureDimension
                    + ",multiNumber:" + multiNumber + ",allDepth:" + allDepth);
        }
        int coreNumber = tfConfig.getCoreNumber();

        for (int depth = 1; depth <= allDepth; depth++) {
            this.encoderBlocks.add(new CodecBlock(multiNumber, featureDimension, studyPoint, depth, true,
                    regularModel, regular, this.maxLength, this.selfTimeCode, coreNumber));
        }
        CodecBlock lastEncoderBlock = this.encoderBlocks.get(this.encoderBlocks.size() - 1);

        for (int i = 0; i < allDepth; i++) {
            CodecBlock decoderBlock = new CodecBlock(multiNumber, featureDimension, studyPoint, i + 2, false,
                    regularModel, regular, this.maxLength, this.selfTimeCode, coreNumber);
            // Every decoder block is given the last encoder block.
            decoderBlock.setLastEncoderBlock(lastEncoderBlock);
            this.decoderBlocks.add(decoderBlock);
        }
        CodecBlock lastDecoderBlock = this.decoderBlocks.get(this.decoderBlocks.size() - 1);

        connectCodecBlock(this.encoderBlocks);
        connectCodecBlock(this.decoderBlocks);

        this.lineBlock = new LineBlock(typeNumber, featureDimension, studyPoint, lastDecoderBlock, isShowLog,
                regularModel, regular, coreNumber);
        lastDecoderBlock.setLineBlock(this.lineBlock);

        CodecBlock firstDecoder = this.decoderBlocks.get(0);
        this.firstDecoderBlock = new FirstDecoderBlock(multiNumber, featureDimension, studyPoint, firstDecoder,
                this.maxLength, this.selfTimeCode, coreNumber);
        this.firstDecoderBlock.setLastEncoderBlock(lastEncoderBlock);
        firstDecoder.setFirstDecoderBlock(this.firstDecoderBlock);

        this.sensoryNerve = new SensoryNerve(this.encoderBlocks.get(0), this.firstDecoderBlock);
    }

    /**
     * Links consecutive blocks of {@code list} in both directions.
     *
     * <p>For each adjacent pair, block {@code i} is given block {@code i + 1} through
     * {@code setBeforeEncoderBlock}. Block {@code i + 1} is given block {@code i} through
     * {@code setAfterEncoderBlock}.
     * NOTE(review): the before/after naming looks inverted relative to list order. Confirm against
     * CodecBlock.
     */
    private void connectCodecBlock(List<CodecBlock> list) {
        for (int i = 0; i + 1 < list.size(); i++) {
            CodecBlock current = list.get(i);
            CodecBlock next = list.get(i + 1);
            current.setBeforeEncoderBlock(next);
            next.setAfterEncoderBlock(current);
        }
    }
}
