package org.wlld.naturalLanguage;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.regex.Pattern;
import org.wlld.config.TfConfig;
import org.wlld.entity.TalkBody;
import org.wlld.matrixTools.Matrix;
import org.wlld.matrixTools.MatrixOperation;
import org.wlld.naturalLanguage.word.WordBack;
import org.wlld.naturalLanguage.word.WordEmbedding;
import org.wlld.transFormer.TransFormerManager;
import org.wlld.transFormer.model.TransFormerModel;
import org.wlld.transFormer.nerve.SensoryNerve;

/**
 * Question/answer wrapper around a {@link TransFormerManager}.
 *
 * <p>Model type ids are word-list indices offset by {@value #WORD_ID_OFFSET}. Ids below the
 * offset are reserved: {@value #END_ID} is appended to every training answer as its end marker,
 * and any reserved id stops generation in {@link #getAnswer(String, long)}.
 *
 * <p>{@link #init()} must be called before {@link #study(List)}, {@link #getAnswer(String, long)}
 * or {@link #insertModel(TransFormerModel)}, because it is the only place the manager is created.
 */
public class TalkToTalk extends MatrixOperation {
    /** Offset between a word's index in the word list and its model type id. */
    private static final int WORD_ID_OFFSET = 2;
    /** Type id appended to every training answer to mark its end. */
    private static final int END_ID = 1;

    private final WordEmbedding wordEmbedding;
    private final TfConfig tfConfig;
    /** Upper bound on answer length, both when truncating training answers and when generating. */
    private final int maxLength;
    /** Number of passes over the training list performed by {@link #study(List)}. */
    private final int times;
    /** Delimiter between words of an answer; null/empty means answers are split per character. */
    private final String splitWord;
    /** Literal (non-regex) pattern for {@link #splitWord}; null when {@link #splitAnswer} is false. */
    private final Pattern splitPattern;
    private TransFormerManager transFormerManager;
    /** True when a non-empty {@link #splitWord} is configured. */
    private final boolean splitAnswer;

    /**
     * Encoded training answer.
     *
     * <p>{@code answerList} holds the word type ids (word index + {@value #WORD_ID_OFFSET}) followed
     * by {@value #END_ID}. {@code answerMatrix} holds the stacked embeddings of the answer words.
     */
    public static class AnswerE {
        List<Integer> answerList;
        Matrix answerMatrix;

        AnswerE() {
        }
    }

    /**
     * Creates the wrapper. The transformer itself is not built until {@link #init()} runs.
     *
     * @param wordEmbedding embedding used for questions and answers; must not be null
     * @param tfConfig transformer configuration; must not be null
     * @throws Exception if {@code tfConfig.getTimes()} is not positive
     */
    public TalkToTalk(WordEmbedding wordEmbedding, TfConfig tfConfig) throws Exception {
        Objects.requireNonNull(tfConfig, "tfConfig");
        this.splitWord = tfConfig.getSplitWord();
        this.splitAnswer = splitWord != null && !splitWord.isEmpty();
        // LITERAL: the configured split word is a delimiter, not a regular expression.
        this.splitPattern = splitAnswer ? Pattern.compile(splitWord, Pattern.LITERAL) : null;
        this.wordEmbedding = Objects.requireNonNull(wordEmbedding, "wordEmbedding");
        this.tfConfig = tfConfig;
        this.maxLength = tfConfig.getMaxLength();
        this.times = tfConfig.getTimes();
        if (times <= 0) {
            throw new Exception("参数times必须大于0");
        }
    }

    /**
     * Configures the transformer for this vocabulary and creates the manager.
     *
     * <p>The feature dimension is set to the word-vector dimension. The number of output types is
     * set to the vocabulary size plus {@value #WORD_ID_OFFSET} reserved ids.
     *
     * @throws Exception propagated from the embedding or from {@link TransFormerManager}
     */
    public void init() throws Exception {
        tfConfig.setFeatureDimension(wordEmbedding.getWordVectorDimension());
        tfConfig.setTypeNumber(wordEmbedding.getWordList().size() + WORD_ID_OFFSET);
        transFormerManager = new TransFormerManager(tfConfig);
    }

    /** Returns the manager, failing with a clear message if {@link #init()} has not run. */
    private TransFormerManager requireManager() {
        if (transFormerManager == null) {
            throw new IllegalStateException("init() must be called before using TalkToTalk");
        }
        return transFormerManager;
    }

    /**
     * Returns a new matrix with row 0 of {@code start} first, followed by every row of {@code body}.
     *
     * @param body matrix whose rows are shifted down by one
     * @param start matrix whose row 0 becomes the first row of the result
     */
    private Matrix insertStart(Matrix body, Matrix start) throws Exception {
        Matrix result = new Matrix(body.getX() + 1, body.getY());
        int rows = result.getX();
        int columns = result.getY();
        for (int col = 0; col < columns; col++) {
            result.setNub(0, col, start.getNumber(0, col));
        }
        for (int row = 1; row < rows; row++) {
            for (int col = 0; col < columns; col++) {
                result.setNub(row, col, body.getNumber(row - 1, col));
            }
        }
        return result;
    }

    /**
     * Generates an answer to {@code str} one word at a time.
     *
     * <p>Each step posts the question embedding together with the start matrix. The embeddings of
     * all words generated so far are appended to that start matrix. Generation stops when the
     * predicted id is reserved (below {@value #WORD_ID_OFFSET}) or after {@code maxLength} steps.
     * At least one step always runs.
     *
     * @param str the question
     * @param j passed unchanged to {@code getEmbedding} and {@code postMessage}; presumably an
     *     event/session id (NOTE(review): confirm its meaning in WordEmbedding/SensoryNerve)
     * @return the generated words concatenated. When a split word is configured, every word,
     *     including the last, is followed by a single space.
     * @throws IllegalStateException if {@link #init()} has not been called
     * @throws Exception propagated from the embedding or the transformer
     */
    public String getAnswer(String str, long j) throws Exception {
        TransFormerManager manager = requireManager();
        SensoryNerve sensoryNerve = manager.getSensoryNerve();
        Matrix featureMatrix = wordEmbedding.getEmbedding(str, j, false).getFeatureMatrix();
        WordBack wordBack = new WordBack();
        StringBuilder sb = new StringBuilder();
        List<String> generated = new ArrayList<>();
        int step = 0;
        do {
            // A fresh copy of the question is posted each step. Presumably this is because
            // postMessage may modify its input (NOTE(review): confirm).
            Matrix question = featureMatrix.copy();
            // Second input: start matrix followed by the embeddings of the words produced so far.
            Matrix target = manager.getStartMatrix(featureMatrix);
            for (String word : generated) {
                Matrix wordMatrix = wordEmbedding.getEmbedding(word, j, true).getFeatureMatrix();
                target = pushVector(target, wordMatrix, true);
            }
            step++;
            sensoryNerve.postMessage(j, question, target, false, null, wordBack, false);
            int id = wordBack.getId();
            if (id < WORD_ID_OFFSET) {
                break; // reserved id: the model signalled the end of the answer
            }
            String word = wordEmbedding.getWord(id - WORD_ID_OFFSET);
            generated.add(word);
            sb.append(word);
            if (splitAnswer) {
                sb.append(' ');
            }
        } while (step < maxLength);
        return sb.toString();
    }

    /**
     * Loads a previously trained model into the manager.
     *
     * @throws IllegalStateException if {@link #init()} has not been called
     * @throws Exception propagated from the manager
     */
    public void insertModel(TransFormerModel transFormerModel) throws Exception {
        requireManager().insertModel(transFormerModel);
    }

    /**
     * Encodes a training answer into word ids and an embedding matrix.
     *
     * <p>With a split word, the answer is split on that literal delimiter. Without one, it is
     * split per character. Either way it is truncated to {@code maxLength} words, and
     * {@value #END_ID} is appended to the id list.
     *
     * @throws IllegalArgumentException if splitting leaves no words (e.g. the answer consists only
     *     of delimiters), which would otherwise surface later as a NullPointerException
     */
    private AnswerE getSentenceMatrix(String str) throws Exception {
        List<Integer> ids = new ArrayList<>();
        Matrix matrix = null;
        if (splitAnswer) {
            String[] tokens = splitPattern.split(str);
            if (tokens.length > maxLength) {
                tokens = Arrays.copyOfRange(tokens, 0, maxLength);
            }
            for (String token : tokens) {
                ids.add(wordEmbedding.getID(token) + WORD_ID_OFFSET);
                Matrix tokenMatrix = wordEmbedding.getEmbedding(token, 3L, true).getFeatureMatrix();
                matrix = matrix == null ? tokenMatrix : pushVector(matrix, tokenMatrix, true);
            }
            if (matrix == null) {
                throw new IllegalArgumentException("answer contains no words: \"" + str + "\"");
            }
        } else {
            String text = str.length() > maxLength ? str.substring(0, maxLength) : str;
            for (int i = 0; i < text.length(); i++) {
                ids.add(wordEmbedding.getID(text.substring(i, i + 1)) + WORD_ID_OFFSET);
            }
            matrix = wordEmbedding.getEmbedding(text, 3L, false).getFeatureMatrix();
        }
        ids.add(END_ID);
        AnswerE answerE = new AnswerE();
        answerE.answerMatrix = matrix;
        answerE.answerList = ids;
        return answerE;
    }

    /**
     * Trains the transformer on question/answer pairs.
     *
     * <p>For {@code times} passes, each pair is posted once. The inputs are the question
     * embedding, the answer embedding with the start matrix's first row prepended, and the
     * answer ids. Progress is printed to standard output.
     *
     * @param list training pairs
     * @return the manager's model after training
     * @throws IllegalStateException if {@link #init()} has not been called
     * @throws Exception propagated from the embedding or the transformer
     */
    public TransFormerModel study(List<TalkBody> list) throws Exception {
        TransFormerManager manager = requireManager();
        SensoryNerve sensoryNerve = manager.getSensoryNerve();
        int size = list.size();
        for (int epoch = 0; epoch < times; epoch++) {
            int index = 0;
            for (TalkBody talkBody : list) {
                index++;
                String question = talkBody.getQuestion();
                String answer = talkBody.getAnswer();
                System.out.println("问题:" + question + ", 回答:" + answer + ",训练语句下标:" + index + ",总数量:" + size + ",当前次数：" + epoch + ",总次数:" + this.times);
                Matrix featureMatrix = wordEmbedding.getEmbedding(question, 1L, false).getFeatureMatrix();
                AnswerE sentenceMatrix = getSentenceMatrix(answer);
                Matrix target = insertStart(sentenceMatrix.answerMatrix, manager.getStartMatrix(featureMatrix));
                sensoryNerve.postMessage(1L, featureMatrix, target, true, sentenceMatrix.answerList, null, false);
            }
        }
        return manager.getModel();
    }
}
