package com.omega.example.resnet.test;

import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.loss.LossType;
import com.omega.engine.nn.layer.AVGPoolingLayer;
import com.omega.engine.nn.layer.BasicBlockLayer;
import com.omega.engine.nn.layer.ConvolutionLayer;
import com.omega.engine.nn.layer.FullyLayer;
import com.omega.engine.nn.layer.InputLayer;
import com.omega.engine.nn.layer.PoolingLayer;
import com.omega.engine.nn.layer.active.ReluLayer;
import com.omega.engine.nn.layer.normalization.BNLayer;
import com.omega.engine.nn.network.CNN;
import com.omega.engine.pooling.PoolingType;
import com.omega.engine.updater.UpdaterType;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;

/* loaded from: input_file:com/omega/example/resnet/test/Resnet18.class */
/**
 * Example that assembles a ResNet-18 style {@link CNN} out of the engine's layer classes and
 * provides a small driver ({@link #main(String[])}) that builds the network and reads a weight
 * file from disk.
 */
public class Resnet18 {

    /** Number of chars requested per read call while loading the weight file. */
    private static final int READ_BUFFER_CHARS = 8192;

    /**
     * Builds the network: input, convolution + batch-norm + ReLU + max-pooling stem, eight
     * {@link BasicBlockLayer}s (each followed by a ReLU) with 64/64/128/128/256/256/512/512 output
     * channels, an average pooling layer and a final fully connected layer.
     *
     * <p>The network uses softmax-with-cross-entropy loss, the AdamW updater, a learning rate of
     * {@code 0.1} and has {@code CUDNN} enabled.
     *
     * @param i  first {@link InputLayer} argument, also the convolution's first argument
     *           (presumably the input channel count - confirm against InputLayer)
     * @param i2 second {@link InputLayer} argument (presumably the input height - confirm)
     * @param i3 third {@link InputLayer} argument (presumably the input width - confirm)
     * @param i4 second {@link FullyLayer} argument (presumably the number of output classes -
     *           confirm)
     * @return the assembled network with all layers added in forward order
     */
    public static CNN instance(int i, int i2, int i3, int i4) {
        CNN cnn = new CNN(LossType.softmax_with_cross_entropy, UpdaterType.adamw);
        cnn.CUDNN = true;
        cnn.learnRate = 0.1f;

        // All layers are constructed first and only then added to the network, because every
        // layer derives its input dimensions from the output fields of the previous one.
        InputLayer input = new InputLayer(i, i2, i3);
        // NOTE(review): i3/i2 are passed here in the opposite order to InputLayer(i, i2, i3);
        // presumably ConvolutionLayer expects (channel, kernels, width, height, ...) - confirm.
        // The trailing 3, 3, 1, 1, false presumably mean a 3x3 kernel, padding 1, stride 1 and no
        // bias - confirm against ConvolutionLayer.
        ConvolutionLayer stemConv = new ConvolutionLayer(i, 64, i3, i2, 3, 3, 1, 1, false);
        BNLayer stemNorm = new BNLayer();
        ReluLayer stemRelu = new ReluLayer();
        // Presumably a 2x2 max-pooling window with stride 2 - confirm against PoolingLayer.
        PoolingLayer stemPool = new PoolingLayer(stemConv.oChannel, stemConv.oWidth, stemConv.oHeight, 2, 2, 2, PoolingType.MAX_POOLING);

        // The fifth BasicBlockLayer argument is 2 on the first block of each wider stage and 1
        // otherwise (presumably the stride / down-sampling factor - confirm).
        BasicBlockLayer block1 = new BasicBlockLayer(stemPool.oChannel, 64, stemPool.oHeight, stemPool.oWidth, 1, cnn);
        ReluLayer blockRelu1 = new ReluLayer();
        BasicBlockLayer block2 = new BasicBlockLayer(block1.oChannel, 64, block1.oHeight, block1.oWidth, 1, cnn);
        ReluLayer blockRelu2 = new ReluLayer();
        BasicBlockLayer block3 = new BasicBlockLayer(block2.oChannel, 128, block2.oHeight, block2.oWidth, 2, cnn);
        ReluLayer blockRelu3 = new ReluLayer();
        BasicBlockLayer block4 = new BasicBlockLayer(block3.oChannel, 128, block3.oHeight, block3.oWidth, 1, cnn);
        ReluLayer blockRelu4 = new ReluLayer();
        BasicBlockLayer block5 = new BasicBlockLayer(block4.oChannel, 256, block4.oHeight, block4.oWidth, 2, cnn);
        ReluLayer blockRelu5 = new ReluLayer();
        BasicBlockLayer block6 = new BasicBlockLayer(block5.oChannel, 256, block5.oHeight, block5.oWidth, 1, cnn);
        ReluLayer blockRelu6 = new ReluLayer();
        BasicBlockLayer block7 = new BasicBlockLayer(block6.oChannel, 512, block6.oHeight, block6.oWidth, 2, cnn);
        ReluLayer blockRelu7 = new ReluLayer();
        BasicBlockLayer block8 = new BasicBlockLayer(block7.oChannel, 512, block7.oHeight, block7.oWidth, 1, cnn);
        ReluLayer blockRelu8 = new ReluLayer();

        AVGPoolingLayer avgPool = new AVGPoolingLayer(block8.oChannel, block8.oWidth, block8.oHeight);
        // The classifier input size is the flattened output of the average pooling layer.
        FullyLayer classifier = new FullyLayer(avgPool.oChannel * avgPool.oWidth * avgPool.oHeight, i4);

        cnn.addLayer(input);
        cnn.addLayer(stemConv);
        cnn.addLayer(stemNorm);
        cnn.addLayer(stemRelu);
        cnn.addLayer(stemPool);
        cnn.addLayer(block1);
        cnn.addLayer(blockRelu1);
        cnn.addLayer(block2);
        cnn.addLayer(blockRelu2);
        cnn.addLayer(block3);
        cnn.addLayer(blockRelu3);
        cnn.addLayer(block4);
        cnn.addLayer(blockRelu4);
        cnn.addLayer(block5);
        cnn.addLayer(blockRelu5);
        cnn.addLayer(block6);
        cnn.addLayer(blockRelu6);
        cnn.addLayer(block7);
        cnn.addLayer(blockRelu7);
        cnn.addLayer(block8);
        cnn.addLayer(blockRelu8);
        cnn.addLayer(avgPool);
        cnn.addLayer(classifier);
        return cnn;
    }

    /**
     * Reads the weight file at {@code str} as UTF-8 text and prints its size in whole mebibytes
     * followed by {@code "m"} to standard output.
     *
     * <p>The content is currently only measured; nothing is applied to {@code cnn}. If the file
     * does not exist the method returns without doing anything. Any failure is reported via
     * {@link Exception#printStackTrace()} and not propagated to the caller.
     *
     * @param cnn the network the weights are intended for (currently unused)
     * @param str path of the weight file
     */
    public static void loadWeight(CNN cnn, String str) {
        try {
            File file = new File(str);
            if (!file.exists()) {
                return;
            }
            StringBuilder content = new StringBuilder();
            // try-with-resources closes the stream and reader even when a read fails.
            try (FileInputStream stream = new FileInputStream(file);
                    InputStreamReader reader = new InputStreamReader(stream, StandardCharsets.UTF_8)) {
                char[] buffer = new char[READ_BUFFER_CHARS];
                int count;
                // Chunked reads avoid one decoder call per character on large weight files.
                while ((count = reader.read(buffer)) != -1) {
                    content.append(buffer, 0, count);
                }
            }
            String text = content.toString();
            // Measure with the same charset the file was decoded with, so the reported size
            // does not depend on the platform default encoding.
            System.out.println(((text.getBytes(StandardCharsets.UTF_8).length / 1024) / 1024) + "m");
        } catch (Exception e) {
            // Deliberately best-effort: report the problem and let the caller continue.
            e.printStackTrace();
        }
    }

    /**
     * Initialises the CUDA context, builds the network via {@code instance(3, 224, 224, 1000)}
     * and passes it to {@link #loadWeight(CNN, String)} with a hard-coded weight file path.
     * {@code CUDAMemoryManager.free()} is always called at the end, even on failure.
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        try {
            CUDAModules.initContext();
            CNN network = instance(3, 224, 224, 1000);
            loadWeight(network, "H:\\voc\\train\\resnet18.json");
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            CUDAMemoryManager.free();
        }
    }
}
