package org.nd4j.linalg.jcublas.context;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import java.util.HashMap;
import java.util.Map;
import jcuda.CudaException;
import jcuda.driver.CUcontext;
import jcuda.driver.CUdevice;
import jcuda.driver.CUresult;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;

/* loaded from: input_file:org/nd4j/linalg/jcublas/context/ContextHolder.class */
/**
 * Process-wide holder for the CUDA driver handles used by this backend: the
 * devices, the context handle registered for each device id, and one stream
 * per (context, thread name) pair.
 *
 * <p>Obtain the shared instance through {@link #getInstance()}. The context and
 * stream accessors are {@code synchronized} on the instance; the map getters
 * are not and return the live internal maps.
 */
public class ContextHolder {
    /** Lazily created singleton; only touched inside {@link #getInstance()}. */
    private static ContextHolder INSTANCE;

    /** Device handle for every device id that has been set up. */
    private final Map<Integer, CUdevice> devices = new HashMap<Integer, CUdevice>();

    /** Context handle registered for every device id that has been set up. */
    private final Map<Integer, CUcontext> deviceIDContexts = new HashMap<Integer, CUcontext>();

    /** Streams keyed by context (row) and by the name of the requesting thread (column). */
    private final Table<CUcontext, String, CUstream> contextStreams = HashBasedTable.create();

    /** Number of devices to set up; never less than 1 (see {@link #getNumDevices()}). */
    private int numDevices = 0;

    /**
     * Queries the device count. Contexts are created lazily on the first
     * {@link #getContext(int)} call, not here.
     */
    private ContextHolder() {
        getNumDevices();
    }

    /**
     * Returns the shared holder, creating it on first use.
     *
     * <p>Synchronized so that concurrent first calls cannot build two holders
     * (which would each create their own CUDA contexts).
     *
     * @return the single {@code ContextHolder} instance
     */
    public static synchronized ContextHolder getInstance() {
        if (INSTANCE == null) {
            INSTANCE = new ContextHolder();
        }
        return INSTANCE;
    }

    /**
     * Reads the device count from the driver into {@link #numDevices}, clamping
     * it to a minimum of 1. The return code of the driver call is ignored.
     *
     * <p>NOTE(review): this runs before this class calls {@code cuInit}; per the
     * CUDA driver API the query presumably fails unless the driver was already
     * initialized elsewhere in the process, leaving the count at 0 (clamped to
     * 1). Confirm whether multi-device hosts are expected to be detected here.
     */
    private void getNumDevices() {
        int[] count = new int[1];
        JCudaDriver.cuDeviceGetCount(count);
        this.numDevices = Math.max(count[0], 1);
    }

    /**
     * Returns the shared context handle; equivalent to {@code getContext(0)}.
     *
     * @return the context handle
     * @throws CudaException if the driver, a device or a context cannot be set up
     */
    public synchronized CUcontext getContext() {
        return getContext(0);
    }

    /**
     * Returns the stream associated with the shared context and the calling
     * thread's name, creating it on first request.
     *
     * <p>Streams are keyed by thread <em>name</em>, so threads sharing a name
     * share a stream.
     *
     * @return the stream for the calling thread
     * @throws CudaException if the context cannot be set up or the stream
     *         cannot be created
     */
    public synchronized CUstream getStream() {
        String threadName = Thread.currentThread().getName();
        CUcontext context = getContext(0);
        CUstream stream = this.contextStreams.get(context, threadName);
        if (stream == null) {
            stream = new CUstream();
            checkResult(JCudaDriver.cuStreamCreate(stream, 0), "Failed to create a stream: ");
            this.contextStreams.put(context, threadName, stream);
        }
        return stream;
    }

    /**
     * Returns the shared context handle, setting up all devices on first use.
     *
     * <p>On the first call the driver is initialized and, for every device id
     * in {@code [0, numDevices)}, {@link #createDevice(CUcontext, int)} is
     * invoked with the same {@code CUcontext} object, which is then registered
     * under that device id. Every device id therefore maps to one shared
     * handle, holding the context created for the last device.
     *
     * <p>Earlier revisions additionally created a context while probing for a
     * current one and then immediately replaced it through
     * {@code createDevice}, leaking the first context. Exactly one context per
     * device is created now; the handle that ends up cached is unchanged.
     *
     * @param i requested device id. NOTE(review): currently not used for the
     *          lookup (device id 0 is always consulted), which is kept for
     *          backward compatibility since all ids share one handle.
     * @return the shared context handle
     * @throws CudaException if the driver, a device or a context cannot be set up
     */
    public synchronized CUcontext getContext(int i) {
        CUcontext context = this.deviceIDContexts.get(0);
        if (context == null) {
            initializeDriver();
            context = new CUcontext();
            for (int device = 0; device < this.numDevices; device++) {
                this.devices.put(device, createDevice(context, device));
                this.deviceIDContexts.put(device, context);
            }
        }
        return context;
    }

    /** Initializes the CUDA driver API, failing loudly if that is not possible. */
    private static void initializeDriver() {
        checkResult(JCudaDriver.cuInit(0), "Failed to initialize the driver: ");
    }

    /**
     * Throws a {@link CudaException} when a driver call did not succeed.
     *
     * @param result         return code of the driver call
     * @param failureMessage message prefix; the textual result code is appended
     */
    private static void checkResult(int result, String failureMessage) {
        if (result != CUresult.CUDA_SUCCESS) {
            throw new CudaException(failureMessage + CUresult.stringFor(result));
        }
    }

    /**
     * Obtains the device with the given ordinal and creates a new context on
     * it, storing the new context in {@code cUcontext}.
     *
     * @param cUcontext handle that receives the newly created context
     * @param i         device ordinal
     * @return the device handle
     * @throws CudaException if the device cannot be obtained or the context
     *         cannot be created
     */
    public static CUdevice createDevice(CUcontext cUcontext, int i) {
        CUdevice device = new CUdevice();
        checkResult(JCudaDriver.cuDeviceGet(device, i), "Failed to obtain a device: ");
        checkResult(JCudaDriver.cuCtxCreate(cUcontext, 0, device), "Failed to create a context: ");
        return device;
    }

    /**
     * Returns the live, mutable map from device id to device handle. It is
     * empty until the first {@link #getContext(int)} call.
     *
     * @return the internal device map (not a copy)
     */
    public Map<Integer, CUdevice> getDevices() {
        return this.devices;
    }

    /**
     * Returns the live, mutable map from device id to context handle. It is
     * empty until the first {@link #getContext(int)} call.
     *
     * @return the internal context map (not a copy)
     */
    public Map<Integer, CUcontext> getDeviceIDContexts() {
        return this.deviceIDContexts;
    }
}
