package com.omega.example.vggnet.test;

import com.omega.common.utils.DataLoader;
import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.loss.LossType;
import com.omega.engine.nn.data.DataSet;
import com.omega.engine.nn.layer.ConvolutionLayer;
import com.omega.engine.nn.layer.FullyLayer;
import com.omega.engine.nn.layer.InputLayer;
import com.omega.engine.nn.layer.PoolingLayer;
import com.omega.engine.nn.layer.SoftmaxWithCrossEntropyLayer;
import com.omega.engine.nn.layer.active.ReluLayer;
import com.omega.engine.nn.network.CNN;
import com.omega.engine.nn.network.Network;
import com.omega.engine.optimizer.MBSGDOptimizer;
import com.omega.engine.optimizer.lr.LearnRateUpdate;
import com.omega.engine.pooling.PoolingType;
import com.omega.engine.updater.UpdaterType;

/* loaded from: input_file:com/omega/example/vggnet/test/vggnetTest.class */
/**
 * Driver that trains and evaluates a VGG-style convolutional network on the CIFAR-10 binary
 * batches using the omega engine (CUDA/cuDNN backend).
 */
public class vggnetTest {

    /** CIFAR-10 directory used by {@link #vgg16_cifar10()}. */
    private static final String DEFAULT_CIFAR10_DIR = "H:/dataset/cifar-10";

    /** Marker in {@link #CONV_STACK} for a 2x2, stride-2 max-pooling layer. */
    private static final int POOL = -1;

    /**
     * Feature-extractor layout: each positive entry is a 3x3 convolution (stride 1, padding 1,
     * no bias flag as in the original) with that many output channels, followed by a ReLU;
     * {@link #POOL} inserts a 2x2/2 max-pooling layer.
     *
     * <p>NOTE(review): this is 2+2+3+4+4 = 15 conv layers (18 weight layers with the 3 fully
     * connected ones), whereas the standard VGG-16 "D" configuration is 2+2+3+3+3 = 13 conv
     * layers. Kept exactly as the original code built it — confirm whether VGG-16 was intended.
     */
    private static final int[] CONV_STACK = {
        64, 64, POOL,
        128, 128, POOL,
        256, 256, 256, POOL,
        512, 512, 512, 512, POOL,
        512, 512, 512, 512, POOL
    };

    /** Trains and tests the network on CIFAR-10 read from {@code H:/dataset/cifar-10}. */
    public void vgg16_cifar10() {
        vgg16_cifar10(DEFAULT_CIFAR10_DIR);
    }

    /**
     * Loads CIFAR-10 from {@code dataDir}, builds the network, trains it, evaluates it on the
     * test batch and prints the combined train+test wall-clock time in whole seconds.
     *
     * <p>Any exception is printed rather than propagated. GPU memory is released through
     * {@code CUDAMemoryManager.freeAll()} exactly once, whether or not the run fails.
     *
     * @param dataDir directory containing {@code data_batch_1.bin} .. {@code data_batch_5.bin}
     *     and {@code test_batch.bin}, without a trailing slash
     */
    public void vgg16_cifar10(String dataDir) {
        try {
            String[] labels = {"airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"};

            String[] trainFiles = new String[5];
            for (int i = 0; i < trainFiles.length; i++) {
                trainFiles[i] = dataDir + "/data_batch_" + (i + 1) + ".bin";
            }
            // NOTE(review): the test set below is loaded with per-channel mean/std arrays while the
            // training set is not. Confirm this asymmetry is intended; if DataLoader has a
            // multi-file overload that accepts mean/std, training should likely use the same values.
            DataSet trainSet = DataLoader.getImagesToDataSetByBin(trainFiles, 10000, 3, 32, 32, 10, labels, true);
            DataSet testSet = DataLoader.getImagesToDataSetByBin(dataDir + "/test_batch.bin", 10000, 3, 32, 32, 10, labels, true,
                    new float[]{0.485f, 0.456f, 0.406f}, new float[]{0.229f, 0.224f, 0.225f});

            CNN cnn = buildNetwork();

            // Arguments presumably are (network, epochs, learning rate, batch size, schedule, warmup)
            // — TODO confirm against MBSGDOptimizer's constructor.
            MBSGDOptimizer optimizer = new MBSGDOptimizer((Network) cnn, 20, 0.001f, 128, LearnRateUpdate.CONSTANT, false);

            long start = System.currentTimeMillis();
            optimizer.train(trainSet);
            optimizer.test(testSet);
            System.out.println(((System.currentTimeMillis() - start) / 1000) + "s.");
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            // Single release point: the original also freed inside the catch, so the error path
            // called freeAll() twice.
            try {
                CUDAMemoryManager.freeAll();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    /** Creates the CNN (cuDNN on, Adam, softmax cross-entropy) and appends all of its layers. */
    private static CNN buildNetwork() {
        CNN cnn = new CNN(LossType.softmax_with_cross_entropy, UpdaterType.adam);
        cnn.CUDNN = true;
        cnn.learnRate = 0.001f;

        cnn.addLayer(new InputLayer(3, 32, 32));
        int flattened = addConvStack(cnn, 3, 32, 32);

        // Classifier head: two 4096-wide hidden layers, then 10 CIFAR-10 classes.
        cnn.addLayer(new FullyLayer(flattened, 4096, false));
        cnn.addLayer(new ReluLayer());
        cnn.addLayer(new FullyLayer(4096, 4096, false));
        cnn.addLayer(new ReluLayer());
        cnn.addLayer(new FullyLayer(4096, 10));
        cnn.addLayer(new SoftmaxWithCrossEntropyLayer(10));
        return cnn;
    }

    /**
     * Appends the layers described by {@link #CONV_STACK} to {@code cnn}, feeding each layer the
     * output dimensions of the one before it.
     *
     * @return the flattened size (channels * width * height) of the last layer's output
     */
    private static int addConvStack(CNN cnn, int inChannels, int inWidth, int inHeight) {
        int channels = inChannels;
        int width = inWidth;
        int height = inHeight;
        for (int spec : CONV_STACK) {
            if (spec == POOL) {
                PoolingLayer pool = new PoolingLayer(channels, width, height, 2, 2, 2, PoolingType.MAX_POOLING);
                cnn.addLayer(pool);
                channels = pool.oChannel;
                width = pool.oWidth;
                height = pool.oHeight;
            } else {
                ConvolutionLayer conv = new ConvolutionLayer(channels, spec, width, height, 3, 3, 1, 1, false);
                cnn.addLayer(conv);
                cnn.addLayer(new ReluLayer());
                channels = conv.oChannel;
                width = conv.oWidth;
                height = conv.oHeight;
            }
        }
        return channels * width * height;
    }

    /**
     * Entry point: initializes the CUDA context, runs {@link #vgg16_cifar10()}, and always calls
     * {@code CUDAMemoryManager.free()} afterwards, even if the run throws.
     */
    public static void main(String[] strArr) {
        try {
            CUDAModules.initContext();
            new vggnetTest().vgg16_cifar10();
        } finally {
            CUDAMemoryManager.free();
        }
    }
}
