package com.omega.engine.gpu.cudnn;

import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnHandle;
import jcuda.jcudnn.cudnnStatus;
import jcuda.runtime.JCuda;

/* loaded from: input_file:com/omega/engine/gpu/cudnn/CudnnHandleManager.class */
/**
 * Holds a single shared cuDNN handle that is created lazily on first request.
 *
 * <p>NOTE(review): the handle is stored in an unsynchronized static field and {@link #getHandle()}
 * performs an unguarded check-then-create, so concurrent first calls could each create a handle.
 * Confirm that initialization only happens from one thread.
 */
public class CudnnHandleManager {

    /** CUDNN_VERSION value from cudnn.h that is printed next to the runtime version (cuDNN 8.9.2). */
    private static final int CUDNN_HEADER_VERSION = 8902;

    /** The shared handle; {@code null} until created, or after {@code GpuHandle} is called with a negative argument. */
    private static cudnnHandle cudnnHandle;

    /**
     * Returns the shared cuDNN handle, creating it via {@code GpuHandle(0)} if none exists yet.
     *
     * <p>When a handle is created, the runtime cuDNN version and {@link #CUDNN_HEADER_VERSION} are
     * printed to stdout. Exceptions raised during creation are printed and swallowed.
     *
     * @return the shared handle, or {@code null} if creation threw an exception that was caught here
     */
    public static cudnnHandle getHandle() {
        try {
            if (cudnnHandle == null) {
                GpuHandle(0);
                System.out.printf(
                        "cudnnGetVersion() : %d , CUDNN_VERSION from cudnn.h : %d\n",
                        Integer.valueOf((int) JCudnn.cudnnGetVersion()),
                        CUDNN_HEADER_VERSION);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        return cudnnHandle;
    }

    /**
     * Creates a new shared cuDNN handle, or clears the shared reference.
     *
     * <p>A negative {@code i} sets the shared reference to {@code null}. Otherwise
     * {@link #initThread()} is called and a new handle is created with {@code cudnnCreate}. The
     * shared reference is replaced only if creation reports success.
     *
     * <p>NOTE(review): apart from the sign check, {@code i} is not used, because
     * {@link #initThread()} always targets device 0. Also, a previously held handle is neither
     * destroyed when cleared nor when replaced. Confirm both are intended.
     *
     * @param i device selector; a negative value clears the shared handle
     * @throws CudnnError if {@code cudnnCreate} returns a non-success status
     */
    public static void GpuHandle(int i) {
        if (i < 0) {
            cudnnHandle = null;
            return;
        }
        initThread();
        cudnnHandle created = new cudnnHandle();
        // Validate the status before publishing, so a failed create never leaves an
        // uninitialized handle cached for subsequent getHandle() calls.
        handle(JCudnn.cudnnCreate(created));
        cudnnHandle = created;
    }

    /**
     * Requests device 0 by calling {@code setDevice(0)}.
     *
     * <p>NOTE(review): {@link #getThreadDeviceId()} always reports 0, so {@link #setDevice(int)}
     * returns early here and {@code cudaSetDevice} is never actually invoked.
     */
    public static void initThread() {
        setDevice(0);
    }

    /**
     * Selects CUDA device {@code i} unless {@link #isThreadDeviceId(int)} reports it is already current.
     *
     * @param i the CUDA device id; must be non-negative
     * @throws IllegalArgumentException if {@code i} is negative
     * @throws IllegalStateException if {@code cudaSetDevice} returns a non-zero (failure) code
     */
    public static void setDevice(int i) {
        if (i < 0) {
            throw new IllegalArgumentException("cudaDeviceId=" + i);
        }
        if (isThreadDeviceId(i)) {
            return;
        }
        int status = JCuda.cudaSetDevice(i);
        // 0 is cudaSuccess; anything else means later work would run on the wrong device.
        if (status != 0) {
            throw new IllegalStateException(
                    "cudaSetDevice(" + i + ") failed with CUDA error code " + status);
        }
    }

    /**
     * Reports whether {@code i} equals the id returned by {@link #getThreadDeviceId()}.
     *
     * @param i the CUDA device id to compare
     * @return {@code true} if the current id is non-null and equal to {@code i}
     */
    public static boolean isThreadDeviceId(int i) {
        Integer threadDeviceId = getThreadDeviceId();
        return threadDeviceId != null && i == threadDeviceId.intValue();
    }

    /**
     * Returns the device id assumed to be current.
     *
     * <p>NOTE(review): this always returns 0 and does not query the CUDA runtime. It appears to
     * be a placeholder.
     *
     * @return always {@code 0}
     */
    public static Integer getThreadDeviceId() {
        return 0;
    }

    /**
     * Throws if a cuDNN status code is not success (0).
     *
     * @param i a cuDNN status code, as returned by JCudnn calls
     * @throws CudnnError if {@code i} is non-zero; the message is {@code cudnnStatus.stringFor(i)}
     */
    public static void handle(int i) {
        if (i != 0) {
            throw new CudnnError(cudnnStatus.stringFor(i));
        }
    }
}
