package org.nd4j.jita.allocator.context.impl;

import java.lang.ref.ReferenceQueue;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.nd4j.jita.allocator.context.ContextPack;
import org.nd4j.jita.allocator.garbage.DeallocatableThread;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.memory.Deallocatable;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/jita/allocator/context/impl/LimitedContextPool.class */
/**
 * Context pool that hands out {@link CudaContext} instances from per-device
 * queues which are pre-filled at construction time and grown on demand up to a
 * bounded size.
 *
 * <p>A context acquired by a thread is remembered in {@link #acquired} (keyed by
 * thread id) so repeated requests from the same thread for the same device
 * return the same context. Contexts go back into the per-device queue through
 * {@link #releaseContext(CudaContext)}.
 */
public class LimitedContextPool extends BasicContextPool {
    private static final Logger log = LoggerFactory.getLogger(LimitedContextPool.class);

    /** Number of contexts added each time a device queue runs dry. */
    private static final int GROWTH_STEP = 16;

    /** A device pool stops growing once its size reaches the configured pool size times this factor. */
    private static final int MAX_GROWTH_FACTOR = 3;

    /** Idle contexts, keyed by device id. */
    protected Map<Integer, LinkedBlockingQueue<CudaContext>> pool = new HashMap<>();

    /** Contexts currently bound to a thread, keyed by thread id. */
    protected Map<Long, CudaContext> acquired = new ConcurrentHashMap<>();

    /**
     * Current size of each device pool.
     *
     * <p>NOTE(review): entries are appended in the iteration order of the
     * available-devices list but looked up by device id in
     * {@link #acquireContextForDevice(Integer)}; this only lines up when the
     * available devices are exactly 0..n-1 in order - confirm against the
     * configuration.
     */
    protected List<AtomicInteger> devicePoolSizes = new ArrayList<>();

    /** Not used by this class directly; kept for subclasses. */
    protected Map<Integer, ReferenceQueue<Thread>> queueMap = new HashMap<>();

    /** Per-thread hook object that was registered with the deallocator service. */
    protected ThreadLocal<Deallocatable> threadHooks = new ThreadLocal<>();

    /**
     * Creates the pool and pre-fills every available device with the configured
     * number of contexts.
     */
    public LimitedContextPool() {
        fillPoolWithResources(CudaEnvironment.getInstance().getConfiguration().getPoolSize(), false);
    }

    /**
     * Adds {@code i} freshly created contexts to the pool of the device reported
     * by {@link AtomicAllocator#getDeviceId()}.
     *
     * <p>The new contexts share one cuBLAS handle and one cuSOLVER handle, the
     * same way the initial contexts built by
     * {@link #fillPoolWithResources(int, boolean)} do.
     *
     * <p>NOTE(review): the target device is the allocator's current device, not
     * necessarily the device whose queue ran dry in the caller - confirm that
     * these always coincide.
     *
     * @param i number of contexts to create
     */
    protected void addResourcesToPool(int i) {
        Integer deviceId = Integer.valueOf(AtomicAllocator.getInstance().getDeviceId().intValue());
        cublasHandle_t blasHandle = createNewCublasHandle();
        // Previously omitted here: without it, on-demand contexts had no solver
        // handle while the initially created ones did.
        cusolverDnHandle_t solverHandle = createNewSolverHandle();
        LinkedBlockingQueue<CudaContext> deviceQueue = this.pool.get(deviceId);
        for (int created = 0; created < i; created++) {
            deviceQueue.add(newPooledContext(deviceId, blasHandle, solverHandle));
        }
    }

    /**
     * (Re)creates the queue of every available device and fills it with
     * {@code i} contexts. All contexts of one device share a single cuBLAS and a
     * single cuSOLVER handle.
     *
     * @param i number of contexts to create per device
     * @param z when {@code true}, the device that was current before the call is
     *          selected again once all devices have been processed
     */
    protected synchronized void fillPoolWithResources(int i, boolean z) {
        List<Integer> availableDevices = CudaEnvironment.getInstance().getConfiguration().getAvailableDevices();
        int previousDevice = z ? AtomicAllocator.getInstance().getDeviceId().intValue() : 0;
        NativeOps deviceNativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
        for (Integer deviceId : availableDevices) {
            deviceNativeOps.setDevice(deviceId.intValue());
            LinkedBlockingQueue<CudaContext> deviceQueue = new LinkedBlockingQueue<>();
            this.pool.put(deviceId, deviceQueue);
            this.devicePoolSizes.add(new AtomicInteger(i));
            cublasHandle_t blasHandle = createNewCublasHandle();
            cusolverDnHandle_t solverHandle = createNewSolverHandle();
            for (int created = 0; created < i; created++) {
                deviceQueue.add(newPooledContext(deviceId, blasHandle, solverHandle));
            }
        }
        if (z) {
            deviceNativeOps.setDevice(previousDevice);
        }
    }

    /** Builds one fully initialised context for {@code deviceId} using the given shared handles. */
    private CudaContext newPooledContext(Integer deviceId, cublasHandle_t blasHandle, cusolverDnHandle_t solverHandle) {
        CudaContext context = createNewStream(deviceId);
        context.initOldStream();
        getDeviceBuffers(context, deviceId.intValue());
        context.setHandle(blasHandle);
        context.setSolverHandle(solverHandle);
        context.syncOldStream();
        return context;
    }

    /** Forgets the context recorded for the calling thread, if any. */
    public void removeAcquired() {
        this.acquired.remove(Thread.currentThread().getId());
    }

    /**
     * Returns a context for device {@code num} bound to the calling thread.
     *
     * <p>If the thread already holds a context for that device, it is returned
     * as is. Otherwise a context is taken from the device queue. When the queue
     * is empty the method triggers GC, waits up to one second for a released
     * context, and then either grows the pool (while it is below the growth
     * limit) or backs off for 500 ms before retrying. It does not return until a
     * context is available.
     *
     * @param num device id
     * @return a context bound to the calling thread
     * @throws RuntimeException wrapping any exception raised while waiting or
     *         growing the pool; if the thread is interrupted, its interrupt flag
     *         is restored before the exception is thrown
     */
    @Override
    public CudaContext acquireContextForDevice(Integer num) {
        long threadId = Thread.currentThread().getId();
        CudaContext current = this.acquired.get(threadId);
        if (current != null && num.intValue() == current.getDeviceId()) {
            return current;
        }
        this.nativeOps.setDevice(num.intValue());

        // Fast path: an idle context is immediately available.
        CudaContext idle = this.pool.get(num).poll();
        if (idle != null) {
            return bindToCurrentThread(idle, num, threadId);
        }

        // Slow path: wait, grow or back off until a context shows up.
        while (true) {
            try {
                Nd4j.getMemoryManager().invokeGc();
                CudaContext released = this.pool.get(num).poll(1L, TimeUnit.SECONDS);
                if (released != null) {
                    return bindToCurrentThread(released, num, threadId);
                }

                AtomicInteger deviceSize = this.devicePoolSizes.get(num.intValue());
                boolean grown;
                synchronized (deviceSize) {
                    int configuredSize = CudaEnvironment.getInstance().getConfiguration().getPoolSize();
                    grown = deviceSize.get() < configuredSize * MAX_GROWTH_FACTOR;
                    if (grown) {
                        addResourcesToPool(GROWTH_STEP);
                        deviceSize.addAndGet(GROWTH_STEP);
                        log.warn("Initial pool size: {}; Current pool size: {}", configuredSize, deviceSize.get());
                    }
                }
                if (!grown) {
                    // Back off outside the monitor so other waiters for this
                    // device are not stalled behind a sleeping thread.
                    log.warn("Can't allocate new context, sleeping...");
                    Nd4j.getMemoryManager().invokeGc();
                    Thread.sleep(500L);
                }
            } catch (InterruptedException e) {
                // Keep the interrupt visible to callers instead of swallowing it.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Marks {@code context} as owned by the calling thread: stamps device and
     * thread id on it, registers a {@link DeallocatableThread} hook with the
     * deallocator service and records the context in {@link #acquired}.
     */
    private CudaContext bindToCurrentThread(CudaContext context, Integer deviceId, long threadId) {
        context.setDeviceId(deviceId.intValue());
        context.setThreadId(threadId);
        DeallocatableThread hook = new DeallocatableThread(Thread.currentThread(), context);
        this.threadHooks.set(hook);
        Nd4j.getDeallocatorService().pickObject(hook);
        this.acquired.put(threadId, context);
        return context;
    }

    /**
     * Wraps the context obtained from {@link #acquireContextForDevice(Integer)}
     * in a {@link ContextPack}.
     *
     * @param num device id
     * @return a pack holding the acquired context
     */
    @Override
    @Deprecated
    public ContextPack acquireContextPackForDevice(Integer num) {
        return new ContextPack(acquireContextForDevice(num));
    }

    /**
     * Same as {@link #acquireContextForDevice(Integer)}.
     *
     * @param num device id
     * @return a context bound to the calling thread
     */
    @Override
    public CudaContext getContextForDevice(Integer num) {
        return acquireContextForDevice(num);
    }

    /**
     * Returns {@code cudaContext} to the queue of its device and drops the
     * thread binding that was recorded for it.
     *
     * @param cudaContext context to release
     */
    @Override
    public void releaseContext(CudaContext cudaContext) {
        long ownerThreadId = cudaContext.getThreadId();
        int deviceId = cudaContext.getDeviceId();
        cudaContext.setThreadId(-1L);
        this.acquired.remove(ownerThreadId);
        this.pool.get(deviceId).add(cudaContext);
    }
}
