package com.omega.engine.gpu;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import jcuda.driver.CUcontext;
import jcuda.driver.CUdevice;
import jcuda.driver.CUfunction;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaDeviceProp;
import jcuda.runtime.cudaError;

/* loaded from: input_file:com/omega/engine/gpu/CUDAModules.class */
/**
 * Static registry that compiles CUDA kernel sources to PTX (via {@code nvcc}), loads them as
 * driver-API modules and caches the resolved {@link CUfunction} handles per module.
 *
 * <p>It also owns the lazily created CUDA context for device 0. All caches are plain
 * {@link HashMap}s without synchronization, so this class is not safe for concurrent use.
 */
public class CUDAModules {
    private static CUdevice device;
    private static CUcontext context;
    private static CUDAUtils instance;
    /** Value of {@code CUDAUtils.getMaxThreads(device)}; refreshed whenever a new module is loaded. */
    public static int maxThreads;
    /** {@code (int) Math.sqrt(maxThreads)} — presumably the side of a square 2-D thread block. */
    public static int threadsPerDimension;
    /** Properties of device 0, filled in by {@link #getContext()}. */
    public static cudaDeviceProp props;
    /** Loaded modules keyed by the path of their PTX file. */
    public static Map<String, MyCUDAModule> modules = new HashMap<>();
    /** Kernel name -> source file (inside the classpath directory {@code /cu/}) that defines it. */
    public static Map<String, String> functions = new HashMap<>();

    static {
        functions.put("col2im_gpu_kernelV2", "Col2imKernel.cu");
        functions.put("im2col_gpu_kernelV2", "Im2colKernel.cu");
        functions.put("pooling_diff", "PoolingKernel.cu");
        functions.put("max_pooling", "PoolingKernel.cu");
        functions.put("mean_pooling", "PoolingKernel.cu");
        functions.put("mean_cov", "MathKernel.cu");
        functions.put("fast_mean_kernel", "MathKernel.cu");
        functions.put("var_cov", "MathKernel.cu");
        functions.put("fast_variance_kernel", "MathKernel.cu");
        functions.put("normalize_kernel", "BNKernel.cu");
        functions.put("std_fn", "MathKernel.cu");
        functions.put("mwa", "MathKernel.cu");
        functions.put("culOutput_cov", "BNKernel.cu");
        functions.put("computeDelta", "BNKernel.cu");
        functions.put("computeDelta_full", "BNKernel.cu");
        functions.put("meanDzSum", "BNKernel.cu");
        functions.put("computeDiff", "BNKernel.cu");
        functions.put("dgama_kernel", "BNKernel.cu");
        functions.put("dbeta_kernel", "BNKernel.cu");
        functions.put("dxhat_kernel2", "BNKernel.cu");
        functions.put("full_mean_delta_kernel", "BNKernel.cu");
        functions.put("full_var_delta_kernel", "BNKernel.cu");
        functions.put("fast_variance_delta_kernel", "BNKernel.cu");
        functions.put("dx_kernel", "BNKernel.cu");
        functions.put("dx_kernel_full", "BNKernel.cu");
        functions.put("copy_kernel", "OPKernel.cu");
        functions.put("copy_number_kernel", "OPKernel.cu");
        functions.put("copy_channel_kernel", "OPKernel.cu");
        functions.put("add_kernel", "OPKernel.cu");
        functions.put("add_scalar_kernel", "OPKernel.cu");
        functions.put("add_number_kernel", "OPKernel.cu");
        functions.put("add_channel_kernel", "OPKernel.cu");
        functions.put("sub_kernel", "OPKernel.cu");
        functions.put("sub_scalar_kernel", "OPKernel.cu");
        functions.put("mul_kernel", "OPKernel.cu");
        functions.put("mul_scalar_kernel", "OPKernel.cu");
        functions.put("mul_plus_kernel", "OPKernel.cu");
        functions.put("mul_plus_scalar_kernel", "OPKernel.cu");
        functions.put("div_kernel", "OPKernel.cu");
        functions.put("div_scalar_kernel", "OPKernel.cu");
        functions.put("scalar_div_kernel", "OPKernel.cu");
        functions.put("div_plus_kernel", "OPKernel.cu");
        functions.put("div_plus_scalar_kernel", "OPKernel.cu");
        functions.put("scalar_plus_div_kernel", "OPKernel.cu");
        functions.put("div_bGrad_kernel", "OPKernel.cu");
        functions.put("div_scalar_bGrad_kernel", "OPKernel.cu");
        functions.put("pow_kernel", "OPKernel.cu");
        functions.put("log_kernel", "OPKernel.cu");
        functions.put("exp_kernel", "OPKernel.cu");
        functions.put("sin_kernel", "OPKernel.cu");
        functions.put("cos_kernel", "OPKernel.cu");
    }

    /**
     * Returns a kernel defined in a source file bundled under the classpath directory {@code /cu/},
     * compiling and loading its module on first use.
     *
     * @param str  file name inside {@code /cu/}, e.g. {@code "OPKernel.cu"}
     * @param str2 name of the kernel function in that module
     * @return the cached or newly resolved function handle
     * @throws IllegalStateException if {@code /cu/} is not on the classpath or the module cannot be
     *     loaded
     */
    public static CUfunction getLocalFunctionByModule(String str, String str2) {
        String sourcePath = kernelResourceDirectory() + str;
        System.out.println(sourcePath);
        return lookupFunction(getModule(sourcePath), sourcePath, str2);
    }

    /**
     * Returns a kernel from an arbitrary kernel source (or PTX) path, compiling and loading its
     * module on first use.
     *
     * @param str  filesystem path of the kernel source
     * @param str2 name of the kernel function in that module
     * @return the cached or newly resolved function handle
     * @throws IllegalStateException if the module cannot be loaded
     */
    public static CUfunction getEXFunctionByModule(String str, String str2) {
        return lookupFunction(getModule(str), str, str2);
    }

    /** Resolves the filesystem path of the classpath directory {@code /cu/} (with trailing slash). */
    private static String kernelResourceDirectory() {
        URL url = CUDAModules.class.getResource("/cu/");
        if (url == null) {
            throw new IllegalStateException("CUDA kernel directory /cu/ not found on the classpath");
        }
        String path = url.getPath();
        // On Windows a file URL path looks like "/C:/...": drop the leading slash so File and nvcc
        // accept it. os.name is e.g. "Windows 10", hence the case-insensitive comparison.
        // NOTE(review): URL.getPath() keeps percent-escapes (a space stays "%20"), so directories
        // containing such characters will still not resolve.
        if (System.getProperty("os.name").toLowerCase(Locale.ROOT).contains("windows")
                && path.startsWith("/")) {
            path = path.substring(1);
        }
        return path;
    }

    /** Returns {@code name} from {@code module}, resolving it via the driver API and caching it. */
    private static CUfunction lookupFunction(MyCUDAModule module, String source, String name) {
        if (module == null) {
            // getModule() reports its IOException and returns null; fail here with context
            // instead of a bare NullPointerException.
            throw new IllegalStateException("Could not load CUDA module for " + source);
        }
        if (module.getFunctions().containsKey(name)) {
            return module.getFunctions().get(name);
        }
        CUfunction function = new CUfunction();
        checkCUDA(JCudaDriver.cuModuleGetFunction(function, module, name));
        module.getFunctions().put(name, function);
        return function;
    }

    /**
     * Returns the module for a kernel source, compiling it to PTX if needed and loading it on first
     * request. Ensures the CUDA context exists and refreshes {@link #maxThreads} /
     * {@link #threadsPerDimension} whenever a new module is loaded.
     *
     * @param str path of the kernel source (or of its PTX file)
     * @return the cached or newly loaded module, or {@code null} if PTX preparation failed (the
     *     exception is printed)
     * @throws RuntimeException if the driver reports an error while loading the module
     */
    public static MyCUDAModule getModule(String str) {
        try {
            String ptxPath = preparePtxFile(str);
            if (modules.containsKey(ptxPath)) {
                return modules.get(ptxPath);
            }
            setContext(getContext());
            maxThreads = instance.getMaxThreads(device);
            threadsPerDimension = (int) Math.sqrt(maxThreads);
            MyCUDAModule module = new MyCUDAModule();
            checkCUDA(JCudaDriver.cuModuleLoad(module, ptxPath));
            modules.put(ptxPath, module);
            return module;
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Returns the PTX path for a kernel source, invoking {@code nvcc -ptx} when the PTX file does
     * not exist yet.
     *
     * @param str path of the kernel source
     * @return path of the (possibly freshly generated) PTX file
     * @throws IOException if the source is missing, nvcc fails, or the wait is interrupted
     */
    private static String preparePtxFile(String str) throws IOException {
        String ptxPath = ptxPathFor(str);
        if (new File(ptxPath).exists()) {
            return ptxPath;
        }
        System.out.println(ptxPath);
        File source = new File(str);
        if (!source.exists()) {
            throw new IOException("Input file not found: " + str);
        }
        // Pass arguments as a list so paths containing spaces are not split apart.
        List<String> command = Arrays.asList(
                "nvcc",
                "-m" + System.getProperty("sun.arch.data.model"),
                "-ptx",
                source.getPath(),
                "-o",
                ptxPath);
        System.out.println("Executing\n" + String.join(" ", command));
        // Merge stderr into stdout: draining the two pipes one after another can deadlock when
        // nvcc fills the buffer of the pipe that is not being read.
        Process process = new ProcessBuilder(command).redirectErrorStream(true).start();
        String output = new String(toByteArray(process.getInputStream()));
        try {
            int exitValue = process.waitFor();
            if (exitValue == 0) {
                System.out.println("Finished creating PTX file");
                return ptxPath;
            }
            System.out.println("nvcc process exitValue " + exitValue);
            System.out.println("nvcc output:\n" + output);
            throw new IOException("Could not create .ptx file: " + output);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IOException("Interrupted while waiting for nvcc output", e);
        }
    }

    /**
     * Derives the PTX path by replacing the file-name extension with {@code .ptx}, or appending
     * {@code .ptx} when the file name has none. Dots in directory names are ignored.
     */
    private static String ptxPathFor(String sourcePath) {
        int nameStart = Math.max(sourcePath.lastIndexOf('/'), sourcePath.lastIndexOf('\\')) + 1;
        int dot = sourcePath.lastIndexOf('.');
        String base = dot >= nameStart ? sourcePath.substring(0, dot) : sourcePath;
        return base + ".ptx";
    }

    /** Reads {@code inputStream} to end-of-stream and returns all bytes read. */
    private static byte[] toByteArray(InputStream inputStream) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        byte[] chunk = new byte[8192];
        int read;
        while ((read = inputStream.read(chunk)) != -1) {
            buffer.write(chunk, 0, read);
        }
        return buffer.toByteArray();
    }

    /**
     * Returns the shared CUDA context, creating it on first call.
     *
     * <p>Initialization enables JCuda driver exceptions, initializes CUDA through
     * {@code CUDAUtils}, creates a context on device 0 and fills {@link #props} with that device's
     * properties.
     *
     * @return the shared context
     * @throws RuntimeException if reading the device properties fails
     */
    public static CUcontext getContext() {
        if (context == null) {
            JCudaDriver.setExceptionsEnabled(true);
            instance = CUDAUtils.getInstance();
            instance.initCUDA();
            device = instance.getDevice(0);
            context = instance.getContext(device);
            props = new cudaDeviceProp();
            // Runtime-API call: its status code is not covered by setExceptionsEnabled above.
            checkCUDA(JCuda.cudaGetDeviceProperties(props, 0));
            System.out.println("CUDA context init finish.");
        }
        return context;
    }

    /** Eagerly creates the shared CUDA context (see {@link #getContext()}). */
    public static void initContext() {
        getContext();
    }

    /** Replaces the cached CUDA context. */
    public static void setContext(CUcontext cUcontext) {
        context = cUcontext;
    }

    /** Compiles/loads every module listed in {@link #functions} and resolves each kernel in it. */
    public static void initCUDAFunctions() {
        for (Map.Entry<String, String> entry : functions.entrySet()) {
            getLocalFunctionByModule(entry.getValue(), entry.getKey());
        }
        System.out.println("CUDA functions init finish.");
    }

    /**
     * Throws if {@code i} is a non-zero CUDA status code.
     *
     * <p>NOTE(review): the message uses {@code cudaError} (runtime-API) names; driver-API
     * {@code CUresult} codes passed here may map to a different or misleading name.
     *
     * @param i status code returned by a CUDA call
     * @throws RuntimeException if {@code i != 0}
     */
    public static void checkCUDA(int i) {
        if (i != 0) {
            System.err.println("Error code " + i + ":" + cudaError.stringFor(i));
            throw new RuntimeException("Error code " + i + ":" + cudaError.stringFor(i));
        }
    }
}
