package com.omega.example.resnet.test;

import com.omega.common.data.Tensor;
import com.omega.common.data.utils.DataTransforms;
import com.omega.common.utils.DataLoader;
import com.omega.common.utils.ImageUtils;
import com.omega.common.utils.LabelUtils;
import com.omega.common.utils.MathUtils;
import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.loss.LossType;
import com.omega.engine.loss.SoftmaxWithCrossEntropyLoss;
import com.omega.engine.nn.data.DataSet;
import com.omega.engine.nn.layer.AVGPoolingLayer;
import com.omega.engine.nn.layer.BasicBlockLayer;
import com.omega.engine.nn.layer.ConvolutionLayer;
import com.omega.engine.nn.layer.FullyLayer;
import com.omega.engine.nn.layer.InputLayer;
import com.omega.engine.nn.layer.ParamsInit;
import com.omega.engine.nn.layer.PoolingLayer;
import com.omega.engine.nn.layer.SoftmaxWithCrossEntropyLayer;
import com.omega.engine.nn.layer.active.ReluLayer;
import com.omega.engine.nn.layer.normalization.BNLayer;
import com.omega.engine.nn.network.CNN;
import com.omega.engine.nn.network.Network;
import com.omega.engine.optimizer.MBSGDOptimizer;
import com.omega.engine.optimizer.lr.LearnRateUpdate;
import com.omega.engine.pooling.PoolingType;
import com.omega.engine.updater.UpdaterType;
import java.io.File;

/* loaded from: input_file:com/omega/example/resnet/test/ResnetTest.class */
/**
 * Manual experiments that train ResNet-18-style networks on MNIST and CIFAR-10 with the Omega
 * engine, plus small utilities that dump dataset images to disk.
 *
 * <p>Dataset locations are hard-coded ({@code H:/...}) or resolved from the classpath; adjust them
 * before running.
 */
public class ResnetTest {

    /** CIFAR-10 class names, in the order handed to the data loaders. */
    private static final String[] CIFAR10_LABELS = {
        "airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"
    };

    /** MNIST class names (the ten digits), in the order handed to the data loaders. */
    private static final String[] MNIST_LABELS = {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"};

    /** Output channel count of each of the eight residual blocks of the ResNet-18 body. */
    private static final int[] STAGE_CHANNELS = {64, 64, 128, 128, 256, 256, 512, 512};

    /** Stride of each residual block; stride 2 opens stages 2-4 to downsample the feature map. */
    private static final int[] STAGE_STRIDES = {1, 1, 2, 1, 2, 1, 2, 1};

    /** Training files of the CIFAR-10 binary distribution. */
    private static final String[] CIFAR10_TRAIN_FILES = {
        "H:/dataset/cifar-10/data_batch_1.bin",
        "H:/dataset/cifar-10/data_batch_2.bin",
        "H:/dataset/cifar-10/data_batch_3.bin",
        "H:/dataset/cifar-10/data_batch_4.bin",
        "H:/dataset/cifar-10/data_batch_5.bin"
    };

    /** Test file of the CIFAR-10 binary distribution. */
    private static final String CIFAR10_TEST_FILE = "H:/dataset/cifar-10/test_batch.bin";

    /**
     * Loads a CIFAR-10 test batch, prints the label of sample 10 and writes the data returned by
     * {@code getByNumberAndChannel(10, 0)} for that sample to a PNG file.
     */
    public void showImage() {
        DataSet testSet = DataLoader.getImagesToDataSetByBin(
                "H:/dataset/cifar-10-binary.tar/cifar-10-binary/cifar-10-batches-bin/test_batch.bin",
                10000, 3, 32, 32, 10, CIFAR10_LABELS, false);
        ImageUtils imageUtils = new ImageUtils();
        System.out.println(testSet.labels[10]);
        imageUtils.createRGBImage("H:/dataset\\r.png", "png", 32, 32,
                testSet.input.getByNumberAndChannel(10, 0), 2);
    }

    /**
     * Trains a ResNet-18-style network on MNIST (loaded from the classpath), evaluates it on the
     * MNIST test set and prints the elapsed wall-clock time in seconds.
     *
     * <p>Any exception is printed rather than propagated. GPU memory is released in every case.
     */
    public void resnet18_mnist() {
        try {
            File trainImages = resourceFile("/dataset/mnist/train-images.idx3-ubyte");
            File trainLabels = resourceFile("/dataset/mnist/train-labels.idx1-ubyte");
            File testImages = resourceFile("/dataset/mnist/t10k-images.idx3-ubyte");
            File testLabels = resourceFile("/dataset/mnist/t10k-labels.idx1-ubyte");
            DataSet trainSet = DataLoader.loadDataByUByte(trainImages, trainLabels, MNIST_LABELS, 1, 1, 784, true);
            DataSet testSet = DataLoader.loadDataByUByte(testImages, testLabels, MNIST_LABELS, 1, 1, 784, true);

            CNN cnn = new CNN(new SoftmaxWithCrossEntropyLoss(), UpdaterType.adamw);
            cnn.learnRate = 0.001f;

            // NOTE(review): input is declared 1x1x784 while the stem conv is configured for a
            // 1x28x28 image; presumably the engine reinterprets the flat buffer - confirm.
            InputLayer inputLayer = new InputLayer(1, 1, 784);
            ConvolutionLayer stemConv = new ConvolutionLayer(1, 64, 28, 28, 3, 3, 1, 1, false);
            BNLayer stemBn = new BNLayer();
            ReluLayer stemRelu = new ReluLayer();
            PoolingLayer stemPool = new PoolingLayer(stemConv.oChannel, stemConv.oWidth, stemConv.oHeight,
                    2, 2, 2, PoolingType.MAX_POOLING);

            BasicBlockLayer[] blocks = new BasicBlockLayer[STAGE_CHANNELS.length];
            ReluLayer[] blockRelus = new ReluLayer[STAGE_CHANNELS.length];
            buildResidualStages(cnn, stemPool.oChannel, stemPool.oHeight, stemPool.oWidth, blocks, blockRelus);
            BasicBlockLayer lastBlock = blocks[blocks.length - 1];

            // Global average pooling over whatever spatial size the body produces. After the 2x2
            // max-pool and three stride-2 blocks the map is only about 2x2, so the fixed 4x4/stride-4
            // mean pool used previously had a window larger than its input.
            AVGPoolingLayer globalPool = new AVGPoolingLayer(lastBlock.oChannel, lastBlock.oWidth, lastBlock.oHeight);
            FullyLayer fullyLayer = new FullyLayer(globalPool.oChannel * globalPool.oWidth * globalPool.oHeight, 10, false);
            SoftmaxWithCrossEntropyLayer softmaxLayer = new SoftmaxWithCrossEntropyLayer(10);

            cnn.addLayer(inputLayer);
            cnn.addLayer(stemConv);
            cnn.addLayer(stemBn);
            cnn.addLayer(stemRelu);
            cnn.addLayer(stemPool);
            addResidualStages(cnn, blocks, blockRelus);
            cnn.addLayer(globalPool);
            cnn.addLayer(fullyLayer);
            cnn.addLayer(softmaxLayer);

            MBSGDOptimizer optimizer = new MBSGDOptimizer((Network) cnn, 20, 1.0E-4f, 128, LearnRateUpdate.NONE, false);
            long start = System.currentTimeMillis();
            optimizer.train(trainSet);
            optimizer.test(testSet);
            System.out.println(((System.currentTimeMillis() - start) / 1000) + "s.");
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            freeGpuMemoryQuietly();
        }
    }

    /**
     * Trains a ResNet-18-style network (cuDNN enabled) on CIFAR-10, validating on the test batch
     * during training, then evaluates it on the test batch and prints the elapsed time in seconds.
     *
     * <p>Any exception is printed rather than propagated. GPU memory is released in every case.
     */
    public void resnet18_cifar10() {
        try {
            // Per-channel normalisation statistics, passed to the test-set loader and to train().
            float[] mean = {0.4914f, 0.4822f, 0.4465f};
            float[] std = {0.2023f, 0.1994f, 0.201f};
            DataSet trainSet = DataLoader.getImagesToDataSetByBin(CIFAR10_TRAIN_FILES, 10000, 3, 32, 32, 10, CIFAR10_LABELS, true);
            DataSet testSet = DataLoader.getImagesToDataSetByBin(CIFAR10_TEST_FILE, 10000, 3, 32, 32, 10, CIFAR10_LABELS, true, mean, std);
            System.out.println("data is ready.");

            CNN cnn = new CNN(LossType.softmax_with_cross_entropy, UpdaterType.adamw);
            cnn.CUDNN = true;
            cnn.learnRate = 0.01f;

            InputLayer inputLayer = new InputLayer(3, 32, 32);
            ConvolutionLayer stemConv = new ConvolutionLayer(3, 64, 32, 32, 3, 3, 1, 1, false);
            stemConv.paramsInit = ParamsInit.relu;
            BNLayer stemBn = new BNLayer();
            ReluLayer stemRelu = new ReluLayer();

            BasicBlockLayer[] blocks = new BasicBlockLayer[STAGE_CHANNELS.length];
            ReluLayer[] blockRelus = new ReluLayer[STAGE_CHANNELS.length];
            buildResidualStages(cnn, stemConv.oChannel, stemConv.oHeight, stemConv.oWidth, blocks, blockRelus);
            BasicBlockLayer lastBlock = blocks[blocks.length - 1];

            AVGPoolingLayer globalPool = new AVGPoolingLayer(lastBlock.oChannel, lastBlock.oWidth, lastBlock.oHeight);
            FullyLayer fullyLayer = new FullyLayer(globalPool.oChannel * globalPool.oWidth * globalPool.oHeight, 10);
            SoftmaxWithCrossEntropyLayer softmaxLayer = new SoftmaxWithCrossEntropyLayer(10);

            cnn.addLayer(inputLayer);
            cnn.addLayer(stemConv);
            cnn.addLayer(stemBn);
            cnn.addLayer(stemRelu);
            addResidualStages(cnn, blocks, blockRelus);
            cnn.addLayer(globalPool);
            cnn.addLayer(fullyLayer);
            cnn.addLayer(softmaxLayer);

            MBSGDOptimizer optimizer = new MBSGDOptimizer((Network) cnn, 500, 1.0E-4f, 128, LearnRateUpdate.GD_GECAY, false);
            long start = System.currentTimeMillis();
            optimizer.train(trainSet, testSet, mean, std);
            optimizer.test(testSet);
            System.out.println(((System.currentTimeMillis() - start) / 1000) + "s.");
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            freeGpuMemoryQuietly();
        }
    }

    /**
     * Writes CIFAR-10 test images as PNGs into {@code H:/testImages/<batchIndex>/}, one directory
     * per batch of 128 indices returned by {@code MathUtils.randomInts}.
     *
     * <p>Any exception, including a failure to create an output directory, is printed and aborts
     * the export.
     */
    public static void getImages() {
        try {
            DataSet trainSet = DataLoader.getImagesToDataSetByBin(CIFAR10_TRAIN_FILES, 10000, 3, 32, 32, 10, CIFAR10_LABELS, false);
            DataSet testSet = DataLoader.getImagesToDataSetByBin(CIFAR10_TEST_FILE, 10000, 3, 32, 32, 10, CIFAR10_LABELS, false);
            Tensor batchInput = new Tensor(128, 3, 32, 32);
            Tensor batchLabels = new Tensor(128, 1, 1, testSet.labelSize);
            // NOTE(review): these augmentations modify the training set, but the export loop below
            // only samples the test set, so augmented images are never written. Confirm intent.
            DataTransforms.randomCrop(trainSet.input, 32, 32, 4);
            DataTransforms.randomHorizontalFilp(trainSet.input);
            ImageUtils imageUtils = new ImageUtils();
            int[][] batches = MathUtils.randomInts(testSet.number, 128);
            for (int b = 0; b < batches.length; b++) {
                String outDir = "H:/testImages/" + b + "/";
                File dir = new File(outDir);
                // mkdirs() also creates a missing H:/testImages parent; the previous mkdir() failed
                // silently in that case and every later write went nowhere.
                if (!dir.exists() && !dir.mkdirs()) {
                    throw new IllegalStateException("could not create directory " + dir);
                }
                testSet.getRandomData(batches[b], batchInput, batchLabels);
                for (int n = 0; n < batchInput.number; n++) {
                    String label = LabelUtils.vectorTolabel(batchLabels.getByNumber(n), CIFAR10_LABELS);
                    imageUtils.createImage(n, batchInput.getByNumber(n), label, 32, 32, outDir, "png");
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Creates the eight BasicBlock layers of a ResNet-18 body, each paired with its own ReLU.
     * Each block's input shape is taken from the previous block's output shape. Blocks and ReLUs
     * are constructed in alternating order.
     *
     * @param cnn     network passed to every {@code BasicBlockLayer} constructor
     * @param channel channel count entering the first block
     * @param height  height entering the first block
     * @param width   width entering the first block
     * @param blocks  output array of length {@code STAGE_CHANNELS.length}, filled with the blocks
     * @param relus   output array of the same length, filled with the ReLU following each block
     */
    private static void buildResidualStages(CNN cnn, int channel, int height, int width,
            BasicBlockLayer[] blocks, ReluLayer[] relus) {
        for (int i = 0; i < STAGE_CHANNELS.length; i++) {
            BasicBlockLayer block = new BasicBlockLayer(channel, STAGE_CHANNELS[i], height, width, STAGE_STRIDES[i], cnn);
            blocks[i] = block;
            relus[i] = new ReluLayer();
            channel = block.oChannel;
            height = block.oHeight;
            width = block.oWidth;
        }
    }

    /** Appends each residual block, followed by its ReLU, to {@code cnn} in order. */
    private static void addResidualStages(CNN cnn, BasicBlockLayer[] blocks, ReluLayer[] relus) {
        for (int i = 0; i < blocks.length; i++) {
            cnn.addLayer(blocks[i]);
            cnn.addLayer(relus[i]);
        }
    }

    /** Resolves a classpath resource to a {@link File}; fails if the resource is missing. */
    private static File resourceFile(String path) throws Exception {
        return new File(ResnetTest.class.getResource(path).toURI());
    }

    /**
     * Calls {@code CUDAMemoryManager.freeAll()}, printing instead of propagating any exception so
     * that cleanup never masks the failure that triggered it.
     */
    private static void freeGpuMemoryQuietly() {
        try {
            CUDAMemoryManager.freeAll();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Initialises the CUDA context, runs the CIFAR-10 experiment, and always frees GPU memory. */
    public static void main(String[] args) {
        try {
            CUDAModules.initContext();
            new ResnetTest().resnet18_cifar10();
        } finally {
            CUDAMemoryManager.free();
        }
    }
}
