package org.nd4j.jita.allocator.tad;

import java.util.Arrays;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.math3.util.Pair;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer;
import org.nd4j.nativeblas.LongPointerWrapper;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/jita/allocator/tad/BasicTADManager.class */
/**
 * Computes tensor-along-dimension (TAD) shape information and a per-TAD long buffer for
 * CUDA-backed arrays.
 *
 * <p>This implementation performs no caching. Every call allocates fresh buffers,
 * {@link #purgeBuffers()} is a no-op, and {@link #bytes} is never updated by this class
 * (subclasses may use it).
 */
public class BasicTADManager implements TADManager {
    // Currently unused in this class.
    private static final Logger logger = LoggerFactory.getLogger(BasicTADManager.class);

    /** Sentinel dimension value that selects the whole array (handled as a 1x1 rank-2 TAD). */
    private static final int ALL_DIMENSIONS = Integer.MAX_VALUE;

    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    protected AtomicLong bytes = new AtomicLong(0);

    /**
     * Builds the TAD shape-info buffer and the per-TAD long buffer for {@code array} along
     * {@code dimension}.
     *
     * <p>NOTE(review): when {@code dimension} has more than one entry it is sorted IN PLACE, so
     * the caller's array is mutated. This is kept as-is because callers may rely on their
     * dimension array being sorted consistently with the returned buffers.
     *
     * @param array the array to split into TADs
     * @param dimension dimensions along which TADs are taken. {@code null} or a single
     *     {@code Integer.MAX_VALUE} selects the whole array. That case yields a fixed rank-2
     *     shape info and a single-element long buffer, which is left as allocated (not written
     *     here).
     * @return a pair of (shape-info buffer of length {@code rank * 2 + 4}, long buffer with one
     *     element per TAD, filled by the native {@code tadOnlyShapeInfo} call; presumably TAD
     *     offsets)
     */
    public Pair<DataBuffer, DataBuffer> getTADOnlyShapeInfo(INDArray array, int[] dimension) {
        if (dimension == null) {
            dimension = new int[] {ALL_DIMENSIONS};
        } else if (dimension.length > 1) {
            Arrays.sort(dimension);
        }

        // After the substitution above dimension is never null, so no null check is needed here.
        boolean wholeArray = dimension.length == 1 && dimension[0] == ALL_DIMENSIONS;

        // The whole-array case is always described with a rank-2 shape info.
        int targetRank = wholeArray ? 2 : array.rank();
        long numTads = 1L;
        if (!wholeArray) {
            long tadLength = tadLength(array, dimension);
            numTads = array.lengthLong() / tadLength;
        }

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaIntDataBuffer shapeInfo = new CudaIntDataBuffer((targetRank * 2) + 4);
        CudaLongDataBuffer tadBuffer = new CudaLongDataBuffer(numTads);
        allocator.getAllocationPoint(shapeInfo).tickHostWrite();
        allocator.getAllocationPoint(tadBuffer).tickHostWrite();

        IntPointer dimensionPointer = allocator.getHostPointer(allocator.getConstantBuffer(dimension));
        IntPointer inputShapeInfoPointer = AddressRetriever.retrieveHostPointer(array.shapeInfoDataBuffer());
        IntPointer targetShapeInfoPointer = AddressRetriever.retrieveHostPointer(shapeInfo);
        Pointer tadBufferPointer = AddressRetriever.retrieveHostPointer(tadBuffer);

        if (wholeArray) {
            writeWholeArrayShapeInfo(shapeInfo);
        } else {
            this.nativeOps.tadOnlyShapeInfo(
                    inputShapeInfoPointer,
                    dimensionPointer,
                    dimension.length,
                    targetShapeInfoPointer,
                    new LongPointerWrapper(tadBufferPointer));
        }

        allocator.getAllocationPoint(shapeInfo).tickHostWrite();
        allocator.getAllocationPoint(tadBuffer).tickHostWrite();
        return new Pair<>(shapeInfo, tadBuffer);
    }

    /**
     * Returns the product of the array's extents along {@code dimension}, i.e. the length of one
     * TAD.
     */
    private static long tadLength(INDArray array, int[] dimension) {
        // Hoisted out of the loop: shape() may build a new array on every call.
        int[] shape = array.shape();
        long length = 1;
        for (int d : dimension) {
            length *= shape[d];
        }
        return length;
    }

    /**
     * Writes the fixed 8-entry rank-2, 1x1 shape info used for the whole-array case.
     *
     * <p>NOTE(review): this appears to follow the layout {@code [rank, shape..., stride...,
     * offset, elementWiseStride, order]} (TODO confirm). The final 99 is the character code of
     * {@code 'c'}.
     */
    private static void writeWholeArrayShapeInfo(CudaIntDataBuffer shapeInfo) {
        int[] values = {2, 1, 1, 1, 1, 0, 0, 99};
        for (int i = 0; i < values.length; i++) {
            shapeInfo.put((long) i, values[i]);
        }
    }

    /** No-op: this manager keeps no cached buffers. */
    public void purgeBuffers() {
    }

    /**
     * Returns the value of the {@link #bytes} counter. This class never increments it, so the
     * result is 0 unless a subclass updates the counter.
     *
     * @return the cached-bytes counter value
     */
    public long getCachedBytes() {
        return this.bytes.get();
    }
}
