package org.nd4j.linalg.jcublas.util;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import java.util.ArrayList;
import java.util.List;
import jcuda.Pointer;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.JCuda;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.jcublas.CublasPointer;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.complex.ComplexDouble;
import org.nd4j.linalg.jcublas.context.ContextHolder;

/* loaded from: input_file:org/nd4j/linalg/jcublas/util/KernelParamsWrapper.class */
/**
 * Wraps the arguments of a CUDA kernel invocation.
 *
 * <p>Every {@link JCudaBuffer} or {@link INDArray} argument is wrapped in a
 * {@link CublasPointer} and replaced by that pointer's device pointer in
 * {@link #kernelParameters}; all other arguments are passed through unchanged.
 * On {@link #close()} the pointers registered as results are read back (either
 * through {@link CublasPointer#copyToHost()} or into the result {@link Op})
 * and every pointer created by this wrapper is closed.
 *
 * <p>Intended to be used in a try-with-resources statement.
 */
public class KernelParamsWrapper implements AutoCloseable {
    /** Numeric value of {@code cudaMemcpyKind.cudaMemcpyDeviceToHost} in the CUDA runtime API. */
    private static final int MEMCPY_DEVICE_TO_HOST = 2;
    /** Size in bytes of two {@code float} values (real and imaginary part). */
    private static final long FLOAT_PAIR_BYTES = 8L;
    /** Size in bytes of two {@code double} values (real and imaginary part). */
    private static final long DOUBLE_PAIR_BYTES = 16L;

    /** Arguments to hand to the kernel; buffers and arrays are replaced by device pointers. */
    public final Object[] kernelParameters;
    /** Operation that receives the result on close, or {@code null} to copy results back to host. */
    private Op resultOp;
    /** Set once the pointers have been processed by {@link #close()}. */
    private boolean closeInvoked = false;
    /** Pointers created for each {@link INDArray} argument, in argument order. */
    private final Multimap<INDArray, CublasPointer> arrayToPointer = ArrayListMultimap.create();
    /** Every pointer created by the constructor; all of them are closed in {@link #close()}. */
    final List<CublasPointer> pointersToFree = new ArrayList<>();
    /** Pointers whose contents are read back in {@link #close()}. */
    final List<CublasPointer> resultPointers = new ArrayList<>();

    /**
     * Builds the kernel parameter array from the given arguments.
     *
     * @param objArr kernel arguments; {@link JCudaBuffer} and {@link INDArray}
     *               instances are replaced by device pointers, everything else
     *               is stored as-is
     */
    public KernelParamsWrapper(Object... objArr) {
        this.kernelParameters = new Object[objArr.length];
        for (int i = 0; i < objArr.length; i++) {
            Object arg = objArr[i];
            if (arg instanceof JCudaBuffer) {
                CublasPointer bufferPointer = new CublasPointer((JCudaBuffer) arg);
                this.kernelParameters[i] = bufferPointer.getDevicePointer();
                this.pointersToFree.add(bufferPointer);
            } else if (arg instanceof INDArray) {
                INDArray array = (INDArray) arg;
                CublasPointer arrayPointer = new CublasPointer(array);
                this.kernelParameters[i] = arrayPointer.getDevicePointer();
                this.pointersToFree.add(arrayPointer);
                // Remembered so setResultArray can later find the pointer for this array.
                this.arrayToPointer.put(array, arrayPointer);
            } else {
                this.kernelParameters[i] = arg;
            }
        }
    }

    /**
     * @return the parameter array to pass to the kernel (same instance as
     *         {@link #kernelParameters})
     */
    public Object[] getKernelParameters() {
        return this.kernelParameters;
    }

    /**
     * Marks an array as a result so that its pointer is read back on {@link #close()}.
     * If the array was passed more than once, the first pointer created for it is used.
     *
     * @param iNDArray an array that was passed to the constructor
     * @return this wrapper, for chaining
     * @throws IllegalArgumentException if the array was not one of the constructor arguments
     */
    public KernelParamsWrapper setResultArray(INDArray iNDArray) {
        // Multimap.get never returns null; an unknown key yields an empty collection,
        // so presence has to be checked explicitly before taking the first element.
        if (!this.arrayToPointer.containsKey(iNDArray)) {
            throw new IllegalArgumentException("Results array must be supplied as a kernel parameter");
        }
        this.resultPointers.add(this.arrayToPointer.get(iNDArray).iterator().next());
        return this;
    }

    /**
     * Registers an accumulation that receives the kernel result on {@link #close()}
     * and marks the given array as the result array.
     *
     * @param accumulation operation whose current result is set from the device data
     * @param iNDArray     array holding the result; must have been passed to the constructor
     * @return this wrapper, for chaining
     * @throws IllegalArgumentException if the array was not one of the constructor arguments
     */
    public KernelParamsWrapper setResultOp(Accumulation accumulation, INDArray iNDArray) {
        this.resultOp = accumulation;
        setResultArray(iNDArray);
        return this;
    }

    /**
     * Synchronizes the stream, reads back every result pointer and closes all
     * pointers created by this wrapper. Calls after the first one only
     * synchronize the stream.
     *
     * <p>A failure while handling one pointer does not prevent the remaining
     * pointers from being closed; the first failure is rethrown at the end with
     * later ones attached as suppressed exceptions.
     *
     * @throws Exception if reading back or closing any pointer failed
     */
    @Override // java.lang.AutoCloseable
    public void close() throws Exception {
        ContextHolder.syncStream();
        if (this.closeInvoked) {
            return;
        }
        // Mark as closed up front so a retry after a failure cannot close pointers twice.
        this.closeInvoked = true;

        Exception failure = null;
        for (CublasPointer pointer : this.pointersToFree) {
            try {
                releasePointer(pointer);
            } catch (Exception e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        // Values are discarded. NOTE(review): purpose of this query is not visible here;
        // kept because the original performed it after releasing the pointers.
        JCudaDriver.cuMemGetInfo(new long[1], new long[1]);
        if (failure != null) {
            throw failure;
        }
    }

    /** Reads back the pointer if it is a result pointer, then closes it even if the read-back fails. */
    private void releasePointer(CublasPointer pointer) throws Exception {
        try {
            if (this.resultPointers.contains(pointer)) {
                if (this.resultOp != null) {
                    setResultForOp(this.resultOp, pointer);
                } else {
                    pointer.copyToHost();
                }
            }
        } finally {
            pointer.close();
        }
    }

    /**
     * Copies the first two elements behind the device pointer to the host and,
     * if {@code op} is an {@link Accumulation}, stores element 0 as its current
     * result and elements 0 and 1 as its current complex result.
     *
     * <p>NOTE(review): two elements are always copied; this assumes the result
     * buffer holds at least two values — confirm against the kernels.
     *
     * @param op            operation receiving the result
     * @param cublasPointer pointer to the device memory holding the result
     */
    private void setResultForOp(Op op, CublasPointer cublasPointer) {
        if (cublasPointer.getBuffer().dataType() == DataBuffer.Type.DOUBLE) {
            double[] hostValues = new double[2];
            copyPairToHost(Pointer.to(hostValues), cublasPointer, DOUBLE_PAIR_BYTES);
            if (op instanceof Accumulation) {
                Accumulation accumulation = (Accumulation) op;
                accumulation.setCurrentResult(Double.valueOf(hostValues[0]));
                accumulation.setCurrentResultComplex(new ComplexDouble(hostValues[0], hostValues[1]));
            }
            return;
        }

        float[] hostValues = new float[2];
        copyPairToHost(Pointer.to(hostValues), cublasPointer, FLOAT_PAIR_BYTES);
        if (op instanceof Accumulation) {
            Accumulation accumulation = (Accumulation) op;
            accumulation.setCurrentResult(Float.valueOf(hostValues[0]));
            accumulation.setCurrentResultComplex(new ComplexDouble(hostValues[0], hostValues[1]));
        }
    }

    /**
     * Copies {@code byteCount} bytes from the device pointer into {@code hostTarget}
     * on the context's stream, synchronizing before and after the asynchronous
     * copy so the host array is fully populated when this method returns.
     */
    private static void copyPairToHost(Pointer hostTarget, CublasPointer source, long byteCount) {
        ContextHolder.syncStream();
        JCuda.cudaMemcpyAsync(hostTarget, source.getDevicePointer(), byteCount, MEMCPY_DEVICE_TO_HOST,
                ContextHolder.getInstance().getCudaStream());
        // The copy is asynchronous: wait for it before the caller reads the host array.
        ContextHolder.syncStream();
    }
}
