package org.nd4j.jita.allocator.utils;

import lombok.NonNull;
import org.bytedeco.javacpp.LongPointer;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;

/* loaded from: input_file:org/nd4j/jita/allocator/utils/AllocationUtils.class */
/**
 * Static helpers for building {@link AllocationShape} descriptors and computing allocation sizes
 * used by the CUDA allocator.
 */
public class AllocationUtils {

    /**
     * Computes the number of bytes required for an allocation described by the given shape.
     *
     * @param allocationShape the shape to size; must not be {@code null}
     * @return {@code allocationShape.getLength()} multiplied by {@link #getElementSize(AllocationShape)}
     * @throws NullPointerException if {@code allocationShape} is {@code null}
     */
    public static long getRequiredMemory(@NonNull AllocationShape allocationShape) {
        if (allocationShape == null) {
            throw new NullPointerException("shape is marked @NonNull but is null");
        }
        return allocationShape.getLength() * getElementSize(allocationShape);
    }

    /**
     * Returns the size in bytes of a single element of the given shape.
     *
     * <p>An explicit element size stored on the shape (a value greater than zero) takes precedence.
     * Otherwise the size is derived from the shape's data type via {@link Nd4j#sizeOfDataType}.
     *
     * @param allocationShape the shape to inspect; must not be {@code null}
     * @return the per-element size in bytes
     * @throws NullPointerException if {@code allocationShape} is {@code null}
     */
    public static int getElementSize(@NonNull AllocationShape allocationShape) {
        if (allocationShape == null) {
            throw new NullPointerException("shape is marked @NonNull but is null");
        }
        int explicitSize = allocationShape.getElementSize();
        return explicitSize > 0 ? explicitSize : Nd4j.sizeOfDataType(allocationShape.getDataType());
    }

    /**
     * Builds an allocation shape for an array.
     *
     * <p>The data type comes from the array's backing buffer ({@code iNDArray.data()}). The length
     * comes from the array itself ({@code iNDArray.length()}), not from the buffer.
     *
     * @param iNDArray the array to describe
     * @return a new shape with data type and length populated
     */
    public static AllocationShape buildAllocationShape(INDArray iNDArray) {
        AllocationShape allocationShape = new AllocationShape();
        allocationShape.setDataType(iNDArray.data().dataType());
        allocationShape.setLength(iNDArray.length());
        return allocationShape;
    }

    /**
     * Builds an allocation shape describing a data buffer.
     *
     * @param dataBuffer the buffer to describe
     * @return a new shape with the buffer's data type and length
     */
    public static AllocationShape buildAllocationShape(DataBuffer dataBuffer) {
        AllocationShape allocationShape = new AllocationShape();
        allocationShape.setDataType(dataBuffer.dataType());
        allocationShape.setLength(dataBuffer.length());
        return allocationShape;
    }

    /**
     * Builds an allocation shape describing a CUDA buffer.
     *
     * @param jCudaBuffer the buffer to describe
     * @return a new shape with the buffer's data type and length
     */
    public static AllocationShape buildAllocationShape(JCudaBuffer jCudaBuffer) {
        AllocationShape allocationShape = new AllocationShape();
        allocationShape.setDataType(jCudaBuffer.dataType());
        allocationShape.setLength(jCudaBuffer.length());
        return allocationShape;
    }

    /**
     * Copies an array of 64-bit values into a newly allocated CUDA buffer.
     *
     * <p>A {@link CudaDoubleDataBuffer} is used only as 8-byte-per-element storage. The raw bits of
     * {@code jArr} are copied in with a blocking memcpy at offset 0; no conversion to double is
     * performed.
     *
     * @param jArr the values to copy (presumably device pointer addresses, per the method name)
     * @return a buffer of {@code jArr.length} elements holding a bitwise copy of {@code jArr}
     */
    public static DataBuffer getPointersBuffer(long[] jArr) {
        CudaDoubleDataBuffer cudaDoubleDataBuffer = new CudaDoubleDataBuffer(jArr.length);
        // Widen to long before multiplying: in int arithmetic this product overflows for very large arrays.
        long byteCount = (long) jArr.length * Long.BYTES;
        AtomicAllocator.getInstance().memcpyBlocking(cudaDoubleDataBuffer, new LongPointer(jArr), byteCount, 0L);
        return cudaDoubleDataBuffer;
    }
}
