package com.omega.example.rnn.test;

import com.omega.common.data.Tensor;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.active.ActiveType;
import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.loss.LossType;
import com.omega.engine.nn.layer.EmbeddingLayer;
import com.omega.engine.nn.layer.FullyLayer;
import com.omega.engine.nn.layer.InputLayer;
import com.omega.engine.nn.layer.LSTMLayer;
import com.omega.engine.nn.layer.RNNBlockLayer;
import com.omega.engine.nn.layer.RNNLayer;
import com.omega.engine.nn.layer.active.LeakyReluLayer;
import com.omega.engine.nn.layer.normalization.BNLayer;
import com.omega.engine.nn.layer.normalization.LNLayer;
import com.omega.engine.nn.network.Network;
import com.omega.engine.nn.network.RNN;
import com.omega.engine.nn.network.RunModel;
import com.omega.engine.optimizer.MBSGDOptimizer;
import com.omega.engine.optimizer.lr.LearnRateUpdate;
import com.omega.engine.updater.UpdaterType;
import com.omega.example.rnn.data.OneHotDataLoader;
import java.util.Arrays;
import java.util.Map;

/* loaded from: input_file:com/omega/example/rnn/test/CharRNN.class */
/**
 * Manual experiments for character-level recurrent language models built on the omega engine.
 *
 * <p>The {@code char*} methods load a corpus with {@link OneHotDataLoader}, assemble a network,
 * train it with {@code MBSGDOptimizer.trainRNN}, then extend a seed string by
 * {@value #GENERATE_STEPS} sampled characters and print the result. {@link #RNN()},
 * {@link #LSTM()} and {@link #RNN_CUDNN()} run a single recurrent layer on a tiny synthetic
 * input through {@code MBSGDOptimizer.testRNN}.
 *
 * <p>Corpus files are read from the hard-coded directory {@link #DATASET_DIR}.
 */
public class CharRNN {

    /** Directory holding the text corpora used by the {@code char*} experiments. */
    private static final String DATASET_DIR = "H:\\rnn_dataset\\";

    /** Number of characters appended to the seed text during sampling. */
    private static final int GENERATE_STEPS = 1000;

    /** Each output position is sampled from among its {@code TOP_N} highest scores. */
    private static final int TOP_N = 3;

    /**
     * Trains three stacked {@link RNNLayer}s (tanh) on {@code dpcc.txt}, topped by a
     * FullyLayer, BNLayer, LeakyReluLayer, FullyLayer head, then prints text sampled from a fixed
     * seed sentence. Any exception is printed and swallowed.
     */
    public void charRNN() {
        try {
            OneHotDataLoader loader = new OneHotDataLoader(DATASET_DIR + "dpcc.txt", 256, 64);
            RNN rnn = new RNN(LossType.softmax_with_cross_entropy, UpdaterType.adamw, 256);
            // Construct every layer before registering any; several constructors receive `rnn`,
            // so the construction/registration order is kept exactly as it was.
            InputLayer input = new InputLayer(1, 1, loader.characters);
            EmbeddingLayer embedding = new EmbeddingLayer(loader.characters, 256);
            RNNLayer recurrent1 = new RNNLayer(256, 512, 256, ActiveType.tanh, false, rnn);
            RNNLayer recurrent2 = new RNNLayer(512, 512, 256, ActiveType.tanh, false, rnn);
            RNNLayer recurrent3 = new RNNLayer(512, 512, 256, ActiveType.tanh, false, rnn);
            FullyLayer hidden = new FullyLayer(512, 512, false);
            BNLayer norm = new BNLayer();
            LeakyReluLayer activation = new LeakyReluLayer();
            FullyLayer logits = new FullyLayer(512, loader.characters, true);
            rnn.addLayer(input);
            rnn.addLayer(embedding);
            rnn.addLayer(recurrent1);
            rnn.addLayer(recurrent2);
            rnn.addLayer(recurrent3);
            rnn.addLayer(hidden);
            rnn.addLayer(norm);
            rnn.addLayer(activation);
            rnn.addLayer(logits);
            rnn.CUDNN = true;
            rnn.learnRate = 0.01f;
            new MBSGDOptimizer((Network) rnn, 2, 0.001f, 64, LearnRateUpdate.POLY, false).trainRNN(loader);
            System.out.println(sampleText(rnn, loader, "这个故事所造成的后果，便是造就了大批每天", 256));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains a single {@link LSTMLayer} on {@code dpcc50.txt}, topped by a FullyLayer, LNLayer,
     * LeakyReluLayer, FullyLayer head, then prints text sampled from the seed "萧". Any exception
     * is printed and swallowed.
     */
    public void charLSTM() {
        try {
            OneHotDataLoader loader = new OneHotDataLoader(DATASET_DIR + "dpcc50.txt", 256, 64);
            RNN rnn = new RNN(LossType.softmax_with_cross_entropy, UpdaterType.adamw, 256);
            InputLayer input = new InputLayer(1, 1, loader.characters);
            EmbeddingLayer embedding = new EmbeddingLayer(loader.characters, 256);
            LSTMLayer recurrent = new LSTMLayer(256, 512, 256, true, rnn);
            FullyLayer hidden = new FullyLayer(512, 512, false);
            LNLayer norm = new LNLayer();
            LeakyReluLayer activation = new LeakyReluLayer();
            FullyLayer logits = new FullyLayer(512, loader.characters, true);
            rnn.addLayer(input);
            rnn.addLayer(embedding);
            rnn.addLayer(recurrent);
            rnn.addLayer(hidden);
            rnn.addLayer(norm);
            rnn.addLayer(activation);
            rnn.addLayer(logits);
            rnn.CUDNN = true;
            rnn.learnRate = 0.01f;
            new MBSGDOptimizer((Network) rnn, 5, 0.001f, 64, LearnRateUpdate.CONSTANT, false).trainRNN(loader);
            System.out.println(sampleText(rnn, loader, "萧", 100));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Smoke-tests one {@link RNNLayer} through {@code MBSGDOptimizer.testRNN} on a synthetic input
     * tensor of 6 rows x 3 columns filled by {@code RandomUtils.order(18, 0.1f, 0.0f)}.
     */
    public void RNN() {
        try {
            Tensor input = new Tensor(3 * 2, 1, 1, 3, RandomUtils.order(3 * 2 * 3, 0.1f, 0.0f), true);
            RNN rnn = new RNN(LossType.softmax_with_cross_entropy, UpdaterType.adamw, 3);
            InputLayer inputLayer = new InputLayer(1, 1, 3);
            RNNLayer recurrent = new RNNLayer(3, 5, 3, ActiveType.tanh, true, rnn);
            rnn.addLayer(inputLayer);
            rnn.addLayer(recurrent);
            rnn.CUDNN = true;
            rnn.learnRate = 0.002f;
            new MBSGDOptimizer((Network) rnn, 500, 0.001f, 2, LearnRateUpdate.POLY, false).testRNN(input);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Smoke-tests one {@link LSTMLayer} through {@code MBSGDOptimizer.testRNN} on the same
     * synthetic 6 x 3 input as {@link #RNN()}.
     */
    public void LSTM() {
        try {
            Tensor input = new Tensor(3 * 2, 1, 1, 3, RandomUtils.order(3 * 2 * 3, 0.1f, 0.0f), true);
            RNN rnn = new RNN(LossType.softmax_with_cross_entropy, UpdaterType.adamw, 3);
            InputLayer inputLayer = new InputLayer(1, 1, 3);
            LSTMLayer recurrent = new LSTMLayer(3, 5, 3, true, rnn);
            rnn.addLayer(inputLayer);
            rnn.addLayer(recurrent);
            rnn.CUDNN = true;
            rnn.learnRate = 0.002f;
            new MBSGDOptimizer((Network) rnn, 500, 0.001f, 2, LearnRateUpdate.POLY, false).testRNN(input);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Smoke-tests one {@link RNNBlockLayer} through {@code MBSGDOptimizer.testRNN} on the same
     * synthetic 6 x 3 input as {@link #RNN()}.
     */
    public void RNN_CUDNN() {
        try {
            Tensor input = new Tensor(3 * 2, 1, 1, 3, RandomUtils.order(3 * 2 * 3, 0.1f, 0.0f), true);
            RNN rnn = new RNN(LossType.softmax_with_cross_entropy, UpdaterType.adamw, 3);
            InputLayer inputLayer = new InputLayer(1, 1, 3);
            RNNBlockLayer recurrent = new RNNBlockLayer(3, 1, 3, 5, 1, false, false, 0.0f, rnn);
            rnn.addLayer(inputLayer);
            rnn.addLayer(recurrent);
            rnn.CUDNN = true;
            rnn.learnRate = 0.002f;
            new MBSGDOptimizer((Network) rnn, 500, 0.001f, 2, LearnRateUpdate.POLY, false).testRNN(input);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains a single {@link RNNBlockLayer} plus a FullyLayer output on {@code shakespeare.txt}
     * (no embedding layer), then prints text sampled from the seed "All:". Sampling runs in
     * {@link RunModel#TEST}, like the other {@code char*} experiments. Any exception is printed
     * and swallowed.
     */
    public void charRNN2() {
        try {
            OneHotDataLoader loader = new OneHotDataLoader(DATASET_DIR + "shakespeare.txt", 576, 64);
            RNN rnn = new RNN(LossType.softmax_with_cross_entropy, UpdaterType.adamw, 576);
            InputLayer input = new InputLayer(1, 1, loader.characters);
            RNNBlockLayer recurrent = new RNNBlockLayer(576, 1, loader.characters, 1024, 1, false, false, 0.0f, rnn);
            FullyLayer logits = new FullyLayer(1024, loader.characters, true);
            rnn.addLayer(input);
            rnn.addLayer(recurrent);
            rnn.addLayer(logits);
            rnn.CUDNN = true;
            rnn.learnRate = 0.001f;
            new MBSGDOptimizer((Network) rnn, 3, 0.001f, 64, LearnRateUpdate.POLY, false).trainRNN(loader);
            System.out.println(sampleText(rnn, loader, "All:", 100));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains a two-layer {@link RNNBlockLayer} (constructed with 0.5f as its last float argument)
     * on {@code dpcc50.txt}, topped by a FullyLayer, BNLayer, LeakyReluLayer, FullyLayer head,
     * using a SMART_HALF learning-rate schedule with steps {5, 8, 10}. Then prints text sampled
     * from the seed "萧". Any exception is printed and swallowed.
     */
    public void charRNN3() {
        try {
            OneHotDataLoader loader = new OneHotDataLoader(DATASET_DIR + "dpcc50.txt", 256, 64);
            RNN rnn = new RNN(LossType.softmax_with_cross_entropy, UpdaterType.adamw, 256);
            InputLayer input = new InputLayer(1, 1, loader.characters);
            EmbeddingLayer embedding = new EmbeddingLayer(loader.characters, 256);
            RNNBlockLayer recurrent = new RNNBlockLayer(256, 1, 256, 512, 2, false, false, 0.5f, rnn);
            FullyLayer hidden = new FullyLayer(512, 512, false);
            BNLayer norm = new BNLayer();
            LeakyReluLayer activation = new LeakyReluLayer();
            FullyLayer logits = new FullyLayer(512, loader.characters, true);
            rnn.addLayer(input);
            rnn.addLayer(embedding);
            rnn.addLayer(recurrent);
            rnn.addLayer(hidden);
            rnn.addLayer(norm);
            rnn.addLayer(activation);
            rnn.addLayer(logits);
            rnn.CUDNN = true;
            rnn.learnRate = 0.001f;
            MBSGDOptimizer optimizer = new MBSGDOptimizer((Network) rnn, 2, 0.001f, 64, LearnRateUpdate.SMART_HALF, false);
            optimizer.lr_step = new int[]{5, 8, 10};
            optimizer.trainRNN(loader);
            System.out.println(sampleText(rnn, loader, "萧", 100));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Switches {@code rnn} to {@link RunModel#TEST} and autoregressively extends {@code seed} by
     * {@value #GENERATE_STEPS} characters.
     *
     * <p>At every step the last {@code window} characters of the text so far are one-hot encoded
     * and run through the network. Only the character predicted for the final position is
     * appended.
     *
     * @param rnn trained network; its {@code RUN_MODEL} and {@code time} fields are overwritten
     * @param loader supplies the vocabulary size, the char-to-index dictionary and the
     *     index-to-char table
     * @param seed starting text; every character must be in the vocabulary
     * @param window maximum number of trailing characters fed to the network per step
     * @return {@code seed} followed by the generated characters
     */
    private static String sampleText(RNN rnn, OneHotDataLoader loader, String seed, int window) {
        rnn.RUN_MODEL = RunModel.TEST;
        StringBuilder text = new StringBuilder(seed);
        Tensor input = createTxtData(null, seed, loader.characters, loader.dictionary, window);
        for (int step = 0; step < GENERATE_STEPS; step++) {
            rnn.time = input.number;
            String predicted = genTxt(input, null, rnn, loader, window);
            // With more than one time step the decoded output holds one character per input
            // position; only the one predicted after the last input character is new.
            text.append(rnn.time > 1 ? predicted.substring(input.number - 1, input.number) : predicted);
            // Re-encode the trailing window. The previous tensor is handed back to createTensor,
            // presumably so it can be reused.
            input = createTxtData(input, text.toString(), loader.characters, loader.dictionary, window);
        }
        return text.toString();
    }

    /**
     * One-hot encodes every character of {@code text} into {@code tensor}, in place.
     *
     * <p>Row {@code r} of the resulting {@code text.length() x width} matrix has a single 1.0f in
     * the column given by {@code dictionary}. The method then sets {@code tensor.number} and
     * {@code tensor.data} and calls {@code tensor.hostToDevice()}.
     *
     * @param text characters to encode
     * @param width vocabulary size (columns per row)
     * @param dictionary maps each character to its column index
     * @param tensor destination tensor, overwritten
     * @throws IllegalArgumentException if a character is missing from {@code dictionary} or
     *     maps outside {@code [0, width)}
     */
    public static void createTxtData(String text, int width, Map<Character, Integer> dictionary, Tensor tensor) {
        char[] chars = text.toCharArray();
        float[] oneHot = new float[chars.length * width];
        for (int row = 0; row < chars.length; row++) {
            oneHot[row * width + vocabIndex(dictionary, chars[row], width)] = 1.0f;
        }
        tensor.number = chars.length;
        tensor.data = oneHot;
        tensor.hostToDevice();
    }

    /**
     * One-hot encodes the last {@code maxLength} characters of {@code text} (or all of them, when
     * shorter).
     *
     * @param tensor existing tensor passed through to {@code Tensor.createTensor}; may be
     *     {@code null}
     * @param text source text; only its trailing window is encoded
     * @param width vocabulary size (columns per row)
     * @param dictionary maps each character to its column index
     * @param maxLength maximum number of trailing characters to encode
     * @return tensor created by {@code Tensor.createTensor(tensor, n, 1, 1, width, data, true)},
     *     where {@code n} is the number of encoded characters
     * @throws IllegalArgumentException if a character is missing from {@code dictionary} or
     *     maps outside {@code [0, width)}
     */
    public static Tensor createTxtData(Tensor tensor, String text, int width, Map<Character, Integer> dictionary, int maxLength) {
        int count = Math.min(text.length(), maxLength);
        char[] window = new char[count];
        text.getChars(text.length() - count, text.length(), window, 0);
        float[] oneHot = new float[count * width];
        for (int row = 0; row < count; row++) {
            oneHot[row * width + vocabIndex(dictionary, window[row], width)] = 1.0f;
        }
        return Tensor.createTensor(tensor, count, 1, 1, width, oneHot, true);
    }

    /**
     * Returns the one-hot column for {@code c}.
     *
     * <p>An explicit check replaces the bare NullPointerException an unknown character would
     * otherwise cause. It also prevents an out-of-range index from silently setting a cell in the
     * neighbouring row.
     *
     * @throws IllegalArgumentException if {@code c} is not in {@code dictionary} or its index lies
     *     outside {@code [0, width)}
     */
    private static int vocabIndex(Map<Character, Integer> dictionary, char c, int width) {
        Integer index = dictionary.get(c);
        if (index == null || index < 0 || index >= width) {
            throw new IllegalArgumentException(
                    "Character '" + c + "' has no one-hot column (index " + index + ", vocabulary width " + width + ")");
        }
        return index;
    }

    /**
     * Runs one forward pass of {@code rnn}, copies the output to the host via {@code syncHost()}
     * and decodes it with {@link #output2TXT}.
     *
     * @param input one-hot encoded input tensor
     * @param unused ignored; kept for source compatibility
     * @param rnn network to evaluate
     * @param loader provides the index-to-character table
     * @param unusedLength ignored; kept for source compatibility
     * @return one decoded character per output row
     */
    public static String genTxt(Tensor input, Tensor unused, RNN rnn, OneHotDataLoader loader, int unusedLength) {
        Tensor output = rnn.forward(input);
        output.syncHost();
        return output2TXT(output, loader);
    }

    /**
     * Decodes each row of {@code tensor} into a character.
     *
     * <p>Each character is sampled with {@link #pickTopN} from the row's {@value #TOP_N} highest
     * scores.
     *
     * @param tensor host-side output scores, one row per position
     * @param oneHotDataLoader supplies {@code dictionaryData}, the index-to-character table
     * @return the decoded string, {@code tensor.number} characters long
     */
    public static String output2TXT(Tensor tensor, OneHotDataLoader oneHotDataLoader) {
        StringBuilder decoded = new StringBuilder();
        for (int row = 0; row < tensor.number; row++) {
            int column = pickTopN(tensor.getByNumber(row), TOP_N);
            decoded.append(oneHotDataLoader.dictionaryData[column].charValue());
        }
        return decoded.toString();
    }

    /**
     * Picks an index among the {@code n} largest entries of {@code scores}.
     *
     * <p>The choice among those values is delegated to {@code RandomUtils.getRandomNumber}
     * (presumably weighted by value; confirm). When several entries share the chosen value, the
     * lowest such index is returned.
     *
     * @param scores candidate scores; not modified
     * @param n how many top values to choose from; values above {@code scores.length} are clamped
     * @return index into {@code scores}
     * @throws IllegalArgumentException if {@code scores} is empty
     */
    public static int pickTopN(float[] scores, int n) {
        if (scores.length == 0) {
            throw new IllegalArgumentException("scores must not be empty");
        }
        float[] sorted = Arrays.copyOf(scores, scores.length);
        Arrays.sort(sorted);
        int k = Math.min(n, sorted.length);
        float[] top = Arrays.copyOfRange(sorted, sorted.length - k, sorted.length);
        float chosen = top[RandomUtils.getRandomNumber(top)];
        // Float.compare (unlike ==) also matches NaN. Because `chosen` was copied out of
        // `scores`, a match is always found.
        for (int idx = 0; idx < scores.length; idx++) {
            if (Float.compare(chosen, scores[idx]) == 0) {
                return idx;
            }
        }
        return 0;
    }

    /** Initialises the CUDA context, runs {@link #charLSTM()}, and always frees CUDA memory. */
    public static void main(String[] args) {
        try {
            CUDAModules.initContext();
            new CharRNN().charLSTM();
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            // Release memory tracked by CUDAMemoryManager exactly once, on success or failure.
            CUDAMemoryManager.free();
        }
    }
}
