package com.omega.example.gan.test;

import com.omega.common.utils.DataLoader;
import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.loss.LossType;
import com.omega.engine.nn.layer.FullyLayer;
import com.omega.engine.nn.layer.InputLayer;
import com.omega.engine.nn.layer.active.LeakyReluLayer;
import com.omega.engine.nn.layer.active.ReluLayer;
import com.omega.engine.nn.layer.active.SigmodLayer;
import com.omega.engine.nn.layer.active.TanhLayer;
import com.omega.engine.nn.network.BPNetwork;
import com.omega.engine.optimizer.GANOptimizer;
import com.omega.engine.optimizer.lr.LearnRateUpdate;
import com.omega.engine.updater.UpdaterType;
import java.io.File;
import java.io.FileNotFoundException;
import java.net.URISyntaxException;
import java.net.URL;

/* loaded from: input_file:com/omega/example/gan/test/MinistGAN.class */
public class MinistGAN {
    public static BPNetwork NetG(int i, int i2) {
        BPNetwork bPNetwork = new BPNetwork(LossType.MSE, UpdaterType.adamw);
        bPNetwork.CUDNN = true;
        bPNetwork.learnRate = 1.0E-4f;
        InputLayer inputLayer = new InputLayer(1, 1, i2);
        FullyLayer fullyLayer = new FullyLayer(i2, 256, true);
        ReluLayer reluLayer = new ReluLayer();
        FullyLayer fullyLayer2 = new FullyLayer(256, 256, true);
        ReluLayer reluLayer2 = new ReluLayer();
        FullyLayer fullyLayer3 = new FullyLayer(256, i, true);
        TanhLayer tanhLayer = new TanhLayer();
        bPNetwork.addLayer(inputLayer);
        bPNetwork.addLayer(fullyLayer);
        bPNetwork.addLayer(reluLayer);
        bPNetwork.addLayer(fullyLayer2);
        bPNetwork.addLayer(reluLayer2);
        bPNetwork.addLayer(fullyLayer3);
        bPNetwork.addLayer(tanhLayer);
        return bPNetwork;
    }

    public BPNetwork NetD(int i) {
        BPNetwork bPNetwork = new BPNetwork(LossType.MSE, UpdaterType.adamw);
        bPNetwork.CUDNN = true;
        bPNetwork.learnRate = 1.0E-4f;
        bPNetwork.PROPAGATE_DOWN = true;
        InputLayer inputLayer = new InputLayer(1, 1, i);
        FullyLayer fullyLayer = new FullyLayer(i, 256, true);
        LeakyReluLayer leakyReluLayer = new LeakyReluLayer();
        FullyLayer fullyLayer2 = new FullyLayer(256, 256, true);
        LeakyReluLayer leakyReluLayer2 = new LeakyReluLayer();
        FullyLayer fullyLayer3 = new FullyLayer(256, 1, true);
        SigmodLayer sigmodLayer = new SigmodLayer();
        bPNetwork.addLayer(inputLayer);
        bPNetwork.addLayer(fullyLayer);
        bPNetwork.addLayer(leakyReluLayer);
        bPNetwork.addLayer(fullyLayer2);
        bPNetwork.addLayer(leakyReluLayer2);
        bPNetwork.addLayer(fullyLayer3);
        bPNetwork.addLayer(sigmodLayer);
        return bPNetwork;
    }

    public void gan_anime() {
        try {
            new GANOptimizer(NetG(784, 100), NetD(784), 2048, 3500, 1, 1, 0.001f, LearnRateUpdate.CONSTANT, false).train(DataLoader.loadDataByUByte(new File(getClass().getClassLoader().getResource("/dataset/mnist/train-images.idx3-ubyte").toURI()), new File(getClass().getClassLoader().getResource("/dataset/mnist/train-labels.idx1-ubyte").toURI()), new String[]{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"}, 1, 1, 784, true, new float[]{0.5f}, new float[]{0.5f}));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public static void main(String[] strArr) {
        try {
            try {
                CUDAModules.initContext();
                new MinistGAN().gan_anime();
                CUDAMemoryManager.free();
            } catch (Exception e) {
                e.printStackTrace();
                CUDAMemoryManager.free();
            }
        } catch (Throwable th) {
            CUDAMemoryManager.free();
            throw th;
        }
    }
}
