package com.omega.engine.optimizer;

import com.omega.common.data.Tensor;
import com.omega.common.data.utils.DataTransforms;
import com.omega.common.utils.MathUtils;
import com.omega.common.utils.MatrixOperation;
import com.omega.common.utils.MatrixUtils;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.check.BaseCheck;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.nn.data.BaseData;
import com.omega.engine.nn.grad.GradClipping;
import com.omega.engine.nn.layer.Layer;
import com.omega.engine.nn.network.Network;
import com.omega.engine.nn.network.OutputsNetwork;
import com.omega.engine.nn.network.RunModel;
import com.omega.engine.nn.network.Yolo;
import com.omega.engine.optimizer.lr.LearnRateUpdate;
import com.omega.example.rnn.data.OneHotDataLoader;
import com.omega.example.rnn.data.RNNDataLoader;
import com.omega.example.yolo.data.BaseDataLoader;
import com.omega.example.yolo.data.DetectionDataLoader;
import com.omega.example.yolo.utils.YoloLabelUtils;
import java.util.Arrays;
import jcuda.driver.JCudaDriver;

/* loaded from: input_file:com/omega/engine/optimizer/MBSGDOptimizer.class */
public class MBSGDOptimizer extends Optimizer {
    /** Lazily created data-enhancement helper; populated on first call to {@link #dataEnhanceInstance()}. */
    private YoloLabelUtils u;

    /**
     * Returns the cached {@link YoloLabelUtils} used for training-time data enhancement,
     * creating it with constructor arguments {@code (1, 4)} the first time it is requested.
     *
     * <p>Not synchronized: concurrent first calls may each construct an instance.
     *
     * @return the cached helper instance
     */
    public YoloLabelUtils dataEnhanceInstance() {
        YoloLabelUtils helper = this.u;
        if (helper != null) {
            return helper;
        }
        helper = new YoloLabelUtils(1, 4);
        this.u = helper;
        return helper;
    }

    /**
     * Creates the optimizer for {@code network} and pre-allocates the loss buffers.
     *
     * @param network the network to train
     * @param i forwarded to {@code Optimizer} (NOTE(review): presumably the epoch count — confirm)
     * @param f forwarded to {@code Optimizer} (NOTE(review): presumably the base learning rate — confirm)
     * @param i2 mini-batch size; also the leading dimension of the loss buffers
     * @param z forwarded to {@code Optimizer} (NOTE(review): presumably the warm-up flag — confirm)
     * @throws Exception if the superclass constructor or buffer allocation fails
     */
    public MBSGDOptimizer(Network network, int i, float f, int i2, boolean z) throws Exception {
        super(network, i2, i, f, z);
        initMbsgdBuffers(i2);
    }

    /**
     * Same as {@link #MBSGDOptimizer(Network, int, float, int, boolean)}, additionally calling
     * {@code setSid(str)} before the loss buffers are allocated.
     */
    public MBSGDOptimizer(String str, Network network, int i, float f, int i2, boolean z) throws Exception {
        super(network, i2, i, f, z);
        // Kept before buffer allocation to preserve the original initialization order.
        setSid(str);
        initMbsgdBuffers(i2);
    }

    /**
     * Same as {@link #MBSGDOptimizer(Network, int, float, int, boolean)}, additionally installing
     * {@code learnRateUpdate}.
     */
    public MBSGDOptimizer(Network network, int i, float f, int i2, LearnRateUpdate learnRateUpdate, boolean z) throws Exception {
        this(network, i, f, i2, z);
        this.learnRateUpdate = learnRateUpdate;
    }

    /**
     * Same as {@link #MBSGDOptimizer(Network, int, float, int, LearnRateUpdate, boolean)},
     * additionally installing {@code baseCheck} as the accuracy checker.
     */
    public MBSGDOptimizer(Network network, int i, float f, int i2, LearnRateUpdate learnRateUpdate, boolean z, BaseCheck baseCheck) throws Exception {
        this(network, i, f, i2, learnRateUpdate, z);
        this.check = baseCheck;
    }

    /**
     * Same as {@link #MBSGDOptimizer(String, Network, int, float, int, boolean)}, additionally
     * installing {@code learnRateUpdate}.
     */
    public MBSGDOptimizer(String str, Network network, int i, float f, int i2, LearnRateUpdate learnRateUpdate, boolean z) throws Exception {
        this(str, network, i, f, i2, z);
        this.learnRateUpdate = learnRateUpdate;
    }

    /**
     * Stores the batch size and allocates {@code loss}/{@code lossDiff} with shape
     * {@code (batchSize, oChannel, oHeight, oWidth)} of the network output. Shared by all constructors.
     */
    private void initMbsgdBuffers(int batchSize) throws Exception {
        this.batchSize = batchSize;
        this.loss = new Tensor(batchSize, this.network.oChannel, this.network.oHeight, this.network.oWidth);
        this.lossDiff = new Tensor(batchSize, this.network.oChannel, this.network.oHeight, this.network.oWidth);
    }

    /**
     * Trains on random mini-batches of {@code baseData}.
     *
     * <p>Runs up to {@code trainTime} epochs (while {@code trainIndex < minTrainTime}). Each epoch
     * draws batch index sets with {@code MathUtils.randomInts} and stops its batch loop once
     * {@code |currentError| <= error}. Per batch: forward, loss, lossDiff, mean loss into
     * {@code currentError}, backward, update, and a log line with accuracy. {@code updateLR(lr_step)}
     * is called after every epoch. Any exception is printed and swallowed.
     *
     * <p>NOTE(review): unlike the other overloads, this one does not set
     * {@code network.RUN_MODEL = RunModel.TRAIN} — confirm that is intended.
     *
     * @param baseData training set
     */
    @Override // com.omega.engine.optimizer.Optimizer
    public void train(BaseData baseData) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = baseData.number;
            // Warm-up scaling is applied once here, from the current batchIndex.
            if (isWarmUp()) {
                this.network.learnRate = (float) (this.lr * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            // Reusable input and label buffers for one mini-batch.
            Tensor tensor = new Tensor(this.batchSize, this.network.channel, this.network.height, this.network.width, true);
            Tensor tensor2 = new Tensor(this.batchSize, 1, 1, baseData.labelSize, true);
            for (int i = 0; i < this.trainTime && this.trainIndex < this.minTrainTime; i++) {
                this.trainIndex = i + 1;
                int[][] randomInts = MathUtils.randomInts(baseData.number, this.batchSize);
                for (int i2 = 0; i2 < randomInts.length && Math.abs(this.currentError) > this.error; i2++) {
                    long nanoTime = System.nanoTime();
                    this.loss.clear();
                    this.lossDiff.clear();
                    baseData.getRandomData(randomInts[i2], tensor, tensor2);
                    tensor.hostToDevice();
                    tensor2.hostToDevice();
                    Tensor forward = this.network.forward(tensor);
                    this.loss = this.network.loss(forward, tensor2);
                    this.lossDiff = this.network.lossDiff(forward, tensor2);
                    // Mean loss per sample; read from the device copy when the loss lives on the GPU.
                    if (this.loss.isHasGPU()) {
                        this.currentError = MatrixOperation.sum(this.loss.syncHost()) / this.batchSize;
                    } else {
                        this.currentError = MatrixOperation.sum(this.loss.data) / this.batchSize;
                    }
                    // NOTE(review): the other overloads pass lossDiff straight to back(). If lossDiff
                    // was produced on the device, this host->device copy may overwrite it with stale
                    // host data — confirm against Tensor/Network semantics.
                    this.lossDiff.hostToDevice();
                    this.network.back(this.lossDiff);
                    this.network.update();
                    // Host copy of the output is needed by accuracy(...) below.
                    forward.syncHost();
                    System.out.println("training[" + this.trainIndex + "]{" + i2 + "} (lr:" + this.network.learnRate + ") accuracy:{" + accuracy(forward, tensor2, baseData.labelSet) + "%} currentError:" + this.currentError + " [costTime:" + ((System.nanoTime() - nanoTime) / 1000000.0d) + "ms.]");
                    this.batchIndex++;
                }
                updateLR(this.lr_step);
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains on random mini-batches of {@code baseData} and calls {@code test(baseData2, batchSize)}
     * after every epoch.
     *
     * <p>Runs up to {@code trainTime} epochs (while {@code trainIndex < minTrainTime}) in
     * {@code RunModel.TRAIN}. The batch loop of an epoch ends early once
     * {@code |currentError| <= error}. Any exception is printed and swallowed.
     *
     * @param baseData training set
     * @param baseData2 evaluation set
     */
    @Override // com.omega.engine.optimizer.Optimizer
    public void train(BaseData baseData, BaseData baseData2) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = baseData.number;
            if (isWarmUp()) {
                this.network.learnRate = (float) (this.lr * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            Tensor input = new Tensor(this.batchSize, this.network.channel, this.network.height, this.network.width, true);
            Tensor label = new Tensor(this.batchSize, 1, 1, baseData.labelSize, true);
            int epoch = 0;
            while (epoch < this.trainTime && this.trainIndex < this.minTrainTime) {
                this.trainIndex = epoch + 1;
                int[][] batches = MathUtils.randomInts(baseData.number, this.batchSize);
                this.network.RUN_MODEL = RunModel.TRAIN;
                int step = 0;
                while (step < batches.length && Math.abs(this.currentError) > this.error) {
                    long startNs = System.nanoTime();
                    baseData.getRandomData(batches[step], input, label);
                    input.hostToDevice();
                    label.hostToDevice();
                    Tensor output = this.network.forward(input);
                    this.loss = this.network.loss(output, label);
                    this.lossDiff = this.network.lossDiff(output, label);
                    this.network.back(this.lossDiff);
                    this.network.update();
                    // Block until queued device work finishes before pulling results to the host.
                    JCudaDriver.cuCtxSynchronize();
                    output.syncHost();
                    float acc = accuracy(output, label, baseData.labelSet);
                    this.currentError = MatrixOperation.sum(this.loss.syncHost()) / this.batchSize;
                    double costMs = (System.nanoTime() - startNs) / 1000000.0d;
                    System.out.println("training[" + this.trainIndex + "]{" + step + "} (lr:" + this.network.learnRate + ") accuracy:{" + acc + "%} currentError:" + this.currentError + " [costTime:" + costMs + "ms.]");
                    this.batchIndex++;
                    step++;
                }
                updateLR(this.lr_step);
                test(baseData2, this.batchSize);
                epoch++;
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Three-dataset overload of {@code Optimizer.train}. This optimizer provides no implementation:
     * calling it does nothing.
     *
     * @param baseData ignored
     * @param baseData2 ignored
     * @param baseData3 ignored
     */
    @Override // com.omega.engine.optimizer.Optimizer
    public void train(BaseData baseData, BaseData baseData2, BaseData baseData3) {
        // Intentionally empty.
    }

    /**
     * Trains on {@code baseData} with a per-epoch input transform and evaluates on {@code baseData2}
     * after every epoch.
     *
     * <p>At the start of each epoch the full training input is passed through
     * {@code transforms(baseData.input, buffer, fArr, fArr2)}; mini-batches are then drawn from that
     * buffer. An epoch's batch loop stops early once {@code |currentError| <= error}. After each epoch
     * the mean training loss is logged and the value returned by {@code testAndLoss} on
     * {@code baseData2} is passed to {@code updateLR}. Any exception is printed and swallowed.
     *
     * @param baseData training set
     * @param baseData2 evaluation set
     * @param fArr first parameter array for {@code transforms} (NOTE(review): presumably per-channel mean — confirm)
     * @param fArr2 second parameter array for {@code transforms} (NOTE(review): presumably per-channel std — confirm)
     */
    public void train(BaseData baseData, BaseData baseData2, float[] fArr, float[] fArr2) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = baseData.number;
            if (isWarmUp()) {
                this.network.learnRate = (float) (this.lr * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            Tensor tensor = new Tensor(this.batchSize, this.network.channel, this.network.height, this.network.width, true);
            Tensor tensor2 = new Tensor(this.batchSize, 1, 1, baseData.labelSize, true);
            // Whole-dataset buffer that receives the transformed training inputs each epoch.
            Tensor tensor3 = new Tensor(baseData.number, baseData.channel, baseData.height, baseData.width);
            // Reusable evaluation buffers for testAndLoss.
            Tensor tensor4 = new Tensor(this.batchSize, baseData2.channel, baseData2.height, baseData2.width, true);
            Tensor tensor5 = new Tensor(this.batchSize, 1, 1, baseData2.labelSize, true);
            for (int i = 0; i < this.trainTime && this.trainIndex < this.minTrainTime; i++) {
                transforms(baseData.input, tensor3, fArr, fArr2);
                this.trainIndex = i + 1;
                int[][] randomInts = MathUtils.randomInts(baseData.number, this.batchSize);
                this.network.RUN_MODEL = RunModel.TRAIN;
                float epochLoss = 0.0f;
                // Batches actually trained this epoch; the loop below may stop early.
                int batchesRun = 0;
                for (int i2 = 0; i2 < randomInts.length; i2++) {
                    long nanoTime = System.nanoTime();
                    if (Math.abs(this.currentError) <= this.error) {
                        break;
                    }
                    baseData.randomData(randomInts[i2], tensor3.data, tensor, tensor2);
                    tensor.hostToDevice();
                    tensor2.hostToDevice();
                    Tensor forward = this.network.forward(tensor);
                    this.loss = this.network.loss(forward, tensor2);
                    this.lossDiff = this.network.lossDiff(forward, tensor2);
                    this.network.back(this.lossDiff);
                    this.network.update();
                    forward.syncHost();
                    float accuracy = accuracy(forward, tensor2, baseData.labelSet);
                    if (this.loss.isHasGPU()) {
                        this.currentError = MatrixOperation.sum(this.loss.syncHost()) / this.batchSize;
                    } else {
                        this.currentError = MatrixOperation.sum(this.loss.data) / this.batchSize;
                    }
                    epochLoss += this.currentError;
                    batchesRun++;
                    System.out.println("training[" + this.trainIndex + "]{" + i2 + "} (lr:" + this.network.learnRate + ") accuracy:{" + accuracy + "%} train_loss:" + this.currentError + " [costTime:" + ((System.nanoTime() - nanoTime) / 1000000.0d) + "ms.]");
                    this.batchIndex++;
                }
                // Average over the batches that actually ran (not the planned count), so an early
                // stop does not dilute the reported loss; 0 when no batch ran.
                float meanLoss = batchesRun > 0 ? epochLoss / batchesRun : 0.0f;
                System.out.println("training[" + this.trainIndex + "] train loss:{" + meanLoss + "} ");
                updateLR(testAndLoss(baseData2, tensor4, tensor5, this.batchSize));
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains from {@code baseDataLoader}, iterating the batches returned by its {@code shuffle()}.
     *
     * <p>Runs up to {@code trainTime} epochs (while {@code trainIndex < minTrainTime}) in
     * {@code RunModel.TRAIN}. An epoch's batch loop stops early once {@code |currentError| <= error}.
     * Each batch logs its mean loss; each epoch logs the mean over the batches that ran, then calls
     * {@code updateLR(lr_step)}. No accuracy is computed by this overload. Any exception is printed
     * and swallowed.
     *
     * @param baseDataLoader training data source
     */
    public void train(BaseDataLoader baseDataLoader) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = baseDataLoader.number;
            if (isWarmUp()) {
                this.network.learnRate = (float) (this.lr * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            Tensor tensor = new Tensor(this.batchSize, this.network.channel, this.network.height, this.network.width, true);
            Tensor tensor2 = new Tensor(this.batchSize, 1, 1, baseDataLoader.labelSize, true);
            for (int i = 0; i < this.trainTime && this.trainIndex < this.minTrainTime; i++) {
                this.trainIndex = i + 1;
                int[][] shuffle = baseDataLoader.shuffle();
                this.network.RUN_MODEL = RunModel.TRAIN;
                float epochLoss = 0.0f;
                // Batches actually trained this epoch; the loop below may stop early.
                int batchesRun = 0;
                for (int i2 = 0; i2 < shuffle.length; i2++) {
                    long nanoTime = System.nanoTime();
                    if (Math.abs(this.currentError) <= this.error) {
                        break;
                    }
                    baseDataLoader.loadData(shuffle[i2], tensor, tensor2);
                    tensor.hostToDevice();
                    tensor2.hostToDevice();
                    Tensor forward = this.network.forward(tensor);
                    this.loss = this.network.loss(forward, tensor2);
                    this.lossDiff = this.network.lossDiff(forward, tensor2);
                    this.network.back(this.lossDiff);
                    this.network.update();
                    JCudaDriver.cuCtxSynchronize();
                    if (this.loss.isHasGPU()) {
                        this.currentError = MatrixOperation.sum(this.loss.syncHost()) / this.batchSize;
                    } else {
                        this.currentError = MatrixOperation.sum(this.loss.data) / this.batchSize;
                    }
                    epochLoss += this.currentError;
                    batchesRun++;
                    // The previous log printed the stop threshold (this.error) labelled as accuracy;
                    // no accuracy is computed here, so only the loss is reported.
                    System.out.println("training[" + this.trainIndex + "]{" + i2 + "} (lr:" + this.network.learnRate + ") train_loss:" + this.currentError + " [costTime:" + ((System.nanoTime() - nanoTime) / 1000000.0d) + "ms.]");
                    this.batchIndex++;
                }
                // Average over the batches that actually ran; 0 when no batch ran.
                float meanLoss = batchesRun > 0 ? epochLoss / batchesRun : 0.0f;
                System.out.println("training[" + this.trainIndex + "] train loss:{" + meanLoss + "} ");
                updateLR(this.lr_step);
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains from {@code baseDataLoader}, scoring every batch with {@code baseCheck}, and runs
     * {@code testAndLoss} on {@code baseDataLoader2} every 10th epoch.
     *
     * <p>Runs up to {@code trainTime} epochs (while {@code trainIndex < minTrainTime}) in
     * {@code RunModel.TRAIN}; there is no early stop on {@code error}. The logged loss is element
     * {@code (0, 0, 0, 0)} of the loss tensor. Any exception is printed and swallowed.
     *
     * <p>NOTE(review): {@code this.currentError} is never updated here, so the final
     * "finalError" line reports whatever value it held before this call.
     *
     * @param baseDataLoader training data source
     * @param baseDataLoader2 evaluation data source
     * @param baseCheck per-batch correctness checker used for the logged accuracy
     */
    public void train(BaseDataLoader baseDataLoader, BaseDataLoader baseDataLoader2, BaseCheck baseCheck) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = baseDataLoader.number;
            if (isWarmUp()) {
                this.network.learnRate = (float) (this.lr * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            Tensor tensor = new Tensor(this.batchSize, this.network.channel, this.network.height, this.network.width, true);
            Tensor initLabelTensor = baseDataLoader.initLabelTensor();
            for (int i = 0; i < this.trainTime && this.trainIndex < this.minTrainTime; i++) {
                this.network.RUN_MODEL = RunModel.TRAIN;
                this.trainIndex = i + 1;
                int[][] shuffle = baseDataLoader.shuffle();
                for (int i2 = 0; i2 < shuffle.length; i2++) {
                    long nanoTime = System.nanoTime();
                    this.loss.clear();
                    this.lossDiff.clear();
                    baseDataLoader.loadData(shuffle[i2], tensor, initLabelTensor);
                    Tensor forward = this.network.forward(tensor);
                    Tensor loss = this.network.loss(forward, initLabelTensor);
                    this.lossDiff = this.network.lossDiff(forward, initLabelTensor);
                    this.network.back(this.lossDiff);
                    this.network.update();
                    if (loss.isHasGPU()) {
                        loss.syncHost();
                    }
                    // Multiply before dividing: if BaseCheck.check returns an integer count, the former
                    // (count / batchSize) * 100 form truncated the percentage to 0 or 100.
                    System.out.println("training[" + this.trainIndex + "]{" + i2 + "} (lr:" + this.network.learnRate + ") (loss:" + loss.getByIndex(0, 0, 0, 0) + ") (accuracy:" + ((baseCheck.check(forward, initLabelTensor, baseDataLoader.labelSet, false) * 100.0f) / this.batchSize) + "%) [costTime:" + ((System.nanoTime() - nanoTime) / 1000000.0d) + "ms.]");
                    this.batchIndex++;
                }
                updateLR(this.lr_step);
                if (this.trainIndex % 10 == 0) {
                    System.out.println("----------------testing start----------------");
                    // The training input/label buffers are reused for evaluation.
                    testAndLoss(baseDataLoader2, tensor, initLabelTensor, this.batchSize, baseCheck);
                    System.out.println("----------------testing finish---------------");
                }
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains an object-recognition network on random mini-batches of {@code baseData}.
     *
     * <p>Runs up to {@code trainTime} epochs (while {@code trainIndex < minTrainTime}). Every batch
     * runs forward, loss, backward and update, and logs the mean per-sample loss. There is no early
     * stop on {@code error} and no evaluation pass. Any exception is printed and swallowed.
     *
     * @param baseData training set
     */
    public void trainObjectRecognition(BaseData baseData) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = baseData.number;
            if (isWarmUp()) {
                this.network.learnRate = (float) (this.lr * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            Tensor input = new Tensor(this.batchSize, this.network.channel, this.network.height, this.network.width, true);
            Tensor label = new Tensor(this.batchSize, 1, 1, baseData.labelSize, true);
            int epoch = 0;
            while (epoch < this.trainTime && this.trainIndex < this.minTrainTime) {
                this.trainIndex = epoch + 1;
                int[][] batches = MathUtils.randomInts(baseData.number, this.batchSize);
                int step = 0;
                for (int[] indices : batches) {
                    long startNs = System.nanoTime();
                    this.loss.clear();
                    this.lossDiff.clear();
                    baseData.getRandomData(indices, input, label);
                    input.hostToDevice();
                    label.hostToDevice();
                    Tensor output = this.network.forward(input);
                    this.loss = this.network.loss(output, label);
                    this.lossDiff = this.network.lossDiff(output, label);
                    this.network.back(this.lossDiff);
                    this.network.update();
                    boolean lossOnGpu = this.loss.isHasGPU();
                    if (lossOnGpu) {
                        this.currentError = MatrixOperation.sum(this.loss.syncHost()) / this.batchSize;
                    } else {
                        this.currentError = MatrixOperation.sum(this.loss.data) / this.batchSize;
                    }
                    double costMs = (System.nanoTime() - startNs) / 1000000.0d;
                    System.out.println("training[" + this.trainIndex + "]{" + step + "} (lr:" + this.network.learnRate + ") train_loss:" + this.currentError + " [costTime:" + costMs + "ms.]");
                    this.batchIndex++;
                    step++;
                }
                updateLR(this.lr_step);
                epoch++;
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains an object-recognition network on random mini-batches of {@code baseData} and runs
     * {@code testObjectRecognition} on {@code baseData2} every 100th epoch.
     *
     * <p>Runs up to {@code trainTime} epochs (while {@code trainIndex < minTrainTime}) in
     * {@code RunModel.TRAIN}; there is no early stop on {@code error}. Any exception is printed and
     * swallowed.
     *
     * @param baseData training set
     * @param baseData2 evaluation set
     */
    public void trainObjectRecognition(BaseData baseData, BaseData baseData2) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = baseData.number;
            if (isWarmUp()) {
                this.network.learnRate = (float) (this.lr * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            Tensor input = new Tensor(this.batchSize, this.network.channel, this.network.height, this.network.width, true);
            Tensor label = new Tensor(this.batchSize, 1, 1, baseData.labelSize, true);
            // Dedicated evaluation buffers sized for the test set.
            Tensor testInput = new Tensor(this.batchSize, baseData2.channel, baseData2.height, baseData2.width, true);
            Tensor testLabel = new Tensor(this.batchSize, 1, 1, baseData2.labelSize, true);
            int epoch = 0;
            while (epoch < this.trainTime && this.trainIndex < this.minTrainTime) {
                this.network.RUN_MODEL = RunModel.TRAIN;
                this.trainIndex = epoch + 1;
                int[][] batches = MathUtils.randomInts(baseData.number, this.batchSize);
                int step = 0;
                for (int[] indices : batches) {
                    long startNs = System.nanoTime();
                    this.loss.clear();
                    this.lossDiff.clear();
                    baseData.getRandomData(indices, input, label);
                    input.hostToDevice();
                    label.hostToDevice();
                    Tensor output = this.network.forward(input);
                    this.loss = this.network.loss(output, label);
                    this.lossDiff = this.network.lossDiff(output, label);
                    this.network.back(this.lossDiff);
                    this.network.update();
                    boolean lossOnGpu = this.loss.isHasGPU();
                    if (lossOnGpu) {
                        this.currentError = MatrixOperation.sum(this.loss.syncHost()) / this.batchSize;
                    } else {
                        this.currentError = MatrixOperation.sum(this.loss.data) / this.batchSize;
                    }
                    double costMs = (System.nanoTime() - startNs) / 1000000.0d;
                    System.out.println("training[" + this.trainIndex + "]{" + step + "} (lr:" + this.network.learnRate + ") train_loss:" + this.currentError + " [costTime:" + costMs + "ms.]");
                    this.batchIndex++;
                    step++;
                }
                updateLR(this.lr_step);
                boolean evaluateNow = this.trainIndex % 100 == 0;
                if (evaluateNow) {
                    System.out.println("----------------testing start----------------");
                    testObjectRecognition(baseData2, testInput, testLabel, this.batchSize);
                    System.out.println("----------------testing finish---------------");
                }
                epoch++;
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains an object-recognition network on random mini-batches of {@code baseData}, optionally
     * applying data enhancement, and runs {@code testObjectRecognition} on {@code baseData2} every
     * 100th epoch.
     *
     * <p>Runs up to {@code trainTime} epochs (while {@code trainIndex < minTrainTime}) in
     * {@code RunModel.TRAIN}; there is no early stop on {@code error}. Any exception is printed and
     * swallowed.
     *
     * @param baseData training set
     * @param baseData2 evaluation set
     * @param z when {@code true}, each batch goes through {@code dataEnhanceInstance().transforms(...)}
     *     and its labels are re-encoded with {@code YoloLabelUtils.formatToYolo}
     */
    public void trainObjectRecognition(BaseData baseData, BaseData baseData2, boolean z) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = baseData.number;
            if (isWarmUp()) {
                this.network.learnRate = (float) (this.lr * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            Tensor tensor = new Tensor(this.batchSize, this.network.channel, this.network.height, this.network.width, true);
            Tensor tensor2 = new Tensor(this.batchSize, 1, 1, baseData.labelSize, true);
            Tensor tensor3 = new Tensor(this.batchSize, baseData2.channel, baseData2.height, baseData2.width, true);
            // Dedicated evaluation label buffer sized for baseData2, as in the two-argument overload.
            // The training label tensor was reused before; it is sized for baseData and is rewritten
            // in place by formatToYolo when augmentation is on.
            Tensor tensor4 = new Tensor(this.batchSize, 1, 1, baseData2.labelSize, true);
            for (int i = 0; i < this.trainTime && this.trainIndex < this.minTrainTime; i++) {
                this.network.RUN_MODEL = RunModel.TRAIN;
                this.trainIndex = i + 1;
                int[][] randomInts = MathUtils.randomInts(baseData.number, this.batchSize);
                for (int i2 = 0; i2 < randomInts.length; i2++) {
                    long nanoTime = System.nanoTime();
                    this.loss.clear();
                    this.lossDiff.clear();
                    baseData.getRandomData(randomInts[i2], tensor, tensor2);
                    if (z) {
                        dataEnhanceInstance().transforms(tensor, tensor2);
                        YoloLabelUtils.formatToYolo(tensor2, tensor.height, tensor.width);
                    }
                    tensor.hostToDevice();
                    tensor2.hostToDevice();
                    Tensor forward = this.network.forward(tensor);
                    this.loss = this.network.loss(forward, tensor2);
                    this.lossDiff = this.network.lossDiff(forward, tensor2);
                    this.network.back(this.lossDiff);
                    this.network.update();
                    if (this.loss.isHasGPU()) {
                        this.currentError = MatrixOperation.sum(this.loss.syncHost()) / this.batchSize;
                    } else {
                        this.currentError = MatrixOperation.sum(this.loss.data) / this.batchSize;
                    }
                    System.out.println("training[" + this.trainIndex + "]{" + i2 + "} (lr:" + this.network.learnRate + ") train_loss:" + this.currentError + " [costTime:" + ((System.nanoTime() - nanoTime) / 1000000.0d) + "ms.]");
                    this.batchIndex++;
                }
                updateLR(this.lr_step);
                if (this.trainIndex % 100 == 0) {
                    System.out.println("----------------testing start----------------");
                    testObjectRecognition(baseData2, tensor3, tensor4, this.batchSize);
                    System.out.println("----------------testing finish---------------");
                }
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains a {@link Yolo} network from {@code detectionDataLoader} and runs
     * {@code testObjectRecognition} on {@code detectionDataLoader2} every 100th epoch.
     *
     * <p>{@code this.network} is cast to {@code Yolo}; a mismatch surfaces as an exception that is
     * printed and swallowed like any other. The loss is computed every batch but its value is not
     * logged. The training input/label buffers are reused for evaluation.
     *
     * @param detectionDataLoader training data source
     * @param detectionDataLoader2 evaluation data source
     */
    public void trainObjectRecognition(DetectionDataLoader detectionDataLoader, DetectionDataLoader detectionDataLoader2) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = detectionDataLoader.number;
            Yolo detector = (Yolo) this.network;
            if (isWarmUp()) {
                this.network.learnRate = (float) (this.lr * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            Tensor images = new Tensor(this.batchSize, this.network.channel, this.network.height, this.network.width, true);
            Tensor labels = detectionDataLoader.initLabelTensor();
            int epoch = 0;
            while (epoch < this.trainTime && this.trainIndex < this.minTrainTime) {
                this.network.RUN_MODEL = RunModel.TRAIN;
                this.trainIndex = epoch + 1;
                int[][] order = detectionDataLoader.shuffle();
                int step = 0;
                for (int[] indices : order) {
                    long startNs = System.nanoTime();
                    this.loss.clear();
                    this.lossDiff.clear();
                    detectionDataLoader.loadData(indices, images, labels);
                    Tensor output = detector.forward(images);
                    this.network.loss(output, labels);
                    this.lossDiff = detector.lossDiff(output, labels);
                    detector.back(this.lossDiff);
                    this.network.update();
                    double costMs = (System.nanoTime() - startNs) / 1000000.0d;
                    System.out.println("training[" + this.trainIndex + "]{" + step + "} (lr:" + this.network.learnRate + ") [costTime:" + costMs + "ms.]");
                    this.batchIndex++;
                    step++;
                }
                updateLR(this.lr_step);
                if (this.trainIndex % 100 == 0) {
                    System.out.println("----------------testing start----------------");
                    testObjectRecognition(detectionDataLoader2, images, labels, this.batchSize);
                    System.out.println("----------------testing finish---------------");
                }
                epoch++;
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains an {@link OutputsNetwork} on random mini-batches of {@code baseData}, optionally applying
     * data enhancement, and runs {@code testObjectRecognitionOutputs} on {@code baseData2} every
     * 100th epoch.
     *
     * <p>Runs up to {@code trainTime} epochs (while {@code trainIndex < minTrainTime}) in
     * {@code RunModel.TRAIN}. Any exception, including a failed cast of {@code this.network}, is
     * printed and swallowed.
     *
     * @param baseData training set
     * @param baseData2 evaluation set
     * @param z when {@code true}, each batch goes through {@code dataEnhanceInstance().transforms(...)}
     *     and its labels are re-encoded with {@code YoloLabelUtils.formatToYoloV3}
     */
    public void trainObjectRecognitionOutputs(BaseData baseData, BaseData baseData2, boolean z) {
        try {
            CUDAModules.initCUDAFunctions();
            OutputsNetwork outputsNetwork = (OutputsNetwork) this.network;
            this.dataSize = baseData.number;
            if (isWarmUp()) {
                this.network.learnRate = (float) (this.lr * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            Tensor tensor = new Tensor(this.batchSize, this.network.channel, this.network.height, this.network.width, true);
            // Allocated with the GPU flag like every other label tensor in this class, because
            // hostToDevice() is called on it before each forward pass.
            Tensor tensor2 = new Tensor(this.batchSize, 1, 1, baseData.labelSize, true);
            Tensor tensor3 = new Tensor(this.batchSize, this.network.channel, this.network.height, this.network.width, true);
            // Dedicated evaluation label buffer sized for baseData2, so the training labels (rewritten
            // by formatToYoloV3 when augmenting) are not reused for testing.
            Tensor tensor4 = new Tensor(this.batchSize, 1, 1, baseData2.labelSize, true);
            for (int i = 0; i < this.trainTime && this.trainIndex < this.minTrainTime; i++) {
                this.network.RUN_MODEL = RunModel.TRAIN;
                this.trainIndex = i + 1;
                int[][] randomInts = MathUtils.randomInts(baseData.number, this.batchSize);
                for (int i2 = 0; i2 < randomInts.length; i2++) {
                    long nanoTime = System.nanoTime();
                    this.loss.clear();
                    this.lossDiff.clear();
                    baseData.getRandomData(randomInts[i2], tensor, tensor2);
                    if (z) {
                        dataEnhanceInstance().transforms(tensor, tensor2);
                        YoloLabelUtils.formatToYoloV3(tensor2, tensor.height, tensor.width);
                    }
                    tensor.hostToDevice();
                    tensor2.hostToDevice();
                    outputsNetwork.forward(tensor);
                    outputsNetwork.loss(tensor2);
                    outputsNetwork.back(outputsNetwork.lossDiff(tensor2));
                    this.network.update();
                    System.out.println("training[" + this.trainIndex + "]{" + i2 + "} (lr:" + this.network.learnRate + ") [costTime:" + ((System.nanoTime() - nanoTime) / 1000000.0d) + "ms.]");
                    this.batchIndex++;
                }
                updateLR(this.lr_step);
                if (this.trainIndex % 100 == 0) {
                    System.out.println("----------------testing start----------------");
                    testObjectRecognitionOutputs(baseData2, tensor3, tensor4, this.batchSize);
                    System.out.println("----------------testing finish---------------");
                }
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains an {@link OutputsNetwork} from {@code baseDataLoader}, optionally applying data
     * enhancement, and runs {@code testObjectRecognitionOutputs} on {@code baseDataLoader2} every
     * 100th epoch.
     *
     * <p>Runs up to {@code trainTime} epochs (while {@code trainIndex < minTrainTime}) in
     * {@code RunModel.TRAIN}. Any exception, including a failed cast of {@code this.network}, is
     * printed and swallowed.
     *
     * @param baseDataLoader training data source
     * @param baseDataLoader2 evaluation data source
     * @param z when {@code true}, each batch goes through {@code dataEnhanceInstance().transforms(...)}
     *     and its labels are re-encoded with {@code YoloLabelUtils.formatToYolo}
     */
    public void trainObjectRecognitionOutputs(BaseDataLoader baseDataLoader, BaseDataLoader baseDataLoader2, boolean z) {
        try {
            CUDAModules.initCUDAFunctions();
            OutputsNetwork outputsNetwork = (OutputsNetwork) this.network;
            this.dataSize = baseDataLoader.number;
            if (isWarmUp()) {
                this.network.learnRate = (float) (this.lr * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            Tensor tensor = new Tensor(this.batchSize, this.network.channel, this.network.height, this.network.width, true);
            Tensor tensor2 = new Tensor(this.batchSize, 1, 1, baseDataLoader.labelSize, true);
            // Dedicated evaluation buffers.
            Tensor tensor3 = new Tensor(this.batchSize, this.network.channel, this.network.height, this.network.width, true);
            Tensor tensor4 = new Tensor(this.batchSize, 1, 1, baseDataLoader2.labelSize, true);
            for (int i = 0; i < this.trainTime && this.trainIndex < this.minTrainTime; i++) {
                this.network.RUN_MODEL = RunModel.TRAIN;
                this.trainIndex = i + 1;
                int[][] shuffle = baseDataLoader.shuffle();
                for (int i2 = 0; i2 < shuffle.length; i2++) {
                    long nanoTime = System.nanoTime();
                    this.loss.clear();
                    this.lossDiff.clear();
                    baseDataLoader.loadData(shuffle[i2], tensor, tensor2);
                    if (z) {
                        dataEnhanceInstance().transforms(tensor, tensor2);
                        // NOTE(review): the BaseData overload of this method uses formatToYoloV3 for the
                        // same OutputsNetwork path — confirm which encoding is intended here.
                        YoloLabelUtils.formatToYolo(tensor2, tensor.height, tensor.width);
                    }
                    tensor.hostToDevice();
                    tensor2.hostToDevice();
                    outputsNetwork.forward(tensor);
                    outputsNetwork.loss(tensor2);
                    outputsNetwork.back(outputsNetwork.lossDiff(tensor2));
                    this.network.update();
                    System.out.println("training[" + this.trainIndex + "]{" + i2 + "} (lr:" + this.network.learnRate + ") [costTime:" + ((System.nanoTime() - nanoTime) / 1000000.0d) + "ms.]");
                    this.batchIndex++;
                }
                updateLR(this.lr_step);
                if (this.trainIndex % 100 == 0) {
                    System.out.println("----------------testing start----------------");
                    testObjectRecognitionOutputs(baseDataLoader2, tensor3, tensor4, this.batchSize);
                    System.out.println("----------------testing finish---------------");
                }
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains the wrapped {@link OutputsNetwork} on {@code detectionDataLoader}, running an evaluation pass on
     * {@code detectionDataLoader2} whenever the epoch counter is a multiple of 100.
     *
     * <p>Any exception is printed and swallowed, so this method never throws.
     *
     * @param detectionDataLoader  training data source; supplies the training label buffer
     * @param detectionDataLoader2 evaluation data source; supplies its own label buffer on first evaluation
     */
    public void trainObjectRecognitionOutputs(DetectionDataLoader detectionDataLoader, DetectionDataLoader detectionDataLoader2) {
        try {
            CUDAModules.initCUDAFunctions();
            OutputsNetwork outputsNetwork = (OutputsNetwork) this.network;
            this.dataSize = detectionDataLoader.number;
            // The warm-up learning rate is computed once here, from the batchIndex at entry.
            if (isWarmUp()) {
                this.network.learnRate = (float) (this.lr * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            }
            Tensor tensor = new Tensor(this.batchSize, this.network.channel, this.network.height, this.network.width, true);
            Tensor initLabelTensor = detectionDataLoader.initLabelTensor();
            Tensor tensor2 = new Tensor(this.batchSize, this.network.channel, this.network.height, this.network.width, true);
            // Label buffer for the evaluation loader, created lazily from that loader. This mirrors the
            // BaseDataLoader overload, which sizes evaluation labels from the evaluation loader rather than
            // reusing the training buffer.
            Tensor testLabelTensor = null;
            for (int i = 0; i < this.trainTime && this.trainIndex < this.minTrainTime; i++) {
                // NOTE(review): this runs before trainIndex is advanced below, so it sees the previous epoch's
                // counter. From a fresh start it fires at the beginning of the third epoch (i == 2).
                // Confirm that timing is intended.
                if (this.trainIndex == 2) {
                    this.network.unfreeze();
                }
                this.network.RUN_MODEL = RunModel.TRAIN;
                this.trainIndex = i + 1;
                int[][] shuffle = detectionDataLoader.shuffle();
                for (int i2 = 0; i2 < shuffle.length; i2++) {
                    long nanoTime = System.nanoTime();
                    this.loss.clear();
                    this.lossDiff.clear();
                    detectionDataLoader.loadData(shuffle[i2], tensor, initLabelTensor);
                    outputsNetwork.forward(tensor);
                    outputsNetwork.loss(initLabelTensor);
                    outputsNetwork.back(outputsNetwork.lossDiff(initLabelTensor));
                    this.network.update();
                    System.out.println("training[" + this.trainIndex + "]{" + i2 + "} (lr:" + this.network.learnRate + ") [costTime:" + ((System.nanoTime() - nanoTime) / 1000000.0d) + "ms.]");
                    this.batchIndex++;
                }
                updateLR(this.lr_step);
                if (this.trainIndex % 100 == 0) {
                    System.out.println("----------------testing start----------------");
                    if (testLabelTensor == null) {
                        testLabelTensor = detectionDataLoader2.initLabelTensor();
                    }
                    testObjectRecognitionOutputs(detectionDataLoader2, tensor2, testLabelTensor, this.batchSize);
                    System.out.println("----------------testing finish---------------");
                }
            }
            // NOTE(review): currentError is never assigned in this method, so this reports a pre-existing value.
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Smoke-tests the network's forward and backward paths on one input.
     *
     * <p>Runs a forward pass and dumps the output via {@code showDM()}. It then back-propagates an all-ones
     * gradient with the same dimensions as the output. The ones tensor replaces {@code this.lossDiff}.
     * Any exception is printed and swallowed.
     *
     * @param tensor input batch fed to {@code network.forward}
     */
    public void testRNN(Tensor tensor) {
        try {
            CUDAModules.initCUDAFunctions();
            Tensor out = this.network.forward(tensor);
            out.showDM();
            // Build a gradient of ones that matches the output's (number, channel, height, width).
            this.lossDiff = new Tensor(out.number, out.channel, out.height, out.width,
                    MatrixUtils.one(out.dataLength), true);
            this.network.back(this.lossDiff);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains the network on sequence batches from {@code rNNDataLoader}.
     *
     * <p>Each step runs forward, loss, lossDiff, back and update, then synchronizes the CUDA context. It
     * records {@code currentError} as the summed loss divided by the input tensor's {@code number}, and
     * logs accuracy and timing. Any exception is printed and swallowed.
     *
     * @param rNNDataLoader sequence data source; the input buffer holds {@code time * batchSize} rows
     */
    public void trainRNN(RNNDataLoader rNNDataLoader) {
        try {
            CUDAModules.initCUDAFunctions();
            this.dataSize = rNNDataLoader.number;
            if (isWarmUp()) {
                double warmupProgress = ((this.batchIndex * 1.0f) / this.burnIn) * 1.0f;
                this.network.learnRate = (float) (this.lr * Math.pow(warmupProgress, this.power));
            }
            int rows = rNNDataLoader.time * this.batchSize;
            Tensor input = new Tensor(rows, this.network.channel, this.network.height, this.network.width, true);
            Tensor label = rNNDataLoader.initLabelTensor();

            for (int epoch = 0; epoch < this.trainTime && this.trainIndex < this.minTrainTime; epoch++) {
                this.network.RUN_MODEL = RunModel.TRAIN;
                this.trainIndex = epoch + 1;
                int[][] batches = rNNDataLoader.shuffle();
                for (int step = 0; step < batches.length; step++) {
                    long startNs = System.nanoTime();
                    this.loss.clear();
                    this.lossDiff.clear();
                    rNNDataLoader.loadData(batches[step], input, label);

                    Tensor output = this.network.forward(input);
                    this.loss = this.network.loss(output, label);
                    this.lossDiff = this.network.lossDiff(output, label);
                    this.network.back(this.lossDiff);
                    this.network.update();
                    JCudaDriver.cuCtxSynchronize();

                    // Read the loss from wherever it currently lives (device copy vs. host array).
                    if (!this.loss.isHasGPU()) {
                        this.currentError = MatrixOperation.sum(this.loss.data) / input.number;
                    } else {
                        this.currentError = MatrixOperation.sum(this.loss.syncHost()) / input.number;
                    }
                    output.syncHost();

                    System.out.println("training[" + this.trainIndex + "]{" + step + "} (lr:"
                            + this.network.learnRate + ") accuracy:{" + accuracy(output, label)
                            + "%} train_loss:" + this.currentError + " [costTime:"
                            + ((System.nanoTime() - startNs) / 1000000.0d) + "ms.]");
                    this.batchIndex++;
                }
                updateLR(this.lr_step);
            }
            System.out.println("training finish. [" + this.trainIndex + "] finalError:" + this.currentError);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Decodes a network output into text, one character per row of {@code tensor}.
     *
     * <p>For each row, the index chosen by {@link #pickTopN(float[], int)} with {@code n = 1} is looked up
     * in the loader's {@code dictionaryData}.
     *
     * @param tensor        output tensor; each of its {@code number} rows is decoded to one character
     * @param rNNDataLoader must be a {@link OneHotDataLoader}; otherwise a {@link ClassCastException} is thrown
     * @return the decoded string (empty when {@code tensor.number == 0})
     */
    public static String output2TXT(Tensor tensor, RNNDataLoader rNNDataLoader) {
        OneHotDataLoader oneHotDataLoader = (OneHotDataLoader) rNNDataLoader;
        // StringBuilder avoids the quadratic copying of repeated String concatenation in the loop.
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < tensor.number; i++) {
            sb.append(oneHotDataLoader.dictionaryData[pickTopN(tensor.getByNumber(i), 1)].charValue());
        }
        return sb.toString();
    }

    /**
     * Picks one of the {@code i} largest values of {@code fArr} and returns its index in {@code fArr}.
     *
     * <p>The candidate is selected by {@code RandomUtils.getRandomNumber} over the top-{@code i} values,
     * taken in ascending sorted order. When equal values occur, the first matching index is returned.
     *
     * @param fArr scores to choose from; must be non-empty and is not modified
     * @param i    number of top candidates to consider; clamped to {@code [1, fArr.length]}
     * @return index into {@code fArr} of the selected value
     * @throws IllegalArgumentException if {@code fArr} is empty
     */
    public static int pickTopN(float[] fArr, int i) {
        if (fArr.length == 0) {
            throw new IllegalArgumentException("pickTopN: fArr must not be empty");
        }
        // Clamp so that i <= 0 never yields an empty candidate set and i > length never indexes out of range.
        int n = Math.max(1, Math.min(i, fArr.length));
        float[] sorted = Arrays.copyOf(fArr, fArr.length);
        Arrays.sort(sorted);
        float[] candidates = Arrays.copyOfRange(sorted, sorted.length - n, sorted.length);
        float picked = candidates[RandomUtils.getRandomNumber(candidates)];
        for (int idx = 0; idx < fArr.length; idx++) {
            // Float.compare (unlike ==) matches NaN, which Arrays.sort places among the largest values.
            if (Float.compare(picked, fArr[idx]) == 0) {
                return idx;
            }
        }
        return 0;
    }

    /** Threshold used by {@link #gradClipping(Network)}; kept identical to the previous hard-coded value. */
    private static final float DEFAULT_GRAD_CLIP_THRESHOLD = 1.0E-7f;

    /**
     * Applies {@code GradClipping.gradClipping} to every layer's weight and bias gradients, using
     * {@link #DEFAULT_GRAD_CLIP_THRESHOLD}.
     *
     * @param network network whose {@code layerList} is traversed
     */
    public void gradClipping(Network network) {
        gradClipping(network, DEFAULT_GRAD_CLIP_THRESHOLD);
    }

    /**
     * Applies {@code GradClipping.gradClipping} with the given threshold to every non-null {@code diffW}
     * and {@code diffB} in {@code network.layerList}.
     *
     * @param network   network whose layers' gradients are clipped
     * @param threshold value passed through to {@code GradClipping.gradClipping}
     */
    public void gradClipping(Network network, float threshold) {
        for (Layer layer : network.layerList) {
            if (layer.diffW != null) {
                GradClipping.gradClipping(layer.diffW, threshold);
            }
            if (layer.diffB != null) {
                GradClipping.gradClipping(layer.diffB, threshold);
            }
        }
    }

    /**
     * Runs a fixed augmentation pipeline: random crop, random horizontal flip, normalize, then cutout.
     *
     * <p>The crop stage is given {@code tensor} as its first argument and {@code tensor2} as its second.
     * Every later stage passes {@code tensor2} as both arguments, presumably transforming it in place.
     *
     * @param tensor  source batch
     * @param tensor2 destination batch, which receives the pipeline's result
     * @param fArr    first statistics array forwarded to {@code DataTransforms.normalize}
     * @param fArr2   second statistics array forwarded to {@code DataTransforms.normalize}
     */
    public void transforms(Tensor tensor, Tensor tensor2, float[] fArr, float[] fArr2) {
        final Tensor out = tensor2;
        // NOTE(review): the literal arguments (32, 32, 4) and (16, 1) look like CIFAR-style crop/padding and
        // cutout settings; confirm against DataTransforms.
        DataTransforms.randomCrop(tensor, out, 32, 32, 4);
        DataTransforms.randomHorizontalFilp(out, out);
        DataTransforms.normalize(out, out, fArr, fArr2);
        DataTransforms.cutout(out, out, 16, 1);
        System.out.println("data transform finish.");
    }

    /**
     * Normalizes {@code tensor} into {@code tensor2} via {@code DataTransforms.normalize}, with no
     * augmentation, then prints a completion message.
     *
     * @param tensor  source batch
     * @param tensor2 destination batch
     * @param fArr    first statistics array forwarded to {@code DataTransforms.normalize}
     * @param fArr2   second statistics array forwarded to {@code DataTransforms.normalize}
     */
    public void transforms2(Tensor tensor, Tensor tensor2, float[] fArr, float[] fArr2) {
        final String doneMessage = "data transform finish.";
        DataTransforms.normalize(tensor, tensor2, fArr, fArr2);
        System.out.println(doneMessage);
    }
}
