package com.omega.example.cnn.test;

import com.omega.common.utils.DataLoader;
import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.loss.LossType;
import com.omega.engine.loss.SoftmaxWithCrossEntropyLoss;
import com.omega.engine.nn.data.DataSet;
import com.omega.engine.nn.layer.ConvolutionLayer;
import com.omega.engine.nn.layer.DropoutLayer;
import com.omega.engine.nn.layer.FullyLayer;
import com.omega.engine.nn.layer.InputLayer;
import com.omega.engine.nn.layer.PoolingLayer;
import com.omega.engine.nn.layer.SoftmaxWithCrossEntropyLayer;
import com.omega.engine.nn.layer.active.ReluLayer;
import com.omega.engine.nn.layer.normalization.BNLayer;
import com.omega.engine.nn.network.CNN;
import com.omega.engine.nn.network.Network;
import com.omega.engine.optimizer.MBSGDOptimizer;
import com.omega.engine.optimizer.lr.LearnRateUpdate;
import com.omega.engine.pooling.PoolingType;
import com.omega.engine.updater.UpdaterType;
import java.io.File;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.Objects;

/* loaded from: input_file:com/omega/example/cnn/test/CNNTest.class */
/**
 * Manual smoke tests that train small convolutional networks on MNIST and CIFAR-10 with the
 * Omega engine.
 *
 * <p>Each test builds a network, trains it with {@link MBSGDOptimizer}, evaluates it on the
 * matching test split and prints the elapsed wall-clock time in whole seconds. Exceptions raised
 * while loading data or training are printed, not propagated. {@link CUDAMemoryManager#freeAll()}
 * is always attempted afterwards.
 */
public class CNNTest {

    /** Default CIFAR-10 binary-batch directory used by {@link #cnnNetwork_cifar10()} (machine-specific). */
    private static final String DEFAULT_CIFAR10_DIR = "H:/dataset/cifar-10";

    /** Number of CIFAR-10 training batch files: {@code data_batch_1.bin} .. {@code data_batch_5.bin}. */
    private static final int CIFAR10_TRAIN_BATCHES = 5;

    /**
     * Trains and evaluates a two-stage convolutional network on MNIST.
     *
     * <p>The IDX files are read from the classpath under {@code /dataset/mnist/}. Any exception
     * (including a missing resource) is printed. GPU memory release is always attempted.
     */
    public void cnnNetwork_mnist() {
        try {
            String[] labels = {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"};
            File trainImages = resourceFile("/dataset/mnist/train-images.idx3-ubyte");
            File trainLabels = resourceFile("/dataset/mnist/train-labels.idx1-ubyte");
            File testImages = resourceFile("/dataset/mnist/t10k-images.idx3-ubyte");
            File testLabels = resourceFile("/dataset/mnist/t10k-labels.idx1-ubyte");
            DataSet trainSet = DataLoader.loadDataByUByte(trainImages, trainLabels, labels, 1, 1, 784, true);
            DataSet testSet = DataLoader.loadDataByUByte(testImages, testLabels, labels, 1, 1, 784, true);

            CNN cnn = new CNN(new SoftmaxWithCrossEntropyLoss(), UpdaterType.adamw);
            cnn.learnRate = 0.001f;

            InputLayer input = new InputLayer(1, 1, 784);
            ConvolutionLayer conv1 = new ConvolutionLayer(1, 6, 28, 28, 5, 5, 2, 1, false);
            ReluLayer relu1 = new ReluLayer();
            PoolingLayer pool1 = new PoolingLayer(conv1.oChannel, conv1.oWidth, conv1.oHeight, 2, 2, 2, PoolingType.MAX_POOLING);
            ConvolutionLayer conv2 = new ConvolutionLayer(pool1.oChannel, 12, pool1.oWidth, pool1.oHeight, 5, 5, 0, 1, false);
            ReluLayer relu2 = new ReluLayer();
            PoolingLayer pool2 = new PoolingLayer(conv2.oChannel, conv2.oWidth, conv2.oHeight, 2, 2, 2, PoolingType.MAX_POOLING);

            int flattened = pool2.oChannel * pool2.oWidth * pool2.oHeight;
            // Empirical hidden-layer width: sqrt(inputs + 10) + 10 (the 10 presumably being the class count).
            int hiddenUnits = (int) (Math.sqrt(flattened + 10) + 10.0d);
            FullyLayer fc1 = new FullyLayer(flattened, hiddenUnits, false);
            ReluLayer relu3 = new ReluLayer();
            FullyLayer fc2 = new FullyLayer(hiddenUnits, 10);
            SoftmaxWithCrossEntropyLayer softmax = new SoftmaxWithCrossEntropyLayer(10);

            cnn.addLayer(input);
            cnn.addLayer(conv1);
            cnn.addLayer(relu1);
            cnn.addLayer(pool1);
            cnn.addLayer(conv2);
            cnn.addLayer(relu2);
            cnn.addLayer(pool2);
            cnn.addLayer(fc1);
            cnn.addLayer(relu3);
            cnn.addLayer(fc2);
            cnn.addLayer(softmax);

            MBSGDOptimizer optimizer = new MBSGDOptimizer((Network) cnn, 10, 1.0E-4f, 128, LearnRateUpdate.NONE, false);
            long start = System.currentTimeMillis();
            optimizer.train(trainSet);
            optimizer.test(testSet);
            System.out.println(((System.currentTimeMillis() - start) / 1000) + "s.");
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            releaseGpuMemory();
        }
    }

    /**
     * Trains and evaluates a three-stage conv/BN/ReLU/pool network on CIFAR-10, reading the binary
     * batches from {@code H:/dataset/cifar-10}.
     */
    public void cnnNetwork_cifar10() {
        cnnNetwork_cifar10(DEFAULT_CIFAR10_DIR);
    }

    /**
     * Trains and evaluates a three-stage conv/BN/ReLU/pool network on CIFAR-10.
     *
     * @param dataDir directory (without trailing slash) containing {@code data_batch_1.bin} ..
     *     {@code data_batch_5.bin} and {@code test_batch.bin}
     * @throws NullPointerException if {@code dataDir} is null
     */
    public void cnnNetwork_cifar10(String dataDir) {
        Objects.requireNonNull(dataDir, "dataDir");
        try {
            String[] labels = {"airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"};
            String[] trainFiles = new String[CIFAR10_TRAIN_BATCHES];
            for (int i = 0; i < trainFiles.length; i++) {
                trainFiles[i] = dataDir + "/data_batch_" + (i + 1) + ".bin";
            }
            DataSet trainSet = DataLoader.getImagesToDataSetByBin(trainFiles, 10000, 3, 32, 32, 10, labels, true);
            DataSet testSet = DataLoader.getImagesToDataSetByBin(dataDir + "/test_batch.bin", 10000, 3, 32, 32, 10, labels, true);
            System.out.println("data is ready.");

            CNN cnn = new CNN(LossType.softmax_with_cross_entropy, UpdaterType.adam);
            cnn.learnRate = 0.01f;
            cnn.addLayer(new InputLayer(3, 32, 32));
            PoolingLayer stage1 = addConvBnReluPool(cnn, 3, 16, 32, 32);
            PoolingLayer stage2 = addConvBnReluPool(cnn, stage1.oChannel, 32, stage1.oWidth, stage1.oHeight);
            PoolingLayer stage3 = addConvBnReluPool(cnn, stage2.oChannel, 64, stage2.oWidth, stage2.oHeight);

            FullyLayer fc1 = new FullyLayer(stage3.oChannel * stage3.oWidth * stage3.oHeight, 256, true);
            cnn.addLayer(fc1);
            cnn.addLayer(new ReluLayer());
            cnn.addLayer(new DropoutLayer(0.5f));
            cnn.addLayer(new FullyLayer(fc1.oWidth, 10, true));
            cnn.addLayer(new SoftmaxWithCrossEntropyLayer(10));

            MBSGDOptimizer optimizer = new MBSGDOptimizer((Network) cnn, 20, 0.001f, 128, LearnRateUpdate.CONSTANT, false);
            long start = System.currentTimeMillis();
            optimizer.train(trainSet);
            optimizer.test(testSet);
            System.out.println(((System.currentTimeMillis() - start) / 1000) + "s.");
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            releaseGpuMemory();
        }
    }

    /**
     * Appends a 3x3 convolution (no bias), a BN layer, a ReLU and a 2x2 max-pooling layer to
     * {@code cnn}, in that order.
     *
     * @return the pooling layer, whose output dimensions size the next stage
     */
    private static PoolingLayer addConvBnReluPool(CNN cnn, int inChannels, int outChannels, int width, int height) {
        ConvolutionLayer conv = new ConvolutionLayer(inChannels, outChannels, width, height, 3, 3, 1, 1, false);
        cnn.addLayer(conv);
        cnn.addLayer(new BNLayer());
        cnn.addLayer(new ReluLayer());
        PoolingLayer pool = new PoolingLayer(conv.oChannel, conv.oWidth, conv.oHeight, 2, 2, 2, PoolingType.MAX_POOLING);
        cnn.addLayer(pool);
        return pool;
    }

    /**
     * Resolves a classpath resource to a {@link File}.
     *
     * @throws IllegalStateException if the resource is not on the classpath
     */
    private static File resourceFile(String path) throws URISyntaxException {
        URL url = CNNTest.class.getResource(path);
        if (url == null) {
            throw new IllegalStateException("classpath resource not found: " + path);
        }
        return new File(url.toURI());
    }

    /** Calls {@link CUDAMemoryManager#freeAll()}, printing any failure so cleanup never masks the primary error. */
    private static void releaseGpuMemory() {
        try {
            CUDAMemoryManager.freeAll();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Initializes the CUDA context, runs the CIFAR-10 test and then calls {@link CUDAMemoryManager#free()}. */
    public static void main(String[] strArr) {
        try {
            CUDAModules.initContext();
            new CNNTest().cnnNetwork_cifar10();
        } finally {
            CUDAMemoryManager.free();
        }
    }
}
