package com.omega.example.transformer.test;

import com.omega.common.data.Tensor;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.gpu.SoftmaxKernel;
import com.omega.engine.loss.LossType;
import com.omega.engine.nn.network.GPT;
import com.omega.engine.nn.network.GPT2;
import com.omega.engine.nn.network.NanoGPT;
import com.omega.engine.nn.network.RNN;
import com.omega.engine.nn.network.RunModel;
import com.omega.engine.optimizer.EDOptimizer;
import com.omega.engine.optimizer.lr.LearnRateUpdate;
import com.omega.engine.updater.UpdaterType;
import com.omega.example.transformer.utils.BPETokenizer;
import com.omega.example.transformer.utils.CNChatTokenizer;
import com.omega.example.transformer.utils.CNChatTokenizer2;
import com.omega.example.transformer.utils.CNTokenizer;
import com.omega.example.transformer.utils.ENTokenizer;
import java.util.Arrays;
import java.util.Map;
import java.util.Scanner;

/* loaded from: input_file:com/omega/example/transformer/test/GPTTest.class */
/**
 * Command-line driver that trains GPT-style networks on local text corpora and,
 * for some scenarios, runs an interactive or fixed-length sampling loop afterwards.
 *
 * <p>Every scenario method catches {@link Exception} and prints the stack trace,
 * so a failure in one scenario never propagates to the caller. Numeric constructor
 * arguments are passed through unchanged to the engine classes; their meaning is
 * defined by those classes and is not restated here.
 */
public class GPTTest {
    /**
     * Softmax kernel used by the eight-argument {@code genTxt} overload. It is assigned by
     * {@link #gpt_ssby()} and created on demand by that overload when it is still unset.
     */
    private static SoftmaxKernel kernel;

    /** Trains a {@link NanoGPT} on the WikiText-2 training tokens via {@code trainGPT}. */
    public static void gpt() {
        try {
            ENTokenizer tokenizer = new ENTokenizer("H:\\transformer_dataset\\gpt\\wikitext-2-v1\\wikitext-2\\wiki.train.tokens", 128, 32);
            NanoGPT network = new NanoGPT(LossType.softmax_with_cross_entropy, UpdaterType.adamw, 12, 12, tokenizer.vocab_size, 128, 768, false, false, false);
            network.learnRate = 1.0E-4f;
            EDOptimizer optimizer = new EDOptimizer(network, 32, 100, 0.001f, LearnRateUpdate.GD_GECAY, false);
            optimizer.trainGPT(tokenizer);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Trains a {@link GPT} network (with {@code CUDNN} enabled) on the "lang" corpus. */
    public static void gpt_lang() {
        try {
            ENTokenizer tokenizer = new ENTokenizer("H:\\transformer_dataset\\gpt\\lang\\lang.txt", 256, 10);
            GPT network = new GPT(LossType.softmax_with_cross_entropy, UpdaterType.adamw, tokenizer.vocab_size, 256, 512, 2048);
            network.CUDNN = true;
            network.learnRate = 1.0E-4f;
            EDOptimizer optimizer = new EDOptimizer(network, 10, 300, 0.001f, LearnRateUpdate.CONSTANT, false);
            optimizer.trainGPT(tokenizer);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Trains a {@link GPT} network (with {@code CUDNN} enabled) on the Chinese chat corpus. */
    public static void ch_chat() {
        try {
            CNChatTokenizer tokenizer = new CNChatTokenizer("H:\\transformer_dataset\\gpt\\chatdata\\train-format1w.txt", 128, 32);
            GPT network = new GPT(LossType.softmax_with_cross_entropy, UpdaterType.adamw, tokenizer.vocab_size, 128, 768, 2048);
            network.CUDNN = true;
            network.learnRate = 0.01f;
            EDOptimizer optimizer = new EDOptimizer(network, 32, 300, 1.0E-4f, LearnRateUpdate.GD_GECAY, false);
            optimizer.trainGPT(tokenizer);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains a {@link NanoGPT} on the Chinese chat corpus, then runs an interactive
     * console chat on standard input.
     *
     * <p>Console commands: {@code clean} resets the accumulated context, {@code exit}
     * ends the session. Any other line is lower-cased, appended to the context and fed
     * to the network for at most 128 generation steps; generation stops early when the
     * newly produced token maps to {@code <sep>} or {@code <eos>} in
     * {@code CNChatTokenizer.sd}.
     */
    public static void ch_chat_gpt2() {
        try {
            CNChatTokenizer tokenizer = new CNChatTokenizer("H:\\transformer_dataset\\gpt\\chatdata\\train-format20w.txt", 128, 32);
            NanoGPT network = new NanoGPT(LossType.softmax_with_cross_entropy, UpdaterType.adamw, 12, 12, tokenizer.vocab_size, 128, 768, false, false, false);
            network.learnRate = 1.0E-4f;
            EDOptimizer optimizer = new EDOptimizer(network, 32, 3, 1.0E-4f, LearnRateUpdate.SMART_HALF, false);
            optimizer.lr_step = new int[]{1, 2};
            optimizer.trainNanoGPT(tokenizer);
            // try-with-resources closes the scanner on "exit" and also when an exception escapes the loop.
            try (Scanner scanner = new Scanner(System.in)) {
                String context = "";
                while (true) {
                    System.out.println("请输入中文:");
                    String line = scanner.nextLine();
                    if (line.equals("clean")) {
                        context = "";
                        continue;
                    }
                    if (line.equals("exit")) {
                        return;
                    }
                    String userText = line.toLowerCase() + " ";
                    System.out.println("user:" + userText);
                    String dialogue = context + userText;
                    Tensor input = tokenizer.loadByTxtToIdx(dialogue);
                    Tensor positions = CNChatTokenizer.getPositions(1, input.number);
                    for (int step = 0; step < 128; step++) {
                        network.time = input.number;
                        Tensor output = network.forward(input, positions);
                        output.syncHost();
                        String decoded = output2TXT(output, tokenizer, true);
                        // NOTE(review): this substring only succeeds when dialogue.length() is
                        // decoded.length() or decoded.length() - 1; otherwise it throws
                        // StringIndexOutOfBoundsException. Confirm the length invariant.
                        String next = decoded.substring(decoded.length() - 1, dialogue.length());
                        dialogue = dialogue + next;
                        if (isStopToken(next)) {
                            break;
                        }
                        input = tokenizer.loadByTxtToIdx(dialogue);
                        CNChatTokenizer.getPositions(1, input.number, positions);
                    }
                    String[] parts = dialogue.split(" ");
                    String reply = parts[parts.length - 1];
                    System.out.println("chatbot:" + reply);
                    // NOTE(review): dialogue already begins with context, so the earlier history
                    // is appended a second time here, followed by the reply again. Left as-is
                    // because it defines what the model is fed; confirm this is intended.
                    context = context + dialogue + reply;
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains a {@link NanoGPT} on the BPE-tokenized corpus, then runs an interactive
     * console chat on standard input.
     *
     * <p>Console commands: {@code clean} resets the accumulated context, {@code exit}
     * ends the session. Generation runs for at most 128 steps per user line and stops
     * early when the newly produced piece equals a single space or {@code <pad>}.
     */
    public static void ch_chat_gpt2_voc() {
        try {
            BPETokenizer bpe = new BPETokenizer("H:\\transformer_dataset\\gpt\\50w_vocab.json", "H:\\transformer_dataset\\gpt\\50w_decode_vocab.json");
            CNChatTokenizer2 tokenizer = new CNChatTokenizer2("H:\\transformer_dataset\\gpt\\50w.txt", 128, 16, bpe);
            NanoGPT network = new NanoGPT(LossType.softmax_with_cross_entropy, UpdaterType.adamw, 8, 8, tokenizer.vocab_size, 128, 512, false, false, false);
            network.learnRate = 1.0E-4f;
            EDOptimizer optimizer = new EDOptimizer(network, 16, 1, 1.0E-4f, LearnRateUpdate.SMART_HALF, false);
            optimizer.lr_step = new int[]{1, 2};
            optimizer.trainNanoGPT(tokenizer);
            // try-with-resources closes the scanner on "exit" and also when an exception escapes the loop.
            try (Scanner scanner = new Scanner(System.in)) {
                String context = "";
                while (true) {
                    System.out.println("请输入中文:");
                    String line = scanner.nextLine();
                    if (line.equals("clean")) {
                        context = "";
                        continue;
                    }
                    if (line.equals("exit")) {
                        return;
                    }
                    String userText = line.toLowerCase() + " ";
                    System.out.println("user:" + userText);
                    String dialogue = context + userText;
                    Tensor input = tokenizer.loadByTxtToIdx(dialogue);
                    Tensor positions = CNChatTokenizer.getPositions(1, input.number);
                    for (int step = 0; step < 128; step++) {
                        network.time = input.number;
                        Tensor output = network.forward(input, positions);
                        output.syncHost();
                        String decoded = bpe.toText(output);
                        System.out.println("output:" + decoded);
                        // NOTE(review): only valid when dialogue.length() is decoded.length()
                        // or decoded.length() - 1; confirm the length invariant.
                        String next = decoded.substring(decoded.length() - 1, dialogue.length());
                        dialogue = dialogue + next;
                        if (next.equals(" ") || next.equals("<pad>")) {
                            break;
                        }
                        input = tokenizer.loadByTxtToIdx(dialogue);
                        CNChatTokenizer.getPositions(1, input.number, positions);
                    }
                    String[] parts = dialogue.split(" ");
                    String reply = parts[parts.length - 1];
                    System.out.println("chatbot:" + reply);
                    // NOTE(review): dialogue already begins with context, so the earlier history
                    // is appended a second time here. Left as-is; confirm this is intended.
                    context = context + dialogue + reply;
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains a {@link NanoGPT} on the cMedQA2 question/answer corpus, switches the
     * network to {@link RunModel#TEST} and answers questions typed on standard input
     * until {@code exit} is entered.
     *
     * <p>Each question is handled independently (no context is carried over). The
     * printed reply is everything generated after the prompt, including the textual
     * stop token ({@code <sep>} or {@code <eos>}) when generation ended on one.
     */
    public static void gpt2_yl_qa() {
        try {
            CNChatTokenizer tokenizer = new CNChatTokenizer("H:\\transformer_dataset\\gpt\\cMedQA2\\qaData.txt", 128, 16);
            NanoGPT network = new NanoGPT(LossType.softmax_with_cross_entropy, UpdaterType.adamw, 12, 6, tokenizer.vocab_size, 128, 768, false, true, false);
            network.learnRate = 0.001f;
            EDOptimizer optimizer = new EDOptimizer(network, 16, 1, 1.0E-4f, LearnRateUpdate.SMART_HALF, false);
            optimizer.trainNanoGPT(tokenizer);
            network.RUN_MODEL = RunModel.TEST;
            // try-with-resources closes the scanner on "exit" and also when an exception escapes the loop.
            try (Scanner scanner = new Scanner(System.in)) {
                while (true) {
                    System.out.println("请输入中文:");
                    String line = scanner.nextLine();
                    if (line.equals("exit")) {
                        return;
                    }
                    String prompt = line.toLowerCase() + " ";
                    System.out.println("user:" + prompt);
                    String dialogue = prompt;
                    Tensor input = tokenizer.loadByTxtToIdx(dialogue);
                    Tensor positions = CNChatTokenizer.getPositions(1, input.number);
                    for (int step = 0; step < 128; step++) {
                        network.time = input.number;
                        Tensor output = network.forward(input, positions);
                        output.syncHost();
                        String decoded = output2TXT(output, tokenizer, true);
                        // NOTE(review): only valid when dialogue.length() is decoded.length()
                        // or decoded.length() - 1; confirm the length invariant.
                        String next = decoded.substring(decoded.length() - 1, dialogue.length());
                        if (isStopToken(next)) {
                            // On a stop token the token's textual name is appended, not the raw piece.
                            dialogue = dialogue + CNChatTokenizer.sd.get(next);
                            break;
                        }
                        dialogue = dialogue + next;
                        input = tokenizer.loadByTxtToIdx(dialogue);
                        CNChatTokenizer.getPositions(1, input.number, positions);
                    }
                    // The reply is whatever was generated after the prompt. Splitting on spaces and
                    // taking element 1 failed when the generation was only whitespace and returned
                    // part of the question when the question itself contained a space.
                    System.out.println("chatbot:" + dialogue.substring(prompt.length()));
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Trains a {@link GPT2} network (with {@code CUDNN} enabled) on the "lang" corpus. */
    public static void gpt2_lang() {
        try {
            ENTokenizer tokenizer = new ENTokenizer("H:\\transformer_dataset\\gpt\\lang\\lang.txt", 256, 10);
            GPT2 network = new GPT2(LossType.softmax_with_cross_entropy, UpdaterType.adamw, 6, 8, tokenizer.vocab_size, 256, 128, false, false);
            network.CUDNN = true;
            network.learnRate = 1.0E-4f;
            EDOptimizer optimizer = new EDOptimizer(network, 10, 300, 0.001f, LearnRateUpdate.CONSTANT, false);
            optimizer.trainGPT2(tokenizer);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Trains a {@link NanoGPT} (with {@code CUDNN} enabled) on the WikiText-2 training tokens. */
    public static void gpt2_gan() {
        try {
            ENTokenizer tokenizer = new ENTokenizer("H:\\transformer_dataset\\gpt\\wikitext-2-v1\\wikitext-2\\wiki.train.tokens", 128, 12);
            NanoGPT network = new NanoGPT(LossType.softmax_with_cross_entropy, UpdaterType.adamw, 4, 6, tokenizer.vocab_size, 128, 512, false, false);
            network.CUDNN = true;
            network.learnRate = 1.0E-4f;
            EDOptimizer optimizer = new EDOptimizer(network, 12, 300, 0.001f, LearnRateUpdate.CONSTANT, false);
            optimizer.trainNanoGPT(tokenizer);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Trains a {@link NanoGPT} on the "lang" corpus via {@code trainNanoGPT}. */
    public static void nano_gpt_lang() {
        try {
            ENTokenizer tokenizer = new ENTokenizer("H:\\transformer_dataset\\gpt\\lang\\lang.txt", 128, 10);
            NanoGPT network = new NanoGPT(LossType.softmax_with_cross_entropy, UpdaterType.adamw, 6, 4, tokenizer.vocab_size, 128, 512, false, false);
            network.learnRate = 1.0E-4f;
            EDOptimizer optimizer = new EDOptimizer(network, 10, 300, 0.001f, LearnRateUpdate.CONSTANT, false);
            optimizer.trainNanoGPT(tokenizer);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains a character-level {@link NanoGPT} on the "dpcc50" corpus, switches it to
     * {@link RunModel#TEST} and generates 1000 further characters from a fixed seed
     * string, printing the final text once at the end.
     */
    public static void gpt_dp() {
        try {
            CNTokenizer tokenizer = new CNTokenizer("H:\\transformer_dataset\\gpt\\dpcc50.txt", 64, 32);
            NanoGPT network = new NanoGPT(LossType.softmax_with_cross_entropy, UpdaterType.adamw, 8, 8, tokenizer.characters, 64, 128, false, true);
            network.learnRate = 0.001f;
            EDOptimizer optimizer = new EDOptimizer(network, 32, 3, 0.001f, LearnRateUpdate.GD_GECAY, false);
            optimizer.trainNanoGPT_GEN(tokenizer);
            network.RUN_MODEL = RunModel.TEST;
            String text = "萧炎";
            Tensor positions = CNChatTokenizer.getPositions(1, text.length());
            Tensor mask = CNChatTokenizer.triu(1, network.headNum, text.length(), text.length(), 1.0f);
            Tensor input = createTxtData(null, text, tokenizer.characters, tokenizer.dictionary, 64);
            for (int step = 0; step < 1000; step++) {
                network.time = input.number;
                String generated = genTxt(input, null, network, tokenizer, text.length(), mask, positions);
                if (network.time > 1) {
                    // Only the character predicted for the last input position is new.
                    text = text + generated.substring(input.number - 1, input.number);
                } else {
                    text = text + generated;
                }
                input = createTxtData(input, text, tokenizer.characters, tokenizer.dictionary, 64);
            }
            System.out.println(text);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains a character-level {@link NanoGPT} on the "ssby" corpus, switches it to
     * {@link RunModel#TEST} and generates 1000 further characters starting from a
     * single space, printing the raw output and the accumulated text at every step.
     *
     * <p>Also assigns the shared softmax {@code kernel} field.
     */
    public static void gpt_ssby() {
        try {
            CNTokenizer tokenizer = new CNTokenizer("H:\\transformer_dataset\\gpt\\ssby\\ssby.txt", 64, 32);
            NanoGPT network = new NanoGPT(LossType.softmax_with_cross_entropy, UpdaterType.adamw, 4, 4, tokenizer.characters, 64, 128, false, true);
            network.learnRate = 0.001f;
            EDOptimizer optimizer = new EDOptimizer(network, 32, 1, 0.001f, LearnRateUpdate.GD_GECAY, false);
            optimizer.trainNanoGPT_GEN(tokenizer);
            network.RUN_MODEL = RunModel.TEST;
            kernel = new SoftmaxKernel();
            String text = " ";
            Tensor positions = CNChatTokenizer.getPositions(1, text.length());
            Tensor mask = CNChatTokenizer.triu(1, network.headNum, text.length(), text.length(), 1.0f);
            Tensor input = createTxtData(null, text, tokenizer.characters, tokenizer.dictionary, 64);
            // Return values are discarded; presumably these calls print the shapes for debugging — TODO confirm.
            input.shape();
            positions.shape();
            for (int step = 0; step < 1000; step++) {
                network.time = input.number;
                String generated = genTxt(input, null, network, tokenizer, text.length(), mask, positions);
                System.out.println("output txt=" + generated);
                if (network.time > 1) {
                    // Only the character predicted for the last input position is new.
                    text = text + generated.substring(input.number - 1, input.number);
                } else {
                    text = text + generated;
                }
                System.out.println(text);
                input = createTxtData(input, text, tokenizer.characters, tokenizer.dictionary, 64);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Reports whether a generated piece maps to {@code <sep>} or {@code <eos>} in
     * {@code CNChatTokenizer.sd}.
     *
     * @param piece the newly generated text piece
     * @return {@code true} when the piece has a mapping equal to {@code <sep>} or {@code <eos>}
     */
    private static boolean isStopToken(String piece) {
        return CNChatTokenizer.sd.get(piece) != null
                && (CNChatTokenizer.sd.get(piece).equals("<sep>") || CNChatTokenizer.sd.get(piece).equals("<eos>"));
    }

    /**
     * Runs one forward pass of a recurrent network and decodes the output.
     *
     * @param tensor          network input
     * @param tensor2         unused
     * @param rnn             network to run
     * @param cNChatTokenizer tokenizer whose {@code vocab} is used for decoding
     * @param i               unused
     * @return decoded text, one vocabulary entry per output row, each picked among the top 3 scores
     */
    public static String genTxt(Tensor tensor, Tensor tensor2, RNN rnn, CNChatTokenizer cNChatTokenizer, int i) {
        Tensor output = rnn.forward(tensor);
        output.syncHost();
        return output2TXT(output, cNChatTokenizer);
    }

    /**
     * Decodes a network output into text using the tokenizer's {@code vocab}, choosing
     * each entry via {@link #pickTopN(float[], int)} with 3 candidates.
     *
     * @param tensor          output tensor; {@code tensor.number} rows are decoded
     * @param cNChatTokenizer tokenizer providing the vocabulary
     * @return concatenation of the chosen vocabulary entries
     */
    public static String output2TXT(Tensor tensor, CNChatTokenizer cNChatTokenizer) {
        String text = "";
        for (int row = 0; row < tensor.number; row++) {
            int index = pickTopN(tensor.getByNumber(row), 3);
            text += cNChatTokenizer.vocab[index];
        }
        return text;
    }

    /**
     * Decodes a network output into text using the tokenizer's {@code vocab}, choosing
     * each entry via {@link #pickTopN(float[], int)} with a single candidate.
     *
     * @param tensor          output tensor; {@code tensor.number} rows are decoded
     * @param cNChatTokenizer tokenizer providing the vocabulary
     * @param z               when {@code true}, every key of
     *                        {@code CNChatTokenizer.specials_dictionary} found in the text is
     *                        replaced by its mapped value
     * @return decoded text
     */
    public static String output2TXT(Tensor tensor, CNChatTokenizer cNChatTokenizer, boolean z) {
        String text = "";
        for (int row = 0; row < tensor.number; row++) {
            int index = pickTopN(tensor.getByNumber(row), 1);
            text += cNChatTokenizer.vocab[index];
        }
        if (z) {
            for (String special : CNChatTokenizer.specials_dictionary.keySet()) {
                // Literal replacement: the keys are tokens, not regular expressions, and the
                // values must not be interpreted as regex replacement templates.
                text = text.replace(special, CNChatTokenizer.specials_dictionary.get(special));
            }
        }
        return text;
    }

    /**
     * Picks the index of one of the {@code i} largest values of {@code fArr}.
     *
     * <p>The {@code i} largest values are handed to {@code RandomUtils.getRandomNumber},
     * which selects one of them; the first index in {@code fArr} holding that value is
     * returned. The candidate count is clamped to {@code [1, fArr.length]}.
     *
     * @param fArr scores to choose from; must be non-empty
     * @param i    number of top candidates to choose among
     * @return index into {@code fArr} of the chosen value, or 0 when no element compares
     *         equal to it (for example when the chosen value is NaN)
     * @throws IllegalArgumentException if {@code fArr} is {@code null} or empty
     */
    public static int pickTopN(float[] fArr, int i) {
        if (fArr == null || fArr.length == 0) {
            throw new IllegalArgumentException("pickTopN requires a non-empty score array");
        }
        int candidateCount = Math.max(1, Math.min(i, fArr.length));
        float[] ascending = Arrays.copyOf(fArr, fArr.length);
        Arrays.sort(ascending);
        float[] candidates = Arrays.copyOfRange(ascending, ascending.length - candidateCount, ascending.length);
        float chosen = candidates[RandomUtils.getRandomNumber(candidates)];
        for (int index = 0; index < fArr.length; index++) {
            if (chosen == fArr[index]) {
                return index;
            }
        }
        return 0;
    }

    /**
     * Encodes the last {@code i2} characters of {@code str} (or all of it when shorter)
     * as dictionary ids and stores them in a tensor.
     *
     * <p>When {@code tensor} is non-null and already has one row per encoded character,
     * its {@code data} is replaced and pushed with {@code hostToDevice()}; otherwise
     * {@code Tensor.createTensor} is asked for a tensor of the required size.
     *
     * @param tensor tensor to reuse, or {@code null}
     * @param str    source text
     * @param i      unused
     * @param map    character-to-id dictionary
     * @param i2     maximum number of trailing characters to encode
     * @return the tensor holding the encoded ids
     * @throws IllegalArgumentException if a character of the window has no dictionary entry
     */
    public static Tensor createTxtData(Tensor tensor, String str, int i, Map<Character, Integer> map, int i2) {
        int windowLength = Math.min(str.length(), i2);
        int windowStart = Math.max(str.length() - i2, 0);
        char[] window = new char[windowLength];
        str.getChars(windowStart, str.length(), window, 0);
        float[] ids = new float[windowLength];
        for (int pos = 0; pos < windowLength; pos++) {
            Integer id = map.get(Character.valueOf(window[pos]));
            if (id == null) {
                throw new IllegalArgumentException("Character '" + window[pos] + "' is not in the dictionary");
            }
            ids[pos] = id.intValue();
        }
        if (tensor == null || tensor.number != window.length) {
            tensor = Tensor.createTensor(tensor, window.length, 1, 1, 1, ids, true);
        } else {
            tensor.data = ids;
            tensor.hostToDevice();
        }
        return tensor;
    }

    /**
     * Runs one forward pass of {@code nanoGPT}, applies the softmax kernel to the result
     * and decodes it with the tokenizer's {@code dictionaryData}.
     *
     * <p>Before the pass, {@code tensor5} and {@code tensor4} are refreshed through
     * {@code CNChatTokenizer.getPositions} and {@code CNChatTokenizer.triu} for length
     * {@code i}, and {@code nanoGPT.time} is set to {@code i}. The shared softmax kernel
     * is created here if no earlier code assigned it.
     *
     * @param tensor      network input
     * @param tensor2     unused
     * @param tensor3     tensor passed to {@code Tensor.createTensor} for the softmax output
     * @param nanoGPT     network to run
     * @param cNTokenizer tokenizer used for decoding
     * @param i           sequence length used for positions, mask and {@code nanoGPT.time}
     * @param tensor4     mask tensor, updated in place
     * @param tensor5     positions tensor, updated in place
     * @return decoded text, one character per output row
     */
    public static String genTxt(Tensor tensor, Tensor tensor2, Tensor tensor3, NanoGPT nanoGPT, CNTokenizer cNTokenizer, int i, Tensor tensor4, Tensor tensor5) {
        CNChatTokenizer.getPositions(1, i, tensor5);
        CNChatTokenizer.triu(1, nanoGPT.headNum, i, i, 1.0f, tensor4);
        nanoGPT.time = i;
        Tensor logits = nanoGPT.forward(tensor, tensor5, tensor4);
        Tensor probabilities = Tensor.createTensor(tensor3, tensor.number, tensor.channel, tensor.height, tensor.width, true);
        if (kernel == null) {
            // Only gpt_ssby() assigns the field; create it here so other callers do not hit a null kernel.
            kernel = new SoftmaxKernel();
        }
        kernel.softmax_out(logits, probabilities);
        probabilities.syncHost();
        return output2TXT(probabilities, cNTokenizer);
    }

    /**
     * Runs one forward pass of {@code nanoGPT} and decodes the raw output with the
     * tokenizer's {@code dictionaryData}.
     *
     * <p>Before the pass, {@code tensor4} and {@code tensor3} are refreshed through
     * {@code CNChatTokenizer.getPositions} and {@code CNChatTokenizer.triu} for
     * {@code tensor.number} positions, and {@code nanoGPT.time} is set to that value.
     *
     * @param tensor      network input
     * @param tensor2     unused
     * @param nanoGPT     network to run
     * @param cNTokenizer tokenizer used for decoding
     * @param i           unused
     * @param tensor3     mask tensor, updated in place
     * @param tensor4     positions tensor, updated in place
     * @return decoded text, one character per output row
     */
    public static String genTxt(Tensor tensor, Tensor tensor2, NanoGPT nanoGPT, CNTokenizer cNTokenizer, int i, Tensor tensor3, Tensor tensor4) {
        int length = tensor.number;
        CNChatTokenizer.getPositions(1, length, tensor4);
        CNChatTokenizer.triu(1, nanoGPT.headNum, length, length, 1.0f, tensor3);
        nanoGPT.time = tensor.number;
        Tensor output = nanoGPT.forward(tensor, tensor4, tensor3);
        output.syncHost();
        return output2TXT(output, cNTokenizer);
    }

    /**
     * Decodes a network output into text using the tokenizer's {@code dictionaryData},
     * taking for each row the index chosen by {@link #pickTopN(float[], int)} with a
     * single candidate.
     *
     * @param tensor      output tensor; {@code tensor.number} rows are decoded
     * @param cNTokenizer tokenizer providing {@code dictionaryData}
     * @return decoded text, one character per row
     */
    public static String output2TXT(Tensor tensor, CNTokenizer cNTokenizer) {
        String text = "";
        for (int row = 0; row < tensor.number; row++) {
            int index = pickTopN(tensor.getByNumber(row), 1);
            text += cNTokenizer.dictionaryData[index].charValue();
        }
        return text;
    }

    /**
     * Entry point: initializes the CUDA context, runs the medical Q&amp;A scenario and
     * always releases CUDA memory afterwards.
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        try {
            CUDAModules.initContext();
            gpt2_yl_qa();
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            CUDAMemoryManager.free();
        }
    }
}
