package org.nd4j.jita.handler.impl;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.apache.commons.lang3.RandomUtils;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.common.base.Preconditions;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.concurrency.DeviceAllocationsTracker;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.enums.CudaConstants;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.flow.FlowController;
import org.nd4j.jita.flow.impl.GridFlowController;
import org.nd4j.jita.handler.MemoryHandler;
import org.nd4j.jita.memory.MemoryProvider;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.MemcpyDirection;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.bindings.Nd4jCuda;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.profiler.OpProfiler;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.OpaqueLaunchContext;
import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Table;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/jita/handler/impl/CudaZeroHandler.class */
public class CudaZeroHandler implements MemoryHandler {
    /** Active CUDA configuration; replaced by {@link #init(Configuration, Allocator)}. */
    private static Configuration configuration = CudaEnvironment.getInstance().getConfiguration();
    private static final Logger log = LoggerFactory.getLogger(CudaZeroHandler.class);
    /** Created in {@link #init(Configuration, Allocator)}; queried by {@link #getAllocatedDeviceMemory(Integer)}. */
    protected volatile DeviceAllocationsTracker deviceMemoryTracker;
    /** Flow controller chosen by the constructor from the configured execution model. */
    private final FlowController flowController;
    /** Value of {@code configuration.getFirstMemory()} captured at construction time. */
    private final AllocationStatus INITIAL_LOCATION;
    /** Running byte total reported as host memory usage. */
    protected final AtomicLong zeroUseCounter = new AtomicLong(0);
    // Not referenced by the code visible in this file.
    protected Map<Long, Integer> devicesAffinity = new ConcurrentHashMap<>();
    // Not referenced by the code visible in this file.
    private ReentrantReadWriteLock deviceLock = new ReentrantReadWriteLock();
    // Not referenced by the code visible in this file.
    private AtomicInteger devPtr = new AtomicInteger(0);
    // Not referenced by the code visible in this file.
    private final AtomicBoolean wasInitialised = new AtomicBoolean(false);
    /** One slot per device (pre-filled with {@code null} by the constructor), filled lazily. */
    private final List<cublasHandle_t> cublasHandles = new ArrayList<>();
    // Not referenced by the code visible in this file.
    private final AffinityManager affinityManager = Nd4j.getAffinityManager();
    /** Per-thread cache of the context built by {@link #getCudaContext()}. */
    private final transient ThreadLocal<CudaContext> tlContext = new ThreadLocal<>();
    /** One object-id map per device, indexed by device id. */
    private final List<ConcurrentHashMap<Long, Long>> deviceAllocations = new ArrayList<>();
    /** Host-side object-id maps keyed by bucket id. */
    private final Map<Long, ConcurrentHashMap<Long, Long>> zeroAllocations = new ConcurrentHashMap<>();
    // Not referenced by the code visible in this file.
    private AtomicLong zeroCounter = new AtomicLong(0);
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    /** Guards access to {@link #cublasHandles}. */
    private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();

    /* renamed from: org.nd4j.jita.handler.impl.CudaZeroHandler$1, reason: invalid class name */
    /* loaded from: input_file:org/nd4j/jita/handler/impl/CudaZeroHandler$1.class */
    /**
     * Compiler-generated style lookup tables that translate enum ordinals into
     * dense switch-case indices. A constant missing at runtime keeps index 0,
     * which falls through to a switch's {@code default} branch.
     */
    static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$org$nd4j$linalg$api$buffer$DataType = new int[DataType.values().length];
        // This table was assigned but never declared in the decompiled source.
        static final /* synthetic */ int[] $SwitchMap$org$nd4j$jita$conf$Configuration$ExecutionModel = new int[Configuration.ExecutionModel.values().length];

        static {
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.DOUBLE.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave index 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.FLOAT.ordinal()] = 2;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave index 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.UINT32.ordinal()] = 3;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave index 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.INT.ordinal()] = 4;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave index 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.SHORT.ordinal()] = 5;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave index 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.UINT16.ordinal()] = 6;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave index 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.HALF.ordinal()] = 7;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave index 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.BFLOAT16.ordinal()] = 8;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave index 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.UINT64.ordinal()] = 9;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave index 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.LONG.ordinal()] = 10;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave index 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.UTF8.ordinal()] = 11;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave index 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.UBYTE.ordinal()] = 12;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave index 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.BYTE.ordinal()] = 13;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave index 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.BOOL.ordinal()] = 14;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave index 0
            }
            try {
                $SwitchMap$org$nd4j$jita$conf$Configuration$ExecutionModel[Configuration.ExecutionModel.SEQUENTIAL.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave index 0
            }
        }
    }

    /**
     * Marks the configuration as initialized, selects the flow controller for
     * the configured execution model and prepares one allocation map and one
     * empty cuBLAS slot per available device.
     *
     * @throws RuntimeException if the execution model is not {@code SEQUENTIAL}
     * @throws ND4JIllegalStateException if device 0 reports a major compute
     *         capability below 3
     */
    public CudaZeroHandler() {
        configuration.setInitialized();
        this.INITIAL_LOCATION = configuration.getFirstMemory();
        switch (configuration.getExecutionModel()) {
            case SEQUENTIAL:
                this.flowController = new GridFlowController();
                break;
            default:
                throw new RuntimeException("Unknown ExecutionModel: [" + configuration.getExecutionModel() + "]");
        }

        final int deviceCount = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices();
        int device = 0;
        while (device < deviceCount) {
            this.deviceAllocations.add(new ConcurrentHashMap<>());
            this.cublasHandles.add(null);
            device++;
        }

        // Only device 0 is inspected for the minimum compute capability.
        final int majorOfFirstDevice = NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceMajor(0);
        if (majorOfFirstDevice < 3) {
            throw new ND4JIllegalStateException("CUDA backend requires compute capatibility of 3.0 and above to run.");
        }
    }

    /**
     * Installs the given configuration (into the class-wide static field),
     * creates a fresh device allocation tracker for it and initialises the
     * flow controller with the allocator.
     *
     * @param configuration2 configuration to use from now on; must not be null
     * @param allocator allocator handed to the flow controller; must not be null
     * @throws NullPointerException if either argument is null
     */
    @Override
    public void init(@NonNull Configuration configuration2, @NonNull Allocator allocator) {
        if (configuration2 == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        } else if (allocator == null) {
            throw new NullPointerException("allocator is marked non-null but is null");
        }
        CudaZeroHandler.configuration = configuration2;
        this.deviceMemoryTracker = new DeviceAllocationsTracker(CudaZeroHandler.configuration);
        this.flowController.init(allocator);
    }

    /**
     * Assigns the allocation point to a randomly chosen bucket (bucket ids range
     * over {@code [0, numberOfGcThreads)}), adds its byte size to the host usage
     * counter and records its object id in that bucket, creating the bucket on
     * first use.
     *
     * @param allocationPoint allocation to register on the host side
     */
    private void pickupHostAllocation(AllocationPoint allocationPoint) {
        final long bucketId = RandomUtils.nextInt(0, configuration.getNumberOfGcThreads());
        this.zeroUseCounter.addAndGet(allocationPoint.getNumberOfBytes());
        final Long boxedBucketId = Long.valueOf(bucketId);
        allocationPoint.setBucketId(boxedBucketId);

        if (!this.zeroAllocations.containsKey(boxedBucketId)) {
            log.debug("Creating bucketID: " + bucketId);
            // Re-check under the monitor so only one thread installs the bucket.
            synchronized (this) {
                boolean stillMissing = !this.zeroAllocations.containsKey(boxedBucketId);
                if (stillMissing) {
                    this.zeroAllocations.put(boxedBucketId, new ConcurrentHashMap<>());
                }
            }
        }

        final ConcurrentHashMap<Long, Long> bucket = this.zeroAllocations.get(boxedBucketId);
        bucket.put(allocationPoint.getObjectId(), allocationPoint.getObjectId());
    }

    /**
     * Not supported by this handler.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public PointersPair alloc(AllocationStatus allocationStatus, AllocationPoint allocationPoint, AllocationShape allocationShape, boolean z) {
        throw new UnsupportedOperationException();
    }

    /**
     * Reports whether the device has enough free memory. This implementation
     * performs no query and ignores both arguments.
     *
     * @return always {@code true}
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public boolean pingDeviceForFreeMemory(Integer num, long j) {
        return true;
    }

    /** No-op in this implementation: all arguments are ignored. */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void relocate(AllocationStatus allocationStatus, AllocationStatus allocationStatus2, AllocationPoint allocationPoint, AllocationShape allocationShape, CudaContext cudaContext) {
    }

    /**
     * Deprecated and unsupported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    @Deprecated
    public void copyback(AllocationPoint allocationPoint, AllocationShape allocationShape) {
        throw new UnsupportedOperationException("Deprecated call");
    }

    /**
     * Deprecated and unsupported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    @Deprecated
    public void copyforward(AllocationPoint allocationPoint, AllocationShape allocationShape) {
        throw new UnsupportedOperationException("Deprecated call");
    }

    /**
     * Deprecated; never performs a fallback.
     *
     * @throws IllegalStateException always, with the point's current allocation
     *         status in the message
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    @Deprecated
    public void fallback(AllocationPoint allocationPoint, AllocationShape allocationShape) {
        throw new IllegalStateException("Can't fallback from [" + allocationPoint.getAllocationStatus() + "]");
    }

    /**
     * No-op in this implementation: nothing is released here. Note that
     * {@code purgeDeviceObject} still calls this method.
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void free(AllocationPoint allocationPoint, AllocationStatus allocationStatus) {
    }

    /**
     * @return the allocation status captured from
     *         {@code configuration.getFirstMemory()} when this handler was constructed
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public AllocationStatus getInitialLocation() {
        return this.INITIAL_LOCATION;
    }

    /** No-op in this implementation: both arguments are ignored. */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void initializeDevice(Long l, Integer num) {
    }

    /**
     * Copies {@code j} bytes from a host pointer into the target buffer.
     * <p>
     * Constant buffers are updated with a plain host-to-host copy only. For all
     * other buffers the data is copied to the device pointer on the special
     * stream and, when the buffer also has a host pointer, to that host pointer
     * as well.
     *
     * @param dataBuffer destination buffer; must be a {@link BaseCudaDataBuffer}
     * @param pointer source host pointer
     * @param j number of bytes to copy; values below 1 make this call a no-op
     * @param j2 offset added to the destination address
     * @throws IllegalArgumentException if {@code j} exceeds the buffer size in bytes
     * @throws IllegalStateException if a native copy reports failure (returns 0)
     */
    @Override
    public void memcpyAsync(DataBuffer dataBuffer, Pointer pointer, long j, long j2) {
        if (j < 1) {
            return;
        }
        Preconditions.checkArgument(j <= dataBuffer.length() * ((long) Nd4j.sizeOfDataType(dataBuffer.dataType())), "Length requested is bigger than target DataBuffer length");
        final AllocationPoint point = ((BaseCudaDataBuffer) dataBuffer).getAllocationPoint();

        if (dataBuffer.isConstant()) {
            final CudaPointer hostTarget = new CudaPointer(point.getHostPointer().address() + j2, 0L);
            final CudaPointer hostSource = new CudaPointer(pointer, j);
            final long constantStart = PerformanceTracker.getInstance().helperStartTransaction();
            Pointer.memcpy(hostTarget, hostSource, j);
            PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), constantStart, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_HOST);
            point.tickHostRead();
            return;
        }

        // Host -> device copy on the special stream.
        final CudaPointer deviceTarget = new CudaPointer(point.getDevicePointer().address() + j2);
        final CudaContext deviceContext = this.flowController.prepareAction(point, new AllocationPoint[0]);
        final long deviceStart = PerformanceTracker.getInstance().helperStartTransaction();
        this.flowController.commitTransfer(deviceContext.getSpecialStream());
        final int deviceResult = this.nativeOps.memcpyAsync(deviceTarget, pointer, j, CudaConstants.cudaMemcpyHostToDevice, deviceContext.getSpecialStream());
        if (deviceResult == 0) {
            throw new IllegalStateException("MemcpyAsync H2D failed: [" + pointer.address() + "] -> [" + deviceTarget.address() + "]");
        }
        this.flowController.commitTransfer(deviceContext.getSpecialStream());
        PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), deviceStart, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
        this.flowController.registerAction(deviceContext, point, new AllocationPoint[0]);
        point.tickDeviceWrite();

        if (point.getHostPointer() == null) {
            return;
        }

        // Mirror the same bytes into the host-side copy of the buffer.
        final CudaPointer hostMirror = new CudaPointer(point.getHostPointer().address() + j2);
        final CudaContext hostContext = this.flowController.prepareAction(point, new AllocationPoint[0]);
        final long hostStart = PerformanceTracker.getInstance().helperStartTransaction();
        final int hostResult = this.nativeOps.memcpyAsync(hostMirror, pointer, j, CudaConstants.cudaMemcpyHostToHost, hostContext.getSpecialStream());
        if (hostResult == 0) {
            throw new IllegalStateException("MemcpyAsync H2H failed: [" + pointer.address() + "] -> [" + hostMirror.address() + "]");
        }
        this.flowController.commitTransfer(hostContext.getSpecialStream());
        PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), hostStart, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_HOST);
        if (point.getAllocationStatus() == AllocationStatus.HOST) {
            this.flowController.registerAction(hostContext, point, new AllocationPoint[0]);
        }
        point.tickHostRead();
    }

    /**
     * Issues a device-to-device copy of {@code j} bytes into the buffer's device
     * pointer on the context's old stream and marks the buffer as written on
     * the device.
     *
     * @param dataBuffer destination buffer; must be a {@link BaseCudaDataBuffer}
     * @param pointer source device pointer
     * @param j number of bytes to copy
     * @param j2 offset added to the destination device address
     * @param cudaContext context whose old stream carries the copy
     * @throws ND4JIllegalStateException if the native copy reports failure (returns 0)
     */
    @Override
    public void memcpyDevice(DataBuffer dataBuffer, Pointer pointer, long j, long j2, CudaContext cudaContext) {
        final AllocationPoint point = ((BaseCudaDataBuffer) dataBuffer).getAllocationPoint();
        final CudaPointer target = new CudaPointer(point.getDevicePointer().address() + j2);
        final int result = this.nativeOps.memcpyAsync(target, pointer, j, CudaConstants.cudaMemcpyDeviceToDevice, cudaContext.getOldStream());
        if (result == 0) {
            throw new ND4JIllegalStateException("memcpyAsync failed");
        }
        point.tickDeviceWrite();
    }

    /**
     * Copies {@code j} bytes from {@code pointer} into the buffer's host pointer
     * and, when the buffer currently lives on the device, forwards the freshly
     * written host bytes to the device pointer. The old stream is synchronised
     * before returning.
     *
     * @param dataBuffer destination buffer; must be a {@link BaseCudaDataBuffer}
     * @param pointer source host pointer
     * @param j number of bytes to copy
     * @param j2 offset added to the destination addresses
     * @throws ND4JIllegalStateException if a native copy reports failure (returns 0)
     */
    @Override
    public void memcpySpecial(DataBuffer dataBuffer, Pointer pointer, long j, long j2) {
        final CudaContext context = getCudaContext();
        final AllocationPoint point = ((BaseCudaDataBuffer) dataBuffer).getAllocationPoint();

        final CudaPointer hostTarget = new CudaPointer(point.getHostPointer().address() + j2);
        final long hostStart = PerformanceTracker.getInstance().helperStartTransaction();
        final int hostResult = this.nativeOps.memcpyAsync(hostTarget, pointer, j, CudaConstants.cudaMemcpyHostToHost, context.getOldStream());
        if (hostResult == 0) {
            throw new ND4JIllegalStateException("memcpyAsync failed");
        }
        PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), hostStart, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_HOST);

        final boolean onDevice = point.getAllocationStatus() == AllocationStatus.DEVICE;
        if (onDevice) {
            final CudaPointer deviceTarget = new CudaPointer(point.getDevicePointer().address() + j2);
            final long deviceStart = PerformanceTracker.getInstance().helperStartTransaction();
            final int deviceResult = this.nativeOps.memcpyAsync(deviceTarget, hostTarget, j, CudaConstants.cudaMemcpyHostToDevice, context.getOldStream());
            if (deviceResult == 0) {
                throw new ND4JIllegalStateException("memcpyAsync failed");
            }
            context.syncOldStream();
            PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), deviceStart, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
        }

        context.syncOldStream();
        point.tickDeviceWrite();
    }

    /**
     * Runs {@link #memcpyAsync(DataBuffer, Pointer, long, long)} and then
     * synchronises the current thread's old stream.
     * <p>
     * NOTE(review): memcpyAsync issues its copies on the special stream, while
     * this method syncs the old stream; confirm that is the intended stream.
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void memcpyBlocking(DataBuffer dataBuffer, Pointer pointer, long j, long j2) {
        CudaContext cudaContext = getCudaContext();
        memcpyAsync(dataBuffer, pointer, j, j2);
        cudaContext.syncOldStream();
    }

    /**
     * Copies the full contents of {@code dataBuffer2} into {@code dataBuffer}'s
     * device memory and blocks until the old stream has finished.
     * <p>
     * The source is read from its device pointer when it is up to date on the
     * device side, otherwise from its host pointer.
     *
     * @param dataBuffer destination buffer; must be a {@link BaseCudaDataBuffer}
     * @param dataBuffer2 source buffer; must be a {@link BaseCudaDataBuffer}
     * @throws ND4JIllegalStateException if the native copy reports failure (returns 0)
     */
    @Override
    public void memcpy(DataBuffer dataBuffer, DataBuffer dataBuffer2) {
        final CudaContext context = getCudaContext();
        final AllocationPoint dstPoint = ((BaseCudaDataBuffer) dataBuffer).getAllocationPoint();
        final AllocationPoint srcPoint = ((BaseCudaDataBuffer) dataBuffer2).getAllocationPoint();
        final long startedAt = PerformanceTracker.getInstance().helperStartTransaction();
        Nd4j.getExecutioner().push();

        final MemcpyDirection direction;
        if (srcPoint.isActualOnDeviceSide()) {
            if (this.nativeOps.memcpyAsync(AtomicAllocator.getInstance().getPointer(dataBuffer, context), AtomicAllocator.getInstance().getPointer(dataBuffer2, context), dataBuffer2.length() * dataBuffer2.getElementSize(), CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0) {
                throw new ND4JIllegalStateException("memcpyAsync failed");
            }
            dstPoint.tickDeviceWrite();
            direction = MemcpyDirection.DEVICE_TO_DEVICE;
        } else {
            if (this.nativeOps.memcpyAsync(AtomicAllocator.getInstance().getPointer(dataBuffer, context), AtomicAllocator.getInstance().getHostPointer(dataBuffer2), dataBuffer2.length() * dataBuffer2.getElementSize(), CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0) {
                throw new ND4JIllegalStateException("memcpyAsync failed");
            }
            direction = MemcpyDirection.HOST_TO_DEVICE;
        }
        dstPoint.tickDeviceWrite();

        // The copy must have completed before the transaction is recorded.
        context.syncOldStream();

        // Pass the start value unmodified, as every other call site in this class
        // does; the previous "/ 2" altered the value the tracker measures from.
        PerformanceTracker.getInstance().helperRegisterTransaction(srcPoint.getDeviceId(), startedAt, dstPoint.getNumberOfBytes(), direction);
    }

    /**
     * Returns a typed view of the buffer's device pointer.
     *
     * @param dataBuffer buffer to resolve; must be a {@link BaseCudaDataBuffer}
     * @param cudaContext context whose old stream is used for the optional
     *        locality check
     * @return a pointer typed according to the buffer's data type, the untyped
     *         pointer for data types without a dedicated view, or {@code null}
     *         when the buffer has no device pointer
     * @throws UnsupportedOperationException if the buffer is allocated on the
     *         device but its device copy is not up to date
     */
    @Override
    public Pointer getDevicePointer(DataBuffer dataBuffer, CudaContext cudaContext) {
        AllocationPoint allocationPoint = ((BaseCudaDataBuffer) dataBuffer).getAllocationPoint();
        if (allocationPoint.getAllocationStatus() == AllocationStatus.DEVICE && !allocationPoint.isActualOnDeviceSide()) {
            throw new UnsupportedOperationException("Pew-pew");
        }
        if (allocationPoint.getDevicePointer() == null) {
            return null;
        }
        CudaPointer cudaPointer = new CudaPointer(allocationPoint.getDevicePointer(), dataBuffer.length(), 0L);
        if (OpProfiler.getInstance().getConfig().isCheckLocality()) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().tryPointer(cudaContext.getOldStream(), cudaPointer, 1);
        }
        // Switch on the enum directly instead of the synthetic ordinal map, whose
        // decompiled case labels named unrelated Nd4jCuda constants.
        switch (dataBuffer.dataType()) {
            case DOUBLE:
                return cudaPointer.asDoublePointer();
            case FLOAT:
                return cudaPointer.asFloatPointer();
            case UINT32:
            case INT:
                return cudaPointer.asIntPointer();
            case SHORT:
            case UINT16:
            case HALF:
            case BFLOAT16:
                return cudaPointer.asShortPointer();
            case UINT64:
            case LONG:
                return cudaPointer.asLongPointer();
            case UTF8:
            case UBYTE:
            case BYTE:
                return cudaPointer.asBytePointer();
            case BOOL:
                return cudaPointer.asBooleanPointer();
            default:
                return cudaPointer;
        }
    }

    /**
     * Returns a typed view of the buffer's host pointer after asking the flow
     * controller to synchronise the allocation to the host.
     *
     * @param dataBuffer buffer to resolve; must be a {@link BaseCudaDataBuffer}
     * @return a pointer typed according to the buffer's data type, the untyped
     *         pointer for data types without a dedicated view here, or
     *         {@code null} when the buffer has no host pointer
     */
    @Override
    public Pointer getHostPointer(DataBuffer dataBuffer) {
        AllocationPoint allocationPoint = ((BaseCudaDataBuffer) dataBuffer).getAllocationPoint();
        if (allocationPoint.getHostPointer() == null) {
            return null;
        }
        synchronizeThreadDevice(Long.valueOf(Thread.currentThread().getId()), Integer.valueOf(allocationPoint.getDeviceId()), allocationPoint);
        CudaPointer cudaPointer = new CudaPointer(allocationPoint.getHostPointer(), dataBuffer.length(), 0L);
        // Switch on the enum directly instead of the synthetic ordinal map, whose
        // decompiled case labels named unrelated Nd4jCuda constants.
        switch (dataBuffer.dataType()) {
            case DOUBLE:
                return cudaPointer.asDoublePointer();
            case FLOAT:
                return cudaPointer.asFloatPointer();
            case UINT32:
            case INT:
                return cudaPointer.asIntPointer();
            case SHORT:
            case UINT16:
            case HALF:
            case BFLOAT16:
                return cudaPointer.asShortPointer();
            case UINT64:
            case LONG:
                return cudaPointer.asLongPointer();
            default:
                // Byte-sized and boolean types get the untyped pointer here.
                return cudaPointer;
        }
    }

    /**
     * Not supported: looks up the buffer's allocation point (result unused) and
     * then fails.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public synchronized void relocateObject(DataBuffer dataBuffer) {
        AtomicAllocator.getInstance().getAllocationPoint(dataBuffer);
        throw new UnsupportedOperationException("Pew-pew");
    }

    /**
     * Not supported: looks up the buffer's allocation point (result unused) and
     * then fails.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public boolean promoteObject(DataBuffer dataBuffer) {
        AtomicAllocator.getInstance().getAllocationPoint(dataBuffer);
        throw new UnsupportedOperationException("Pew-pew");
    }

    /**
     * Builds a snapshot table of allocated bytes: one HOST row under column 0
     * holding the host usage counter, plus one DEVICE row per configured device.
     *
     * @return a new table keyed by allocation status and device id
     */
    @Override
    public Table<AllocationStatus, Integer, Long> getAllocationStatistics() {
        final Table<AllocationStatus, Integer, Long> stats = HashBasedTable.create();
        final long hostBytes = this.zeroUseCounter.get();
        stats.put(AllocationStatus.HOST, 0, Long.valueOf(hostBytes));
        for (Integer deviceId : configuration.getAvailableDevices()) {
            final long deviceBytes = getAllocatedDeviceMemory(deviceId);
            stats.put(AllocationStatus.DEVICE, deviceId, Long.valueOf(deviceBytes));
        }
        return stats;
    }

    /**
     * @param num device id
     * @return the allocated size the device memory tracker reports for that
     *         device; the tracker is only created in {@code init}
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public long getAllocatedDeviceMemory(Integer num) {
        return this.deviceMemoryTracker.getAllocatedSize(num);
    }

    /** @return the current value of the host usage byte counter */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public long getAllocatedHostMemory() {
        return this.zeroUseCounter.get();
    }

    /**
     * @param num device id, used as an index into the per-device allocation list
     * @return number of object ids tracked for that device
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public long getAllocatedDeviceObjects(Integer num) {
        return this.deviceAllocations.get(num.intValue()).size();
    }

    /**
     * @param l bucket id
     * @return number of object ids tracked in that host bucket, or 0 when the
     *         bucket does not exist
     */
    @Override
    public long getAllocatedHostObjects(Long l) {
        final ConcurrentHashMap<Long, Long> bucket = this.zeroAllocations.get(l);
        return bucket == null ? 0L : bucket.size();
    }

    /** @return total number of object ids tracked across all host buckets */
    @Override
    public long getAllocatedHostObjects() {
        long total = 0L;
        for (ConcurrentHashMap<Long, Long> bucket : this.zeroAllocations.values()) {
            total += bucket.size();
        }
        return total;
    }

    /**
     * @param num device id, used as an index into the per-device allocation list
     * @return the key set of that device's allocation map (the map's own key
     *         view, not a copy)
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public Set<Long> getDeviceTrackingPoints(Integer num) {
        return this.deviceAllocations.get(num.intValue()).keySet();
    }

    /**
     * @param l bucket id
     * @return the key view of that host bucket, or a new empty set when the
     *         bucket does not exist
     */
    @Override
    public Set<Long> getHostTrackingPoints(Long l) {
        final ConcurrentHashMap<Long, Long> bucket = this.zeroAllocations.get(l);
        if (bucket == null) {
            return new HashSet<>();
        }
        return bucket.keySet();
    }

    /**
     * Drops a device-resident allocation from tracking and marks it as HOST.
     * Points that are not in DEVICE status are left untouched.
     *
     * @param l not used by this implementation
     * @param num device id, used as an index into the per-device allocation list
     * @param l2 object id expected to be tracked for that device
     * @param allocationPoint allocation being purged
     * @param z not used by this implementation
     * @throws IllegalStateException if the object id is not tracked before the
     *         purge, or is still tracked afterwards
     */
    @Override
    public void purgeDeviceObject(Long l, Integer num, Long l2, AllocationPoint allocationPoint, boolean z) {
        final boolean onDevice = allocationPoint.getAllocationStatus() == AllocationStatus.DEVICE;
        if (!onDevice) {
            return;
        }

        this.flowController.waitTillReleased(allocationPoint);
        free(allocationPoint, AllocationStatus.DEVICE);

        final boolean trackedBefore = this.deviceAllocations.get(num.intValue()).containsKey(l2);
        if (!trackedBefore) {
            throw new IllegalStateException("Can't happen ever");
        }

        forget(allocationPoint, AllocationStatus.DEVICE);

        final boolean trackedAfter = this.deviceAllocations.get(num.intValue()).containsKey(l2);
        if (trackedAfter) {
            throw new IllegalStateException("Can't happen ever");
        }

        allocationPoint.setAllocationStatus(AllocationStatus.HOST);
    }

    /**
     * Not supported by this handler.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void purgeZeroObject(Long l, Long l2, AllocationPoint allocationPoint, boolean z) {
        throw new UnsupportedOperationException("Pew-pew");
    }

    /**
     * Removes the allocation's object id from this handler's tracking maps.
     * <p>
     * For DEVICE the id is removed from the map of the point's device. For HOST
     * the id is removed from the point's bucket, provided the point has a host
     * pointer. A HOST point with no bucket id, or whose bucket was never
     * created, is not tracked, so there is nothing to forget and the call is a
     * no-op rather than a {@code NullPointerException}.
     *
     * @param allocationPoint allocation to stop tracking
     * @param allocationStatus which side (DEVICE or HOST) to forget; any other
     *        status is ignored
     */
    @Override
    public void forget(AllocationPoint allocationPoint, AllocationStatus allocationStatus) {
        if (allocationStatus == AllocationStatus.DEVICE) {
            this.deviceAllocations.get(allocationPoint.getDeviceId()).remove(allocationPoint.getObjectId());
            return;
        }
        if (allocationStatus != AllocationStatus.HOST || allocationPoint.getHostPointer() == null) {
            return;
        }
        final Long bucketId = allocationPoint.getBucketId();
        if (bucketId == null) {
            // Never assigned to a bucket; ConcurrentHashMap rejects null keys.
            return;
        }
        final ConcurrentHashMap<Long, Long> bucket = this.zeroAllocations.get(bucketId);
        if (bucket != null) {
            bucket.remove(allocationPoint.getObjectId());
        }
    }

    /** @return the device id the affinity manager reports for the calling thread */
    @Override
    public Integer getDeviceId() {
        final int currentDevice = Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue();
        return Integer.valueOf(currentDevice);
    }

    /**
     * @return a new {@link CudaPointer} constructed from the current thread's
     *         device id
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public Pointer getDeviceIdPointer() {
        return new CudaPointer(getDeviceId().intValue());
    }

    /**
     * @return a new set holding a copy of the configured available device ids;
     *         changes to it do not affect the configuration
     */
    @Override
    public Set<Integer> getAvailableDevices() {
        final Set<Integer> devices = new HashSet<>(configuration.getAvailableDevices());
        return devices;
    }

    /** @return the same context as {@link #getCudaContext()} */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public CudaContext getDeviceContext() {
        return getCudaContext();
    }

    /**
     * Returns the cuBLAS handle cached for the calling thread's device, creating
     * it from the launch context on first use.
     * <p>
     * The handle is stored with {@code List.set}, so every device keeps its own
     * slot. The previous code called {@code remove(Integer)}, which resolves to
     * {@code remove(Object)} and removed nothing, followed by
     * {@code add(index, handle)}, which shifted later slots to the right and
     * could lose handles of higher-numbered devices.
     *
     * @param opaqueLaunchContext launch context supplying the native cuBLAS handle
     * @return the handle for the current thread's device
     */
    protected cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext opaqueLaunchContext) {
        final int deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue();
        // Acquire outside the try so a failed lock() is never followed by unlock().
        this.lock.writeLock().lock();
        try {
            cublasHandle_t handle = this.cublasHandles.get(deviceId);
            if (handle == null) {
                handle = new cublasHandle_t(this.nativeOps.lcBlasHandle(opaqueLaunchContext));
                this.cublasHandles.set(deviceId, handle);
            }
            return handle;
        } finally {
            this.lock.writeLock().unlock();
        }
    }

    /**
     * Returns the calling thread's cached {@link CudaContext}, building it from
     * the native default launch context and caching it on first use.
     * <p>
     * NOTE(review): {@code bufferSpecial} is filled from {@code lcScalarPointer},
     * the same source as {@code bufferScalar}; confirm this is intended.
     *
     * @return the per-thread context
     */
    @Override
    public CudaContext getCudaContext() {
        CudaContext context = this.tlContext.get();
        if (context == null) {
            final OpaqueLaunchContext launchContext = this.nativeOps.defaultLaunchContext();
            context = CudaContext.builder()
                    .bufferScalar(this.nativeOps.lcScalarPointer(launchContext))
                    .bufferReduction(this.nativeOps.lcReductionPointer(launchContext))
                    .bufferAllocation(this.nativeOps.lcAllocationPointer(launchContext))
                    .bufferSpecial(this.nativeOps.lcScalarPointer(launchContext))
                    .oldStream(new cudaStream_t(this.nativeOps.lcExecutionStream(launchContext)))
                    .specialStream(new cudaStream_t(this.nativeOps.lcCopyStream(launchContext)))
                    .cublasHandle(getCudaCublasHandle(launchContext))
                    .solverHandle(new cusolverDnHandle_t(this.nativeOps.lcSolverHandle(launchContext)))
                    .build();
            this.tlContext.set(context);
        }
        return context;
    }

    /**
     * Discards the calling thread's cached context; the next
     * {@link #getCudaContext()} call on this thread builds a new one.
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void resetCachedContext() {
        this.tlContext.remove();
    }

    /** @return always {@code true} for this handler */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public boolean isDeviceDependant() {
        return true;
    }

    /**
     * Asks the flow controller to synchronise the allocation to the host. The
     * thread id and device id arguments are not used by this implementation.
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void synchronizeThreadDevice(Long l, Integer num, AllocationPoint allocationPoint) {
        this.flowController.synchronizeToHost(allocationPoint);
    }

    /** Delegates to the flow controller's {@code registerAction} with the same arguments. */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void registerAction(CudaContext cudaContext, INDArray iNDArray, INDArray... iNDArrayArr) {
        this.flowController.registerAction(cudaContext, iNDArray, iNDArrayArr);
    }

    /** @return the flow controller created by the constructor */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public FlowController getFlowController() {
        return this.flowController;
    }

    /** @return always {@code null}; this handler exposes no memory provider */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public MemoryProvider getMemoryProvider() {
        return null;
    }
}
