package org.nd4j.linalg.jcublas;

import java.io.ObjectStreamException;
import java.util.List;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.enums.CudaConstants;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.FloatBuffer;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.BaseNDArray;
import org.nd4j.linalg.api.ndarray.BaseNDArrayProxy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ndarray.JvmShapeInfo;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.memory.MemcpyDirection;
import org.nd4j.linalg.workspace.WorkspaceUtils;
import org.nd4j.nativeblas.NativeOpsHolder;

/* loaded from: input_file:org/nd4j/linalg/jcublas/JCublasNDArray.class */
/**
 * {@link BaseNDArray} subclass used by the jcublas (CUDA) backend.
 *
 * <p>Most constructors forward their arguments unchanged to the superclass constructor with the
 * same signature. On top of that the class provides duplication, workspace leveraging and
 * migration routines that copy the backing buffer through the native {@code memcpyAsync} call,
 * using the {@link AllocationPoint}s tracked by {@link AtomicAllocator}.
 */
public class JCublasNDArray extends BaseNDArray {

    /**
     * Wraps already prepared buffers without going through any superclass constructor logic.
     *
     * @param dataBuffer         buffer stored as this array's data
     * @param cudaLongDataBuffer buffer stored as this array's shape information
     * @param jArr               values used to build the {@link JvmShapeInfo}
     *                           (presumably the shape-info vector matching
     *                           {@code cudaLongDataBuffer} - confirm against callers)
     */
    public JCublasNDArray(DataBuffer dataBuffer, CudaLongDataBuffer cudaLongDataBuffer, long[] jArr) {
        this.jvmShapeInfo = new JvmShapeInfo(jArr);
        this.shapeInformation = cudaLongDataBuffer;
        this.data = dataBuffer;
    }

    // ---------------------------------------------------------------------------------------
    // Pass-through constructors. Unless documented otherwise, each one forwards its arguments,
    // in the same order, to the BaseNDArray constructor with the identical parameter list.
    // ---------------------------------------------------------------------------------------

    public JCublasNDArray(double[][] dArr) {
        super(dArr);
    }

    public JCublasNDArray(double[][] dArr, char c) {
        super(dArr, c);
    }

    public JCublasNDArray(int[] iArr, DataBuffer dataBuffer) {
        super(iArr, dataBuffer);
    }

    public JCublasNDArray(float[] fArr, int[] iArr, char c) {
        super(fArr, iArr, c);
    }

    public JCublasNDArray(float[] fArr, int[] iArr, long j, char c) {
        super(fArr, iArr, j, c);
    }

    public JCublasNDArray(int[] iArr, int[] iArr2, long j, char c) {
        super(iArr, iArr2, j, c);
    }

    public JCublasNDArray(int[] iArr, int[] iArr2, long j, char c, boolean z) {
        super(iArr, iArr2, j, c, z);
    }

    public JCublasNDArray(long[] jArr, long[] jArr2, long j, char c, boolean z) {
        super(jArr, jArr2, j, c, z);
    }

    public JCublasNDArray(int[] iArr, int[] iArr2, char c) {
        super(iArr, iArr2, c);
    }

    public JCublasNDArray(int[] iArr, long j, char c) {
        super(iArr, j, c);
    }

    public JCublasNDArray(long[] jArr, long j, char c) {
        super(jArr, j, c);
    }

    public JCublasNDArray(int[] iArr) {
        super(iArr);
    }

    public JCublasNDArray(long[] jArr) {
        super(jArr);
    }

    public JCublasNDArray(int i, int i2, char c) {
        super(i, i2, c);
    }

    public JCublasNDArray(List<INDArray> list, int[] iArr, char c) {
        super(list, iArr, c);
    }

    public JCublasNDArray(List<INDArray> list, long[] jArr, char c) {
        super(list, jArr, c);
    }

    public JCublasNDArray(List<INDArray> list, int[] iArr, int[] iArr2, char c) {
        super(list, iArr, iArr2, c);
    }

    public JCublasNDArray(float[] fArr, int[] iArr, int[] iArr2, char c) {
        super(fArr, iArr, iArr2, c);
    }

    public JCublasNDArray(float[] fArr, int[] iArr, int[] iArr2, long j, char c) {
        super(fArr, iArr, iArr2, j, c);
    }

    public JCublasNDArray(float[] fArr, long[] jArr, long[] jArr2, long j, char c) {
        super(fArr, jArr, jArr2, j, c);
    }

    public JCublasNDArray(double[] dArr, long[] jArr, long[] jArr2, long j, char c) {
        super(dArr, jArr, jArr2, j, c);
    }

    public JCublasNDArray(int[] iArr, int[] iArr2, int[] iArr3) {
        super(iArr, iArr2, iArr3);
    }

    public JCublasNDArray(DataBuffer dataBuffer, int[] iArr) {
        super(dataBuffer, iArr);
    }

    public JCublasNDArray(DataBuffer dataBuffer, long[] jArr) {
        super(dataBuffer, jArr);
    }

    public JCublasNDArray(DataBuffer dataBuffer, int[] iArr, long j) {
        super(dataBuffer, iArr, j);
    }

    public JCublasNDArray(float[] fArr, int[] iArr) {
        super(fArr, iArr);
    }

    public JCublasNDArray(float[] fArr, int[] iArr, long j) {
        super(fArr, iArr, j);
    }

    public JCublasNDArray(int[] iArr, int[] iArr2, long j) {
        super(iArr, iArr2, j);
    }

    public JCublasNDArray(int[] iArr, int[] iArr2) {
        super(iArr, iArr2);
    }

    public JCublasNDArray(int[] iArr, long j) {
        super(iArr, j);
    }

    public JCublasNDArray(int[] iArr, char c) {
        super(iArr, c);
    }

    public JCublasNDArray(int i, int i2) {
        super(i, i2);
    }

    public JCublasNDArray(List<INDArray> list, int[] iArr) {
        super(list, iArr);
    }

    public JCublasNDArray(List<INDArray> list, long[] jArr) {
        super(list, jArr);
    }

    public JCublasNDArray(List<INDArray> list, int[] iArr, int[] iArr2) {
        super(list, iArr, iArr2);
    }

    public JCublasNDArray(float[] fArr, int[] iArr, int[] iArr2) {
        super(fArr, iArr, iArr2);
    }

    public JCublasNDArray(float[] fArr, int[] iArr, int[] iArr2, long j) {
        super(fArr, iArr, iArr2, j);
    }

    public JCublasNDArray(float[] fArr) {
        super(fArr);
    }

    /**
     * Copy constructor: creates a {@code rows x columns} array and fills it with a duplicate of
     * the source array's data.
     *
     * <p>The duplicate is taken from {@code jCublasNDArray}. Previously {@code dup()} was invoked
     * on the newly created (still empty) instance, so the source contents were never copied.
     *
     * <p>NOTE(review): the new array is allocated through the {@code long[]} constructor, so this
     * assumes the source uses the data type that constructor allocates - confirm for callers that
     * pass arrays of another data type.
     *
     * @param jCublasNDArray array to copy; only its row and column counts define the new shape
     */
    public JCublasNDArray(JCublasNDArray jCublasNDArray) {
        this(new long[]{jCublasNDArray.rows(), jCublasNDArray.columns()});
        this.data = jCublasNDArray.dup().data();
    }

    public JCublasNDArray(double[] dArr, int[] iArr, int[] iArr2, long j) {
        super(dArr, iArr, iArr2, j);
    }

    public JCublasNDArray(float[][] fArr) {
        super(fArr);
    }

    public JCublasNDArray(float[][] fArr, char c) {
        super(fArr, c);
    }

    public JCublasNDArray(DataBuffer dataBuffer, int[] iArr, long j, char c) {
        super(dataBuffer, iArr, j, c);
    }

    /** Creates an instance via the superclass no-argument constructor. */
    public JCublasNDArray() {
    }

    public JCublasNDArray(DataBuffer dataBuffer) {
        super(dataBuffer);
    }

    public JCublasNDArray(DataBuffer dataBuffer, int[] iArr, int[] iArr2, long j, char c) {
        super(dataBuffer, iArr, iArr2, j, c);
    }

    public JCublasNDArray(DataBuffer dataBuffer, long[] jArr, long[] jArr2, long j, char c, DataType dataType) {
        super(dataBuffer, jArr, jArr2, j, c, dataType);
    }

    public JCublasNDArray(DataBuffer dataBuffer, long[] jArr, long[] jArr2, long j, long j2, char c, DataType dataType) {
        super(dataBuffer, jArr, jArr2, j, j2, c, dataType);
    }

    public JCublasNDArray(DataBuffer dataBuffer, long[] jArr, long[] jArr2, char c, DataType dataType) {
        super(dataBuffer, jArr, jArr2, c, dataType);
    }

    public JCublasNDArray(float[] fArr, char c) {
        super(fArr, c);
    }

    public JCublasNDArray(FloatBuffer floatBuffer, char c) {
        super(floatBuffer, c);
    }

    public JCublasNDArray(DataBuffer dataBuffer, int[] iArr, int[] iArr2) {
        super(dataBuffer, iArr, iArr2);
    }

    public JCublasNDArray(double[] dArr, int[] iArr, char c) {
        super(dArr, iArr, c);
    }

    public JCublasNDArray(double[] dArr, long[] jArr, char c) {
        super(dArr, jArr, c);
    }

    public JCublasNDArray(float[] fArr, long[] jArr, char c) {
        super(fArr, jArr, c);
    }

    public JCublasNDArray(double[] dArr, int[] iArr, int[] iArr2, long j, char c) {
        super(dArr, iArr, iArr2, j, c);
    }

    /**
     * Duplicates this array.
     *
     * <p>A compressed array whose ordering equals {@code Nd4j.order()} is copied by duplicating
     * its data buffer and reusing the shape-info buffer; the copy is marked as compressed. In
     * every other case the superclass {@code dup()} is used and the executioner is committed
     * before the copy is returned.
     *
     * @return the duplicate array
     */
    public INDArray dup() {
        if (isCompressed() && ordering() == Nd4j.order().charValue()) {
            INDArray compressedCopy = Nd4j.createArrayFromShapeBuffer(data().dup(), shapeInfoDataBuffer());
            compressedCopy.markAsCompressed(true);
            return compressedCopy;
        }
        INDArray copy = super.dup();
        Nd4j.getExecutioner().commit();
        return copy;
    }

    /**
     * Duplicates this array using the requested ordering.
     *
     * <p>A compressed array that already has ordering {@code c} is copied buffer-wise and the
     * copy is marked as compressed; otherwise the superclass implementation is used.
     *
     * @param c ordering of the duplicate
     * @return the duplicate array
     */
    public INDArray dup(char c) {
        if (isCompressed() && ordering() == c) {
            INDArray compressedCopy = Nd4j.createArrayFromShapeBuffer(data().dup(), shapeInfoDataBuffer());
            compressedCopy.markAsCompressed(true);
            return compressedCopy;
        }
        return super.dup(c);
    }

    /** Delegates to the superclass equality check. */
    public boolean equals(Object obj) {
        return super.equals(obj);
    }

    /**
     * Returns the superclass string representation. Unless {@code isS()} reports {@code true},
     * host data is synchronized through the {@link AtomicAllocator} first.
     */
    public String toString() {
        if (!isS()) {
            AtomicAllocator.getInstance().synchronizeHostData((INDArray) this);
        }
        return super.toString();
    }

    /**
     * Replaces the shape-information buffer and rebuilds the cached {@link JvmShapeInfo} from it.
     *
     * @param dataBuffer new shape-information buffer
     */
    public void setShapeInfoDataBuffer(DataBuffer dataBuffer) {
        this.shapeInformation = dataBuffer;
        this.jvmShapeInfo = new JvmShapeInfo(this.shapeInformation.asLong());
    }

    /** Java serialization hook: a {@link BaseNDArrayProxy} is written in place of this object. */
    private Object writeReplace() throws ObjectStreamException {
        return new BaseNDArrayProxy(this);
    }

    /**
     * Calls {@code push()} on the executioner and then delegates to the superclass
     * {@code permutei}.
     *
     * @param iArr arguments forwarded unchanged to the superclass
     * @return the result of the superclass call
     */
    public INDArray permutei(int... iArr) {
        Nd4j.getExecutioner().push();
        return super.permutei(iArr);
    }

    /**
     * Builds a {@link LongShapeDescriptor} from this array's shape, stride, element-wise stride,
     * ordering, data type and empty flag.
     */
    public LongShapeDescriptor shapeDescriptor() {
        return LongShapeDescriptor.fromShape(shape(), stride(), elementWiseStride(), ordering(), dataType(), isEmpty());
    }

    /** Equivalent to {@code unsafeDuplication(true)}. */
    public INDArray unsafeDuplication() {
        return unsafeDuplication(true);
    }

    /**
     * Duplicates this array by copying the backing buffer with a single {@code memcpyAsync} call.
     *
     * <p>The copy endpoints (host or device pointer) are selected from the allocation status of
     * the source and destination {@link AllocationPoint}s. The stream used is waited on before
     * the method returns.
     *
     * @param z when {@code true} the executioner is pushed first and the copy runs on the
     *          context's "old" stream; when {@code false} the "special" stream is used instead
     * @return a new array that shares this array's shape-info buffer
     */
    public INDArray unsafeDuplication(boolean z) {
        WorkspaceUtils.assertValidArray(this, "Cannot duplicate array");

        MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        DataBuffer copyBuffer;
        if (workspace == null) {
            copyBuffer = Nd4j.getDataBufferFactory().createSame(this.data, false);
        } else {
            copyBuffer = Nd4j.getDataBufferFactory().createSame(this.data, false, Nd4j.getMemoryManager().getCurrentWorkspace());
        }
        INDArray copy = Nd4j.createArrayFromShapeBuffer(copyBuffer, shapeInfoDataBuffer());

        if (z) {
            Nd4j.getExecutioner().push();
        }

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();
        AllocationPoint srcPoint = allocator.getAllocationPoint((INDArray) this);
        AllocationPoint dstPoint = allocator.getAllocationPoint(copy);

        MemcpyDirection direction = MemcpyDirection.HOST_TO_HOST;
        long transactionStart = PerformanceTracker.getInstance().helperStartTransaction();

        long byteCount = this.data.length() * this.data.getElementSize();
        AllocationStatus dstStatus = dstPoint.getAllocationStatus();
        AllocationStatus srcStatus = srcPoint.getAllocationStatus();

        if (dstStatus == AllocationStatus.DEVICE && srcStatus == AllocationStatus.DEVICE) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getDevicePointer(), srcPoint.getDevicePointer(), byteCount, CudaConstants.cudaMemcpyDeviceToDevice, z ? context.getOldStream() : context.getSpecialStream());
            dstPoint.tickDeviceWrite();
            direction = MemcpyDirection.DEVICE_TO_DEVICE;
        } else if (dstStatus == AllocationStatus.HOST && srcStatus == AllocationStatus.DEVICE) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getHostPointer(), srcPoint.getDevicePointer(), byteCount, CudaConstants.cudaMemcpyDeviceToHost, z ? context.getOldStream() : context.getSpecialStream());
            dstPoint.tickHostWrite();
            direction = MemcpyDirection.DEVICE_TO_HOST;
        } else if (dstStatus == AllocationStatus.DEVICE && srcStatus == AllocationStatus.HOST) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getDevicePointer(), srcPoint.getHostPointer(), byteCount, CudaConstants.cudaMemcpyHostToDevice, z ? context.getOldStream() : context.getSpecialStream());
            dstPoint.tickDeviceWrite();
            direction = MemcpyDirection.HOST_TO_DEVICE;
        } else {
            // Remaining combinations are copied host-to-host; direction keeps its initial value.
            NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getHostPointer(), srcPoint.getHostPointer(), byteCount, CudaConstants.cudaMemcpyHostToHost, z ? context.getOldStream() : context.getSpecialStream());
            dstPoint.tickHostWrite();
        }

        // Wait on the same stream the copy was enqueued on.
        if (z) {
            context.syncOldStream();
        } else {
            context.syncSpecialStream();
        }

        PerformanceTracker.getInstance().helperRegisterTransaction(dstPoint.getDeviceId(), transactionStart, dstPoint.getNumberOfBytes(), direction);
        return copy;
    }

    /**
     * Returns a copy of this array allocated in the workspace with the given id.
     *
     * <p>The array itself is returned when it is not attached, when no workspace with that id
     * exists, when the target workspace is already the current one, or when the data already
     * belongs to the target workspace. While the copy is made the target workspace is set as the
     * current workspace; the previous one is restored afterwards, including when the copy fails.
     *
     * @param str id of the target workspace
     * @return this array, or a copy created while the target workspace was current
     * @throws ND4JIllegalStateException if the native {@code memcpyAsync} call returns 0
     */
    public INDArray leverageTo(String str) {
        if (!isAttached() || !Nd4j.getWorkspaceManager().checkIfWorkspaceExists(str)) {
            return this;
        }
        WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace");

        MemoryWorkspace previousWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        MemoryWorkspace targetWorkspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(str);
        if (previousWorkspace == targetWorkspace || this.data.getParentWorkspace() == targetWorkspace) {
            return this;
        }

        Nd4j.getMemoryManager().setCurrentWorkspace(targetWorkspace);
        try {
            if (isView()) {
                INDArray viewCopy = dup(ordering());
                Nd4j.getExecutioner().commit();
                return viewCopy;
            }

            Nd4j.getExecutioner().commit();
            DataBuffer targetBuffer = Nd4j.createBuffer(lengthLong(), false);
            AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(targetBuffer);
            AllocationPoint srcPoint = AtomicAllocator.getInstance().getAllocationPoint(this.data);
            CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(dstPoint, srcPoint);

            long transactionStart = PerformanceTracker.getInstance().helperStartTransaction();
            MemcpyDirection direction = enqueueCopyToDevice(targetBuffer, dstPoint, srcPoint, context);
            context.syncOldStream();

            // Report the direction that was actually used rather than a fixed HOST_TO_DEVICE.
            PerformanceTracker.getInstance().helperRegisterTransaction(dstPoint.getDeviceId(), transactionStart, srcPoint.getNumberOfBytes(), direction);

            INDArray copy = Nd4j.createArrayFromShapeBuffer(targetBuffer, shapeInfoDataBuffer());
            dstPoint.tickHostRead();
            dstPoint.tickDeviceWrite();
            AtomicAllocator.getInstance().getFlowController().registerAction(context, dstPoint, srcPoint);
            return copy;
        } finally {
            // Always hand the thread back to the workspace that was current on entry.
            Nd4j.getMemoryManager().setCurrentWorkspace(previousWorkspace);
        }
    }

    /**
     * Returns a copy of this array created while the current workspace is active.
     *
     * <p>When there is no current workspace the array itself is returned. Views are copied with
     * {@code dup(ordering())}; other arrays are copied into a freshly created buffer, and the
     * destination allocation point's device id is aligned with the current workspace's device id.
     *
     * @return this array, or the copy
     * @throws ND4JIllegalStateException if the native {@code memcpyAsync} call returns 0
     */
    public INDArray migrate() {
        WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace");
        if (Nd4j.getMemoryManager().getCurrentWorkspace() == null) {
            return this;
        }
        if (isView()) {
            return dup(ordering());
        }

        Nd4j.getExecutioner().commit();
        DataBuffer targetBuffer = Nd4j.createBuffer(lengthLong(), false);
        AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(targetBuffer);
        AllocationPoint srcPoint = AtomicAllocator.getInstance().getAllocationPoint(this.data);
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(dstPoint, srcPoint);

        long transactionStart = PerformanceTracker.getInstance().helperStartTransaction();
        MemcpyDirection direction = enqueueCopyToDevice(targetBuffer, dstPoint, srcPoint, context);
        context.syncOldStream();
        PerformanceTracker.getInstance().helperRegisterTransaction(dstPoint.getDeviceId(), transactionStart, dstPoint.getNumberOfBytes(), direction);

        if (dstPoint.getDeviceId() != Nd4j.getMemoryManager().getCurrentWorkspace().getDeviceId()) {
            dstPoint.setDeviceId(Nd4j.getMemoryManager().getCurrentWorkspace().getDeviceId());
        }

        INDArray copy = Nd4j.createArrayFromShapeBuffer(targetBuffer, shapeInfoDataBuffer());
        dstPoint.tickHostRead();
        dstPoint.tickDeviceWrite();
        AtomicAllocator.getInstance().getFlowController().registerAction(context, dstPoint, srcPoint);
        return copy;
    }

    /**
     * Enqueues, on the context's "old" stream, an asynchronous copy of this array's elements into
     * the device pointer of {@code dstPoint}. The source is the device pointer of
     * {@code srcPoint} when it reports being actual on the device side, otherwise its host
     * pointer. The caller is responsible for synchronizing the stream.
     *
     * @param targetBuffer destination buffer; its data type determines the element size
     * @param dstPoint     allocation point of the destination buffer
     * @param srcPoint     allocation point of this array's data buffer
     * @param context      context whose "old" stream carries the copy
     * @return the direction of the copy that was enqueued
     * @throws ND4JIllegalStateException if the native {@code memcpyAsync} call returns 0
     */
    private MemcpyDirection enqueueCopyToDevice(DataBuffer targetBuffer, AllocationPoint dstPoint, AllocationPoint srcPoint, CudaContext context) {
        if (srcPoint.isActualOnDeviceSide()) {
            if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getDevicePointer(), srcPoint.getDevicePointer(), lengthLong() * Nd4j.sizeOfDataType(targetBuffer.dataType()), CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0) {
                throw new ND4JIllegalStateException("memcpyAsync failed");
            }
            return MemcpyDirection.DEVICE_TO_DEVICE;
        }
        if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(dstPoint.getDevicePointer(), srcPoint.getHostPointer(), lengthLong() * Nd4j.sizeOfDataType(targetBuffer.dataType()), CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0) {
            throw new ND4JIllegalStateException("memcpyAsync failed");
        }
        return MemcpyDirection.HOST_TO_DEVICE;
    }
}
