package org.nd4j.jita.flow.impl;

import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.enums.CudaConstants;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.concurrency.EventsProvider;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.flow.FlowController;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.JCublasNDArray;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/jita/flow/impl/SynchronousFlowController.class */
/**
 * {@link FlowController} that performs every host/device transfer on the context's special
 * stream and blocks until that stream has finished before returning.
 *
 * <p>Operations are bracketed by a {@code prepareAction*} call and a matching
 * {@code registerAction*} call. The prepare call decompresses arrays, relocates buffers that
 * belong to another device, acquires allocation-point locks and binds the current
 * {@link CudaContext}. The register call records write/read events on the context's stream and
 * releases the locks acquired during preparation.
 */
public class SynchronousFlowController implements FlowController {
    private static final Logger log = LoggerFactory.getLogger(SynchronousFlowController.class);

    private volatile Allocator allocator;
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    protected Configuration configuration = CudaEnvironment.getInstance().getConfiguration();
    protected EventsProvider eventsProvider = new EventsProvider();

    /**
     * Binds this controller to the allocator that supplies contexts, device ids, allocation
     * points and the memory handler.
     *
     * @param allocator allocator to delegate to
     */
    @Override
    public void init(Allocator allocator) {
        this.allocator = allocator;
    }

    /**
     * Makes the host copy of {@code point} current.
     *
     * <p>Does nothing if the host side is already actual. Otherwise it waits for the last recorded
     * write (unless the point is constant). For DEVICE-allocated memory it then copies device to
     * host on the special stream and blocks until that copy completes. Finally it records a host
     * read.
     *
     * @param point allocation point to synchronize
     * @throws IllegalStateException if the native memcpy reports failure
     */
    @Override
    public void synchronizeToHost(AllocationPoint point) {
        if (point.isActualOnHostSide()) {
            return;
        }

        CudaContext context = currentContext();
        if (!point.isConstant()) {
            waitTillFinished(point);
        }

        // host-side state is re-checked after the wait, since it may have changed meanwhile
        if (point.getAllocationStatus() == AllocationStatus.DEVICE && !point.isActualOnHostSide()) {
            cudaStream_t stream = context.getSpecialStream();
            if (nativeOps.memcpyAsync(point.getHostPointer(), point.getDevicePointer(),
                    AllocationUtils.getRequiredMemory(point.getShape()),
                    CudaConstants.cudaMemcpyDeviceToHost, stream) == 0) {
                throw new IllegalStateException("MemcpyAsync failed: " + point.getShape());
            }
            commitTransfer(stream);
        }
        point.tickHostRead();
    }

    /**
     * Makes the device copy of {@code point} current.
     *
     * <p>Only DEVICE-allocated, non-constant points whose device side is stale are copied. The
     * copy is issued host to device on the special stream and completes before this method
     * returns.
     *
     * @param point allocation point to synchronize
     * @throws IllegalStateException if the native memcpy reports failure
     */
    @Override
    public void synchronizeToDevice(AllocationPoint point) {
        if (point.isConstant()
                || point.isActualOnDeviceSide()
                || point.getAllocationStatus() != AllocationStatus.DEVICE) {
            return;
        }

        CudaContext context = currentContext();
        cudaStream_t stream = context.getSpecialStream();
        if (nativeOps.memcpyAsync(point.getDevicePointer(), point.getHostPointer(),
                AllocationUtils.getRequiredMemory(point.getShape()),
                CudaConstants.cudaMemcpyHostToDevice, stream) == 0) {
            throw new IllegalStateException("MemcpyAsync failed: " + point.getShape());
        }
        commitTransfer(stream);
        point.tickDeviceRead();
    }

    /**
     * Blocks until the last recorded write event of {@code point}, if any, has completed.
     *
     * @param point allocation point to wait on
     */
    @Override
    public void waitTillFinished(AllocationPoint point) {
        if (point.getLastWriteEvent() != null) {
            point.getLastWriteEvent().synchronize();
        }
    }

    /**
     * Prepares every non-null array for an op that writes all of them.
     *
     * <p>Data on another device is relocated unconditionally; peer access is not considered.
     *
     * @param arrays arrays the op writes; {@code null} entries are skipped
     * @return the current context, which each prepared array is bound to
     */
    @Override
    public CudaContext prepareActionAllWrite(INDArray... arrays) {
        CudaContext context = currentContext();
        int deviceId = allocator.getDeviceId().intValue();
        for (INDArray array : arrays) {
            if (array != null) {
                prepareArray(array, deviceId, context, false);
            }
        }
        return context;
    }

    /**
     * Prepares a result array and its operands for an op.
     *
     * <p>Data living on another device is relocated unless cross-device access is allowed and
     * P2P is available. Shape-info buffers on another device are always relocated.
     *
     * @param result   array the op writes, may be {@code null}
     * @param operands arrays the op reads; {@code null} entries are skipped
     * @return the current context, which each prepared array is bound to
     */
    @Override
    public CudaContext prepareAction(INDArray result, INDArray... operands) {
        CudaContext context = currentContext();
        int deviceId = allocator.getDeviceId().intValue();

        if (result != null) {
            Nd4j.getCompressor().autoDecompress(result);
            // unlike operands, the result's delayed memory is prepared before relocation checks
            prepareDelayedMemory(result);
            AllocationPoint pointData = allocator.getAllocationPoint(result);
            AllocationPoint pointShape = allocator.getAllocationPoint(result.shapeInfoDataBuffer());
            pointData.acquireLock();
            if (isOnOtherDevice(pointData, deviceId) && !isPeerAccessUsable()) {
                relocateData(result);
            }
            if (isOnOtherDevice(pointShape, deviceId)) {
                relocateShapeInfo(result);
            }
            allocator.getAllocationPoint(result).setCurrentContext(context);
        }

        for (INDArray operand : operands) {
            if (operand != null) {
                prepareArray(operand, deviceId, context, true);
            }
        }
        return context;
    }

    /**
     * Blocks until both the last write and the last read of {@code point} have completed.
     *
     * @param point allocation point to wait on
     */
    @Override
    public void waitTillReleased(AllocationPoint point) {
        waitTillFinished(point);
        if (point.getLastReadEvent() != null) {
            point.getLastReadEvent().synchronize();
        }
    }

    /**
     * Records a write event for {@code result} and read events for {@code operands} on the
     * context's stream, releasing each point's lock afterwards.
     *
     * <p>{@code null} points are skipped, mirroring
     * {@link #prepareAction(AllocationPoint, AllocationPoint...)}.
     *
     * @param context  context the op ran on
     * @param result   point the op wrote, may be {@code null}
     * @param operands points the op read; {@code null} entries are skipped
     */
    @Override
    public void registerAction(CudaContext context, AllocationPoint result, AllocationPoint... operands) {
        if (result != null) {
            registerWrite(context, result);
        }
        for (AllocationPoint operand : operands) {
            if (operand != null) {
                registerRead(context, operand);
            }
        }
    }

    /**
     * Marks every non-null array as device-written and records a write event for it on the
     * context's stream, releasing its lock afterwards.
     *
     * @param context context the op ran on
     * @param arrays  arrays the op wrote; {@code null} entries are skipped
     */
    @Override
    public void registerActionAllWrite(CudaContext context, INDArray... arrays) {
        for (INDArray array : arrays) {
            if (array != null) {
                AllocationPoint point = allocator.getAllocationPoint(array);
                point.tickDeviceWrite();
                registerWrite(context, point);
            }
        }
    }

    /**
     * Marks {@code result} as device-written and records its write event. Records read events
     * for every operand and releases each lock afterwards.
     *
     * <p>A {@code null} result is tolerated (as in {@link #prepareAction(INDArray, INDArray...)}).
     * Operands are still registered in that case, so the locks taken while preparing them are
     * released.
     *
     * @param context  context the op ran on
     * @param result   array the op wrote, may be {@code null}
     * @param operands arrays the op read; {@code null} entries are skipped
     */
    @Override
    public void registerAction(CudaContext context, INDArray result, INDArray... operands) {
        if (result != null) {
            AllocationPoint point = allocator.getAllocationPoint(result);
            point.tickDeviceWrite();
            registerWrite(context, point);
        }
        for (INDArray operand : operands) {
            if (operand != null) {
                registerRead(context, allocator.getAllocationPoint(operand));
            }
        }
    }

    /**
     * Locks the given points and binds them to the current context. No relocation is performed.
     *
     * @param result   point the op writes, may be {@code null}
     * @param operands points the op reads; {@code null} entries are skipped
     * @return the current context
     */
    @Override
    public CudaContext prepareAction(AllocationPoint result, AllocationPoint... operands) {
        CudaContext context = currentContext();
        if (result != null) {
            result.acquireLock();
            result.setCurrentContext(context);
        }
        for (AllocationPoint operand : operands) {
            if (operand != null) {
                operand.acquireLock();
                operand.setCurrentContext(context);
            }
        }
        return context;
    }

    /**
     * Completes a transfer by blocking until {@code stream} has finished all queued work.
     *
     * @param stream stream the transfer was issued on
     */
    @Override
    public void commitTransfer(cudaStream_t stream) {
        stream.synchronize();
    }

    /**
     * Under the DELAYED memory model, makes sure {@code array} is usable on the device.
     *
     * <p>The data buffer is promoted via the memory handler if its own allocation point is not
     * yet DEVICE-allocated. A HOST-resident shape-info buffer is replaced by its constant-space
     * counterpart, and is moved to constant space if the relocation returned the same buffer.
     *
     * @param array array to prepare
     */
    protected void prepareDelayedMemory(INDArray array) {
        if (configuration.getMemoryModel() != Configuration.MemoryModel.DELAYED) {
            return;
        }

        // the data buffer's own status decides whether its data must be promoted
        AllocationPoint pointData = allocator.getAllocationPoint(array);
        AllocationPoint pointShape = allocator.getAllocationPoint(array.shapeInfoDataBuffer());

        if (pointData.getAllocationStatus() != AllocationStatus.DEVICE) {
            prepareDelayedMemory(array.data());
        }

        if (pointShape.getAllocationStatus() == AllocationStatus.HOST) {
            DataBuffer oldShape = array.shapeInfoDataBuffer();
            DataBuffer newShape = Nd4j.getConstantHandler().relocateConstantSpace(oldShape);
            if (newShape == oldShape) {
                Nd4j.getConstantHandler().moveToConstantSpace(newShape);
            }
            ((JCublasNDArray) array).setShapeInfoDataBuffer(newShape);
        }
    }

    /**
     * Asks the memory handler to promote {@code buffer}.
     *
     * @param buffer buffer to promote
     */
    protected void prepareDelayedMemory(DataBuffer buffer) {
        allocator.getMemoryHandler().promoteObject(buffer);
    }

    /**
     * Returns the provider used to obtain and recycle read/write events.
     *
     * @return this controller's events provider
     */
    @Override
    public EventsProvider getEventsProvider() {
        return eventsProvider;
    }

    /** Returns the allocator's current device context as a {@link CudaContext}. */
    private CudaContext currentContext() {
        return (CudaContext) allocator.getDeviceContext().getContext();
    }

    /**
     * Prepares a single array for an op. It decompresses the array, locks its data point,
     * relocates foreign data and shape-info buffers, prepares delayed memory and binds the array
     * to {@code context}.
     *
     * @param checkPeerAccess when {@code true}, data on another device is left in place if
     *                        cross-device access is allowed and P2P is available
     */
    private void prepareArray(INDArray array, int deviceId, CudaContext context, boolean checkPeerAccess) {
        Nd4j.getCompressor().autoDecompress(array);
        AllocationPoint pointData = allocator.getAllocationPoint(array);
        AllocationPoint pointShape = allocator.getAllocationPoint(array.shapeInfoDataBuffer());
        pointData.acquireLock();
        if (isOnOtherDevice(pointData, deviceId) && !(checkPeerAccess && isPeerAccessUsable())) {
            relocateData(array);
        }
        if (isOnOtherDevice(pointShape, deviceId)) {
            relocateShapeInfo(array);
        }
        prepareDelayedMemory(array);
        allocator.getAllocationPoint(array).setCurrentContext(context);
    }

    /** True when {@code point} is bound to a device (id >= 0) other than {@code deviceId}. */
    private static boolean isOnOtherDevice(AllocationPoint point, int deviceId) {
        int pointDevice = point.getDeviceId();
        return pointDevice != deviceId && pointDevice >= 0;
    }

    /** True when cross-device access is allowed by configuration and P2P is available. */
    private static boolean isPeerAccessUsable() {
        return CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed()
                && NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable();
    }

    /** Relocates the array's original data buffer (or its own buffer, if it has no original). */
    private void relocateData(INDArray array) {
        DataBuffer buffer = array.data().originalDataBuffer() == null
                ? array.data()
                : array.data().originalDataBuffer();
        allocator.getMemoryHandler().relocateObject(buffer);
    }

    /** Replaces the array's shape-info buffer with the one returned by constant-space relocation. */
    private static void relocateShapeInfo(INDArray array) {
        ((JCublasNDArray) array).setShapeInfoDataBuffer(
                Nd4j.getConstantHandler().relocateConstantSpace(array.shapeInfoDataBuffer()));
    }

    /**
     * Recycles the previous write event, records a fresh one on the context's stream, then
     * releases the point's lock.
     */
    private void registerWrite(CudaContext context, AllocationPoint point) {
        eventsProvider.storeEvent(point.getLastWriteEvent());
        point.setLastWriteEvent(eventsProvider.getEvent());
        point.getLastWriteEvent().register(context.getOldStream());
        point.releaseLock();
    }

    /**
     * Recycles the previous read event, records a fresh one on the context's stream, then
     * releases the point's lock. The lock is released only after the new event is in place.
     */
    private void registerRead(CudaContext context, AllocationPoint point) {
        eventsProvider.storeEvent(point.getLastReadEvent());
        point.setLastReadEvent(eventsProvider.getEvent());
        point.getLastReadEvent().register(context.getOldStream());
        point.releaseLock();
    }
}
