package com.omega.example.gan.test;

import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.loss.LossType;
import com.omega.engine.nn.layer.ConvolutionLayer;
import com.omega.engine.nn.layer.ConvolutionTransposeLayer;
import com.omega.engine.nn.layer.FullyLayer;
import com.omega.engine.nn.layer.InputLayer;
import com.omega.engine.nn.layer.active.LeakyReluLayer;
import com.omega.engine.nn.layer.active.ReluLayer;
import com.omega.engine.nn.layer.active.SigmodLayer;
import com.omega.engine.nn.layer.active.TanhLayer;
import com.omega.engine.nn.layer.normalization.BNLayer;
import com.omega.engine.nn.network.CNN;
import com.omega.engine.optimizer.GANOptimizer;
import com.omega.engine.optimizer.lr.LearnRateUpdate;
import com.omega.engine.updater.UpdaterType;
import com.omega.example.gan.utils.ImageDataLoader;
import java.util.HashMap;

/**
 * DCGAN example: builds a transposed-convolution generator ({@link #NetG}) and a convolutional
 * discriminator ({@link #NetD}) and trains them together with {@link GANOptimizer} on a folder of
 * images loaded by {@link ImageDataLoader}.
 */
public class DCGAN {

    /** Dataset directory used by {@link #gan_anime()} and by {@link #main} when no path is given. */
    private static final String DEFAULT_DATA_PATH = "H:\\voc\\gan_anime\\ml2021spring-hw6\\faces\\";

    /**
     * Builds the generator network.
     *
     * <p>Layer order: input, fully connected ({@code noiseDim -> baseChannels * 8 * 4 * 4}),
     * BN, ReLU, then three [transpose conv, BN, ReLU] blocks producing {@code baseChannels * 4},
     * {@code * 2} and {@code * 1} channels, then a final transpose conv to 3 channels and tanh.
     * Each transpose conv takes its input channel count and spatial size from the previous
     * transpose conv's {@code oChannel}/{@code oWidth}/{@code oHeight}.
     *
     * @param baseChannels base feature-map count; the widest layer uses {@code baseChannels * 8}
     * @param noiseDim size of the input vector fed to the fully connected layer
     * @return the assembled generator (BCE loss, AdamW updater, cuDNN on, learning rate 3e-4)
     */
    public static CNN NetG(int baseChannels, int noiseDim) {
        CNN netG = new CNN(LossType.BCE, UpdaterType.adamw);
        netG.CUDNN = true;
        netG.learnRate = 3.0E-4f;

        InputLayer input = new InputLayer(1, 1, noiseDim);
        // Output size baseChannels*8*4*4 is consumed by the first transpose conv as
        // baseChannels*8 channels of 4x4 (presumably a reshape inside the engine).
        FullyLayer project = new FullyLayer(noiseDim, baseChannels * 8 * 4 * 4, false);
        BNLayer projectBn = new BNLayer();
        ReluLayer projectRelu = new ReluLayer();

        // Upsampling stack. The trailing boolean is false on every layer followed by BN and
        // true only on the output layer (presumably the bias flag -- TODO confirm).
        ConvolutionTransposeLayer up1 = new ConvolutionTransposeLayer(
                baseChannels * 8, baseChannels * 4, 4, 4, 5, 5, 2, 2, 1, 1, false);
        BNLayer up1Bn = new BNLayer();
        ReluLayer up1Relu = new ReluLayer();
        ConvolutionTransposeLayer up2 = new ConvolutionTransposeLayer(
                up1.oChannel, baseChannels * 2, up1.oWidth, up1.oHeight, 5, 5, 2, 2, 1, 1, false);
        BNLayer up2Bn = new BNLayer();
        ReluLayer up2Relu = new ReluLayer();
        ConvolutionTransposeLayer up3 = new ConvolutionTransposeLayer(
                up2.oChannel, baseChannels, up2.oWidth, up2.oHeight, 5, 5, 2, 2, 1, 1, false);
        BNLayer up3Bn = new BNLayer();
        ReluLayer up3Relu = new ReluLayer();
        // Output layer: 3 channels, then tanh.
        ConvolutionTransposeLayer toImage = new ConvolutionTransposeLayer(
                up3.oChannel, 3, up3.oWidth, up3.oHeight, 5, 5, 2, 2, 1, 1, true);
        TanhLayer outputTanh = new TanhLayer();

        // Layers are all constructed first and only then registered, in sequence.
        netG.addLayer(input);
        netG.addLayer(project);
        netG.addLayer(projectBn);
        netG.addLayer(projectRelu);
        netG.addLayer(up1);
        netG.addLayer(up1Bn);
        netG.addLayer(up1Relu);
        netG.addLayer(up2);
        netG.addLayer(up2Bn);
        netG.addLayer(up2Relu);
        netG.addLayer(up3);
        netG.addLayer(up3Bn);
        netG.addLayer(up3Relu);
        netG.addLayer(toImage);
        netG.addLayer(outputTanh);
        return netG;
    }

    /**
     * Builds the discriminator network.
     *
     * <p>Layer order: 3-channel input, four [conv, BN, LeakyReLU] blocks producing
     * {@code baseChannels}, {@code * 2}, {@code * 4} and {@code * 8} channels, then a final
     * 4x4 conv to a single channel and sigmoid. Each conv takes its input channel count and
     * spatial size from the previous conv's {@code oChannel}/{@code oWidth}/{@code oHeight}.
     *
     * @param baseChannels feature-map count of the first conv; doubled by each following block
     * @param width input image width (passed in the same position as {@code oWidth} later)
     * @param height input image height (passed in the same position as {@code oHeight} later)
     * @return the assembled discriminator (BCE loss, AdamW updater, cuDNN on, learning rate 3e-4)
     */
    public static CNN NetD(int baseChannels, int width, int height) {
        CNN netD = new CNN(LossType.BCE, UpdaterType.adamw);
        // An explicitly empty parameter map. The decompiled original used an empty anonymous
        // HashMap subclass (double-brace init with no entries), which only added an extra class.
        netD.updaterParams = new HashMap<String, Float>();
        netD.CUDNN = true;
        netD.learnRate = 3.0E-4f;
        // NOTE(review): presumably lets the discriminator propagate gradients to its input so
        // the generator can be trained through it -- confirm against CNN.PROPAGATE_DOWN.
        netD.PROPAGATE_DOWN = true;

        InputLayer input = new InputLayer(3, height, width);
        // Downsampling stack. Assuming the argument order (..., kernelW, kernelH, padding, stride,
        // bias), each 5x5/pad-2/stride-2 conv halves the spatial size.
        ConvolutionLayer down1 = new ConvolutionLayer(
                3, baseChannels, width, height, 5, 5, 2, 2, true);
        BNLayer down1Bn = new BNLayer();
        LeakyReluLayer down1Act = new LeakyReluLayer();
        ConvolutionLayer down2 = new ConvolutionLayer(
                down1.oChannel, baseChannels * 2, down1.oWidth, down1.oHeight, 5, 5, 2, 2, false);
        BNLayer down2Bn = new BNLayer();
        LeakyReluLayer down2Act = new LeakyReluLayer();
        ConvolutionLayer down3 = new ConvolutionLayer(
                down2.oChannel, baseChannels * 4, down2.oWidth, down2.oHeight, 5, 5, 2, 2, false);
        BNLayer down3Bn = new BNLayer();
        LeakyReluLayer down3Act = new LeakyReluLayer();
        ConvolutionLayer down4 = new ConvolutionLayer(
                down3.oChannel, baseChannels * 8, down3.oWidth, down3.oHeight, 5, 5, 2, 2, false);
        BNLayer down4Bn = new BNLayer();
        LeakyReluLayer down4Act = new LeakyReluLayer();
        // Final 4x4 conv to one channel, then sigmoid.
        ConvolutionLayer toScore = new ConvolutionLayer(
                down4.oChannel, 1, down4.oWidth, down4.oHeight, 4, 4, 0, 1, true);
        SigmodLayer outputSigmoid = new SigmodLayer();

        // Layers are all constructed first and only then registered, in sequence.
        netD.addLayer(input);
        netD.addLayer(down1);
        netD.addLayer(down1Bn);
        netD.addLayer(down1Act);
        netD.addLayer(down2);
        netD.addLayer(down2Bn);
        netD.addLayer(down2Act);
        netD.addLayer(down3);
        netD.addLayer(down3Bn);
        netD.addLayer(down3Act);
        netD.addLayer(down4);
        netD.addLayer(down4Bn);
        netD.addLayer(down4Act);
        netD.addLayer(toScore);
        netD.addLayer(outputSigmoid);
        return netD;
    }

    /**
     * Trains the DCGAN on the images under {@link #DEFAULT_DATA_PATH}.
     *
     * <p>Any exception from building or training is printed and swallowed.
     */
    public static void gan_anime() {
        gan_anime(DEFAULT_DATA_PATH);
    }

    /**
     * Trains the DCGAN ({@code NetG(64, 100)} against {@code NetD(64, 64, 64)}) on the images
     * in {@code dataPath}.
     *
     * <p>Any exception from building or training is printed and swallowed, matching
     * {@link #gan_anime()}.
     *
     * @param dataPath directory handed to {@link ImageDataLoader}
     * @throws IllegalArgumentException if {@code dataPath} is {@code null}
     */
    public static void gan_anime(String dataPath) {
        if (dataPath == null) {
            throw new IllegalArgumentException("dataPath must not be null");
        }
        // NOTE(review): presumably per-channel normalization mean/std (0.5/0.5 would map [0,1]
        // pixels to [-1,1], the range of the generator's tanh) -- confirm against ImageDataLoader.
        float[] normMean = {0.5f, 0.5f, 0.5f};
        float[] normStd = {0.5f, 0.5f, 0.5f};
        try {
            // Same evaluation order as a single chained expression: G, D, optimizer, loader, train.
            CNN netG = NetG(64, 100);
            CNN netD = NetD(64, 64, 64);
            GANOptimizer optimizer =
                    new GANOptimizer(netG, netD, 64, 100, 1, 1, 0.001f, LearnRateUpdate.POLY, false);
            ImageDataLoader loader = new ImageDataLoader(dataPath, 64, 64, 64, true, normMean, normStd);
            optimizer.train(loader);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Initializes the CUDA context, runs {@link #gan_anime(String)}, and always frees CUDA memory.
     *
     * @param args optional; {@code args[0]} overrides the dataset directory
     */
    public static void main(String[] args) {
        String dataPath = (args != null && args.length > 0) ? args[0] : DEFAULT_DATA_PATH;
        try {
            CUDAModules.initContext();
            // gan_anime handles its own exceptions, so this catch only sees init failures.
            gan_anime(dataPath);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            CUDAMemoryManager.free();
        }
    }
}
