package com.omega.example.rnn.seq2seq;

import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.loss.LossType;
import com.omega.engine.nn.model.RNNCellType;
import com.omega.engine.nn.network.Seq2Seq;
import com.omega.engine.nn.network.Seq2SeqRNN;
import com.omega.engine.optimizer.EDOptimizer;
import com.omega.engine.optimizer.lr.LearnRateUpdate;
import com.omega.engine.updater.UpdaterType;
import com.omega.example.rnn.data.IndexDataLoader;
import java.util.Locale;
import java.util.Scanner;

/* loaded from: input_file:com/omega/example/rnn/seq2seq/Seq2seq.class */
/**
 * Example driver that trains a sequence-to-sequence network on a CSV dataset loaded through
 * {@link IndexDataLoader}, then repeatedly reads English text from standard input and prints the
 * model's prediction for it.
 *
 * <p>Two variants are provided: {@link #seq2seq()} trains a {@link Seq2Seq} network built with
 * {@link RNNCellType#LSTM} cells, and {@link #seq2seqRNN()} trains a {@link Seq2SeqRNN} network.
 */
public class Seq2seq {

    /** Dataset read by {@link #seq2seq()} (hard-coded path kept from the original example). */
    private static final String SEQ2SEQ_DATASET = "H:\\rnn_dataset\\translate.csv";

    /** Dataset read by {@link #seq2seqRNN()} (hard-coded path kept from the original example). */
    private static final String SEQ2SEQ_RNN_DATASET = "H:\\rnn_dataset\\translate4000.csv";

    /** Prompt printed before each line is read ("Please enter English:"). */
    private static final String PROMPT = "请输入英文:";

    /** Input line that ends the interactive loop. */
    private static final String EXIT_COMMAND = "exit";

    /** Callback that produces a prediction for one already-normalized line of user input. */
    @FunctionalInterface
    private interface Translator {
        void translate(String text) throws Exception;
    }

    /**
     * Trains the LSTM-based {@link Seq2Seq} model on {@code SEQ2SEQ_DATASET}, then runs the
     * interactive prediction loop until the user types {@code exit} or standard input ends.
     * Any exception is printed to standard error and not propagated.
     */
    public void seq2seq() {
        try {
            IndexDataLoader indexDataLoader = new IndexDataLoader(SEQ2SEQ_DATASET, 128);
            Seq2Seq network = new Seq2Seq(RNNCellType.LSTM, LossType.softmax_with_cross_entropy,
                    UpdaterType.adamw, indexDataLoader.max_en, indexDataLoader.max_ch - 1, 64, 512,
                    indexDataLoader.en_characters, 128, 512, indexDataLoader.ch_characters);
            network.CUDNN = true;
            // NOTE(review): network.learnRate (0.01f) differs from the 0.001f passed to EDOptimizer
            // below, which may also be a learning rate -- confirm which value training actually uses.
            network.learnRate = 0.01f;
            EDOptimizer optimizer =
                    new EDOptimizer(network, 128, 200, 0.001f, LearnRateUpdate.SMART_HALF, false);
            optimizer.lr_step = new int[]{100};
            optimizer.trainSeq2Seq(indexDataLoader);
            runInteractiveLoop(text -> optimizer.predict(indexDataLoader, text));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains the {@link Seq2SeqRNN} model on {@code SEQ2SEQ_RNN_DATASET}, then runs the interactive
     * prediction loop until the user types {@code exit} or standard input ends. Any exception is
     * printed to standard error and not propagated.
     */
    public void seq2seqRNN() {
        try {
            IndexDataLoader indexDataLoader = new IndexDataLoader(SEQ2SEQ_RNN_DATASET, 128);
            Seq2SeqRNN network = new Seq2SeqRNN(LossType.softmax_with_cross_entropy, UpdaterType.adamw,
                    indexDataLoader.max_en, indexDataLoader.max_ch - 1, 64, 512,
                    indexDataLoader.en_characters, 128, 512, indexDataLoader.ch_characters);
            network.CUDNN = true;
            // NOTE(review): see the learning-rate note in seq2seq(); the same 0.01f / 0.001f split applies.
            network.learnRate = 0.01f;
            EDOptimizer optimizer =
                    new EDOptimizer(network, 128, 100, 0.001f, LearnRateUpdate.SMART_HALF, false);
            optimizer.trainSeq2SeqRNN(indexDataLoader);
            runInteractiveLoop(text -> optimizer.predictRNN(indexDataLoader, text));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Reads lines from standard input, echoes each lower-cased line and passes it to
     * {@code translator}, until the user types {@code exit} or input is exhausted.
     *
     * @param translator receives each lower-cased input line
     * @throws Exception whatever {@code translator} throws
     */
    private static void runInteractiveLoop(Translator translator) throws Exception {
        // Closing a Scanner closes its source, so this also closes System.in -- the original
        // example did the same on "exit"; try-with-resources additionally covers error paths.
        try (Scanner scanner = new Scanner(System.in)) {
            while (true) {
                System.out.println(PROMPT);
                if (!scanner.hasNextLine()) {
                    // End of input (Ctrl-D / closed pipe): stop cleanly instead of letting
                    // nextLine() throw NoSuchElementException.
                    return;
                }
                String line = scanner.nextLine();
                if (EXIT_COMMAND.equals(line)) {
                    return;
                }
                // Locale.ROOT makes lower-casing independent of the JVM default locale
                // (e.g. the Turkish dotted/dotless 'i').
                String normalized = line.toLowerCase(Locale.ROOT);
                System.out.println(normalized);
                translator.translate(normalized);
            }
        }
    }

    /**
     * Entry point: initializes the CUDA context, runs {@link #seq2seq()}, and always calls
     * {@link CUDAMemoryManager#free()} exactly once afterwards.
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        try {
            CUDAModules.initContext();
            new Seq2seq().seq2seq();
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            // A single finally replaces the decompiled nested try/catch, which could call free()
            // twice when the first free() itself threw.
            CUDAMemoryManager.free();
        }
    }
}
