package com.omega.engine.optimizer;

import com.omega.common.data.Tensor;
import com.omega.common.utils.MathUtils;
import com.omega.common.utils.MatrixOperation;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.nn.data.BaseData;
import com.omega.engine.nn.grad.GradClipping;
import com.omega.engine.nn.layer.Layer;
import com.omega.engine.nn.network.GPT;
import com.omega.engine.nn.network.GPT2;
import com.omega.engine.nn.network.NanoGPT;
import com.omega.engine.nn.network.Network;
import com.omega.engine.nn.network.RunModel;
import com.omega.engine.nn.network.Seq2Seq;
import com.omega.engine.nn.network.Seq2SeqRNN;
import com.omega.engine.optimizer.lr.LearnRateUpdate;
import com.omega.example.rnn.data.IndexDataLoader;
import com.omega.example.transformer.utils.CNChatTokenizer;
import com.omega.example.transformer.utils.CNChatTokenizer2;
import com.omega.example.transformer.utils.CNTokenizer;
import com.omega.example.transformer.utils.ENTokenizer;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import jcuda.driver.JCudaDriver;

/* loaded from: input_file:com/omega/engine/optimizer/EDOptimizer.class */
public class EDOptimizer extends Optimizer {
    // Assigned -100.0f by both constructors; not read anywhere in the visible part of this class.
    private float clamp_val;

    /**
     * Creates an optimizer with an explicit learning-rate update policy and freshly allocated
     * loss / loss-gradient tensors.
     *
     * @param network network to train; forwarded to the superclass constructor
     * @param i value stored in {@code batchSize}; also the first dimension of both loss tensors
     * @param i2 value stored in {@code trainTime}, which bounds the outer loop of the train* methods
     * @param f forwarded unchanged to the superclass constructor (its meaning is not visible here)
     * @param learnRateUpdate value stored in {@code learnRateUpdate}
     * @param z forwarded unchanged to the superclass constructor (its meaning is not visible here)
     * @throws Exception declared on the signature; nothing in the visible body throws explicitly
     */
    public EDOptimizer(Network network, int i, int i2, float f, LearnRateUpdate learnRateUpdate, boolean z) throws Exception {
        super(network, i, i2, f, z);
        this.learnRateUpdate = learnRateUpdate;
        this.trainTime = i2;
        this.batchSize = i;
        this.clamp_val = -100.0f;
        // Both tensors are sized as (i, oChannel, oHeight, oWidth) taken from the network.
        this.loss = new Tensor(i, this.network.oChannel, this.network.oHeight, this.network.oWidth);
        this.lossDiff = new Tensor(i, this.network.oChannel, this.network.oHeight, this.network.oWidth);
    }

    /**
     * Creates an optimizer by delegating every argument to the superclass constructor.
     *
     * <p>Unlike the overload that takes a {@code LearnRateUpdate}, this constructor does not itself
     * assign {@code loss} or {@code lossDiff}.
     * NOTE(review): the train* methods call {@code this.loss.clear()} before the first assignment;
     * confirm the superclass initialises those fields.
     *
     * @param network network to train; forwarded to the superclass constructor
     * @param i forwarded to the superclass constructor (the sibling overload stores it as batch size)
     * @param i2 forwarded to the superclass constructor (the sibling overload stores it as trainTime)
     * @param f forwarded unchanged to the superclass constructor (its meaning is not visible here)
     * @param z forwarded unchanged to the superclass constructor (its meaning is not visible here)
     * @throws Exception declared on the signature; nothing in the visible body throws explicitly
     */
    public EDOptimizer(Network network, int i, int i2, float f, boolean z) throws Exception {
        super(network, i, i2, f, z);
        this.clamp_val = -100.0f;
    }

    /**
     * Empty override: calling this method does nothing.
     *
     * <p>The training logic visible in this class lives in the model-specific methods such as
     * {@code trainSeq2Seq}, {@code trainGPT} and {@code trainNanoGPT}.
     *
     * @param baseData unused
     */
    @Override // com.omega.engine.optimizer.Optimizer
    public void train(BaseData baseData) {
    }

    /**
     * Trains the wrapped network, cast to {@link Seq2SeqRNN}, on batches drawn from the given loader.
     *
     * <p>The outer loop runs while fewer than {@code trainTime} rounds have been done and
     * {@code trainIndex < minTrainTime}; the inner loop walks the index batches returned by
     * {@code MathUtils.randomInts} and stops early once {@code |currentError| <= error}. Each step
     * loads data, runs forward, computes loss and its gradient, back-propagates, updates the
     * network and prints a progress line. {@code updateLR(lr_step)} is called after every round.
     *
     * <p>Any exception is caught and printed with {@code printStackTrace()}; nothing propagates.
     *
     * @param indexDataLoader source of training samples; its {@code number} field is used as the
     *                        data-set size
     */
    public void trainSeq2SeqRNN(IndexDataLoader indexDataLoader) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = indexDataLoader.number;
            if (isWarmUp()) {
                // Warm-up start value: lr * (batchIndex / burnIn) ^ power.
                this.network.learnRate = (float) (this.lr
                        * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            Seq2SeqRNN model = (Seq2SeqRNN) this.network;
            // Buffers reused for every batch: encoder input, decoder input and decoder label.
            Tensor encoderInput = new Tensor(model.en_time * this.batchSize, 1, 1, model.en_len, true);
            Tensor decoderInput = new Tensor(model.de_time * this.batchSize, 1, 1, model.de_len, true);
            Tensor label = new Tensor(model.de_time * this.batchSize, 1, 1, model.de_len, true);
            for (int round = 0; round < this.trainTime && this.trainIndex < this.minTrainTime; round++) {
                this.trainIndex = round + 1;
                int[][] batches = MathUtils.randomInts(indexDataLoader.number, this.batchSize);
                for (int step = 0; step < batches.length && Math.abs(this.currentError) > this.error; step++) {
                    long startNanos = System.nanoTime();
                    this.loss.clear();
                    this.lossDiff.clear();
                    indexDataLoader.loadData(batches[step], encoderInput, decoderInput, label);
                    Tensor output = model.forward(encoderInput, decoderInput);
                    this.loss = model.loss(output, label);
                    this.lossDiff = model.lossDiff(output, label);
                    this.network.back(this.lossDiff);
                    this.network.update();
                    JCudaDriver.cuCtxSynchronize();
                    // The loss is summed from whichever copy is current and divided by decoderInput.number.
                    if (this.loss.isHasGPU()) {
                        this.currentError = MatrixOperation.sum(this.loss.syncHost()) / decoderInput.number;
                    } else {
                        this.currentError = MatrixOperation.sum(this.loss.data) / decoderInput.number;
                    }
                    output.syncHost();
                    System.out.println("training[" + this.trainIndex + "]{" + step + "} (lr:" + this.network.learnRate
                            + ") accuracy:{" + accuracy(output, label, output.number / this.batchSize, this.batchSize)
                            + "%} train_loss:" + this.currentError
                            + " [costTime:" + ((System.nanoTime() - startNanos) / 1000000.0d) + "ms.]");
                    this.batchIndex++;
                }
                updateLR(this.lr_step);
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains the wrapped network, cast to {@link Seq2Seq}, on batches drawn from the given loader.
     *
     * <p>The outer loop runs while fewer than {@code trainTime} rounds have been done and
     * {@code trainIndex < minTrainTime}; the inner loop walks the index batches returned by
     * {@code MathUtils.randomInts} and stops early once {@code |currentError| <= error}. Each step
     * loads data, runs forward, computes loss and its gradient, back-propagates, updates the
     * network and prints a progress line. {@code updateLR(lr_step)} is called after every round.
     *
     * <p>Any exception is caught and printed with {@code printStackTrace()}; nothing propagates.
     *
     * @param indexDataLoader source of training samples; its {@code number} field is used as the
     *                        data-set size
     */
    public void trainSeq2Seq(IndexDataLoader indexDataLoader) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = indexDataLoader.number;
            if (isWarmUp()) {
                // Warm-up start value: lr * (batchIndex / burnIn) ^ power.
                this.network.learnRate = (float) (this.lr
                        * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            Seq2Seq model = (Seq2Seq) this.network;
            // Buffers reused for every batch: encoder input, decoder input and decoder label.
            Tensor encoderInput = new Tensor(model.en_time * this.batchSize, 1, 1, model.en_len, true);
            Tensor decoderInput = new Tensor(model.de_time * this.batchSize, 1, 1, model.de_len, true);
            Tensor label = new Tensor(model.de_time * this.batchSize, 1, 1, model.de_len, true);
            for (int round = 0; round < this.trainTime && this.trainIndex < this.minTrainTime; round++) {
                this.trainIndex = round + 1;
                int[][] batches = MathUtils.randomInts(indexDataLoader.number, this.batchSize);
                for (int step = 0; step < batches.length && Math.abs(this.currentError) > this.error; step++) {
                    long startNanos = System.nanoTime();
                    this.loss.clear();
                    this.lossDiff.clear();
                    indexDataLoader.loadData(batches[step], encoderInput, decoderInput, label);
                    Tensor output = model.forward(encoderInput, decoderInput);
                    this.loss = model.loss(output, label);
                    this.lossDiff = model.lossDiff(output, label);
                    this.network.back(this.lossDiff);
                    this.network.update();
                    JCudaDriver.cuCtxSynchronize();
                    // The loss is summed from whichever copy is current and divided by decoderInput.number.
                    if (this.loss.isHasGPU()) {
                        this.currentError = MatrixOperation.sum(this.loss.syncHost()) / decoderInput.number;
                    } else {
                        this.currentError = MatrixOperation.sum(this.loss.data) / decoderInput.number;
                    }
                    output.syncHost();
                    System.out.println("training[" + this.trainIndex + "]{" + step + "} (lr:" + this.network.learnRate
                            + ") accuracy:{" + accuracy(output, label, output.number / this.batchSize, this.batchSize)
                            + "%} train_loss:" + this.currentError
                            + " [costTime:" + ((System.nanoTime() - startNanos) / 1000000.0d) + "ms.]");
                    this.batchIndex++;
                }
                updateLR(this.lr_step);
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains the wrapped network, cast to {@link GPT}, on data served by an {@link ENTokenizer}.
     *
     * <p>Input and label tensors, the {@code triu} tensor and the positions tensor are built once
     * and reused. Each step loads a batch, runs forward, computes loss and gradient (passing the
     * tokenizer's {@code "<pad>"} id), back-propagates, updates the network, synchronises the CUDA
     * context and prints a progress line. The inner loop stops early once
     * {@code |currentError| <= error}; {@code updateLR(lr_step)} runs after every round.
     *
     * <p>Any exception is caught and printed with {@code printStackTrace()}; nothing propagates.
     *
     * @param eNTokenizer tokenizer/data source; its {@code number} field is used as the data-set size
     */
    public void trainGPT(ENTokenizer eNTokenizer) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = eNTokenizer.number;
            if (isWarmUp()) {
                // Warm-up start value: lr * (batchIndex / burnIn) ^ power.
                this.network.learnRate = (float) (this.lr
                        * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            GPT model = (GPT) this.network;
            Tensor input = new Tensor(this.batchSize * model.time, 1, 1, model.vocab_size, true);
            Tensor label = new Tensor(this.batchSize * model.time, 1, 1, model.vocab_size, true);
            Tensor triuMask = ENTokenizer.triu(this.batchSize, model.head_num, model.time, model.time, 1.0f);
            Tensor positions = ENTokenizer.getPositions(this.batchSize, model.time);
            for (int round = 0; round < this.trainTime && this.trainIndex < this.minTrainTime; round++) {
                this.trainIndex = round + 1;
                int[][] batches = MathUtils.randomInts(eNTokenizer.number, this.batchSize);
                for (int step = 0; step < batches.length && Math.abs(this.currentError) > this.error; step++) {
                    long startNanos = System.nanoTime();
                    this.loss.clear();
                    this.lossDiff.clear();
                    eNTokenizer.loadData(batches[step], input, label);
                    Tensor output = model.forward(input, positions, triuMask);
                    // Id of the "<pad>" token, handed to loss, gradient and accuracy.
                    int padId = eNTokenizer.dictionary.get("<pad>").intValue();
                    this.loss = model.loss(output, label, padId);
                    this.lossDiff = model.lossDiff(output, label, padId);
                    this.network.back(this.lossDiff);
                    this.network.update();
                    JCudaDriver.cuCtxSynchronize();
                    if (this.loss.isHasGPU()) {
                        this.currentError = MatrixOperation.sum(this.loss.syncHost()) / input.number;
                    } else {
                        this.currentError = MatrixOperation.sum(this.loss.data) / input.number;
                    }
                    output.syncHost();
                    System.out.println("training[" + this.trainIndex + "]{" + step + "} (lr:" + this.network.learnRate
                            + ") accuracy:{"
                            + accuracyBatchFisrt(input, output, label, output.number / this.batchSize, this.batchSize, eNTokenizer.vocab, padId)
                            + "%} train_loss:" + this.currentError
                            + " [costTime:" + ((System.nanoTime() - startNanos) / 1000000.0d) + "ms.]");
                    this.batchIndex++;
                }
                updateLR(this.lr_step);
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains the wrapped network, cast to {@link GPT2}, on data served by an {@link ENTokenizer}.
     *
     * <p>Same step sequence as the other GPT trainers, with one addition: after back-propagation
     * {@code network.clipGradNorm(1.0f)} is called before the update. The inner loop stops early
     * once {@code |currentError| <= error}; {@code updateLR(lr_step)} runs after every round.
     *
     * <p>Any exception is caught and printed with {@code printStackTrace()}; nothing propagates.
     *
     * @param eNTokenizer tokenizer/data source; its {@code number} field is used as the data-set size
     */
    public void trainGPT2(ENTokenizer eNTokenizer) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = eNTokenizer.number;
            if (isWarmUp()) {
                // Warm-up start value: lr * (batchIndex / burnIn) ^ power.
                this.network.learnRate = (float) (this.lr
                        * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            GPT2 model = (GPT2) this.network;
            Tensor input = new Tensor(this.batchSize * model.time, 1, 1, model.vocabSize, true);
            Tensor label = new Tensor(this.batchSize * model.time, 1, 1, model.vocabSize, true);
            Tensor triuMask = ENTokenizer.triu(this.batchSize, model.headNum, model.time, model.time, 1.0f);
            Tensor positions = ENTokenizer.getPositions(this.batchSize, model.time);
            for (int round = 0; round < this.trainTime && this.trainIndex < this.minTrainTime; round++) {
                this.trainIndex = round + 1;
                int[][] batches = MathUtils.randomInts(eNTokenizer.number, this.batchSize);
                for (int step = 0; step < batches.length && Math.abs(this.currentError) > this.error; step++) {
                    long startNanos = System.nanoTime();
                    this.loss.clear();
                    this.lossDiff.clear();
                    eNTokenizer.loadData(batches[step], input, label);
                    Tensor output = model.forward(input, positions, triuMask);
                    // Id of the "<pad>" token, handed to loss, gradient and accuracy.
                    int padId = eNTokenizer.dictionary.get("<pad>").intValue();
                    this.loss = model.loss(output, label, padId);
                    this.lossDiff = model.lossDiff(output, label, padId);
                    this.network.back(this.lossDiff);
                    // Gradient-norm clipping happens between the backward pass and the update.
                    this.network.clipGradNorm(1.0f);
                    this.network.update();
                    JCudaDriver.cuCtxSynchronize();
                    if (this.loss.isHasGPU()) {
                        this.currentError = MatrixOperation.sum(this.loss.syncHost()) / input.number;
                    } else {
                        this.currentError = MatrixOperation.sum(this.loss.data) / input.number;
                    }
                    output.syncHost();
                    System.out.println("training[" + this.trainIndex + "]{" + step + "} (lr:" + this.network.learnRate
                            + ") accuracy:{"
                            + accuracyBatchFisrt(input, output, label, output.number / this.batchSize, this.batchSize, eNTokenizer.vocab, padId)
                            + "%} train_loss:" + this.currentError
                            + " [costTime:" + ((System.nanoTime() - startNanos) / 1000000.0d) + "ms.]");
                    this.batchIndex++;
                }
                updateLR(this.lr_step);
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains the wrapped network, cast to {@link GPT}, on data served by a {@link CNChatTokenizer}.
     *
     * <p>The {@code triu} and positions tensors are obtained from the static helpers on
     * {@link ENTokenizer}. Each step loads a batch, runs forward, computes loss and gradient
     * (passing the tokenizer's {@code "<pad>"} id), back-propagates, updates the network,
     * synchronises the CUDA context and prints a progress line. The inner loop stops early once
     * {@code |currentError| <= error}; {@code updateLR(lr_step)} runs after every round.
     *
     * <p>Any exception is caught and printed with {@code printStackTrace()}; nothing propagates.
     *
     * @param cNChatTokenizer tokenizer/data source; its {@code number} field is used as the
     *                        data-set size
     */
    public void trainGPT(CNChatTokenizer cNChatTokenizer) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = cNChatTokenizer.number;
            if (isWarmUp()) {
                // Warm-up start value: lr * (batchIndex / burnIn) ^ power.
                this.network.learnRate = (float) (this.lr
                        * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            GPT model = (GPT) this.network;
            Tensor input = new Tensor(this.batchSize * model.time, 1, 1, model.vocab_size, true);
            Tensor label = new Tensor(this.batchSize * model.time, 1, 1, model.vocab_size, true);
            Tensor triuMask = ENTokenizer.triu(this.batchSize, model.head_num, model.time, model.time, 1.0f);
            Tensor positions = ENTokenizer.getPositions(this.batchSize, model.time);
            for (int round = 0; round < this.trainTime && this.trainIndex < this.minTrainTime; round++) {
                this.trainIndex = round + 1;
                int[][] batches = MathUtils.randomInts(cNChatTokenizer.number, this.batchSize);
                for (int step = 0; step < batches.length && Math.abs(this.currentError) > this.error; step++) {
                    long startNanos = System.nanoTime();
                    this.loss.clear();
                    this.lossDiff.clear();
                    cNChatTokenizer.loadData(batches[step], input, label);
                    Tensor output = model.forward(input, positions, triuMask);
                    // Id of the "<pad>" token, handed to loss, gradient and accuracy.
                    int padId = cNChatTokenizer.dictionary.get("<pad>").intValue();
                    this.loss = model.loss(output, label, padId);
                    this.lossDiff = model.lossDiff(output, label, padId);
                    this.network.back(this.lossDiff);
                    this.network.update();
                    JCudaDriver.cuCtxSynchronize();
                    if (this.loss.isHasGPU()) {
                        this.currentError = MatrixOperation.sum(this.loss.syncHost()) / input.number;
                    } else {
                        this.currentError = MatrixOperation.sum(this.loss.data) / input.number;
                    }
                    output.syncHost();
                    System.out.println("training[" + this.trainIndex + "]{" + step + "} (lr:" + this.network.learnRate
                            + ") accuracy:{"
                            + accuracyBatchFisrt(input, output, label, output.number / this.batchSize, this.batchSize, cNChatTokenizer.vocab, padId)
                            + "%} train_loss:" + this.currentError
                            + " [costTime:" + ((System.nanoTime() - startNanos) / 1000000.0d) + "ms.]");
                    this.batchIndex++;
                }
                updateLR(this.lr_step);
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains the wrapped network, cast to {@link GPT2}, on data served by a {@link CNChatTokenizer}.
     *
     * <p>The {@code triu} and positions tensors are obtained from the static helpers on
     * {@link ENTokenizer}. Each step loads a batch, runs forward, computes loss and gradient
     * (passing the tokenizer's {@code "<pad>"} id), back-propagates, updates the network,
     * synchronises the CUDA context and prints a progress line. Unlike the {@code ENTokenizer}
     * overload, this variant does not call {@code clipGradNorm}. The inner loop stops early once
     * {@code |currentError| <= error}; {@code updateLR(lr_step)} runs after every round.
     *
     * <p>Any exception is caught and printed with {@code printStackTrace()}; nothing propagates.
     *
     * @param cNChatTokenizer tokenizer/data source; its {@code number} field is used as the
     *                        data-set size
     */
    public void trainGPT2(CNChatTokenizer cNChatTokenizer) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = cNChatTokenizer.number;
            if (isWarmUp()) {
                // Warm-up start value: lr * (batchIndex / burnIn) ^ power.
                this.network.learnRate = (float) (this.lr
                        * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            GPT2 model = (GPT2) this.network;
            Tensor input = new Tensor(this.batchSize * model.time, 1, 1, model.vocabSize, true);
            Tensor label = new Tensor(this.batchSize * model.time, 1, 1, model.vocabSize, true);
            Tensor triuMask = ENTokenizer.triu(this.batchSize, model.headNum, model.time, model.time, 1.0f);
            Tensor positions = ENTokenizer.getPositions(this.batchSize, model.time);
            for (int round = 0; round < this.trainTime && this.trainIndex < this.minTrainTime; round++) {
                this.trainIndex = round + 1;
                int[][] batches = MathUtils.randomInts(cNChatTokenizer.number, this.batchSize);
                for (int step = 0; step < batches.length && Math.abs(this.currentError) > this.error; step++) {
                    long startNanos = System.nanoTime();
                    this.loss.clear();
                    this.lossDiff.clear();
                    cNChatTokenizer.loadData(batches[step], input, label);
                    Tensor output = model.forward(input, positions, triuMask);
                    // Id of the "<pad>" token, handed to loss, gradient and accuracy.
                    int padId = cNChatTokenizer.dictionary.get("<pad>").intValue();
                    this.loss = model.loss(output, label, padId);
                    this.lossDiff = model.lossDiff(output, label, padId);
                    this.network.back(this.lossDiff);
                    this.network.update();
                    JCudaDriver.cuCtxSynchronize();
                    if (this.loss.isHasGPU()) {
                        this.currentError = MatrixOperation.sum(this.loss.syncHost()) / input.number;
                    } else {
                        this.currentError = MatrixOperation.sum(this.loss.data) / input.number;
                    }
                    output.syncHost();
                    System.out.println("training[" + this.trainIndex + "]{" + step + "} (lr:" + this.network.learnRate
                            + ") accuracy:{"
                            + accuracyBatchFisrt(input, output, label, output.number / this.batchSize, this.batchSize, cNChatTokenizer.vocab, padId)
                            + "%} train_loss:" + this.currentError
                            + " [costTime:" + ((System.nanoTime() - startNanos) / 1000000.0d) + "ms.]");
                    this.batchIndex++;
                }
                updateLR(this.lr_step);
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains the wrapped network, cast to {@link NanoGPT}, on data served by an {@link ENTokenizer}.
     *
     * <p>An {@code int[ceil(number / batchSize)][batchSize]} array and a list are allocated once and
     * handed to {@code MathUtils.randomInts} on every round (presumably as reusable scratch
     * storage; confirm against {@code MathUtils}). Unlike the GPT trainers, this variant neither
     * clears {@code loss}/{@code lossDiff} before each step nor synchronises the CUDA context after
     * the update. The inner loop stops early once {@code |currentError| <= error};
     * {@code updateLR(lr_step)} runs after every round.
     *
     * <p>Any exception is caught and printed with {@code printStackTrace()}; nothing propagates.
     *
     * @param eNTokenizer tokenizer/data source; its {@code number} field is used as the data-set size
     */
    public void trainNanoGPT(ENTokenizer eNTokenizer) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = eNTokenizer.number;
            if (isWarmUp()) {
                // Warm-up start value: lr * (batchIndex / burnIn) ^ power.
                this.network.learnRate = (float) (this.lr
                        * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            NanoGPT model = (NanoGPT) this.network;
            Tensor input = new Tensor(this.batchSize * model.time, 1, 1, model.vocabSize, true);
            Tensor label = new Tensor(this.batchSize * model.time, 1, 1, model.vocabSize, true);
            Tensor triuMask = ENTokenizer.triu(this.batchSize, model.headNum, model.time, model.time, 1.0f);
            Tensor positions = ENTokenizer.getPositions(this.batchSize, model.time);
            // number / batchSize at scale 0, rounded away from zero (same as the legacy mode constant 0).
            int batchCount = new BigDecimal(eNTokenizer.number)
                    .divide(new BigDecimal(this.batchSize), 0, java.math.RoundingMode.UP)
                    .intValue();
            int[][] indexBuffer = new int[batchCount][this.batchSize];
            ArrayList scratchList = new ArrayList();
            for (int round = 0; round < this.trainTime && this.trainIndex < this.minTrainTime; round++) {
                this.trainIndex = round + 1;
                int[][] batches = MathUtils.randomInts(eNTokenizer.number, this.batchSize, indexBuffer, scratchList);
                for (int step = 0; step < batches.length && Math.abs(this.currentError) > this.error; step++) {
                    long startNanos = System.nanoTime();
                    eNTokenizer.loadData(batches[step], input, label);
                    Tensor output = model.forward(input, positions, triuMask);
                    // Id of the "<pad>" token, handed to loss, gradient and accuracy.
                    int padId = eNTokenizer.dictionary.get("<pad>").intValue();
                    this.loss = model.loss(output, label, padId);
                    this.lossDiff = model.lossDiff(output, label, padId);
                    this.network.back(this.lossDiff);
                    this.network.update();
                    if (this.loss.isHasGPU()) {
                        this.currentError = MatrixOperation.sum(this.loss.syncHost()) / input.number;
                    } else {
                        this.currentError = MatrixOperation.sum(this.loss.data) / input.number;
                    }
                    output.syncHost();
                    System.out.println("training[" + this.trainIndex + "]{" + step + "} (lr:" + this.network.learnRate
                            + ") accuracy:{"
                            + accuracyBatchFisrt(input, output, label, output.number / this.batchSize, this.batchSize, eNTokenizer.vocab, padId)
                            + "%} train_loss:" + this.currentError
                            + " [costTime:" + ((System.nanoTime() - startNanos) / 1000000.0d) + "ms.]");
                    this.batchIndex++;
                }
                updateLR(this.lr_step);
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains the wrapped network, cast to {@link NanoGPT}, on chat data from a
     * {@link CNChatTokenizer}.
     *
     * <p>The input tensor has a last dimension of 1 while the label tensor uses
     * {@code vocabSize}; batches are drawn over {@code trainData.size()} and loaded with
     * {@code loadTrainData}. Forward is called with input and positions only. After every 200th
     * inner step (excluding step 0) {@code vail_chat} is invoked and the model's {@code RUN_MODEL}
     * is set back to {@code RunModel.TRAIN}. The inner loop stops early once
     * {@code |currentError| <= error}; {@code updateLR(lr_step)} runs after every round.
     *
     * <p>Any exception is caught and printed with {@code printStackTrace()}; nothing propagates.
     *
     * @param cNChatTokenizer tokenizer/data source; its {@code number} field is stored as the
     *                        data-set size
     */
    public void trainNanoGPT(CNChatTokenizer cNChatTokenizer) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = cNChatTokenizer.number;
            if (isWarmUp()) {
                // Warm-up start value: lr * (batchIndex / burnIn) ^ power.
                this.network.learnRate = (float) (this.lr
                        * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            NanoGPT model = (NanoGPT) this.network;
            Tensor input = new Tensor(this.batchSize * model.time, 1, 1, 1, true);
            Tensor label = new Tensor(this.batchSize * model.time, 1, 1, model.vocabSize, true);
            Tensor positions = CNChatTokenizer.getPositions(this.batchSize, model.time);
            for (int round = 0; round < this.trainTime && this.trainIndex < this.minTrainTime; round++) {
                this.trainIndex = round + 1;
                int[][] batches = MathUtils.randomInts(cNChatTokenizer.trainData.size(), this.batchSize);
                for (int step = 0; step < batches.length && Math.abs(this.currentError) > this.error; step++) {
                    long startNanos = System.nanoTime();
                    this.loss.clear();
                    this.lossDiff.clear();
                    cNChatTokenizer.loadTrainData(batches[step], input, label);
                    Tensor output = model.forward(input, positions);
                    // Id of the "<pad>" token, handed to loss, gradient and accuracy.
                    int padId = cNChatTokenizer.dictionary.get("<pad>").intValue();
                    this.loss = model.loss(output, label, padId);
                    this.lossDiff = model.lossDiff(output, label, padId);
                    this.network.back(this.lossDiff);
                    this.network.update();
                    if (this.loss.isHasGPU()) {
                        this.currentError = MatrixOperation.sum(this.loss.syncHost()) / input.number;
                    } else {
                        this.currentError = MatrixOperation.sum(this.loss.data) / input.number;
                    }
                    output.syncHost();
                    System.out.println("training[" + this.trainIndex + "]{" + step + "} (lr:" + this.network.learnRate
                            + ") accuracy:{"
                            + accuracyBatchFisrt(input, output, label, output.number / this.batchSize, this.batchSize, cNChatTokenizer.vocab, padId)
                            + "%} train_loss:" + this.currentError
                            + " [costTime:" + ((System.nanoTime() - startNanos) / 1000000.0d) + "ms.]");
                    this.batchIndex++;
                    // Periodic check every 200 steps; the run mode is restored to TRAIN afterwards.
                    if (step > 0 && step % 200 == 0) {
                        vail_chat(model, input, output, label, positions, cNChatTokenizer);
                        model.RUN_MODEL = RunModel.TRAIN;
                    }
                }
                updateLR(this.lr_step);
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains the wrapped network, cast to {@link NanoGPT}, on chat data from a
     * {@link CNChatTokenizer2}.
     *
     * <p>The input tensor has a last dimension of 1 while the label tensor uses
     * {@code vocabSize}; batches are drawn over {@code trainData.size()} and loaded with
     * {@code loadTrainData}. The {@code "<pad>"} id is read from {@code tokenizer.specials}, and
     * the nested {@code tokenizer} object is what gets passed to the accuracy helper. The inner
     * loop stops early once {@code |currentError| <= error}; {@code updateLR(lr_step)} runs after
     * every round.
     *
     * <p>Any exception is caught and printed with {@code printStackTrace()}; nothing propagates.
     *
     * @param cNChatTokenizer2 tokenizer/data source; its {@code number} field is stored as the
     *                         data-set size
     */
    public void trainNanoGPT(CNChatTokenizer2 cNChatTokenizer2) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = cNChatTokenizer2.number;
            if (isWarmUp()) {
                // Warm-up start value: lr * (batchIndex / burnIn) ^ power.
                this.network.learnRate = (float) (this.lr
                        * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            NanoGPT model = (NanoGPT) this.network;
            Tensor input = new Tensor(this.batchSize * model.time, 1, 1, 1, true);
            Tensor label = new Tensor(this.batchSize * model.time, 1, 1, model.vocabSize, true);
            Tensor positions = CNChatTokenizer.getPositions(this.batchSize, model.time);
            for (int round = 0; round < this.trainTime && this.trainIndex < this.minTrainTime; round++) {
                this.trainIndex = round + 1;
                int[][] batches = MathUtils.randomInts(cNChatTokenizer2.trainData.size(), this.batchSize);
                for (int step = 0; step < batches.length && Math.abs(this.currentError) > this.error; step++) {
                    long startNanos = System.nanoTime();
                    this.loss.clear();
                    this.lossDiff.clear();
                    cNChatTokenizer2.loadTrainData(batches[step], input, label);
                    Tensor output = model.forward(input, positions);
                    // Id of the "<pad>" special token, handed to loss, gradient and accuracy.
                    int padId = cNChatTokenizer2.tokenizer.specials.get("<pad>").intValue();
                    this.loss = model.loss(output, label, padId);
                    this.lossDiff = model.lossDiff(output, label, padId);
                    this.network.back(this.lossDiff);
                    this.network.update();
                    if (this.loss.isHasGPU()) {
                        this.currentError = MatrixOperation.sum(this.loss.syncHost()) / input.number;
                    } else {
                        this.currentError = MatrixOperation.sum(this.loss.data) / input.number;
                    }
                    output.syncHost();
                    System.out.println("training[" + this.trainIndex + "]{" + step + "} (lr:" + this.network.learnRate
                            + ") accuracy:{"
                            + accuracyBatchFisrt(input, output, label, output.number / this.batchSize, this.batchSize, cNChatTokenizer2.tokenizer, padId)
                            + "%} train_loss:" + this.currentError
                            + " [costTime:" + ((System.nanoTime() - startNanos) / 1000000.0d) + "ms.]");
                    this.batchIndex++;
                }
                updateLR(this.lr_step);
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains the wrapped network, cast to {@link NanoGPT}, on data from a {@link CNTokenizer}.
     *
     * <p>Batches are drawn over {@code trainData.length - time}; loss, gradient and accuracy are
     * computed without a pad id. After every 200th inner step (excluding step 0) {@code vail_gen}
     * is invoked and the model's {@code RUN_MODEL} is set back to {@code RunModel.TRAIN}. The
     * inner loop stops early once {@code |currentError| <= error}; {@code updateLR(lr_step)} runs
     * after every round.
     *
     * <p>Any exception is caught and printed with {@code printStackTrace()}; nothing propagates.
     *
     * @param cNTokenizer tokenizer/data source; its {@code number} field is stored as the
     *                    data-set size
     */
    public void trainNanoGPT_GEN(CNTokenizer cNTokenizer) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = cNTokenizer.number;
            if (isWarmUp()) {
                // Warm-up start value: lr * (batchIndex / burnIn) ^ power.
                this.network.learnRate = (float) (this.lr
                        * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            NanoGPT model = (NanoGPT) this.network;
            Tensor input = new Tensor(this.batchSize * model.time, 1, 1, 1, true);
            Tensor label = new Tensor(this.batchSize * model.time, 1, 1, model.vocabSize, true);
            Tensor triuMask = CNChatTokenizer.triu(this.batchSize, model.headNum, model.time, model.time, 1.0f);
            Tensor positions = CNChatTokenizer.getPositions(this.batchSize, model.time);
            for (int round = 0; round < this.trainTime && this.trainIndex < this.minTrainTime; round++) {
                this.trainIndex = round + 1;
                int[][] batches = MathUtils.randomInts(cNTokenizer.trainData.length - model.time, this.batchSize);
                for (int step = 0; step < batches.length && Math.abs(this.currentError) > this.error; step++) {
                    long startNanos = System.nanoTime();
                    this.loss.clear();
                    this.lossDiff.clear();
                    cNTokenizer.loadData(batches[step], input, label);
                    Tensor output = model.forward(input, positions, triuMask);
                    this.loss = model.loss(output, label);
                    this.lossDiff = model.lossDiff(output, label);
                    this.network.back(this.lossDiff);
                    this.network.update();
                    if (this.loss.isHasGPU()) {
                        this.currentError = MatrixOperation.sum(this.loss.syncHost()) / input.number;
                    } else {
                        this.currentError = MatrixOperation.sum(this.loss.data) / input.number;
                    }
                    output.syncHost();
                    System.out.println("training[" + this.trainIndex + "]{" + step + "} (lr:" + this.network.learnRate
                            + ") accuracy:{"
                            + accuracyBatchFisrt(input, output, label, output.number / this.batchSize, this.batchSize, cNTokenizer.vocab)
                            + "%} train_loss:" + this.currentError
                            + " [costTime:" + ((System.nanoTime() - startNanos) / 1000000.0d) + "ms.]");
                    this.batchIndex++;
                    // Periodic check every 200 steps; the run mode is restored to TRAIN afterwards.
                    if (step > 0 && step % 200 == 0) {
                        vail_gen(model, input, output, label, triuMask, positions, cNTokenizer);
                        model.RUN_MODEL = RunModel.TRAIN;
                    }
                }
                updateLR(this.lr_step);
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Empty override: calling this method does nothing.
     *
     * <p>The training logic visible in this class lives in the model-specific methods such as
     * {@code trainSeq2Seq}, {@code trainGPT} and {@code trainNanoGPT}.
     *
     * @param baseData unused
     * @param baseData2 unused
     */
    @Override // com.omega.engine.optimizer.Optimizer
    public void train(BaseData baseData, BaseData baseData2) {
    }

    /**
     * Empty override: calling this method does nothing.
     *
     * <p>The training logic visible in this class lives in the model-specific methods such as
     * {@code trainSeq2Seq}, {@code trainGPT} and {@code trainNanoGPT}.
     *
     * @param baseData unused
     * @param baseData2 unused
     * @param baseData3 unused
     */
    @Override // com.omega.engine.optimizer.Optimizer
    public void train(BaseData baseData, BaseData baseData2, BaseData baseData3) {
    }

    /**
     * Prints, for each of the first {@code i} samples, the decoded input, output and label text.
     *
     * @param indexDataLoader loader whose dictionaries are used for decoding
     * @param tensor tensor decoded with {@code input2TXT}
     * @param tensor2 tensor decoded with {@code output2TXT} and printed as "output"
     * @param tensor3 tensor decoded with {@code output2TXT} and printed as "label"
     * @param i number of samples to decode and print
     */
    public void showOutputAndLabel(IndexDataLoader indexDataLoader, Tensor tensor, Tensor tensor2, Tensor tensor3, int i) {
        String[] inputTexts = input2TXT(tensor, indexDataLoader, i);
        String[] outputTexts = output2TXT(tensor2, indexDataLoader, i);
        String[] labelTexts = output2TXT(tensor3, indexDataLoader, i);
        int sample = 0;
        while (sample < i) {
            System.out.println("input:" + inputTexts[sample]);
            System.out.println("output:" + outputTexts[sample]);
            System.out.println("label :" + labelTexts[sample]);
            sample++;
        }
    }

    /**
     * Decodes {@code i} sequences of {@code max_ch - 1} symbols each from a tensor.
     *
     * <p>For sample {@code s} and step {@code t}, row {@code t * i + s} of the tensor is read with
     * {@code getByNumber}, reduced to an index with {@code pickTopN(row, 1)} and looked up in
     * {@code indexDataLoader.ch_dic}.
     *
     * @param tensor tensor to decode
     * @param indexDataLoader loader providing {@code max_ch} and {@code ch_dic}
     * @param i number of samples to decode
     * @return one decoded string per sample
     */
    public static String[] output2TXT(Tensor tensor, IndexDataLoader indexDataLoader, int i) {
        String[] texts = new String[i];
        for (int sample = 0; sample < i; sample++) {
            StringBuilder text = new StringBuilder();
            for (int step = 0; step < indexDataLoader.max_ch - 1; step++) {
                int row = (step * i) + sample;
                text.append(indexDataLoader.ch_dic[pickTopN(tensor.getByNumber(row), 1)]);
            }
            texts[sample] = text.toString();
        }
        return texts;
    }

    /**
     * Decodes {@code i2} sequences of {@code i} symbols each from a tensor.
     *
     * <p>For sample {@code s} and step {@code t}, row {@code t * i2 + s} of the tensor is read with
     * {@code getByNumber}, reduced to an index with {@code pickTopN(row, 1)} and looked up in
     * {@code indexDataLoader.ch_dic}.
     *
     * @param tensor tensor to decode
     * @param indexDataLoader loader providing {@code ch_dic}
     * @param i number of steps decoded per sample
     * @param i2 number of samples to decode
     * @return one decoded string per sample
     */
    public static String[] output2TXT(Tensor tensor, IndexDataLoader indexDataLoader, int i, int i2) {
        String[] texts = new String[i2];
        for (int sample = 0; sample < i2; sample++) {
            StringBuilder text = new StringBuilder();
            for (int step = 0; step < i; step++) {
                int row = (step * i2) + sample;
                text.append(indexDataLoader.ch_dic[pickTopN(tensor.getByNumber(row), 1)]);
            }
            texts[sample] = text.toString();
        }
        return texts;
    }

    /**
     * Decodes a tensor into one string per sample using {@code indexDataLoader.en_dic}.
     *
     * <p>For every sample, {@code indexDataLoader.max_en} steps are decoded. The row for a given
     * step and sample is read at index {@code step * i + sample}, and the dictionary entry
     * selected by {@link #pickTopN(float[], int)} with {@code n = 1} is appended to the text.
     *
     * @param tensor          tensor whose rows are read through {@code getByNumber}
     * @param indexDataLoader supplies {@code max_en} and {@code en_dic}
     * @param i               number of samples; also the stride between consecutive steps
     * @return array of length {@code i} holding the decoded text of each sample
     */
    public static String[] input2TXT(Tensor tensor, IndexDataLoader indexDataLoader, int i) {
        final String[] texts = new String[i];
        for (int sample = 0; sample < i; sample++) {
            final StringBuilder text = new StringBuilder();
            for (int step = 0; step < indexDataLoader.max_en; step++) {
                // Rows are addressed step-major: all samples of step 0, then step 1, ...
                final int row = (step * i) + sample;
                text.append(indexDataLoader.en_dic[pickTopN(tensor.getByNumber(row), 1)]);
            }
            texts[sample] = text.toString();
        }
        return texts;
    }

    /**
     * Picks one of the {@code i} largest values of {@code fArr} and returns the index of the
     * first element of {@code fArr} that compares equal to it.
     *
     * <p>The candidate is selected with {@code RandomUtils.getRandomNumber(candidates)}, whose
     * result is used as an index into the candidate array (presumably a random valid index —
     * confirm against {@code RandomUtils}). The input array itself is not modified; sorting is
     * done on a copy.
     *
     * <p>If no element compares equal with {@code ==} (for example when the chosen value is
     * {@code NaN}), {@code 0} is returned. {@code Arrays.copyOfRange} rejects a negative start
     * index, so {@code i} must not exceed {@code fArr.length}.
     *
     * @param fArr scores to select from
     * @param i    how many of the largest values are eligible
     * @return index into {@code fArr} of the selected value, or {@code 0} if none matched
     */
    public static int pickTopN(float[] fArr, int i) {
        // Work on a sorted copy so the caller's array keeps its original order.
        final float[] ascending = fArr.clone();
        Arrays.sort(ascending);

        // The last i entries of the ascending copy are the i largest values.
        final int total = ascending.length;
        final float[] candidates = Arrays.copyOfRange(ascending, total - i, total);
        final float chosen = candidates[RandomUtils.getRandomNumber(candidates)];

        // Map the chosen value back to its first position in the original array.
        int position = 0;
        while (position < fArr.length) {
            if (fArr[position] == chosen) {
                return position;
            }
            position++;
        }
        return 0;
    }

    /**
     * Applies {@code GradClipping.gradClipping} with a value of {@code 1.0E-7f} to the
     * {@code diffW} and {@code diffB} of every layer in {@code network.layerList}, skipping
     * whichever of the two is {@code null}.
     *
     * @param network network whose layers are processed
     */
    public void gradClipping(Network network) {
        // Same value is passed for both diffW and diffB.
        final float clipValue = 1.0E-7f;
        for (Layer current : network.layerList) {
            final boolean hasWeightDiff = current.diffW != null;
            if (hasWeightDiff) {
                GradClipping.gradClipping(current.diffW, clipValue);
            }
            final boolean hasBiasDiff = current.diffB != null;
            if (hasBiasDiff) {
                GradClipping.gradClipping(current.diffB, clipValue);
            }
        }
    }

    /**
     * Generates text for {@code str} with the {@link Seq2Seq} network held by this optimizer.
     *
     * <p>The input is encoded once; elements 1 and 2 of the encoder result are passed to every
     * decoder call. The decoder input starts as a one-hot vector for {@code "<BOS>"} and, after
     * each step, is replaced by the one-hot vector of the token just decoded. Exactly
     * {@code de_time} steps are run. The result is printed to standard output and returned.
     *
     * @param indexDataLoader supplies text loading, {@code ch_characters}, {@code ch_dictionary}
     *                        and the dictionary used by {@code output2TXT}
     * @param str             input text to encode
     * @return concatenation of the tokens decoded at each step
     * @throws ClassCastException if {@code this.network} is not a {@link Seq2Seq}
     */
    public String predict(IndexDataLoader indexDataLoader, String str) {
        final Seq2Seq model = (Seq2Seq) this.network;
        final Tensor[] encoded = model.encoder(indexDataLoader.loadByTxt(str));
        final Tensor encoderOut1 = encoded[1];
        final Tensor encoderOut2 = encoded[2];

        // Initial decoder input: one-hot vector selecting "<BOS>".
        final float[] startOneHot = new float[indexDataLoader.ch_characters];
        final int bosIndex = indexDataLoader.ch_dictionary.get("<BOS>").intValue();
        startOneHot[bosIndex] = 1.0f;
        final Tensor decoderInput = new Tensor(1, 1, 1, indexDataLoader.ch_characters, startOneHot, true);

        final StringBuilder generated = new StringBuilder();
        for (int step = 0; step < model.de_time; step++) {
            final Tensor decoded = model.decoder(encoderOut1, encoderOut2, decoderInput);
            decoded.syncHost();
            final String token = output2TXT(decoded, indexDataLoader, 1, 1)[0];
            generated.append(token);

            // Feed the decoded token back in as the next one-hot decoder input.
            decoderInput.clear();
            final int tokenIndex = indexDataLoader.ch_dictionary.get(token).intValue();
            decoderInput.data[tokenIndex] = 1.0f;
            decoderInput.hostToDevice();
        }

        final String result = generated.toString();
        System.out.println(result);
        return result;
    }

    /**
     * Generates text for {@code str} with the {@link Seq2SeqRNN} network held by this optimizer.
     *
     * <p>The input is encoded once and the encoder result's last row
     * ({@code getByNumber(number - 1)}) is wrapped in a new tensor that is passed to every
     * decoder call. The decoder input starts as a one-hot vector for {@code "<BOS>"} and, after
     * each step, is replaced by the one-hot vector of the token just decoded. Exactly
     * {@code de_time} steps are run. The result is printed to standard output and returned.
     *
     * @param indexDataLoader supplies text loading, {@code ch_characters}, {@code ch_dictionary}
     *                        and the dictionary used by {@code output2TXT}
     * @param str             input text to encode
     * @return concatenation of the tokens decoded at each step
     * @throws ClassCastException if {@code this.network} is not a {@link Seq2SeqRNN}
     */
    public String predictRNN(IndexDataLoader indexDataLoader, String str) {
        final Seq2SeqRNN model = (Seq2SeqRNN) this.network;
        final Tensor encoded = model.encoder(indexDataLoader.loadByTxt(str));
        encoded.syncHost();

        // Only the last row of the encoder result is handed to the decoder.
        final Tensor lastEncoded = new Tensor(1, 1, 1, encoded.width, encoded.getByNumber(encoded.number - 1), true);

        // Initial decoder input: one-hot vector selecting "<BOS>".
        final float[] startOneHot = new float[indexDataLoader.ch_characters];
        final int bosIndex = indexDataLoader.ch_dictionary.get("<BOS>").intValue();
        startOneHot[bosIndex] = 1.0f;
        final Tensor decoderInput = new Tensor(1, 1, 1, indexDataLoader.ch_characters, startOneHot, true);

        final StringBuilder generated = new StringBuilder();
        for (int step = 0; step < model.de_time; step++) {
            final Tensor decoded = model.decoder(lastEncoded, decoderInput);
            decoded.syncHost();
            final String token = output2TXT(decoded, indexDataLoader, 1, 1)[0];
            generated.append(token);

            // Feed the decoded token back in as the next one-hot decoder input.
            decoderInput.clear();
            final int tokenIndex = indexDataLoader.ch_dictionary.get(token).intValue();
            decoderInput.data[tokenIndex] = 1.0f;
            decoderInput.hostToDevice();
        }

        final String result = generated.toString();
        System.out.println(result);
        return result;
    }

    /**
     * Runs up to 21 validation batches through {@code nanoGPT} and prints accuracy, loss and
     * elapsed time for each one.
     *
     * <p>Sets {@code nanoGPT.RUN_MODEL} to {@link RunModel#TEST} and does not restore it. For each
     * batch, {@code this.loss} and {@code this.currentError} are overwritten; the error is the
     * sum of the loss divided by {@code tensor.number}.
     *
     * @param nanoGPT     model to evaluate
     * @param tensor      input tensor filled by {@code cNTokenizer.loadDataVail}
     * @param tensor2     not referenced by this method
     * @param tensor3     label tensor filled by {@code cNTokenizer.loadDataVail}
     * @param tensor4     passed as the third argument of {@code nanoGPT.forward}
     * @param tensor5     passed as the second argument of {@code nanoGPT.forward}
     * @param cNTokenizer source of validation data and vocabulary
     */
    public void vail_gen(NanoGPT nanoGPT, Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4, Tensor tensor5, CNTokenizer cNTokenizer) {
        nanoGPT.RUN_MODEL = RunModel.TEST;
        final int[][] batches = MathUtils.randomInts(cNTokenizer.vailData.length - nanoGPT.time, this.batchSize);

        // Indices 0..20 inclusive are processed, i.e. at most 21 batches.
        final int batchCount = Math.min(batches.length, 21);
        for (int batch = 0; batch < batchCount; batch++) {
            final long startedAt = System.nanoTime();
            this.loss.clear();
            cNTokenizer.loadDataVail(batches[batch], tensor, tensor3);
            final Tensor output = nanoGPT.forward(tensor, tensor5, tensor4);
            this.loss = nanoGPT.loss(output, tensor3);

            // Sum the loss from the synced host copy when the tensor reports GPU data.
            this.currentError = (this.loss.isHasGPU()
                    ? MatrixOperation.sum(this.loss.syncHost())
                    : MatrixOperation.sum(this.loss.data)) / tensor.number;
            output.syncHost();

            // Built in stages, left to right, so the elapsed time still includes the accuracy call.
            final String prefix = "vail[" + this.trainIndex + "]{" + batch + "} (lr:" + this.network.learnRate + ") accuracy:{";
            final String withAccuracy = prefix
                    + accuracyBatchFisrt(tensor, output, tensor3, output.number / this.batchSize, this.batchSize, cNTokenizer.vocab);
            System.out.println(withAccuracy + "%} vail_loss:" + this.currentError
                    + " [costTime:" + ((System.nanoTime() - startedAt) / 1000000.0d) + "ms.]");
        }
    }

    /**
     * Runs up to 21 validation batches of chat data through {@code nanoGPT} and prints accuracy,
     * loss and elapsed time for each one.
     *
     * <p>Sets {@code nanoGPT.RUN_MODEL} to {@link RunModel#TEST} and does not restore it. The
     * dictionary index of {@code "<pad>"} is passed to both the loss and the accuracy
     * computation. For each batch, {@code this.loss} and {@code this.currentError} are
     * overwritten; the error is the sum of the loss divided by {@code tensor.number}.
     *
     * @param nanoGPT         model to evaluate
     * @param tensor          input tensor filled by {@code cNChatTokenizer.loadVailData}
     * @param tensor2         not referenced by this method
     * @param tensor3         label tensor filled by {@code cNChatTokenizer.loadVailData}
     * @param tensor4         passed as the second argument of {@code nanoGPT.forward}
     * @param cNChatTokenizer source of validation data, vocabulary and dictionary
     */
    public void vail_chat(NanoGPT nanoGPT, Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4, CNChatTokenizer cNChatTokenizer) {
        nanoGPT.RUN_MODEL = RunModel.TEST;
        final int[][] batches = MathUtils.randomInts(cNChatTokenizer.vailData.size() - nanoGPT.time, this.batchSize);

        // Indices 0..20 inclusive are processed, i.e. at most 21 batches.
        final int batchCount = Math.min(batches.length, 21);
        for (int batch = 0; batch < batchCount; batch++) {
            final long startedAt = System.nanoTime();
            this.loss.clear();
            cNChatTokenizer.loadVailData(batches[batch], tensor, tensor3);
            final Tensor output = nanoGPT.forward(tensor, tensor4);
            this.loss = nanoGPT.loss(output, tensor3, cNChatTokenizer.dictionary.get("<pad>").intValue());

            // Sum the loss from the synced host copy when the tensor reports GPU data.
            this.currentError = (this.loss.isHasGPU()
                    ? MatrixOperation.sum(this.loss.syncHost())
                    : MatrixOperation.sum(this.loss.data)) / tensor.number;
            output.syncHost();

            // Built in stages, left to right, so the elapsed time still includes the accuracy call.
            final String prefix = "vail[" + this.trainIndex + "]{" + batch + "} (lr:" + this.network.learnRate + ") accuracy:{";
            final String withAccuracy = prefix
                    + accuracyBatchFisrt(tensor, output, tensor3, output.number / this.batchSize, this.batchSize, cNChatTokenizer.vocab, cNChatTokenizer.dictionary.get("<pad>").intValue());
            System.out.println(withAccuracy + "%} vail_loss:" + this.currentError
                    + " [costTime:" + ((System.nanoTime() - startedAt) / 1000000.0d) + "ms.]");
        }
    }
}
