package org.nd4j.linalg.jcublas.buffer;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.ref.WeakReference;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import jcuda.Pointer;
import jcuda.jcublas.JCublas2;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.blas.BlasBufferUtil;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexDouble;
import org.nd4j.linalg.api.complex.IComplexFloat;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.complex.CudaComplexConversion;
import org.nd4j.linalg.jcublas.context.ContextHolder;
import org.nd4j.linalg.jcublas.util.PointerUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.class */
public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCudaBuffer {
    // Running allocation counter shared by all buffers: getDevicePointer adds the requested
    // byte count, freeDevicePointer subtracts DevicePointerInfo.getLength().
    // NOTE(review): the two sides do not obviously use the same unit (bytes vs. length) - confirm.
    static AtomicLong allocated = new AtomicLong();
    // Cumulative counter: only ever incremented (in getDevicePointer), never decremented in this class.
    static AtomicLong totalAllocated = new AtomicLong();
    private static Logger log = LoggerFactory.getLogger(BaseCudaDataBuffer.class);
    // Pointer bookkeeping. Row key: name of the requesting thread (Thread.currentThread().getName());
    // column key: Pair of (offset, length); value: the pointer info registered for that view.
    // Transient, so it is re-created in readObject.
    protected transient Table<String, Pair<Integer, Integer>, DevicePointerInfo> pointersToContexts;
    // Set to true by the put/set methods of this class.
    protected AtomicBoolean modified;
    // String identifiers managed through addReferencing / removeReferencing / references.
    protected Collection<String> referencing;
    // Weak self-reference; within this class it is only assigned in readObject.
    protected transient WeakReference<DataBuffer> ref;
    // Set to true by freeDevicePointer after a release and reset to false by getDevicePointer.
    protected AtomicBoolean freed;
    // Lazily created cache of Pointer.to(asNio()); see getHostPointer.
    private transient Pointer hostPointer;
    // Per-name "copied" flags; freeDevicePointer removes the entry for the current thread name.
    private Map<String, Boolean> copied;

    /* loaded from: input_file:org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer$DevicePointerInfo.class */
    /**
     * Value holder describing one registered pointer: the pointer itself plus the
     * length, stride and offset it was registered with, and a mutable "freed" flag.
     */
    public static class DevicePointerInfo {
        private final Pointer pointer;
        private final long length;
        private final int stride;
        private final int offset;
        /** Mutable marker; starts out as {@code false}. */
        private boolean freed;

        /**
         * @param pointer the pointer being described
         * @param j       the length stored for this pointer
         * @param i       the stride stored for this pointer
         * @param i2      the offset stored for this pointer
         */
        public DevicePointerInfo(Pointer pointer, long j, int i, int i2) {
            this.pointer = pointer;
            this.length = j;
            this.stride = i;
            this.offset = i2;
        }

        /** @return the pointer passed at construction */
        public Pointer getPointer() {
            return pointer;
        }

        /** @return the length passed at construction */
        public long getLength() {
            return length;
        }

        /** @return the stride passed at construction */
        public int getStride() {
            return stride;
        }

        /** @return the offset passed at construction */
        public int getOffset() {
            return offset;
        }

        /** @return the current value of the freed flag */
        public boolean isFreed() {
            return freed;
        }

        /** @param z new value of the freed flag */
        public void setFreed(boolean z) {
            freed = z;
        }
    }

    /**
     * Creates a buffer from a netty {@link ByteBuf}; both arguments are forwarded
     * unchanged to the superclass constructor.
     */
    public BaseCudaDataBuffer(ByteBuf byteBuf, int i) {
        super(byteBuf, i);
        initCudaBookkeeping();
    }

    /**
     * Creates a buffer from a float array; both arguments are forwarded unchanged
     * to the superclass constructor.
     */
    public BaseCudaDataBuffer(float[] fArr, boolean z) {
        super(fArr, z);
        initCudaBookkeeping();
    }

    /**
     * Creates a buffer from a double array; both arguments are forwarded unchanged
     * to the superclass constructor.
     */
    public BaseCudaDataBuffer(double[] dArr, boolean z) {
        super(dArr, z);
        initCudaBookkeeping();
    }

    /**
     * Creates a buffer from an int array; both arguments are forwarded unchanged
     * to the superclass constructor.
     */
    public BaseCudaDataBuffer(int[] iArr, boolean z) {
        super(iArr, z);
        initCudaBookkeeping();
    }

    /**
     * Creates a buffer from two int arguments that are forwarded unchanged to the
     * superclass constructor.
     */
    public BaseCudaDataBuffer(int i, int i2) {
        super(i, i2);
        initCudaBookkeeping();
    }

    /**
     * Creates a buffer from a single int argument that is forwarded unchanged to
     * the superclass constructor.
     */
    public BaseCudaDataBuffer(int i) {
        super(i);
        initCudaBookkeeping();
    }

    /** Creates a buffer from a float array (forwarded to the superclass constructor). */
    public BaseCudaDataBuffer(float[] fArr) {
        super(fArr);
        initCudaBookkeeping();
    }

    /** Creates a buffer from an int array (forwarded to the superclass constructor). */
    public BaseCudaDataBuffer(int[] iArr) {
        super(iArr);
        initCudaBookkeeping();
    }

    /** Creates a buffer from a double array (forwarded to the superclass constructor). */
    public BaseCudaDataBuffer(double[] dArr) {
        super(dArr);
        initCudaBookkeeping();
    }

    /**
     * Creates a buffer from a byte array; both arguments are forwarded unchanged
     * to the superclass constructor.
     */
    public BaseCudaDataBuffer(byte[] bArr, int i) {
        super(bArr, i);
        initCudaBookkeeping();
    }

    /**
     * Initialises the bookkeeping state every constructor needs: an empty pointer
     * table, cleared modified/freed flags, a synchronized reference set and a
     * concurrent map of copied flags. Private so it cannot be overridden while the
     * object is still under construction.
     */
    private void initCudaBookkeeping() {
        this.pointersToContexts = HashBasedTable.create();
        this.modified = new AtomicBoolean(false);
        this.referencing = Collections.synchronizedSet(new HashSet<String>());
        this.freed = new AtomicBoolean(false);
        this.copied = new ConcurrentHashMap<>();
    }

    /**
     * Reports whether the "copied" flag is set for the given name.
     *
     * <p>The map is read exactly once so that a concurrent removal of the entry
     * (see {@code freeDevicePointer}) cannot slip in between a null check and the
     * unboxing of the value.
     *
     * @param str the name to look up
     * @return {@code true} only if an entry exists for {@code str} and it is true
     */
    @Override // org.nd4j.linalg.jcublas.buffer.JCudaBuffer
    public boolean copied(String str) {
        Boolean flag = this.copied.get(str);
        return flag != null && flag.booleanValue();
    }

    /**
     * Marks the given name as copied.
     *
     * @param str the name whose flag is set to true
     */
    @Override // org.nd4j.linalg.jcublas.buffer.JCudaBuffer
    public void setCopied(String str) {
        copied.put(str, Boolean.TRUE);
    }

    /** @return the value of the {@code allocationMode} field of this buffer */
    public DataBuffer.AllocationMode allocationMode() {
        return allocationMode;
    }

    /** @return the NIO view produced by {@code dataBuffer.nioBuffer()} */
    @Override // org.nd4j.linalg.jcublas.buffer.JCudaBuffer
    public ByteBuffer getHostBuffer() {
        final ByteBuffer nioView = dataBuffer.nioBuffer();
        return nioView;
    }

    /**
     * Replaces the backing data with a wrapper around the given NIO buffer.
     *
     * <p>The lazily cached host pointer is dropped as well so that the next call
     * to {@code getHostPointer} rebuilds it instead of returning a pointer created
     * for the previous backing buffer.
     * NOTE(review): this presumes {@code asNio()} reflects {@code dataBuffer};
     * if it does not, clearing the cache is still harmless.
     *
     * @param byteBuffer the new backing storage
     */
    @Override // org.nd4j.linalg.jcublas.buffer.JCudaBuffer
    public void setHostBuffer(ByteBuffer byteBuffer) {
        this.dataBuffer = Unpooled.wrappedBuffer(byteBuffer);
        this.hostPointer = null;
    }

    /**
     * Returns the host pointer, creating it from {@code asNio()} on first use and
     * caching it afterwards.
     *
     * @return the cached host pointer
     */
    @Override // org.nd4j.linalg.jcublas.buffer.JCudaBuffer
    public Pointer getHostPointer() {
        Pointer cached = hostPointer;
        if (cached != null) {
            return cached;
        }
        cached = Pointer.to(asNio());
        hostPointer = cached;
        return cached;
    }

    /**
     * Returns the host pointer advanced by {@code i} elements; the pointer is
     * created from {@code asNio()} and cached if it does not exist yet.
     *
     * @param i element index; multiplied by {@link #getElementSize()} to get the byte offset
     * @return a pointer at the resulting byte offset
     */
    @Override // org.nd4j.linalg.jcublas.buffer.JCudaBuffer
    public Pointer getHostPointer(int i) {
        if (hostPointer == null) {
            hostPointer = Pointer.to(asNio());
        }
        final Pointer base = hostPointer;
        final int byteOffset = i * getElementSize();
        return base.withByteOffset(byteOffset);
    }

    /**
     * Removes an identifier from the referencing collection.
     *
     * @param str the identifier to remove
     */
    public void removeReferencing(String str) {
        referencing.remove(str);
    }

    /** @return the live referencing collection itself (not a copy) */
    public Collection<String> references() {
        return referencing;
    }

    /** @return the value of the {@code elementSize} field */
    public int getElementSize() {
        return elementSize;
    }

    /**
     * Adds an identifier to the referencing collection.
     *
     * @param str the identifier to add
     */
    public void addReferencing(String str) {
        referencing.add(str);
    }

    /**
     * Marks the buffer as modified and issues a {@code cublasSetVector} call whose
     * source is the complex number converted to its CUDA representation and whose
     * destination is the host pointer of this buffer.
     *
     * <p>NOTE(review): the index {@code i} is never used in the body and the
     * element count passed to cuBLAS is {@code length()} although the source is a
     * single converted number - confirm the intended behaviour with the callers.
     *
     * @param i              index argument (currently unused)
     * @param iComplexNumber the complex value to transfer
     */
    public void put(int i, IComplexNumber iComplexNumber) {
        modified.set(true);
        if (dataType() != DataBuffer.Type.FLOAT) {
            // every non-float type takes the double-precision conversion
            JCublas2.cublasSetVector(length(), getElementSize(), PointerUtil.getPointer(CudaComplexConversion.toComplexDouble(iComplexNumber.asDouble())), 1, getHostPointer(), 1);
            return;
        }
        JCublas2.cublasSetVector(length(), getElementSize(), PointerUtil.getPointer(CudaComplexConversion.toComplex(iComplexNumber.asFloat())), 1, getHostPointer(), 1);
    }

    /**
     * Returns the pointer registered for the calling thread at the given offset
     * and length, allocating the backing memory through the configured memory
     * strategy when the thread has none yet.
     *
     * <p>Entries are keyed by thread name and an (offset, length) pair. The whole
     * buffer is registered under (0, {@code this.length}); requests with a positive
     * offset are registered as views that share the base pointer.
     *
     * <p>Fixes over the previous version: the base entry is always looked up with
     * its (0, length) pair key (it used to be looked up with a bare {@code Integer},
     * which can never match a pair-keyed table and ended in a
     * {@code NullPointerException}); a zero-offset request that finds an existing
     * base entry now returns the base pointer instead of dereferencing null; and
     * the byte offset is computed in {@code long}.
     *
     * <p>NOTE(review): the allocation counters are bumped by the requested size on
     * every cache miss, even when no new memory is allocated - kept as is.
     *
     * @param i  stride recorded in the view entry
     * @param i2 offset in elements
     * @param i3 length in elements
     * @return the base pointer advanced by {@code i2} elements
     */
    @Override // org.nd4j.linalg.jcublas.buffer.JCudaBuffer
    public Pointer getDevicePointer(int i, int i2, int i3) {
        String threadName = Thread.currentThread().getName();
        long byteOffset = (long) i2 * getElementSize();

        Pair<Integer, Integer> requestedKey = new Pair<>(i2, i3);
        DevicePointerInfo cached = this.pointersToContexts.get(threadName, requestedKey);
        if (cached != null) {
            return cached.getPointer().withByteOffset(byteOffset);
        }

        int requestedBytes = getElementSize() * i3;
        allocated.addAndGet(requestedBytes);
        totalAllocated.addAndGet(requestedBytes);
        log.trace("Allocating {} bytes, total: {}, overall: {}", new Object[]{Integer.valueOf(requestedBytes), Long.valueOf(allocated.get()), totalAllocated});

        // The whole buffer lives under (0, length); allocate it once per thread.
        Pair<Integer, Integer> baseKey = new Pair<>(0, this.length);
        DevicePointerInfo base = this.pointersToContexts.get(threadName, baseKey);
        if (base == null) {
            base = (DevicePointerInfo) ContextHolder.getInstance().getConf().getMemoryStrategy().alloc(this, 1, 0, this.length);
            this.pointersToContexts.put(threadName, baseKey, base);
        }

        if (i2 > 0) {
            // Views store the base pointer; the offset is applied when the pointer is handed out.
            this.pointersToContexts.put(threadName, requestedKey, new DevicePointerInfo(base.getPointer(), i3, i, i2));
            return base.getPointer().withByteOffset(byteOffset);
        }

        this.freed.set(false);
        return base.getPointer().withByteOffset(byteOffset);
    }

    /**
     * Returns the pointer registered for the calling thread for the given array
     * view, allocating the backing memory through the configured memory strategy
     * when the thread has none yet.
     *
     * <p>Fixes over the previous version: the array is validated before the
     * allocation counters are touched; a request that finds an existing base entry
     * but is neither an offset view nor a shorter zero-offset view now returns the
     * base pointer instead of failing with a {@code NullPointerException}; and the
     * byte offset is computed in {@code long}.
     *
     * <p>NOTE(review): view entries are stored under (offset, array length) while
     * the lookup at the top uses (offset, {@code i3}); the two only coincide when
     * callers pass the array length as {@code i3} - kept as is, confirm with callers.
     *
     * @param iNDArray the array whose data buffer must be this buffer
     * @param i        stride recorded in the view entry
     * @param i2       offset in elements
     * @param i3       length in elements
     * @return the base pointer advanced by {@code i2} elements
     * @throws IllegalArgumentException if {@code iNDArray.data()} is not this buffer
     */
    @Override // org.nd4j.linalg.jcublas.buffer.JCudaBuffer
    public Pointer getDevicePointer(INDArray iNDArray, int i, int i2, int i3) {
        String threadName = Thread.currentThread().getName();
        long byteOffset = (long) i2 * getElementSize();

        Pair<Integer, Integer> requestedKey = new Pair<>(i2, i3);
        DevicePointerInfo cached = this.pointersToContexts.get(threadName, requestedKey);
        if (cached != null) {
            return cached.getPointer().withByteOffset(byteOffset);
        }

        if (iNDArray.data() != this) {
            throw new IllegalArgumentException("Unable to get pointer for array that doesn't have this as the buffer");
        }

        int requestedBytes = getElementSize() * i3;
        allocated.addAndGet(requestedBytes);
        totalAllocated.addAndGet(requestedBytes);
        log.trace("Allocating {} bytes, total: {}, overall: {}", new Object[]{Integer.valueOf(requestedBytes), Long.valueOf(allocated.get()), totalAllocated});

        // Complex arrays are counted with twice their reported length.
        int arrayLength = iNDArray instanceof IComplexNDArray ? iNDArray.length() * 2 : iNDArray.length();

        // The whole buffer lives under (0, length); allocate it once per thread.
        Pair<Integer, Integer> baseKey = new Pair<>(0, this.length);
        DevicePointerInfo base = this.pointersToContexts.get(threadName, baseKey);
        if (base == null) {
            base = (DevicePointerInfo) ContextHolder.getInstance().getConf().getMemoryStrategy().alloc(this, 1, 0, this.length);
            this.pointersToContexts.put(threadName, baseKey, base);
        }

        if (i2 > 0) {
            // Offset view: shares the base pointer, offset applied on hand-out.
            Pair<Integer, Integer> viewKey = new Pair<>(i2, arrayLength);
            this.pointersToContexts.put(threadName, viewKey, new DevicePointerInfo(base.getPointer(), i3, i, i2));
            return base.getPointer().withByteOffset(byteOffset);
        }

        if (i2 == 0 && arrayLength < iNDArray.data().length()) {
            // Zero-offset view over only part of the buffer.
            DevicePointerInfo partialView = new DevicePointerInfo(base.getPointer(), this.length, BlasBufferUtil.getBlasStride(iNDArray), iNDArray.offset());
            Pair<Integer, Integer> viewKey = new Pair<>(i2, arrayLength);
            this.pointersToContexts.put(threadName, viewKey, partialView);
            return partialView.getPointer();
        }

        this.freed.set(false);
        return base.getPointer().withByteOffset(byteOffset);
    }

    /**
     * Marks the buffer as modified and copies {@code length()} elements from the
     * given pointer to the host pointer with the cuBLAS copy routine matching the
     * data type: {@code cublasDcopy} for DOUBLE, {@code cublasScopy} for everything else.
     *
     * @param pointer the source of the copy
     */
    @Override // org.nd4j.linalg.jcublas.buffer.JCudaBuffer
    public void set(Pointer pointer) {
        modified.set(true);
        if (dataType() != DataBuffer.Type.DOUBLE) {
            JCublas2.cublasScopy(ContextHolder.getInstance().getHandle(), length(), pointer, 1, getHostPointer(), 1);
            return;
        }
        JCublas2.cublasDcopy(ContextHolder.getInstance().getHandle(), length(), pointer, 1, getHostPointer(), 1);
    }

    /**
     * Builds a complex float from two consecutive elements.
     *
     * @param i index of the first component; the second is read from {@code i + 1}
     * @return the complex number created from both values
     */
    public IComplexFloat getComplexFloat(int i) {
        final float first = getFloat(i);
        final float second = getFloat(i + 1);
        return Nd4j.createFloat(first, second);
    }

    /**
     * Builds a complex double from two consecutive elements.
     *
     * @param i index of the first component; the second is read from {@code i + 1}
     * @return the complex number created from both values
     */
    public IComplexDouble getComplexDouble(int i) {
        final double first = getDouble(i);
        final double second = getDouble(i + 1);
        return Nd4j.createDouble(first, second);
    }

    /**
     * Reads a complex number at the given index, as a complex float when the data
     * type is FLOAT and as a complex double otherwise.
     *
     * @param i index of the first component
     * @return the complex number read at {@code i}
     */
    public IComplexNumber getComplex(int i) {
        if (dataType() == DataBuffer.Type.FLOAT) {
            return getComplexFloat(i);
        }
        return getComplexDouble(i);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Copies {@code i2} elements from the given pointer into this buffer's host
     * pointer starting at element {@code i}, using an asynchronous cuBLAS set
     * followed by a stream synchronisation.
     *
     * <p>The index is validated before any state is touched. Negative indices are
     * rejected too: they used to pass the upper-bound check and produce a pointer
     * in front of the buffer.
     *
     * @param i       index of the first destination element
     * @param i2      number of elements to copy
     * @param pointer source of the copy
     * @param i3      increment between consecutive source elements
     * @throws IllegalArgumentException if {@code i} is negative or not below {@code length()}
     */
    public void set(int i, int i2, Pointer pointer, int i3) {
        int byteOffset = getElementSize() * i;
        if (i < 0 || byteOffset >= length() * getElementSize()) {
            throw new IllegalArgumentException("Illegal offset " + byteOffset + " with index of " + i + " and length " + length());
        }
        this.modified.set(true);
        JCublas2.cublasSetVectorAsync(i2, getElementSize(), pointer, i3, getHostPointer().withByteOffset(byteOffset), 1, ContextHolder.getInstance().getCudaStream());
        // Block until the asynchronous copy has completed.
        ContextHolder.syncStream();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Copies {@code i2} elements from the given pointer starting at element
     * {@code i}, reading the source with an increment of 1.
     *
     * @param i       index of the first destination element
     * @param i2      number of elements to copy
     * @param pointer source of the copy
     */
    public void set(int i, int i2, Pointer pointer) {
        final int unitIncrement = 1;
        set(i, i2, pointer, unitIncrement);
    }

    /**
     * Copies the full contents of another buffer into this one.
     *
     * <p>The previous version delegated to the single-element {@code set}, so only
     * the first element was transferred. All {@code length()} elements are copied
     * now; buffers of different length are rejected so the copy can never read
     * past the end of the source, and an empty buffer is a no-op.
     *
     * @param dataBuffer the source buffer; must be a {@link JCudaBuffer} of the same length
     * @throws IllegalArgumentException if the lengths differ
     */
    public void assign(DataBuffer dataBuffer) {
        if (dataBuffer.length() != length()) {
            throw new IllegalArgumentException("Unable to assign buffer of length " + dataBuffer.length() + " to buffer of length " + length());
        }
        if (length() == 0) {
            return;
        }
        set(0, length(), ((JCudaBuffer) dataBuffer).getHostPointer());
    }

    /**
     * Copies a single element from the given pointer to element {@code i}.
     *
     * @param i       index of the destination element
     * @param pointer source of the copy
     */
    protected void set(int i, Pointer pointer) {
        final int singleElement = 1;
        set(i, singleElement, pointer);
    }

    /**
     * Releases the pointer registered for the calling thread under the given
     * offset and length.
     *
     * <p>A non-zero offset only drops the view entry and returns {@code false};
     * views share the base allocation. A zero offset frees the allocation through
     * the memory strategy unless the buffer is persistent.
     *
     * <p>Fix over the previous version: the table is keyed by (offset, length)
     * pairs, but the lookup and removals used a bare {@code Integer}, so they never
     * matched and nothing was ever freed or removed.
     *
     * <p>NOTE(review): the counter is decremented by
     * {@code DevicePointerInfo.getLength()} while {@code getDevicePointer}
     * increments it by a byte count - kept as is, confirm the intended unit.
     *
     * @param i  offset in elements
     * @param i2 length in elements
     * @return {@code true} if the buffer is persistent or the memory was released,
     *         {@code false} for view entries, unknown entries or an already freed buffer
     */
    @Override // org.nd4j.linalg.jcublas.buffer.JCudaBuffer
    public boolean freeDevicePointer(int i, int i2) {
        String threadName = Thread.currentThread().getName();
        Pair<Integer, Integer> key = new Pair<>(i, i2);
        if (i != 0) {
            this.pointersToContexts.remove(threadName, key);
            return false;
        }
        if (this.isPersist) {
            return true;
        }
        DevicePointerInfo info = this.pointersToContexts.get(threadName, key);
        if (info == null || this.freed.get()) {
            return false;
        }
        allocated.addAndGet(-info.getLength());
        log.trace("freeing {} bytes, total: {}", Long.valueOf(info.getLength()), Long.valueOf(allocated.get()));
        ContextHolder.getInstance().getMemoryStrategy().free(this, i, i2);
        this.freed.set(true);
        this.copied.remove(threadName);
        this.pointersToContexts.remove(threadName, key);
        return true;
    }

    /**
     * Copies the data behind the pointer registered for the calling thread under
     * (offset, length) back to the host pointer.
     *
     * <p>The stream is synchronised before and after the asynchronous cuBLAS get.
     * When the registered offset is zero and {@code i2} is smaller than
     * {@code length()}, {@code i2} elements are copied; otherwise the length stored
     * in the registered entry is used.
     *
     * <p>Compared with the previous version the unreachable second null check is
     * gone, the two near-identical cuBLAS calls are merged, and the source byte
     * offset is computed in {@code long}.
     *
     * @param i  offset in elements
     * @param i2 length in elements
     * @throws IllegalStateException if no entry is registered for (i, i2) or its
     *                               stored offset differs from {@code i}
     */
    @Override // org.nd4j.linalg.jcublas.buffer.JCudaBuffer
    public void copyToHost(int i, int i2) {
        Pair<Integer, Integer> key = new Pair<>(i, i2);
        DevicePointerInfo info = this.pointersToContexts.get(Thread.currentThread().getName(), key);
        if (info == null) {
            throw new IllegalStateException("No pointer found for offset " + i);
        }
        if (info.getOffset() != i) {
            throw new IllegalStateException("Device pointer offset didn't match specified offset in pointer map");
        }
        ContextHolder.syncStream();
        int stride = info.getStride();
        int offset = info.getOffset();
        // Only a zero-offset request shorter than the buffer uses the requested length.
        int count = (offset != 0 || i2 >= length()) ? (int) info.getLength() : i2;
        Pointer source = info.getPointer().withByteOffset((long) i * getElementSize());
        JCublas2.cublasGetVectorAsync(count, getElementSize(), source, stride, getHostPointer(offset), stride, ContextHolder.getInstance().getCudaStream());
        ContextHolder.syncStream();
    }

    /**
     * Not supported by this buffer.
     *
     * @throws UnsupportedOperationException always
     */
    public void flush() {
        throw new UnsupportedOperationException();
    }

    /** Calls {@code clear()} on the backing {@code dataBuffer}. */
    public void destroy() {
        dataBuffer.clear();
    }

    /**
     * Serialization hook: writes the default (non-transient) fields first and then
     * delegates to {@code write} for the remaining state.
     *
     * @param objectOutputStream the stream to serialize into
     * @throws IOException if writing to the stream fails
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        final ObjectOutputStream out = objectOutputStream;
        out.defaultWriteObject();
        write(out);
    }

    /**
     * Deserialization hook: restores the serialized state through
     * {@code doReadObject} and then re-creates the bookkeeping fields.
     *
     * <p>{@code copied} is restored as a {@link ConcurrentHashMap}, matching what
     * every constructor installs; the previous version replaced it with a plain
     * {@code HashMap}, so a deserialized buffer silently lost the thread-safety of
     * its copied flags.
     *
     * @param objectInputStream the stream to read from
     * @throws IOException            if reading from the stream fails
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        doReadObject(objectInputStream);
        this.copied = new ConcurrentHashMap<>();
        this.pointersToContexts = HashBasedTable.create();
        this.ref = new WeakReference<>(this, Nd4j.bufferRefQueue());
        this.freed = new AtomicBoolean(false);
    }

    /** @return the live pointer table itself (not a copy) */
    @Override // org.nd4j.linalg.jcublas.buffer.JCudaBuffer
    public Table<String, Pair<Integer, Integer>, DevicePointerInfo> getPointersToContexts() {
        return pointersToContexts;
    }

    /**
     * Replaces the pointer table with the given one; the reference is stored as is.
     *
     * @param table the new pointer table
     */
    public void setPointersToContexts(Table<String, Pair<Integer, Integer>, DevicePointerInfo> table) {
        pointersToContexts = table;
    }

    /**
     * Renders every element as a double, comma-separated and wrapped in square
     * brackets, e.g. {@code [1.0,2.0]}.
     *
     * @return the textual form of the buffer contents
     */
    public String toString() {
        final StringBuilder text = new StringBuilder();
        text.append("[");
        int index = 0;
        while (index < length()) {
            text.append(getDouble(index));
            // a separator follows every element except the last one
            if (index < length() - 1) {
                text.append(",");
            }
            index++;
        }
        text.append("]");
        return text.toString();
    }
}
