package org.nd4j.linalg.jcublas.buffer;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.util.Collection;
import lombok.NonNull;
import org.bytedeco.javacpp.BooleanPointer;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.ShortPointer;
import org.bytedeco.javacpp.indexer.Bfloat16Indexer;
import org.bytedeco.javacpp.indexer.BooleanIndexer;
import org.bytedeco.javacpp.indexer.ByteIndexer;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.HalfIndexer;
import org.bytedeco.javacpp.indexer.Indexer;
import org.bytedeco.javacpp.indexer.IntIndexer;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.bytedeco.javacpp.indexer.ShortIndexer;
import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacpp.indexer.UIntIndexer;
import org.bytedeco.javacpp.indexer.UShortIndexer;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.enums.CudaConstants;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.impl.CudaDeallocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.memory.Deallocatable;
import org.nd4j.linalg.api.memory.Deallocator;
import org.nd4j.linalg.api.memory.MemcpyDirection;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.bindings.Nd4jCuda;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.util.LongUtils;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.OpaqueDataBuffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.class */
public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCudaBuffer, Deallocatable {
    // Native buffer handle; created via allocateDataBuffer, externalizedDataBuffer or createView,
    // and holding the host ("primary") and device ("special") buffers.
    protected OpaqueDataBuffer ptrDataBuffer;
    // Java-side bookkeeping for ptrDataBuffer: host/device pointers, byte count and read/write ticks.
    protected volatile transient AllocationPoint allocationPoint;
    // Used by the set(...) overloads for host-to-buffer copies.
    // NOTE(review): initialized outside this view (presumably a static initializer) — confirm.
    private static AtomicAllocator allocator;
    private static Logger log;
    // Data type read from DataTypeUtil's context when the buffer is constructed.
    protected DataType globalType;
    // Decompiler artifact: javac's synthetic flag backing `assert` statements; do not rename.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer$1, reason: invalid class name */
    /* loaded from: input_file:org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer$1.class */
    /**
     * javac-generated "switch map" for switches over {@link DataType} in this class (decompiled form).
     *
     * <p>{@code $SwitchMap...[type.ordinal()]} yields the dense case label used by the decompiled
     * switches: DOUBLE=1, FLOAT=2, HALF=3, LONG=4, INT=5, SHORT=6, UBYTE=7, BYTE=8, BOOL=9, UTF8=10,
     * BFLOAT16=11, UINT16=12, UINT32=13, UINT64=14. Any constant not listed keeps the array's default 0
     * and therefore falls into a switch's {@code default} branch.
     *
     * <p>Labels written as {@code case Nd4jCuda.FLOAT32 /* 5 *\/} elsewhere in this class appear to be
     * decompiler substitutions for the integer literal in the comment; they do NOT refer to the Nd4jCuda
     * type of that name (e.g. label 5 handles {@code DataType.INT}).
     */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$org$nd4j$linalg$api$buffer$DataType = new int[DataType.values().length];

        // Each entry is guarded separately so that a DataType enum lacking one of these constants at
        // runtime (NoSuchFieldError) does not abort class initialization.
        static {
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.DOUBLE.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.FLOAT.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.HALF.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.LONG.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.INT.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.SHORT.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.UBYTE.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.BYTE.ordinal()] = 8;
            } catch (NoSuchFieldError e8) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.BOOL.ordinal()] = 9;
            } catch (NoSuchFieldError e9) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.UTF8.ordinal()] = 10;
            } catch (NoSuchFieldError e10) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.BFLOAT16.ordinal()] = 11;
            } catch (NoSuchFieldError e11) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.UINT16.ordinal()] = 12;
            } catch (NoSuchFieldError e12) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.UINT32.ordinal()] = 13;
            } catch (NoSuchFieldError e13) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.UINT64.ordinal()] = 14;
            } catch (NoSuchFieldError e14) {
            }
        }
    }

    public BaseCudaDataBuffer() {
        this.globalType = DataTypeUtil.getDtypeFromContext();
    }

    /**
     * Returns the native buffer handle backing this buffer.
     *
     * @return the {@link OpaqueDataBuffer} handle
     * @throws IllegalStateException if this buffer has already been released
     */
    public OpaqueDataBuffer getOpaqueDataBuffer() {
        if (!this.released) {
            return this.ptrDataBuffer;
        }
        throw new IllegalStateException("You can't use DataBuffer once it was released");
    }

    public BaseCudaDataBuffer(@NonNull Pointer pointer, @NonNull Pointer pointer2, @NonNull Indexer indexer, long j) {
        this.globalType = DataTypeUtil.getDtypeFromContext();
        if (pointer == null) {
            throw new NullPointerException("pointer is marked non-null but is null");
        }
        if (pointer2 == null) {
            throw new NullPointerException("specialPointer is marked non-null but is null");
        }
        if (indexer == null) {
            throw new NullPointerException("indexer is marked non-null but is null");
        }
        this.allocationMode = DataBuffer.AllocationMode.MIXED_DATA_TYPES;
        this.indexer = indexer;
        this.offset = 0L;
        this.originalOffset = 0L;
        this.underlyingLength = j;
        this.length = j;
        initTypeAndSize();
        this.ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(j, this.type, pointer, pointer2);
        this.allocationPoint = new AllocationPoint(this.ptrDataBuffer, this.type.width() * j);
        Nd4j.getDeallocatorService().pickObject(this);
        if (this.released) {
            throw new IllegalStateException("You can't use DataBuffer once it was released");
        }
    }

    /**
     * Creates a buffer whose host side is {@code pointer} and synchronously uploads its contents to the device.
     *
     * @param pointer host memory holding {@code length} elements
     * @param indexer indexer over {@code pointer}
     * @param length  number of elements
     */
    public BaseCudaDataBuffer(Pointer pointer, Indexer indexer, long length) {
        super(pointer, indexer, length);
        this.globalType = DataTypeUtil.getDtypeFromContext();

        this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, this.type, false);
        this.ptrDataBuffer.setPrimaryBuffer(pointer, length);
        this.allocationPoint = new AllocationPoint(this.ptrDataBuffer, length * this.elementSize);
        Nd4j.getDeallocatorService().pickObject(this);

        final CudaContext context = AtomicAllocator.getInstance().getDeviceContext();
        final long transactionStart = PerformanceTracker.getInstance().helperStartTransaction();
        final long bytesToCopy = length * getElementSize();
        NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(
                this.allocationPoint.getDevicePointer(), pointer, bytesToCopy,
                CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream());
        // NOTE(review): the start timestamp halved is passed where the tracker's second argument is
        // expected; looks suspicious — confirm against PerformanceTracker.helperRegisterTransaction.
        PerformanceTracker.getInstance().helperRegisterTransaction(
                this.allocationPoint.getDeviceId(), transactionStart / 2,
                this.allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
        // block until the upload issued on the special stream has finished
        context.getSpecialStream().synchronize();
    }

    /**
     * Creates a buffer initialized from {@code data}, starting at element 0.
     *
     * @param data source values
     * @param copy forwarded to the offset-taking constructor
     */
    public BaseCudaDataBuffer(float[] data, boolean copy) {
        this(data, copy, 0L);
    }

    /**
     * Creates a workspace-attached buffer initialized from {@code data}, starting at element 0.
     *
     * @param data      source values
     * @param copy      forwarded to the offset-taking constructor
     * @param workspace workspace that provides the memory
     */
    public BaseCudaDataBuffer(float[] data, boolean copy, MemoryWorkspace workspace) {
        this(data, copy, 0L, workspace);
    }

    /**
     * Creates a buffer of {@code data.length} 4-byte elements and fills it from {@code data},
     * exposing the elements from {@code offset} onwards.
     *
     * @param data   source values
     * @param copy   not read by this constructor
     * @param offset first visible element
     */
    public BaseCudaDataBuffer(float[] data, boolean copy, long offset) {
        this(data.length, 4, false);
        this.underlyingLength = data.length;
        this.originalOffset = offset;
        this.offset = offset;
        this.length = data.length - offset;
        set(data, this.length, offset, offset);
    }

    /**
     * Creates a workspace-attached buffer of {@code data.length} 8-byte elements, filled from {@code data}
     * and exposing the elements from {@code offset} onwards.
     *
     * @param data      source values
     * @param copy      not read by this constructor
     * @param offset    first visible element
     * @param workspace workspace that provides the memory
     */
    public BaseCudaDataBuffer(double[] data, boolean copy, long offset, MemoryWorkspace workspace) {
        this(data.length, 8, false, workspace);
        this.underlyingLength = data.length;
        this.originalOffset = offset;
        this.offset = offset;
        this.length = data.length - offset;
        set(data, this.length, offset, offset);
    }

    /**
     * Creates a workspace-attached buffer of {@code data.length} 4-byte elements, filled from {@code data}
     * and exposing the elements from {@code offset} onwards.
     *
     * @param data      source values
     * @param copy      not read by this constructor
     * @param offset    first visible element
     * @param workspace workspace that provides the memory
     */
    public BaseCudaDataBuffer(float[] data, boolean copy, long offset, MemoryWorkspace workspace) {
        this(data.length, 4, false, workspace);
        this.underlyingLength = data.length;
        this.originalOffset = offset;
        this.offset = offset;
        this.length = data.length - offset;
        set(data, this.length, offset, offset);
    }

    /**
     * Creates a buffer initialized from {@code data}, starting at element 0.
     *
     * @param data source values
     * @param copy forwarded to the offset-taking constructor
     */
    public BaseCudaDataBuffer(double[] data, boolean copy) {
        this(data, copy, 0L);
    }

    /**
     * Creates a buffer of {@code data.length} 8-byte elements and fills it from {@code data},
     * exposing the elements from {@code offset} onwards.
     *
     * @param data   source values
     * @param copy   not read by this constructor
     * @param offset first visible element
     */
    public BaseCudaDataBuffer(double[] data, boolean copy, long offset) {
        this(data.length, 8, false);
        this.underlyingLength = data.length;
        this.originalOffset = offset;
        this.offset = offset;
        this.length = data.length - offset;
        set(data, this.length, offset, offset);
    }

    /**
     * Creates a buffer initialized from {@code data}, starting at element 0.
     *
     * @param data source values
     * @param copy forwarded to the offset-taking constructor
     */
    public BaseCudaDataBuffer(int[] data, boolean copy) {
        this(data, copy, 0L);
    }

    /**
     * Creates a workspace-attached buffer initialized from {@code data}, starting at element 0.
     *
     * @param data      source values
     * @param copy      forwarded to the offset-taking constructor
     * @param workspace workspace that provides the memory
     */
    public BaseCudaDataBuffer(int[] data, boolean copy, MemoryWorkspace workspace) {
        this(data, copy, 0L, workspace);
    }

    /**
     * Creates a buffer of {@code data.length} 4-byte elements and fills it from {@code data},
     * exposing the elements from {@code offset} onwards.
     *
     * @param data   source values
     * @param copy   not read by this constructor
     * @param offset first visible element
     */
    public BaseCudaDataBuffer(int[] data, boolean copy, long offset) {
        this(data.length, 4, false);
        this.underlyingLength = data.length;
        this.originalOffset = offset;
        this.offset = offset;
        this.length = data.length - offset;
        set(data, this.length, offset, offset);
    }

    /**
     * Creates a workspace-attached buffer of {@code data.length} 4-byte elements, filled from {@code data}
     * and exposing the elements from {@code offset} onwards.
     *
     * @param data      source values
     * @param copy      not read by this constructor
     * @param offset    first visible element
     * @param workspace workspace that provides the memory
     */
    public BaseCudaDataBuffer(int[] data, boolean copy, long offset, MemoryWorkspace workspace) {
        this(data.length, 4, false, workspace);
        this.underlyingLength = data.length;
        this.originalOffset = offset;
        this.offset = offset;
        this.length = data.length - offset;
        set(data, this.length, offset, offset);
    }

    /**
     * Allocates storage for {@code length} elements, deriving the element size from {@code dtype}.
     *
     * @param length     number of elements
     * @param dtype      data type whose size is used as the element size
     * @param initialize whether to zero the device memory
     */
    protected void initPointers(long length, DataType dtype, boolean initialize) {
        final int bytesPerElement = Nd4j.sizeOfDataType(dtype);
        initPointers(length, bytesPerElement, initialize);
    }

    /**
     * Makes sure the Java-side host pointer and indexer exist and match the native host buffer.
     * Does nothing for empty buffers.
     */
    public void lazyAllocateHostPointer() {
        if (length() == 0) {
            return;
        }

        final boolean javaSideMissing = this.indexer == null || this.pointer == null || this.pointer.address() == 0;
        if (javaSideMissing) {
            initHostPointerAndIndexer();
            return;
        }

        // Re-bind only when a native host buffer exists and differs from the one we point at.
        final Pointer nativeHost = this.allocationPoint.getHostPointer();
        if (nativeHost != null && nativeHost.address() != this.pointer.address()) {
            initHostPointerAndIndexer();
        }
    }

    /**
     * Creates a buffer of {@code length} elements and uploads the contents of {@code buffer} to the device.
     *
     * <p>The bytes are interpreted according to this buffer's {@link #dataType()}. Every supported type
     * wraps {@code buffer} itself, including BOOL, UTF8, BFLOAT16, UINT16, UINT32 and UINT64. Previously
     * those types copied freshly allocated, uninitialized memory and ignored the caller's data.
     * NOTE(review): the typed views wrap the buffer's memory address, so {@code buffer} presumably must be
     * a direct buffer — confirm.
     *
     * @param buffer source bytes
     * @param dtype  data type used to size the allocation and the copy
     * @param length number of elements to copy
     * @param offset number of leading elements of {@code buffer} to skip
     * @throws UnsupportedOperationException if this buffer's data type cannot be uploaded from bytes
     */
    protected BaseCudaDataBuffer(ByteBuffer buffer, DataType dtype, long length, long offset) {
        this(length, Nd4j.sizeOfDataType(dtype));

        final Pointer source;
        switch (dataType()) {
            case DOUBLE:
                source = new DoublePointer(buffer.asDoubleBuffer());
                break;
            case FLOAT:
                source = new FloatPointer(buffer.asFloatBuffer());
                break;
            case HALF:
            case SHORT:
            case BFLOAT16:
            case UINT16:
                source = new ShortPointer(buffer.asShortBuffer());
                break;
            case LONG:
            case UINT64:
                source = new LongPointer(buffer.asLongBuffer());
                break;
            case INT:
            case UINT32:
                source = new IntPointer(buffer.asIntBuffer());
                break;
            case UBYTE:
            case BYTE:
            case BOOL:
            case UTF8:
                source = new BytePointer(buffer);
                break;
            default:
                throw new UnsupportedOperationException("Unsupported data type: " + dataType());
        }

        final cudaStream_t stream = AtomicAllocator.getInstance().getDeviceContext().getSpecialStream();
        final Pointer copyFrom = offset > 0 ? new PagedPointer(source.address() + (offset * getElementSize())) : source;
        NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(
                this.ptrDataBuffer.specialBuffer(), copyFrom, length * Nd4j.sizeOfDataType(dtype),
                CudaConstants.cudaMemcpyHostToDevice, stream);
        stream.synchronize();
        // keep the wrapping pointer reachable until the asynchronous copy has completed
        source.address();

        // the device copy is now the most recent one
        this.allocationPoint.tickDeviceWrite();
    }

    /**
     * Allocates the host (primary) buffer if it does not exist yet. Then binds {@link #pointer} and
     * {@link #indexer} to that host memory using a pointer/indexer pair that matches {@link #dataType()}.
     * Does nothing for empty buffers.
     *
     * @throws UnsupportedOperationException for data types without a host indexer
     */
    protected void initHostPointerAndIndexer() {
        if (length() == 0) {
            return;
        }

        if (this.allocationPoint.getHostPointer() == null) {
            // Remember the allocation status so that creating host memory does not alter it.
            final AllocationStatus statusBefore = this.allocationPoint.getAllocationStatus();
            if (this.parentWorkspace != null) {
                final long hostBytes = this.length * this.elementSize;
                this.ptrDataBuffer.setPrimaryBuffer(this.parentWorkspace.alloc(hostBytes, MemoryKind.HOST, dataType(), false), this.length);
            } else {
                NativeOpsHolder.getInstance().getDeviceNativeOps().dbAllocatePrimaryBuffer(this.ptrDataBuffer);
            }
            this.allocationPoint.setAllocationStatus(statusBefore);
            // NOTE(review): presumably marks the fresh host memory as stale relative to the device copy.
            this.allocationPoint.tickDeviceWrite();
        }

        final Pointer hostPointer = this.allocationPoint.getHostPointer();
        if (!$assertionsDisabled && hostPointer == null) {
            throw new AssertionError();
        }

        switch (dataType()) {
            case DOUBLE:
                this.pointer = new CudaPointer(hostPointer, this.length, 0L).asDoublePointer();
                this.indexer = DoubleIndexer.create(this.pointer);
                break;
            case FLOAT:
                this.pointer = new CudaPointer(hostPointer, this.length, 0L).asFloatPointer();
                this.indexer = FloatIndexer.create(this.pointer);
                break;
            case HALF:
                this.pointer = new CudaPointer(hostPointer, this.length, 0L).asShortPointer();
                this.indexer = HalfIndexer.create(this.pointer);
                break;
            case LONG:
            case UINT64:
                this.pointer = new CudaPointer(hostPointer, this.length, 0L).asLongPointer();
                this.indexer = LongIndexer.create(this.pointer);
                break;
            case INT:
                this.pointer = new CudaPointer(hostPointer, this.length, 0L).asIntPointer();
                this.indexer = IntIndexer.create(this.pointer);
                break;
            case SHORT:
                this.pointer = new CudaPointer(hostPointer, this.length, 0L).asShortPointer();
                this.indexer = ShortIndexer.create(this.pointer);
                break;
            case UBYTE:
                this.pointer = new CudaPointer(hostPointer, this.length, 0L).asBytePointer();
                this.indexer = UByteIndexer.create(this.pointer);
                break;
            case BYTE:
            case UTF8:
                this.pointer = new CudaPointer(hostPointer, this.length, 0L).asBytePointer();
                this.indexer = ByteIndexer.create(this.pointer);
                break;
            case BOOL:
                this.pointer = new CudaPointer(hostPointer, this.length, 0L).asBooleanPointer();
                this.indexer = BooleanIndexer.create(this.pointer);
                break;
            case BFLOAT16:
                this.pointer = new CudaPointer(hostPointer, this.length, 0L).asShortPointer();
                this.indexer = Bfloat16Indexer.create(this.pointer);
                break;
            case UINT16:
                this.pointer = new CudaPointer(hostPointer, this.length, 0L).asShortPointer();
                this.indexer = UShortIndexer.create(this.pointer);
                break;
            case UINT32:
                this.pointer = new CudaPointer(hostPointer, this.length, 0L).asIntPointer();
                this.indexer = UIntIndexer.create(this.pointer);
                break;
            default:
                throw new UnsupportedOperationException();
        }
    }

    /**
     * Allocates a fresh native buffer of {@code length} elements of this buffer's data type and
     * registers this buffer with the deallocator service.
     *
     * @param length      number of elements
     * @param elementSize element size recorded in {@link #elementSize}
     * @param initialize  if {@code true}, the device memory is zero-filled before returning
     */
    protected void initPointers(long length, int elementSize, boolean initialize) {
        this.allocationMode = DataBuffer.AllocationMode.MIXED_DATA_TYPES;
        this.length = length;
        this.elementSize = (byte) elementSize;
        this.offset = 0L;
        this.originalOffset = 0L;

        // The native allocation receives only the element count and this.type, and the AllocationPoint
        // records length * type.width() bytes; every byte count below uses that same figure.
        this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, this.type, false);
        final long allocatedBytes = length * this.type.width();
        this.allocationPoint = new AllocationPoint(this.ptrDataBuffer, allocatedBytes);

        if (initialize) {
            // Zero exactly the allocated bytes. Using length * elementSize overruns the device allocation
            // whenever elementSize > type.width(), e.g. when BaseCudaDataBuffer(long) sizes elements by the
            // global Nd4j.dataType() rather than by this buffer's type.
            CudaContext context = AtomicAllocator.getInstance().getDeviceContext();
            NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(
                    this.allocationPoint.getDevicePointer(), 0, allocatedBytes, 0, context.getSpecialStream());
            context.getSpecialStream().synchronize();
        }
        Nd4j.getDeallocatorService().pickObject(this);
    }

    /**
     * Allocates a buffer of {@code length} elements.
     *
     * @param length      number of elements
     * @param elementSize element size in bytes
     * @param initialize  whether to zero the device memory
     */
    public BaseCudaDataBuffer(long length, int elementSize, boolean initialize) {
        this.globalType = DataTypeUtil.getDtypeFromContext();
        // initPointers reads this.type, so the type has to be resolved first
        initTypeAndSize();
        initPointers(length, elementSize, initialize);
    }

    public BaseCudaDataBuffer(long j, int i, boolean z, @NonNull MemoryWorkspace memoryWorkspace) {
        this.globalType = DataTypeUtil.getDtypeFromContext();
        if (memoryWorkspace == null) {
            throw new NullPointerException("workspace is marked non-null but is null");
        }
        this.allocationMode = DataBuffer.AllocationMode.MIXED_DATA_TYPES;
        initTypeAndSize();
        this.attached = true;
        this.parentWorkspace = memoryWorkspace;
        this.length = j;
        this.offset = 0L;
        this.originalOffset = 0L;
        if (memoryWorkspace.getWorkspaceConfiguration().getPolicyMirroring() == MirroringPolicy.FULL) {
            PagedPointer alloc = memoryWorkspace.alloc(j * i, MemoryKind.DEVICE, this.type, z);
            this.ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(this.length, this.type, (Pointer) null, alloc);
            if (z) {
                CudaContext deviceContext = AtomicAllocator.getInstance().getDeviceContext();
                NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(alloc, 0, j * i, 0, deviceContext.getSpecialStream());
                deviceContext.getSpecialStream().synchronize();
            }
        } else {
            PagedPointer alloc2 = memoryWorkspace.alloc(j * i, MemoryKind.HOST, this.type, z);
            this.ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(this.length, this.type, (Pointer) null, alloc2);
            if (z) {
                CudaContext deviceContext2 = AtomicAllocator.getInstance().getDeviceContext();
                NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(alloc2, 0, j * i, 0, deviceContext2.getSpecialStream());
                deviceContext2.getSpecialStream().synchronize();
            }
        }
        this.allocationPoint = new AllocationPoint(this.ptrDataBuffer, i * j);
        Nd4j.getDeallocatorService().pickObject(this);
        this.workspaceGenerationId = memoryWorkspace.getGenerationId();
        this.attached = true;
        this.parentWorkspace = memoryWorkspace;
    }

    /**
     * Replaces the host-side indexer.
     *
     * @param replacement the new indexer
     */
    protected void setIndexer(Indexer replacement) {
        this.indexer = replacement;
    }

    /**
     * Allocates a zero-initialized buffer of {@code length} elements.
     *
     * @param length      number of elements
     * @param elementSize element size in bytes
     */
    public BaseCudaDataBuffer(long length, int elementSize) {
        this(length, elementSize, true);
    }

    /**
     * Allocates a zero-initialized buffer of {@code length} elements inside {@code workspace}.
     *
     * @param length      number of elements
     * @param elementSize element size in bytes
     * @param workspace   owning workspace
     */
    public BaseCudaDataBuffer(long length, int elementSize, MemoryWorkspace workspace) {
        this(length, elementSize, true, workspace);
    }

    /**
     * Allocates a zero-initialized buffer of {@code length} elements and records {@code offset} as both the
     * current and the original offset.
     *
     * @param length      number of elements
     * @param elementSize element size in bytes
     * @param offset      element offset
     */
    public BaseCudaDataBuffer(long length, int elementSize, long offset) {
        this(length, elementSize);
        this.originalOffset = offset;
        this.offset = offset;
    }

    /**
     * Creates a view of {@code length} elements over {@code underlyingBuffer}, starting {@code offset}
     * elements into it. The view shares the underlying native memory via {@code createView}.
     *
     * @param underlyingBuffer buffer to view; must be a {@link BaseCudaDataBuffer}
     * @param length           number of elements in the view
     * @param offset           element offset into {@code underlyingBuffer}
     * @throws NullPointerException          if {@code underlyingBuffer} is {@code null}
     * @throws IllegalStateException         if {@code underlyingBuffer} was already closed
     * @throws IllegalArgumentException      if a UTF8 buffer is viewed at a non-zero offset
     * @throws UnsupportedOperationException for data types without a host indexer
     */
    public BaseCudaDataBuffer(@NonNull DataBuffer underlyingBuffer, long length, long offset) {
        this.globalType = DataTypeUtil.getDtypeFromContext();
        if (underlyingBuffer == null) {
            throw new NullPointerException("underlyingBuffer is marked non-null but is null");
        }
        if (underlyingBuffer.wasClosed()) {
            throw new IllegalStateException("You can't use DataBuffer once it was released");
        }
        final DataType viewType = underlyingBuffer.dataType();
        if (viewType == DataType.UTF8) {
            // Reject invalid string views before a native view is created and registered for deallocation.
            Preconditions.checkArgument(offset == 0, "String array can't be a view");
        }

        this.allocationMode = DataBuffer.AllocationMode.MIXED_DATA_TYPES;
        initTypeAndSize();
        this.wrappedDataBuffer = underlyingBuffer;
        this.originalBuffer = underlyingBuffer.originalDataBuffer() == null ? underlyingBuffer : underlyingBuffer.originalDataBuffer();
        this.length = length;
        this.offset = offset;
        this.originalOffset = offset;
        final int bytesPerElement = underlyingBuffer.getElementSize();
        this.elementSize = (byte) bytesPerElement;

        final BaseCudaDataBuffer cudaParent = (BaseCudaDataBuffer) underlyingBuffer;
        // make sure the parent has host (primary) memory before slicing it
        cudaParent.lazyAllocateHostPointer();
        this.ptrDataBuffer = cudaParent.ptrDataBuffer.createView(length * bytesPerElement, offset * bytesPerElement);
        // AllocationPoint takes a byte count, as at the other construction sites in this class;
        // passing the raw element count under-reported getNumberOfBytes() for views.
        this.allocationPoint = new AllocationPoint(this.ptrDataBuffer, length * bytesPerElement);
        final Pointer hostPointer = this.allocationPoint.getHostPointer();
        Nd4j.getDeallocatorService().pickObject(this);

        final long capacity = this.originalBuffer.length();
        switch (viewType) {
            case DOUBLE:
                this.pointer = new CudaPointer(hostPointer, capacity).asDoublePointer();
                this.indexer = DoubleIndexer.create(this.pointer);
                return;
            case FLOAT:
                this.pointer = new CudaPointer(hostPointer, capacity).asFloatPointer();
                this.indexer = FloatIndexer.create(this.pointer);
                return;
            case HALF:
                this.pointer = new CudaPointer(hostPointer, capacity).asShortPointer();
                this.indexer = HalfIndexer.create(this.pointer);
                return;
            case LONG:
            case UINT64:
                this.pointer = new CudaPointer(hostPointer, capacity).asLongPointer();
                this.indexer = LongIndexer.create(this.pointer);
                return;
            case INT:
                this.pointer = new CudaPointer(hostPointer, capacity).asIntPointer();
                this.indexer = IntIndexer.create(this.pointer);
                return;
            case SHORT:
                this.pointer = new CudaPointer(hostPointer, capacity).asShortPointer();
                this.indexer = ShortIndexer.create(this.pointer);
                return;
            case UBYTE:
                this.pointer = new CudaPointer(hostPointer, capacity).asBytePointer();
                this.indexer = UByteIndexer.create(this.pointer);
                return;
            case BYTE:
            case UTF8:
                this.pointer = new CudaPointer(hostPointer, capacity).asBytePointer();
                this.indexer = ByteIndexer.create(this.pointer);
                return;
            case BOOL:
                this.pointer = new CudaPointer(hostPointer, capacity).asBooleanPointer();
                this.indexer = BooleanIndexer.create(this.pointer);
                return;
            case BFLOAT16:
                this.pointer = new CudaPointer(hostPointer, capacity).asShortPointer();
                this.indexer = Bfloat16Indexer.create(this.pointer);
                return;
            case UINT16:
                this.pointer = new CudaPointer(hostPointer, capacity).asShortPointer();
                this.indexer = UShortIndexer.create(this.pointer);
                return;
            case UINT32:
                this.pointer = new CudaPointer(hostPointer, capacity).asIntPointer();
                this.indexer = UIntIndexer.create(this.pointer);
                return;
            default:
                throw new UnsupportedOperationException();
        }
    }

    /**
     * Allocates a zero-initialized buffer of {@code length} elements. The element size is taken from the
     * global default {@code Nd4j.dataType()}, not from this buffer's own type.
     *
     * @param length number of elements
     */
    public BaseCudaDataBuffer(long length) {
        this(length, Nd4j.sizeOfDataType(Nd4j.dataType()));
    }

    /**
     * Creates a buffer sized for FLOAT elements and fills it with {@code data}.
     *
     * @param data source values
     */
    public BaseCudaDataBuffer(float[] data) {
        this(data.length, Nd4j.sizeOfDataType(DataType.FLOAT), false);
        set(data, data.length, 0L, 0L);
    }

    /**
     * Creates a buffer sized for INT elements and fills it with {@code data}.
     *
     * @param data source values
     */
    public BaseCudaDataBuffer(int[] data) {
        this(data.length, Nd4j.sizeOfDataType(DataType.INT), false);
        set(data, data.length, 0L, 0L);
    }

    /**
     * Creates a buffer sized for LONG elements and fills it with {@code data}.
     *
     * @param data source values
     */
    public BaseCudaDataBuffer(long[] data) {
        this(data.length, Nd4j.sizeOfDataType(DataType.LONG), false);
        set(data, data.length, 0L, 0L);
    }

    /**
     * Creates a buffer sized for {@code data.length} LONG elements.
     *
     * @param data source values
     * @param copy if {@code true} the values are copied in; otherwise the memory is left as allocated
     *             (the allocation is not requested to be zeroed)
     */
    public BaseCudaDataBuffer(long[] data, boolean copy) {
        this(data.length, Nd4j.sizeOfDataType(DataType.LONG), false);
        if (!copy) {
            return;
        }
        set(data, data.length, 0L, 0L);
    }

    /**
     * Creates a buffer sized for DOUBLE elements and fills it with {@code data}.
     *
     * @param data source values
     */
    public BaseCudaDataBuffer(double[] data) {
        this(data.length, Nd4j.sizeOfDataType(DataType.DOUBLE), false);
        set(data, data.length, 0L, 0L);
    }

    /**
     * Returns the address of the native host (primary) buffer.
     *
     * @return host address
     * @throws IllegalStateException if this buffer has already been released
     */
    public long address() {
        if (!this.released) {
            final Pointer host = this.allocationPoint.getHostPointer();
            return host.address();
        }
        throw new IllegalStateException("You can't use DataBuffer once it was released");
    }

    /**
     * Returns the address of the device (special) buffer.
     *
     * <p>Like {@code address()} and {@code pointer()}, this refuses to hand out an address once the buffer
     * has been released, instead of returning a potentially dangling device pointer.
     *
     * @return device address
     * @throws IllegalStateException if this buffer has already been released
     */
    public long platformAddress() {
        if (this.released) {
            throw new IllegalStateException("You can't use DataBuffer once it was released");
        }
        return this.allocationPoint.getDevicePointer().address();
    }

    /**
     * Returns the host-side pointer, allocating and binding host memory first if needed.
     *
     * @return host pointer
     * @throws IllegalStateException if this buffer has already been released
     */
    public Pointer pointer() {
        if (!this.released) {
            lazyAllocateHostPointer();
            return super.pointer();
        }
        throw new IllegalStateException("You can't use DataBuffer once it was released");
    }

    /**
     * Copies {@code length} elements of {@code data}, starting at element {@code srcOffset}, into this buffer
     * at element {@code dstOffset}, converting each value to this buffer's data type.
     *
     * <p>For UBYTE buffers, values are written one by one via {@code put}. Negative values are clamped to
     * zero, consistent with the {@code long[]} overload. The whole array is written starting at index 0;
     * {@code length} and the offsets are not applied on that path.
     * NOTE(review): this differs from the other types — confirm it is intended.
     *
     * @param data      source values
     * @param length    number of elements to copy
     * @param srcOffset first element of {@code data} to copy
     * @param dstOffset destination element offset
     * @throws UnsupportedOperationException for data types that cannot be filled from ints
     */
    public void set(int[] data, long length, long srcOffset, long dstOffset) {
        switch (dataType()) {
            case DOUBLE:
                copyHostArrayIntoBuffer(new DoublePointer(ArrayUtil.toDouble(data)), length, srcOffset, dstOffset);
                return;
            case FLOAT:
                copyHostArrayIntoBuffer(new FloatPointer(ArrayUtil.toFloats(data)), length, srcOffset, dstOffset);
                return;
            case HALF:
                copyHostArrayIntoBuffer(new ShortPointer(ArrayUtil.toHalfs(data)), length, srcOffset, dstOffset);
                return;
            case LONG:
                copyHostArrayIntoBuffer(new LongPointer(LongUtils.toLongs(data)), length, srcOffset, dstOffset);
                return;
            case INT:
                copyHostArrayIntoBuffer(new IntPointer(data), length, srcOffset, dstOffset);
                return;
            case SHORT:
                copyHostArrayIntoBuffer(new ShortPointer(ArrayUtil.toShorts(data)), length, srcOffset, dstOffset);
                return;
            case UBYTE:
                for (int i = 0; i < data.length; i++) {
                    // unsigned bytes cannot hold negatives; clamp like set(long[], ...) does
                    put(i, Math.max(data[i], 0));
                }
                return;
            case BYTE:
            case BOOL:
                copyHostArrayIntoBuffer(new BytePointer(ArrayUtil.toBytes(data)), length, srcOffset, dstOffset);
                return;
            default:
                throw new UnsupportedOperationException("Unsupported data type: " + dataType());
        }
    }

    /**
     * Copies {@code length} elements from the temporary host array behind {@code source}, starting at
     * element {@code srcOffset}, into this buffer at element {@code dstOffset}. All byte arithmetic uses
     * this buffer's element size.
     */
    private void copyHostArrayIntoBuffer(Pointer source, long length, long srcOffset, long dstOffset) {
        allocator.memcpyAsync(this, new CudaPointer(source.address() + (srcOffset * this.elementSize)),
                length * this.elementSize, dstOffset * this.elementSize);
        // keep the temporary host array reachable until the copy call has returned
        source.address();
    }

    /**
     * Copies {@code numElements} values from {@code data} into this buffer, converting each value
     * to this buffer's data type.
     *
     * <p>Values are read from {@code data} starting at index {@code srcOffset} and written to this
     * buffer starting at element {@code dstOffset}. For most types the converted values are staged
     * in a newly allocated native host array and handed to {@code allocator.memcpyAsync}. {@code UBYTE}
     * is instead written element by element through {@code put}. For the unsigned types
     * {@code UBYTE}, {@code UINT16}, {@code UINT32} and {@code UINT64}, values are first passed
     * through {@code ArrayUtil.cutBelowZero}.
     *
     * @param data        source values
     * @param numElements number of elements to copy
     * @param srcOffset   index of the first element of {@code data} to copy
     * @param dstOffset   index of the first element of this buffer to write
     * @throws UnsupportedOperationException if this buffer's data type cannot be set from longs
     */
    public void set(long[] data, long numElements, long srcOffset, long dstOffset) {
        final Pointer host;
        switch (dataType()) {
            case DOUBLE:
                host = new DoublePointer(ArrayUtil.toDouble(data));
                break;
            case FLOAT:
                host = new FloatPointer(ArrayUtil.toFloats(data));
                break;
            case HALF:
                host = new ShortPointer(ArrayUtil.toHalfs(data));
                break;
            case LONG:
                host = new LongPointer(data);
                break;
            case INT:
                host = new IntPointer(ArrayUtil.toInts(data));
                break;
            case SHORT:
                host = new ShortPointer(ArrayUtil.toShorts(data));
                break;
            case UBYTE: {
                // Element-wise path: honour numElements/srcOffset/dstOffset exactly like the memcpy
                // path below (the old code ignored them and wrote the whole array from element 0).
                final long[] clamped = ArrayUtil.cutBelowZero(data);
                for (long i = 0; i < numElements; i++) {
                    put(dstOffset + i, clamped[(int) (srcOffset + i)]);
                }
                return;
            }
            case BYTE:
            case BOOL:
                host = new BytePointer(ArrayUtil.toBytes(data));
                break;
            case BFLOAT16:
                host = new ShortPointer(ArrayUtil.toBfloats(data));
                break;
            case UINT16:
                host = new ShortPointer(ArrayUtil.toShorts(ArrayUtil.cutBelowZero(data)));
                break;
            case UINT32:
                host = new IntPointer(ArrayUtil.toInts(ArrayUtil.cutBelowZero(data)));
                break;
            case UINT64:
                host = new LongPointer(ArrayUtil.cutBelowZero(data));
                break;
            default:
                throw new UnsupportedOperationException("Unsupported data type: " + dataType());
        }
        // The staged array has this buffer's element type, so offsets scale by elementSize on both sides.
        allocator.memcpyAsync(this, new CudaPointer(host.address() + (srcOffset * this.elementSize)),
                numElements * this.elementSize, dstOffset * this.elementSize);
        // NOTE(review): touching `host` keeps the staging array strongly reachable until the copy has
        // been enqueued; whether that is long enough for an asynchronous copy depends on
        // memcpyAsync's semantics — confirm.
        host.address();
    }

    /**
     * Copies {@code numElements} values from {@code data} into this buffer, converting each value
     * to this buffer's data type.
     *
     * <p>Values are read from {@code data} starting at index {@code srcOffset} and written to this
     * buffer starting at element {@code dstOffset}. For most types the converted values are staged
     * in a newly allocated native host array and handed to {@code allocator.memcpyAsync}.
     * {@code UBYTE} is instead written element by element through {@code put}.
     *
     * @param data        source values
     * @param numElements number of elements to copy
     * @param srcOffset   index of the first element of {@code data} to copy
     * @param dstOffset   index of the first element of this buffer to write
     * @throws UnsupportedOperationException if this buffer's data type cannot be set from floats
     */
    public void set(float[] data, long numElements, long srcOffset, long dstOffset) {
        final Pointer host;
        switch (dataType()) {
            case DOUBLE:
                host = new DoublePointer(ArrayUtil.toDoubles(data));
                break;
            case FLOAT:
                host = new FloatPointer(data);
                break;
            case HALF:
                host = new ShortPointer(ArrayUtil.toHalfs(data));
                break;
            case LONG:
                host = new LongPointer(ArrayUtil.toLongArray(data));
                break;
            case INT:
                host = new IntPointer(ArrayUtil.toInts(data));
                break;
            case SHORT:
                host = new ShortPointer(ArrayUtil.toShorts(data));
                break;
            case UBYTE:
                // Element-wise path: honour numElements/srcOffset/dstOffset exactly like the memcpy
                // path below (the old code ignored them and wrote the whole array from element 0).
                for (long i = 0; i < numElements; i++) {
                    put(dstOffset + i, data[(int) (srcOffset + i)]);
                }
                return;
            case BYTE:
            case BOOL:
                host = new BytePointer(ArrayUtil.toBytes(data));
                break;
            default:
                throw new UnsupportedOperationException("Unsupported data type: " + dataType());
        }
        // The staged array has this buffer's element type, so offsets scale by elementSize on both sides.
        allocator.memcpyAsync(this, new CudaPointer(host.address() + (srcOffset * this.elementSize)),
                numElements * this.elementSize, dstOffset * this.elementSize);
        // NOTE(review): touching `host` keeps the staging array strongly reachable until the copy has
        // been enqueued; whether that is long enough for an asynchronous copy depends on
        // memcpyAsync's semantics — confirm.
        host.address();
    }

    /**
     * Copies {@code numElements} values from {@code data} into this buffer, converting each value
     * to this buffer's data type.
     *
     * <p>Values are read from {@code data} starting at index {@code srcOffset} and written to this
     * buffer starting at element {@code dstOffset}. For most types the converted values are staged
     * in a newly allocated native host array and handed to {@code allocator.memcpyAsync}.
     * {@code UBYTE} is instead written element by element through {@code put}.
     *
     * @param data        source values
     * @param numElements number of elements to copy
     * @param srcOffset   index of the first element of {@code data} to copy
     * @param dstOffset   index of the first element of this buffer to write
     * @throws UnsupportedOperationException if this buffer's data type cannot be set from doubles
     */
    public void set(double[] data, long numElements, long srcOffset, long dstOffset) {
        final Pointer host;
        switch (dataType()) {
            case DOUBLE:
                host = new DoublePointer(data);
                break;
            case FLOAT:
                host = new FloatPointer(ArrayUtil.toFloats(data));
                break;
            case HALF:
                host = new ShortPointer(ArrayUtil.toHalfs(data));
                break;
            case LONG:
                host = new LongPointer(ArrayUtil.toLongs(data));
                break;
            case INT:
                host = new IntPointer(ArrayUtil.toInts(data));
                break;
            case SHORT:
                host = new ShortPointer(ArrayUtil.toShorts(data));
                break;
            case UBYTE:
                // Element-wise path: honour numElements/srcOffset/dstOffset exactly like the memcpy
                // path below (the old code ignored them and wrote the whole array from element 0).
                for (long i = 0; i < numElements; i++) {
                    put(dstOffset + i, data[(int) (srcOffset + i)]);
                }
                return;
            case BYTE:
            case BOOL:
                host = new BytePointer(ArrayUtil.toBytes(data));
                break;
            default:
                throw new UnsupportedOperationException("Unsupported data type: " + dataType());
        }
        // The staged array has this buffer's element type, so offsets scale by elementSize on both sides.
        allocator.memcpyAsync(this, new CudaPointer(host.address() + (srcOffset * this.elementSize)),
                numElements * this.elementSize, dstOffset * this.elementSize);
        // NOTE(review): touching `host` keeps the staging array strongly reachable until the copy has
        // been enqueued; whether that is long enough for an asynchronous copy depends on
        // memcpyAsync's semantics — confirm.
        host.address();
    }

    /**
     * Overwrites this buffer from element 0 with the contents of {@code data}; an empty array is a no-op.
     *
     * @param data values to store
     */
    public void setData(int[] data) {
        final boolean hasContent = data.length > 0;
        if (hasContent) {
            set(data, data.length, 0L, 0L);
        }
    }

    /**
     * Overwrites this buffer from element 0 with the contents of {@code data}; an empty array is a no-op.
     *
     * @param data values to store
     */
    public void setData(long[] data) {
        final boolean hasContent = data.length > 0;
        if (hasContent) {
            set(data, data.length, 0L, 0L);
        }
    }

    /**
     * Overwrites this buffer from element 0 with the contents of {@code data}; an empty array is a no-op.
     *
     * @param data values to store
     */
    public void setData(float[] data) {
        final boolean hasContent = data.length > 0;
        if (hasContent) {
            set(data, data.length, 0L, 0L);
        }
    }

    /**
     * Overwrites this buffer from element 0 with the contents of {@code data}; an empty array is a no-op.
     *
     * @param data values to store
     */
    public void setData(double[] data) {
        final boolean hasContent = data.length > 0;
        if (hasContent) {
            set(data, data.length, 0L, 0L);
        }
    }

    /**
     * Not available on the CUDA backend.
     *
     * @throws UnsupportedOperationException always
     */
    protected void setNioBuffer() {
        final String reason = "setNioBuffer() is not supported for CUDA backend";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Ensures a host allocation exists and synchronizes the host copies of this buffer and
     * {@code other}, then delegates the strided copy to the superclass.
     */
    public void copyAtStride(DataBuffer other, long j, long j2, long j3, long j4, long j5) {
        // the superclass works on host memory, so both sides must be current on the host first
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        allocator.synchronizeHostData(other);
        super.copyAtStride(other, j, j2, j3, j4, j5);
    }

    /** Returns the allocation mode currently stored in the {@code allocationMode} field. */
    public DataBuffer.AllocationMode allocationMode() {
        final DataBuffer.AllocationMode mode = this.allocationMode;
        return mode;
    }

    /**
     * Wraps the current {@code pointer} as a {@link ByteBuffer}.
     *
     * <p>No host allocation or synchronization is performed here.
     */
    @Override
    public ByteBuffer getHostBuffer() {
        final Pointer hostPointer = this.pointer;
        return hostPointer.asByteBuffer();
    }

    /** Returns the host pointer for this buffer as reported by the global {@link AtomicAllocator}. */
    @Override
    public Pointer getHostPointer() {
        final AtomicAllocator atomicAllocator = AtomicAllocator.getInstance();
        return atomicAllocator.getHostPointer(this);
    }

    /**
     * Offset host pointers are not supported by this buffer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Pointer getHostPointer(long offset) {
        final UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /**
     * Intentionally empty in this implementation: nothing is recorded or removed.
     *
     * @param str ignored
     */
    public void removeReferencing(String str) {
    }

    /**
     * Returns the names of buffers referencing this one.
     *
     * <p>This class does not record references ({@code addReferencing} and
     * {@code removeReferencing} have empty bodies), so the result is always empty. An empty,
     * immutable collection is returned rather than {@code null}, so callers can iterate it without
     * a null check.
     *
     * @return an empty, immutable collection
     */
    public Collection<String> references() {
        return java.util.Collections.emptyList();
    }

    /** Returns the size in bytes of one element, as stored in the {@code elementSize} field. */
    public int getElementSize() {
        final int bytesPerElement = elementSize;
        return bytesPerElement;
    }

    /**
     * Intentionally empty in this implementation: the name is not recorded.
     *
     * @param str ignored
     */
    public void addReferencing(String str) {
    }

    /**
     * Deprecated and unsupported.
     *
     * @throws UnsupportedOperationException always
     */
    @Deprecated
    public Pointer getHostPointer(INDArray array, int a, long b, int c) {
        final String reason = "This method is deprecated";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Deprecated and unsupported.
     *
     * @throws UnsupportedOperationException always
     */
    @Deprecated
    public void set(Pointer source) {
        final String reason = "set(Pointer) is not supported";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Stores a float at {@code index} through the host copy.
     *
     * <p>Ensures a host allocation exists, synchronizes it, records a host write via
     * {@code allocator.tickHostWrite}, then delegates the store to the superclass.
     */
    public void put(long index, float value) {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        allocator.tickHostWrite(this);
        super.put(index, value);
    }

    /**
     * Stores a boolean at {@code index} through the host copy.
     *
     * <p>Ensures a host allocation exists, synchronizes it, records a host write via
     * {@code allocator.tickHostWrite}, then delegates the store to the superclass.
     */
    public void put(long index, boolean value) {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        allocator.tickHostWrite(this);
        super.put(index, value);
    }

    /**
     * Stores a double at {@code index} through the host copy.
     *
     * <p>Ensures a host allocation exists, synchronizes it, records a host write via
     * {@code allocator.tickHostWrite}, then delegates the store to the superclass.
     */
    public void put(long index, double value) {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        allocator.tickHostWrite(this);
        super.put(index, value);
    }

    /**
     * Stores an int at {@code index} through the host copy.
     *
     * <p>Ensures a host allocation exists, synchronizes it, records a host write via
     * {@code allocator.tickHostWrite}, then delegates the store to the superclass.
     */
    public void put(long index, int value) {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        allocator.tickHostWrite(this);
        super.put(index, value);
    }

    /**
     * Stores a long at {@code index} through the host copy.
     *
     * <p>Ensures a host allocation exists, synchronizes it, records a host write via
     * {@code allocator.tickHostWrite}, then delegates the store to the superclass.
     */
    public void put(long index, long value) {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        allocator.tickHostWrite(this);
        super.put(index, value);
    }

    /**
     * Returns this buffer's host pointer from the global {@link AtomicAllocator}.
     *
     * @throws IllegalStateException if the buffer has been released
     */
    public Pointer addressPointer() {
        if (!this.released) {
            return AtomicAllocator.getInstance().getHostPointer(this);
        }
        throw new IllegalStateException("You can't use DataBuffer once it was released");
    }

    /**
     * Deprecated pointer-based set; never performs a copy.
     *
     * @throws IllegalArgumentException      if the byte offset of {@code index} is at or beyond the
     *                                       buffer's byte size
     * @throws UnsupportedOperationException otherwise
     */
    @Deprecated
    protected void set(long index, long count, Pointer source, long inc) {
        final long byteOffset = getElementSize() * index;
        final long byteCapacity = length() * getElementSize();
        if (byteOffset >= byteCapacity) {
            throw new IllegalArgumentException("Illegal offset " + byteOffset + " with index of " + index + " and length " + length());
        }
        throw new UnsupportedOperationException("Deprecated set() call");
    }

    /** Deprecated: forwards to the four-argument {@code set} with an increment of 1. */
    @Deprecated
    protected void set(long index, long count, Pointer source) {
        final long unitIncrement = 1L;
        set(index, count, source, unitIncrement);
    }

    /**
     * Copies the contents of {@code source} into this buffer via {@code allocator.memcpy}.
     *
     * @param source buffer to copy from
     */
    public void assign(DataBuffer source) {
        allocator.memcpy(this, source);
    }

    /**
     * Writes {@code data[k]} to element {@code indices[k]} for every k, one {@code put} at a time.
     *
     * @param indices     target element indices
     * @param data        values to write; must have the same length as {@code indices}
     * @param ignoredFlag not used by this implementation
     * @param ignoredArg  not used by this implementation
     * @throws IllegalArgumentException if the arrays differ in length or there are more indices than
     *                                  elements in this buffer
     */
    public void assign(long[] indices, float[] data, boolean ignoredFlag, long ignoredArg) {
        if (indices.length != data.length) {
            throw new IllegalArgumentException("Indices and data length must be the same");
        }
        if (indices.length > length()) {
            throw new IllegalArgumentException("More elements than space to assign. This buffer is of length " + length() + " where the indices are of length " + data.length);
        }
        int cursor = 0;
        for (long target : indices) {
            put(target, data[cursor++]);
        }
    }

    /**
     * Writes {@code data[k]} to element {@code indices[k]} for every k, one {@code put} at a time.
     *
     * @param indices     target element indices
     * @param data        values to write; must have the same length as {@code indices}
     * @param ignoredFlag not used by this implementation
     * @param ignoredArg  not used by this implementation
     * @throws IllegalArgumentException if the arrays differ in length or there are more indices than
     *                                  elements in this buffer
     */
    public void assign(long[] indices, double[] data, boolean ignoredFlag, long ignoredArg) {
        if (indices.length != data.length) {
            throw new IllegalArgumentException("Indices and data length must be the same");
        }
        if (indices.length > length()) {
            throw new IllegalArgumentException("More elements than space to assign. This buffer is of length " + length() + " where the indices are of length " + data.length);
        }
        int cursor = 0;
        for (long target : indices) {
            put(target, data[cursor++]);
        }
    }

    /** Deprecated: forwards to the three-argument {@code set} with a count of 1. */
    @Deprecated
    protected void set(long index, Pointer source) {
        final long singleElement = 1L;
        set(index, singleElement, source);
    }

    /** Intentionally a no-op in this implementation. */
    public void flush() {
    }

    /** Intentionally a no-op in this implementation; no memory is released here. */
    public void destroy() {
    }

    /** Reads element {@code index} as a double straight from the superclass, with no host sync. */
    protected double getDoubleUnsynced(long index) {
        final double value = super.getDouble(index);
        return value;
    }

    /** Reads element {@code index} as a float straight from the superclass, with no host sync. */
    protected float getFloatUnsynced(long index) {
        final float value = super.getFloat(index);
        return value;
    }

    /** Reads element {@code index} as a long straight from the superclass, with no host sync. */
    protected long getLongUnsynced(long index) {
        final long value = super.getLong(index);
        return value;
    }

    /** Reads element {@code index} as an int straight from the superclass, with no host sync. */
    protected int getIntUnsynced(long index) {
        final int value = super.getInt(index);
        return value;
    }

    /**
     * Serializes this buffer to {@code out} after making sure the host copy exists and is synchronized.
     *
     * @throws IOException propagated from the superclass writer
     */
    public void write(DataOutputStream out) throws IOException {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        super.write(out);
    }

    /** Serializes this buffer to {@code out} after making sure the host copy exists and is synchronized. */
    public void write(OutputStream out) {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        super.write(out);
    }

    /**
     * Java serialization hook. Synchronizes the host copy, writes the default serializable fields,
     * then appends the buffer payload via {@code write(OutputStream)}.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        out.defaultWriteObject();
        write(out);
    }

    /** Java deserialization hook; all work is delegated to {@code doReadObject}. */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        doReadObject(in);
    }

    /** Renders the buffer via the superclass after the host copy has been allocated and synchronized. */
    @Override
    public String toString() {
        lazyAllocateHostPointer();
        final AtomicAllocator atomicAllocator = AtomicAllocator.getInstance();
        atomicAllocator.synchronizeHostData(this);
        return super.toString();
    }

    /**
     * Reports whether {@code dataBuffer} shares this buffer's native {@code ptrDataBuffer}.
     *
     * <p>Buffers that are not {@code BaseCudaDataBuffer} instances, including {@code null},
     * cannot share a CUDA opaque buffer, so they yield {@code false`} rather than the
     * {@code ClassCastException}/{@code NullPointerException} the previous unchecked cast threw.
     *
     * @param dataBuffer buffer to compare against
     * @return {@code true} if both wrap the same native buffer address
     */
    public boolean sameUnderlyingData(DataBuffer dataBuffer) {
        if (!(dataBuffer instanceof BaseCudaDataBuffer)) {
            return false;
        }
        final BaseCudaDataBuffer other = (BaseCudaDataBuffer) dataBuffer;
        return this.ptrDataBuffer.address() == other.ptrDataBuffer.address();
    }

    /**
     * Identity equality: a buffer equals only itself. This is consistent with the inherited
     * identity-based {@code hashCode}.
     */
    @Override
    public boolean equals(Object obj) {
        // `this` is never null, so identity alone already rejects null
        return this == obj;
    }

    /**
     * Reads buffer content from {@code in}. Pointers are initialised first if there is no
     * allocation point yet, and a host write is recorded on the allocation point afterwards.
     */
    public void read(InputStream in, DataBuffer.AllocationMode mode, long numElements, DataType type) {
        final boolean needsInit = this.allocationPoint == null;
        if (needsInit) {
            initPointers(numElements, type, false);
        }
        super.read(in, mode, numElements, type);
        this.allocationPoint.tickHostWrite();
    }

    /**
     * Intentionally empty in this implementation; the pointer and indexer are not changed here.
     *
     * @param dataType ignored
     */
    public void pointerIndexerByCurrentType(DataType dataType) {
    }

    /**
     * Deserializes this buffer from {@code dataInputStream}.
     *
     * <p>The stream is read in this order:
     * <ul>
     *   <li>the allocation mode name (UTF);</li>
     *   <li>the element count — an {@code int} when that mode's ordinal is below 3, otherwise a
     *       {@code long};</li>
     *   <li>the data type name (UTF).</li>
     * </ul>
     *
     * <p>If the type is {@code COMPRESSED}, only the type is recorded. Otherwise the method:
     * <ol>
     *   <li>allocates fresh memory through {@link AtomicAllocator};</li>
     *   <li>binds a host pointer and indexer matching the type;</li>
     *   <li>reads the payload with {@code readContent};</li>
     *   <li>records a host write and synchronizes the allocation to the device.</li>
     * </ol>
     *
     * @param dataInputStream source stream
     * @throws RuntimeException wrapping any failure, including I/O errors and unsupported data types
     */
    public void read(DataInputStream dataInputStream) {
        try {
            DataBuffer.AllocationMode storedMode = DataBuffer.AllocationMode.valueOf(dataInputStream.readUTF());
            // the in-memory buffer always uses MIXED_DATA_TYPES, whatever mode was stored
            this.allocationMode = DataBuffer.AllocationMode.MIXED_DATA_TYPES;
            // the stored mode decides whether the element count was written as an int or a long
            this.length = storedMode.ordinal() < 3 ? dataInputStream.readInt() : dataInputStream.readLong();
            DataType storedType = DataType.valueOf(dataInputStream.readUTF());
            if (this.globalType == null && Nd4j.dataType() != null) {
                this.globalType = Nd4j.dataType();
            }
            if (storedType == DataType.COMPRESSED) {
                // no allocation or payload read for compressed buffers; only the type is recorded
                this.type = storedType;
                return;
            }
            this.elementSize = (byte) Nd4j.sizeOfDataType(storedType);
            this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(this.length, this.elementSize, storedType), false);
            this.type = storedType;
            Nd4j.getDeallocatorService().pickObject(this);
            switch (storedType) {
                case DOUBLE:
                    this.pointer = new CudaPointer(this.allocationPoint.getHostPointer(), this.length).asDoublePointer();
                    this.indexer = DoubleIndexer.create((DoublePointer) this.pointer);
                    break;
                case FLOAT:
                    this.pointer = new CudaPointer(this.allocationPoint.getHostPointer(), this.length).asFloatPointer();
                    this.indexer = FloatIndexer.create((FloatPointer) this.pointer);
                    break;
                case HALF:
                    this.pointer = new CudaPointer(this.allocationPoint.getHostPointer(), this.length).asShortPointer();
                    this.indexer = HalfIndexer.create((ShortPointer) this.pointer);
                    break;
                case LONG:
                    this.pointer = new CudaPointer(this.allocationPoint.getHostPointer(), this.length).asLongPointer();
                    this.indexer = LongIndexer.create((LongPointer) this.pointer);
                    break;
                case INT:
                    this.pointer = new CudaPointer(this.allocationPoint.getHostPointer(), this.length).asIntPointer();
                    this.indexer = IntIndexer.create((IntPointer) this.pointer);
                    break;
                case SHORT:
                    this.pointer = new CudaPointer(this.allocationPoint.getHostPointer(), this.length).asShortPointer();
                    this.indexer = ShortIndexer.create((ShortPointer) this.pointer);
                    break;
                case UBYTE:
                    this.pointer = new CudaPointer(this.allocationPoint.getHostPointer(), this.length).asBytePointer();
                    this.indexer = UByteIndexer.create((BytePointer) this.pointer);
                    break;
                case BYTE:
                    this.pointer = new CudaPointer(this.allocationPoint.getHostPointer(), this.length).asBytePointer();
                    this.indexer = ByteIndexer.create((BytePointer) this.pointer);
                    break;
                case BOOL:
                    this.pointer = new CudaPointer(this.allocationPoint.getHostPointer(), this.length).asBooleanPointer();
                    this.indexer = BooleanIndexer.create((BooleanPointer) this.pointer);
                    break;
                default:
                    throw new UnsupportedOperationException("Unsupported data type: " + this.type);
            }
            readContent(dataInputStream, storedType, storedType);
            this.allocationPoint.tickHostWrite();
            AtomicAllocator.getInstance().getFlowController().synchronizeToDevice(this.allocationPoint);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the contents as bytes, after the host copy has been allocated and synchronized. */
    public byte[] asBytes() {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        final byte[] result = super.asBytes();
        return result;
    }

    /** Returns the contents as doubles, after the host copy has been allocated and synchronized. */
    public double[] asDouble() {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        final double[] result = super.asDouble();
        return result;
    }

    /** Returns the contents as floats, after the host copy has been allocated and synchronized. */
    public float[] asFloat() {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        final float[] result = super.asFloat();
        return result;
    }

    /** Returns the contents as ints, after the host copy has been allocated and synchronized. */
    public int[] asInt() {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        final int[] result = super.asInt();
        return result;
    }

    /** Returns the contents as longs, after the host copy has been allocated and synchronized. */
    public long[] asLong() {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        final long[] result = super.asLong();
        return result;
    }

    /** Returns an NIO byte view from the superclass, after the host copy has been allocated and synchronized. */
    public ByteBuffer asNio() {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        final ByteBuffer view = super.asNio();
        return view;
    }

    /** Returns an NIO double view from the superclass, after the host copy has been allocated and synchronized. */
    public DoubleBuffer asNioDouble() {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        final DoubleBuffer view = super.asNioDouble();
        return view;
    }

    /** Returns an NIO float view from the superclass, after the host copy has been allocated and synchronized. */
    public FloatBuffer asNioFloat() {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        final FloatBuffer view = super.asNioFloat();
        return view;
    }

    /** Returns an NIO int view from the superclass, after the host copy has been allocated and synchronized. */
    public IntBuffer asNioInt() {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        final IntBuffer view = super.asNioInt();
        return view;
    }

    /**
     * Creates a new buffer of the same length and fills it with a blocking copy of this buffer's
     * host data. The host copy is allocated and synchronized first.
     *
     * @return the duplicate
     */
    public DataBuffer dup() {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        final DataBuffer copy = create(this.length);
        final CudaPointer source = new CudaPointer(allocator.getHostPointer(this).address());
        final long byteCount = this.length * this.elementSize;
        allocator.memcpyBlocking(copy, source, byteCount, 0L);
        return copy;
    }

    /** Reads element {@code index} as a {@link Number}, after the host copy has been allocated and synchronized. */
    public Number getNumber(long index) {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        final Number value = super.getNumber(index);
        return value;
    }

    /** Reads element {@code index} as a double, after the host copy has been allocated and synchronized. */
    public double getDouble(long index) {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        final double value = super.getDouble(index);
        return value;
    }

    /** Reads element {@code index} as a long, after the host copy has been allocated and synchronized. */
    public long getLong(long index) {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        final long value = super.getLong(index);
        return value;
    }

    /** Reads element {@code index} as a float, after the host copy has been allocated and synchronized. */
    public float getFloat(long index) {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        final float value = super.getFloat(index);
        return value;
    }

    /** Reads element {@code index} as an int, after the host copy has been allocated and synchronized. */
    public int getInt(long index) {
        lazyAllocateHostPointer();
        allocator.synchronizeHostData(this);
        final int value = super.getInt(index);
        return value;
    }

    /**
     * Re-binds {@code pointer} and the indexer to the primary buffer of {@code ptrDataBuffer}.
     *
     * <p>Nothing happens when the primary buffer and {@code pointer} are both non-null and already
     * share an address. Otherwise a {@link PagedPointer} view over the primary buffer is created,
     * with the pointer/indexer flavour that matches {@link #dataType()}.
     *
     * @throws IllegalArgumentException if the data type has no mapping here
     */
    public void actualizePointerAndIndexer() {
        final Pointer primary = this.ptrDataBuffer.primaryBuffer();
        if (primary != null && this.pointer != null && primary.address() == this.pointer.address()) {
            return;
        }
        final DataType type = dataType();
        if (type == null) {
            throw new IllegalArgumentException("Unknown datatype: " + dataType());
        }
        switch (type) {
            case BOOL:
                this.pointer = new PagedPointer(primary, this.length).asBoolPointer();
                setIndexer(BooleanIndexer.create((BooleanPointer) this.pointer));
                break;
            case UBYTE:
                this.pointer = new PagedPointer(primary, this.length).asBytePointer();
                setIndexer(UByteIndexer.create((BytePointer) this.pointer));
                break;
            case BYTE:
                this.pointer = new PagedPointer(primary, this.length).asBytePointer();
                setIndexer(ByteIndexer.create((BytePointer) this.pointer));
                break;
            case UINT16:
                this.pointer = new PagedPointer(primary, this.length).asShortPointer();
                setIndexer(UShortIndexer.create((ShortPointer) this.pointer));
                break;
            case SHORT:
                this.pointer = new PagedPointer(primary, this.length).asShortPointer();
                setIndexer(ShortIndexer.create((ShortPointer) this.pointer));
                break;
            case UINT32:
                this.pointer = new PagedPointer(primary, this.length).asIntPointer();
                setIndexer(UIntIndexer.create((IntPointer) this.pointer));
                break;
            case INT:
                this.pointer = new PagedPointer(primary, this.length).asIntPointer();
                setIndexer(IntIndexer.create((IntPointer) this.pointer));
                break;
            case UINT64:
            case LONG:
                this.pointer = new PagedPointer(primary, this.length).asLongPointer();
                setIndexer(LongIndexer.create((LongPointer) this.pointer));
                break;
            case BFLOAT16:
                this.pointer = new PagedPointer(primary, this.length).asShortPointer();
                setIndexer(Bfloat16Indexer.create((ShortPointer) this.pointer));
                break;
            case HALF:
                this.pointer = new PagedPointer(primary, this.length).asShortPointer();
                setIndexer(HalfIndexer.create((ShortPointer) this.pointer));
                break;
            case FLOAT:
                this.pointer = new PagedPointer(primary, this.length).asFloatPointer();
                setIndexer(FloatIndexer.create((FloatPointer) this.pointer));
                break;
            case DOUBLE:
                this.pointer = new PagedPointer(primary, this.length).asDoublePointer();
                setIndexer(DoubleIndexer.create((DoublePointer) this.pointer));
                break;
            case UTF8:
                this.pointer = new PagedPointer(primary, length()).asBytePointer();
                setIndexer(ByteIndexer.create((BytePointer) this.pointer));
                break;
            default:
                throw new IllegalArgumentException("Unknown datatype: " + dataType());
        }
    }

    /**
     * Resizes this buffer to {@code newLength} elements, preserving existing contents, and returns it.
     *
     * <p>When {@code isAttached()} is false, the opaque buffer is grown in place with
     * {@code expand}. Otherwise new host and/or device regions are allocated from the parent
     * workspace. Only the regions that currently exist (non-null, non-zero address) are
     * reallocated, and at most {@code min(oldLength, newLength)} elements are copied into each
     * one.
     *
     * <p>Changes from the previous version:
     * <ul>
     *   <li>The new device region is stored with {@code setSpecialBuffer}; it was previously passed
     *       to {@code setPrimaryBuffer}, leaving the special buffer on the old allocation.</li>
     *   <li>The device copy no longer reads {@code newLength} elements from the smaller old
     *       region.</li>
     *   <li>The host copy no longer overruns the new region when shrinking.</li>
     *   <li>{@code UBYTE} and {@code UINT16} get the unsigned indexers used by
     *       {@code actualizePointerAndIndexer}.</li>
     * </ul>
     *
     * @param newLength the new number of elements
     * @return this buffer
     */
    public DataBuffer reallocate(long newLength) {
        final Pointer oldHost = this.ptrDataBuffer.primaryBuffer();
        final Pointer oldDevice = this.ptrDataBuffer.specialBuffer();
        // host view to rebind pointer/indexer to; stays null when no host region exists
        PagedPointer hostView = null;
        if (!isAttached()) {
            this.ptrDataBuffer.expand(newLength);
            hostView = new PagedPointer(this.ptrDataBuffer.primaryBuffer(), newLength);
        } else {
            final long capacityBytes = newLength * getElementSize();
            // never copy more than either the old or the new region holds
            final long copyBytes = Math.min(length(), newLength) * getElementSize();
            if (oldDevice != null && oldDevice.address() != 0) {
                PagedPointer newDevice = getParentWorkspace().alloc(capacityBytes, MemoryKind.DEVICE, dataType(), false);
                NativeOpsHolder.getInstance().getDeviceNativeOps().memcpySync(newDevice, oldDevice, copyBytes, 3, (Pointer) null);
                this.ptrDataBuffer.setSpecialBuffer(newDevice, newLength);
                this.allocationPoint.tickDeviceRead();
            }
            if (oldHost != null && oldHost.address() != 0) {
                PagedPointer newHost = getParentWorkspace().alloc(capacityBytes, MemoryKind.HOST, dataType(), false);
                Pointer.memcpy(newHost, oldHost, copyBytes);
                this.ptrDataBuffer.setPrimaryBuffer(newHost, newLength);
                this.allocationPoint.tickHostRead();
                hostView = newHost;
            }
            this.workspaceGenerationId = getParentWorkspace().getGenerationId();
        }
        if (hostView != null) {
            switch (dataType()) {
                case DOUBLE:
                    this.pointer = hostView.asDoublePointer();
                    this.indexer = DoubleIndexer.create((DoublePointer) this.pointer);
                    break;
                case FLOAT:
                    this.pointer = hostView.asFloatPointer();
                    this.indexer = FloatIndexer.create((FloatPointer) this.pointer);
                    break;
                case HALF:
                    this.pointer = hostView.asShortPointer();
                    this.indexer = HalfIndexer.create((ShortPointer) this.pointer);
                    break;
                case LONG:
                case UINT64:
                    this.pointer = hostView.asLongPointer();
                    this.indexer = LongIndexer.create((LongPointer) this.pointer);
                    break;
                case INT:
                    this.pointer = hostView.asIntPointer();
                    this.indexer = IntIndexer.create((IntPointer) this.pointer);
                    break;
                case UINT32:
                    this.pointer = hostView.asIntPointer();
                    this.indexer = UIntIndexer.create((IntPointer) this.pointer);
                    break;
                case SHORT:
                    this.pointer = hostView.asShortPointer();
                    this.indexer = ShortIndexer.create((ShortPointer) this.pointer);
                    break;
                case UINT16:
                    this.pointer = hostView.asShortPointer();
                    this.indexer = UShortIndexer.create((ShortPointer) this.pointer);
                    break;
                case BFLOAT16:
                    this.pointer = hostView.asShortPointer();
                    this.indexer = Bfloat16Indexer.create((ShortPointer) this.pointer);
                    break;
                case UBYTE:
                    this.pointer = hostView.asBytePointer();
                    this.indexer = UByteIndexer.create((BytePointer) this.pointer);
                    break;
                case BYTE:
                case UTF8:
                    this.pointer = hostView.asBytePointer();
                    this.indexer = ByteIndexer.create((BytePointer) this.pointer);
                    break;
                case BOOL:
                    this.pointer = hostView.asBoolPointer();
                    this.indexer = BooleanIndexer.create((BooleanPointer) this.pointer);
                    break;
                default:
                    // other types keep their current pointer/indexer, as before
                    break;
            }
        }
        this.underlyingLength = newLength;
        this.length = newLength;
        return this;
    }

    /**
     * Reports the capacity of this buffer.
     * <p>
     * If the allocation point has no host pointer, the logical {@code length}
     * field is returned as a fallback. Otherwise the capacity of the Java-side
     * {@code pointer} is returned.
     *
     * @return the buffer capacity
     */
    public long capacity() {
        if (allocationPoint.getHostPointer() == null) {
            return length;
        }
        return pointer.capacity();
    }

    /**
     * Releases this buffer.
     * <p>
     * When the {@code released} flag is not yet set, the native buffer held by
     * {@code ptrDataBuffer} is closed and the allocation point is marked as
     * released. The parent class's {@code release()} is invoked unconditionally
     * afterwards (presumably it is what sets {@code released}; TODO confirm).
     */
    protected void release() {
        final boolean firstRelease = !released;
        if (firstRelease) {
            ptrDataBuffer.closeBuffer();
            allocationPoint.setReleased(true);
        }
        super.release();
    }

    /**
     * Builds an identifier for this buffer from its allocation point.
     *
     * @return {@code "BCDB_"} followed by the allocation point's object id
     */
    public String getUniqueId() {
        final String prefix = "BCDB_";
        return prefix.concat(String.valueOf(allocationPoint.getObjectId()));
    }

    /**
     * Creates the deallocator responsible for this buffer.
     *
     * @return a new {@code CudaDeallocator} bound to this buffer instance
     */
    public Deallocator deallocator() {
        final Deallocator cudaDeallocator = new CudaDeallocator(this);
        return cudaDeallocator;
    }

    /**
     * Looks up the device this buffer is associated with.
     * <p>
     * The allocation point is resolved through {@code AtomicAllocator.getInstance()}
     * rather than read from the local field.
     *
     * @return the device id recorded on the buffer's allocation point
     */
    public int targetDevice() {
        final AtomicAllocator globalAllocator = AtomicAllocator.getInstance();
        final AllocationPoint point = globalAllocator.getAllocationPoint(this);
        return point.getDeviceId();
    }

    /**
     * Requests synchronization of the primary buffer (presumably host-side;
     * TODO confirm) by delegating to the underlying {@code ptrDataBuffer}.
     */
    public void syncToPrimary() {
        ptrDataBuffer.syncToPrimary();
    }

    /**
     * Requests synchronization of the special buffer (presumably device-side;
     * TODO confirm) by delegating to the underlying {@code ptrDataBuffer}.
     */
    public void syncToSpecial() {
        ptrDataBuffer.syncToSpecial();
    }

    /**
     * Exposes the allocation point stored on this buffer.
     *
     * @return the {@code allocationPoint} field (not copied)
     */
    public AllocationPoint getAllocationPoint() {
        return allocationPoint;
    }

    // Static initializer: populates the class-level fields assigned below.
    static {
        // NOTE(review): `$assertionsDisabled` looks like decompiler output; in source form
        // javac synthesizes this field automatically for classes that contain `assert`.
        $assertionsDisabled = !BaseCudaDataBuffer.class.desiredAssertionStatus();
        // Cache the instance returned by AtomicAllocator.getInstance() in a static field.
        allocator = AtomicAllocator.getInstance();
        // Class-scoped SLF4J logger.
        log = LoggerFactory.getLogger(BaseCudaDataBuffer.class);
    }
}
