package org.nd4j.linalg.jcublas;

import jcublas.cublasHandle;
import jcuda.CudaException;
import jcuda.LogLevel;
import jcuda.cuComplex;
import jcuda.cuDoubleComplex;
import jcuda.driver.JCudaDriver;
import jcuda.jcublas.JCublas;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaDeviceProp;
import jcuda.runtime.cudaError;
import jcuda.utils.KernelLauncher;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexDouble;
import org.nd4j.linalg.api.complex.IComplexFloat;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.CopyOp;
import org.nd4j.linalg.factory.DataTypeValidation;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.context.ContextHolder;
import org.nd4j.linalg.jcublas.kernel.KernelFunctionLoader;

/* loaded from: input_file:org/nd4j/linalg/jcublas/SimpleJCublas.class */
/**
 * Static wrappers around the legacy JCublas API for nd4j arrays.
 *
 * <p>Each operation wraps its operands in {@link CublasPointer}s, invokes the matching cuBLAS
 * routine, synchronizes the device, copies the output operand back to the host and releases the
 * pointers. Pointers are released in a {@code finally} block, so a failing cuBLAS call
 * (JCublas exceptions are enabled in {@link #init()}) does not leak them.
 */
public class SimpleJCublas {
    /** Value of {@code cudaMemcpyKind.cudaMemcpyDeviceToDevice} in the CUDA runtime API. */
    private static final int CUDA_MEMCPY_DEVICE_TO_DEVICE = 3;

    private static boolean init = false;
    private static final cublasHandle handle = new cublasHandle();

    /**
     * Verifies that every array is backed by a {@link JCudaBuffer}.
     *
     * @param arrays the arrays to check
     * @throws IllegalArgumentException naming the offending buffer type if any array is not
     *     CUDA-backed
     */
    public static void assertCudaBuffer(INDArray... arrays) {
        for (INDArray array : arrays) {
            DataBuffer buffer = array.data();
            if (!(buffer instanceof JCudaBuffer)) {
                throw new IllegalArgumentException("Unable to allocate pointer for buffer of type " + typeName(buffer));
            }
        }
    }

    /**
     * Verifies that every buffer is a {@link JCudaBuffer}.
     *
     * @param buffers the buffers to check
     * @throws IllegalArgumentException naming the offending buffer type if any buffer is not
     *     CUDA-backed
     */
    public static void assertCudaBuffer(DataBuffer... buffers) {
        for (DataBuffer buffer : buffers) {
            if (!(buffer instanceof JCudaBuffer)) {
                throw new IllegalArgumentException("Unable to allocate pointer for buffer of type " + typeName(buffer));
            }
        }
    }

    /** Returns the shared cuBLAS handle created when this class is loaded. */
    public static cublasHandle handle() {
        return handle;
    }

    /**
     * Turns a non-zero CUDA runtime status code into an exception.
     *
     * @param result the status code returned by a CUDA runtime call
     * @return {@code result} (always 0) when no error occurred
     * @throws CudaException carrying {@code cudaError.stringFor(result)} when {@code result != 0}
     */
    public static int checkResult(int result) {
        if (result != 0) {
            throw new CudaException(cudaError.stringFor(result));
        }
        return result;
    }

    /**
     * One-time setup. It enables JCublas debug logging and exceptions, loads the kernel
     * functions, and queries device 0's properties.
     *
     * <p>If device 0 cannot map host memory, the properties are printed to stderr and the class
     * is left uninitialised, so a later call retries. Any exception is rethrown wrapped in a
     * {@link RuntimeException}.
     */
    public static void init() {
        if (init) {
            return;
        }
        JCublas.setLogLevel(LogLevel.LOG_DEBUG);
        JCublas.setExceptionsEnabled(true);
        try {
            KernelFunctionLoader.getInstance().load();
            cudaDeviceProp deviceProperties = new cudaDeviceProp();
            checkResult(JCuda.cudaGetDeviceProperties(deviceProperties, 0));
            if (deviceProperties.canMapHostMemory == 0) {
                System.err.println("This device can not map host memory");
                System.err.println(deviceProperties.toFormattedString());
            } else {
                JCudaDriver.cuCtxGetApiVersion(ContextHolder.getInstance().getContext(), new int[1]);
                init = true;
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Calls {@code cudaDeviceSynchronize} (throwing on a non-zero status), then
     * {@code KernelLauncher.setContext()}.
     */
    public static void sync() {
        checkResult(JCuda.cudaDeviceSynchronize());
        KernelLauncher.setContext();
    }

    /**
     * Computes {@code y = alpha * a * x + beta * y} with {@code cublasDgemv} (no transpose).
     *
     * <p>The leading dimension is {@code a.rows()} and both vector increments are 1.
     * NOTE(review): this assumes contiguous column-major operands; confirm with callers.
     *
     * @return {@code y}, updated in place and copied back to the host
     */
    public static INDArray gemv(INDArray a, INDArray x, INDArray y, double alpha, double beta) {
        DataTypeValidation.assertDouble(new INDArray[]{a, x, y});
        assertCudaBuffer(a.data(), x.data(), y.data());
        sync();
        CublasPointer aPointer = null;
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        try {
            aPointer = new CublasPointer(a);
            xPointer = new CublasPointer(x);
            yPointer = new CublasPointer(y);
            JCublas.cublasDgemv('N', a.rows(), a.columns(), alpha, aPointer, a.rows(), xPointer, 1, beta, yPointer, 1);
            // Make sure the kernel has finished before reading the result back.
            sync();
            yPointer.copyToHost();
        } finally {
            releaseCublasPointers(aPointer, xPointer, yPointer);
        }
        return y;
    }

    /**
     * Computes {@code y = alpha * a * x + beta * y} with {@code cublasSgemv} (no transpose).
     *
     * @return {@code y}, updated in place and copied back to the host
     */
    public static INDArray gemv(INDArray a, INDArray x, INDArray y, float alpha, float beta) {
        DataTypeValidation.assertFloat(new INDArray[]{a, x, y});
        assertCudaBuffer(a.data(), x.data(), y.data());
        sync();
        CublasPointer aPointer = null;
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        try {
            aPointer = new CublasPointer(a);
            xPointer = new CublasPointer(x);
            yPointer = new CublasPointer(y);
            JCublas.cublasSgemv('N', a.rows(), a.columns(), alpha, aPointer, a.rows(), xPointer, 1, beta, yPointer, 1);
            sync();
            yPointer.copyToHost();
        } finally {
            releaseCublasPointers(aPointer, xPointer, yPointer);
        }
        return y;
    }

    /**
     * Computes the double-complex {@code y = alpha * a * x + beta * y} with {@code cublasZgemv}.
     *
     * <p>The dimensions are {@code a.rows()} x {@code a.columns()}; the vector increments are
     * taken from {@code secondaryStride()}.
     *
     * @return {@code y}, updated in place and copied back to the host
     */
    public static IComplexNDArray gemv(IComplexNDArray a, IComplexNDArray x, IComplexDouble alpha, IComplexNDArray y, IComplexDouble beta) {
        DataTypeValidation.assertDouble(new INDArray[]{a, x, y});
        assertCudaBuffer(a, x, y);
        sync();
        CublasPointer aPointer = null;
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        try {
            aPointer = new CublasPointer((INDArray) a);
            xPointer = new CublasPointer((INDArray) x);
            yPointer = new CublasPointer((INDArray) y);
            JCublas.cublasZgemv('n', a.rows(), a.columns(), toCuDoubleComplex(alpha), aPointer, a.rows(), xPointer, x.secondaryStride(), toCuDoubleComplex(beta), yPointer, y.secondaryStride());
            sync();
            yPointer.copyToHost();
        } finally {
            releaseCublasPointers(aPointer, xPointer, yPointer);
        }
        return y;
    }

    /**
     * Computes the single-complex {@code y = alpha * a * x + beta * y} with {@code cublasCgemv}.
     *
     * @return {@code y}, updated in place and copied back to the host
     */
    public static IComplexNDArray gemv(IComplexNDArray a, IComplexNDArray x, IComplexFloat alpha, IComplexNDArray y, IComplexFloat beta) {
        DataTypeValidation.assertFloat(new INDArray[]{a, x, y});
        assertCudaBuffer(a, x, y);
        sync();
        CublasPointer aPointer = null;
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        try {
            aPointer = new CublasPointer((INDArray) a);
            xPointer = new CublasPointer((INDArray) x);
            yPointer = new CublasPointer((INDArray) y);
            JCublas.cublasCgemv('n', a.rows(), a.columns(), toCuComplex(alpha), aPointer, a.rows(), xPointer, x.secondaryStride(), toCuComplex(beta), yPointer, y.secondaryStride());
            sync();
            yPointer.copyToHost();
        } finally {
            releaseCublasPointers(aPointer, xPointer, yPointer);
        }
        return y;
    }

    /**
     * Computes the double-complex {@code c = alpha * a * b + beta * c} with {@code cublasZgemm}
     * (no transposes).
     *
     * <p>m and n come from {@code c}, k from {@code a.columns()}, and each leading dimension is
     * its operand's row count.
     *
     * @return {@code c}, updated in place and copied back to the host
     */
    public static IComplexNDArray gemm(IComplexNDArray a, IComplexNDArray b, IComplexDouble alpha, IComplexNDArray c, IComplexDouble beta) {
        DataTypeValidation.assertDouble(new INDArray[]{a, b, c});
        sync();
        CublasPointer aPointer = null;
        CublasPointer bPointer = null;
        CublasPointer cPointer = null;
        try {
            aPointer = new CublasPointer((INDArray) a);
            bPointer = new CublasPointer((INDArray) b);
            cPointer = new CublasPointer((INDArray) c);
            JCublas.cublasZgemm('n', 'n', c.rows(), c.columns(), a.columns(), toCuDoubleComplex(alpha), aPointer, a.rows(), bPointer, b.rows(), toCuDoubleComplex(beta), cPointer, c.rows());
            sync();
            cPointer.copyToHost();
        } finally {
            releaseCublasPointers(aPointer, bPointer, cPointer);
        }
        return c;
    }

    /**
     * Computes the single-complex {@code c = alpha * a * b + beta * c} with {@code cublasCgemm}
     * (no transposes).
     *
     * @return {@code c}, updated in place and copied back to the host
     */
    public static IComplexNDArray gemm(IComplexNDArray a, IComplexNDArray b, IComplexFloat alpha, IComplexNDArray c, IComplexFloat beta) {
        DataTypeValidation.assertFloat(new INDArray[]{a, b, c});
        sync();
        CublasPointer aPointer = null;
        CublasPointer bPointer = null;
        CublasPointer cPointer = null;
        try {
            aPointer = new CublasPointer((INDArray) a);
            bPointer = new CublasPointer((INDArray) b);
            cPointer = new CublasPointer((INDArray) c);
            JCublas.cublasCgemm('n', 'n', c.rows(), c.columns(), a.columns(), toCuComplex(alpha), aPointer, a.rows(), bPointer, b.rows(), toCuComplex(beta), cPointer, c.rows());
            sync();
            cPointer.copyToHost();
        } finally {
            releaseCublasPointers(aPointer, bPointer, cPointer);
        }
        return c;
    }

    /**
     * Computes {@code c = alpha * a * b + beta * c} with {@code cublasDgemm} (no transposes).
     *
     * @return {@code c}, updated in place and copied back to the host
     */
    public static INDArray gemm(INDArray a, INDArray b, INDArray c, double alpha, double beta) {
        DataTypeValidation.assertDouble(new INDArray[]{a, b, c});
        sync();
        CublasPointer aPointer = null;
        CublasPointer bPointer = null;
        CublasPointer cPointer = null;
        try {
            aPointer = new CublasPointer(a);
            bPointer = new CublasPointer(b);
            cPointer = new CublasPointer(c);
            JCublas.cublasDgemm('n', 'n', c.rows(), c.columns(), a.columns(), alpha, aPointer, a.rows(), bPointer, b.rows(), beta, cPointer, c.rows());
            sync();
            cPointer.copyToHost();
        } finally {
            releaseCublasPointers(aPointer, bPointer, cPointer);
        }
        return c;
    }

    /**
     * Computes {@code c = alpha * a * b + beta * c} with {@code cublasSgemm} (no transposes).
     *
     * @return {@code c}, updated in place and copied back to the host
     */
    public static INDArray gemm(INDArray a, INDArray b, INDArray c, float alpha, float beta) {
        DataTypeValidation.assertFloat(new INDArray[]{a, b, c});
        sync();
        CublasPointer aPointer = null;
        CublasPointer bPointer = null;
        CublasPointer cPointer = null;
        try {
            aPointer = new CublasPointer(a);
            bPointer = new CublasPointer(b);
            cPointer = new CublasPointer(c);
            JCublas.cublasSgemm('n', 'n', c.rows(), c.columns(), a.columns(), alpha, aPointer, a.rows(), bPointer, b.rows(), beta, cPointer, c.rows());
            sync();
            cPointer.copyToHost();
        } finally {
            releaseCublasPointers(aPointer, bPointer, cPointer);
        }
        return c;
    }

    /**
     * Returns the Euclidean norm {@code sqrt(sum |z_i|^2)} of a complex vector.
     *
     * <p>Uses {@code cublasScnrm2} for float data and {@code cublasDznrm2} otherwise, so the
     * imaginary parts are included. Assumes contiguous complex elements (increment 1).
     */
    public static double nrm2(IComplexNDArray x) {
        sync();
        CublasPointer xPointer = null;
        try {
            xPointer = new CublasPointer((INDArray) x);
            if (x.data().dataType() == DataBuffer.Type.FLOAT) {
                return JCublas.cublasScnrm2(x.length(), xPointer, 1);
            }
            return JCublas.cublasDznrm2(x.length(), xPointer, 1);
        } finally {
            releaseCublasPointers(xPointer);
        }
    }

    /**
     * Copies {@code x} into {@code y}.
     *
     * <p>When both arrays have major stride 2, the copy is a single device-to-device
     * {@code cudaMemcpy} of {@code length * elementSize * 2} bytes; otherwise a {@link CopyOp}
     * is executed. {@code y} is copied back to the host afterwards.
     *
     * @throws CudaException if {@code cudaMemcpy} reports an error
     */
    public static void copy(IComplexNDArray x, IComplexNDArray y) {
        DataTypeValidation.assertSameDataType(new INDArray[]{x, y});
        sync();
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        try {
            xPointer = new CublasPointer((INDArray) x);
            yPointer = new CublasPointer((INDArray) y);
            if (x.majorStride() == 2 && y.majorStride() == 2) {
                JCudaBuffer buffer = (JCudaBuffer) x.data();
                // Computed in long so large arrays do not overflow the byte count.
                long bytes = (long) x.length() * buffer.getElementSize() * 2L;
                checkResult(JCuda.cudaMemcpy(yPointer, xPointer, bytes, CUDA_MEMCPY_DEVICE_TO_DEVICE));
            } else {
                Nd4j.getExecutioner().exec(new CopyOp(x, y, y, x.length()));
            }
            sync();
            yPointer.copyToHost();
        } finally {
            releaseCublasPointers(yPointer, xPointer);
        }
    }

    /**
     * Returns the 0-based index of the complex element with the largest {@code |re| + |im|}.
     *
     * <p>Uses {@code cublasIcamax} for float data and {@code cublasIzamax} otherwise. cuBLAS
     * indices are 1-based, so 1 is subtracted, matching {@link #iamax(INDArray)}.
     */
    public static int iamax(IComplexNDArray x) {
        sync();
        CublasPointer xPointer = null;
        try {
            xPointer = new CublasPointer((INDArray) x);
            int oneBasedIndex = x.data().dataType() == DataBuffer.Type.FLOAT
                    ? JCublas.cublasIcamax(x.length(), xPointer, 1)
                    : JCublas.cublasIzamax(x.length(), xPointer, 1);
            return oneBasedIndex - 1;
        } finally {
            releaseCublasPointers(xPointer);
        }
    }

    /**
     * Returns {@code sum(|re_i| + |im_i|)} of a complex vector.
     *
     * <p>Uses {@code cublasScasum} for float data and {@code cublasDzasum} otherwise. The double
     * result is narrowed to {@code float} to keep the existing return type.
     */
    public static float asum(IComplexNDArray x) {
        sync();
        CublasPointer xPointer = null;
        try {
            xPointer = new CublasPointer((INDArray) x);
            if (x.data().dataType() == DataBuffer.Type.FLOAT) {
                return JCublas.cublasScasum(x.length(), xPointer, 1);
            }
            return (float) JCublas.cublasDzasum(x.length(), xPointer, 1);
        } finally {
            releaseCublasPointers(xPointer);
        }
    }

    /**
     * Swaps the contents of {@code x} and {@code y} on the device with {@code cublasSswap} or
     * {@code cublasDswap} (increment 1), then copies both back to the host.
     */
    public static void swap(INDArray x, INDArray y) {
        DataTypeValidation.assertSameDataType(new INDArray[]{x, y});
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        try {
            xPointer = new CublasPointer(x);
            yPointer = new CublasPointer(y);
            sync();
            if (x.data().dataType() == DataBuffer.Type.FLOAT) {
                JCublas.cublasSswap(x.length(), xPointer, 1, yPointer, 1);
            } else {
                JCublas.cublasDswap(x.length(), xPointer, 1, yPointer, 1);
            }
            sync();
            // Both operands are modified, so both must be copied back.
            xPointer.copyToHost();
            yPointer.copyToHost();
        } finally {
            releaseCublasPointers(xPointer, yPointer);
        }
    }

    /**
     * Returns {@code sum(|x_i|)} via {@code cublasSasum} (float data) or {@code cublasDasum}
     * (otherwise), with increment 1.
     */
    public static double asum(INDArray x) {
        sync();
        CublasPointer xPointer = null;
        try {
            xPointer = new CublasPointer(x);
            if (x.data().dataType() == DataBuffer.Type.FLOAT) {
                return JCublas.cublasSasum(x.length(), xPointer, 1);
            }
            return JCublas.cublasDasum(x.length(), xPointer, 1);
        } finally {
            releaseCublasPointers(xPointer);
        }
    }

    /**
     * Returns the Euclidean norm of {@code x} via {@code cublasSnrm2}/{@code cublasDnrm2}
     * (increment 1).
     *
     * @throws IllegalStateException if the data type is neither FLOAT nor DOUBLE
     */
    public static double nrm2(INDArray x) {
        DataBuffer.Type type = x.data().dataType();
        if (type != DataBuffer.Type.FLOAT && type != DataBuffer.Type.DOUBLE) {
            throw new IllegalStateException("Illegal data type on array ");
        }
        sync();
        CublasPointer xPointer = null;
        try {
            xPointer = new CublasPointer(x);
            if (type == DataBuffer.Type.FLOAT) {
                return JCublas.cublasSnrm2(x.length(), xPointer, 1);
            }
            return JCublas.cublasDnrm2(x.length(), xPointer, 1);
        } finally {
            releaseCublasPointers(xPointer);
        }
    }

    /**
     * Returns the 0-based index of the element with the largest absolute value
     * ({@code cublasIsamax}/{@code cublasIdamax} with increment {@code x.majorStride()}).
     *
     * @throws IllegalStateException if the data type is neither FLOAT nor DOUBLE
     */
    public static int iamax(INDArray x) {
        DataBuffer.Type type = x.data().dataType();
        if (type != DataBuffer.Type.FLOAT && type != DataBuffer.Type.DOUBLE) {
            throw new IllegalStateException("Illegal data type on array ");
        }
        sync();
        CublasPointer xPointer = null;
        try {
            xPointer = new CublasPointer(x);
            // cuBLAS returns 1-based indices.
            if (type == DataBuffer.Type.FLOAT) {
                return JCublas.cublasIsamax(x.length(), xPointer, x.majorStride()) - 1;
            }
            return JCublas.cublasIdamax(x.length(), xPointer, x.majorStride()) - 1;
        } finally {
            releaseCublasPointers(xPointer);
        }
    }

    /**
     * Computes {@code y = alpha * x + y} with {@code cublasSaxpy} and copies {@code y} back to
     * the host.
     */
    public static void axpy(float alpha, INDArray x, INDArray y) {
        DataTypeValidation.assertFloat(new INDArray[]{x, y});
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        try {
            xPointer = new CublasPointer(x);
            yPointer = new CublasPointer(y);
            sync();
            JCublas.cublasSaxpy(x.length(), alpha, xPointer, x.majorStride(), yPointer, y.majorStride());
            sync();
            yPointer.copyToHost();
        } finally {
            releaseCublasPointers(xPointer, yPointer);
        }
    }

    /**
     * Computes the single-complex {@code y = alpha * x + y} with {@code cublasCaxpy}
     * (increment 1) and copies {@code y} back to the host.
     */
    public static void axpy(IComplexFloat alpha, IComplexNDArray x, IComplexNDArray y) {
        DataTypeValidation.assertFloat(new INDArray[]{x, y});
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        try {
            xPointer = new CublasPointer((INDArray) x);
            yPointer = new CublasPointer((INDArray) y);
            sync();
            JCublas.cublasCaxpy(x.length(), toCuComplex(alpha), xPointer, 1, yPointer, 1);
            sync();
            yPointer.copyToHost();
        } finally {
            releaseCublasPointers(xPointer, yPointer);
        }
    }

    /**
     * Computes the double-complex {@code y = alpha * x + y} with {@code cublasZaxpy} and copies
     * {@code y} back to the host. Alpha keeps full double precision.
     */
    public static void axpy(IComplexDouble alpha, IComplexNDArray x, IComplexNDArray y) {
        DataTypeValidation.assertDouble(new INDArray[]{x, y});
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        try {
            xPointer = new CublasPointer((INDArray) x);
            yPointer = new CublasPointer((INDArray) y);
            sync();
            JCublas.cublasZaxpy(x.length(), toCuDoubleComplex(alpha), xPointer, x.majorStride(), yPointer, y.majorStride());
            sync();
            yPointer.copyToHost();
        } finally {
            releaseCublasPointers(xPointer, yPointer);
        }
    }

    /**
     * Scales {@code x} in place by {@code alpha} with {@code cublasDscal}.
     *
     * @return {@code x}, copied back to the host
     */
    public static INDArray scal(double alpha, INDArray x) {
        DataTypeValidation.assertDouble(x);
        sync();
        CublasPointer xPointer = null;
        try {
            xPointer = new CublasPointer(x);
            JCublas.cublasDscal(x.length(), alpha, xPointer, x.majorStride());
            sync();
            xPointer.copyToHost();
        } finally {
            releaseCublasPointers(xPointer);
        }
        return x;
    }

    /**
     * Scales {@code x} in place by {@code alpha} with {@code cublasSscal}.
     *
     * @return {@code x}, copied back to the host
     */
    public static INDArray scal(float alpha, INDArray x) {
        DataTypeValidation.assertFloat(x);
        sync();
        CublasPointer xPointer = null;
        try {
            xPointer = new CublasPointer(x);
            JCublas.cublasSscal(x.length(), alpha, xPointer, x.majorStride());
            sync();
            xPointer.copyToHost();
        } finally {
            releaseCublasPointers(xPointer);
        }
        return x;
    }

    /**
     * Copies {@code x} into {@code y} with {@code cublasDcopy}/{@code cublasScopy} and copies
     * {@code y} back to the host.
     *
     * @throws IllegalStateException if the data type is neither DOUBLE nor FLOAT (previously
     *     this silently copied nothing)
     */
    public static void copy(INDArray x, INDArray y) {
        DataTypeValidation.assertSameDataType(new INDArray[]{x, y});
        DataBuffer.Type type = x.data().dataType();
        if (type != DataBuffer.Type.FLOAT && type != DataBuffer.Type.DOUBLE) {
            throw new IllegalStateException("Illegal data type on array ");
        }
        sync();
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        try {
            xPointer = new CublasPointer(x);
            yPointer = new CublasPointer(y);
            if (type == DataBuffer.Type.DOUBLE) {
                JCublas.cublasDcopy(x.length(), xPointer, x.majorStride(), yPointer, y.majorStride());
            } else {
                JCublas.cublasScopy(x.length(), xPointer, x.majorStride(), yPointer, y.majorStride());
            }
            sync();
            yPointer.copyToHost();
        } finally {
            releaseCublasPointers(yPointer, xPointer);
        }
    }

    /**
     * Returns the dot product of {@code x} and {@code y} ({@code cublasSdot} for float data,
     * {@code cublasDdot} otherwise).
     *
     * <p>Each operand uses its own {@code majorStride()} as increment.
     */
    public static double dot(INDArray x, INDArray y) {
        DataTypeValidation.assertSameDataType(new INDArray[]{x, y});
        sync();
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        try {
            xPointer = new CublasPointer(x);
            yPointer = new CublasPointer(y);
            double result;
            if (x.data().dataType() == DataBuffer.Type.FLOAT) {
                result = JCublas.cublasSdot(x.length(), xPointer, x.majorStride(), yPointer, y.majorStride());
            } else {
                result = JCublas.cublasDdot(x.length(), xPointer, x.majorStride(), yPointer, y.majorStride());
            }
            sync();
            return result;
        } finally {
            releaseCublasPointers(xPointer, yPointer);
        }
    }

    /**
     * Closes every non-null pointer.
     *
     * @throws RuntimeException wrapping the first exception thrown by {@code close()}
     */
    private static void releaseCublasPointers(CublasPointer... pointers) {
        for (CublasPointer pointer : pointers) {
            if (pointer != null) {
                try {
                    pointer.close();
                } catch (Exception e) {
                    throw new RuntimeException("Could not run cublas command", e);
                }
            }
        }
    }

    /**
     * Returns the conjugated complex dot product {@code sum(conj(x_i) * y_i)}.
     *
     * <p>Uses {@code cublasCdotc} for float data and {@code cublasZdotc} otherwise; the result
     * is always returned as an {@link IComplexDouble}.
     */
    public static IComplexDouble dot(IComplexNDArray x, IComplexNDArray y) {
        DataTypeValidation.assertSameDataType(new INDArray[]{x, y});
        sync();
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        try {
            xPointer = new CublasPointer((INDArray) x);
            yPointer = new CublasPointer((INDArray) y);
            IComplexDouble result;
            if (x.data().dataType() == DataBuffer.Type.FLOAT) {
                cuComplex dot = JCublas.cublasCdotc(x.length(), xPointer, x.majorStride(), yPointer, y.majorStride());
                result = Nd4j.createDouble(dot.x, dot.y);
            } else {
                cuDoubleComplex dot = JCublas.cublasZdotc(x.length(), xPointer, x.majorStride(), yPointer, y.majorStride());
                result = Nd4j.createDouble(dot.x, dot.y);
            }
            sync();
            return result;
        } finally {
            releaseCublasPointers(xPointer, yPointer);
        }
    }

    /**
     * Rank-1 update {@code a = alpha * x * y^T + a} with {@code cublasDger}.
     *
     * <p>NOTE(review): m, n and incx are all taken from {@code x.rows()}/{@code x.columns()},
     * and incy from {@code y.rows()}; confirm this matches the intended operand shapes.
     *
     * @return {@code a}, copied back to the host
     */
    public static INDArray ger(INDArray x, INDArray y, INDArray a, double alpha) {
        DataTypeValidation.assertDouble(new INDArray[]{x, y, a});
        sync();
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        CublasPointer aPointer = null;
        try {
            xPointer = new CublasPointer(x);
            yPointer = new CublasPointer(y);
            aPointer = new CublasPointer(a);
            JCublas.cublasDger(x.rows(), x.columns(), alpha, xPointer, x.rows(), yPointer, y.rows(), aPointer, a.rows());
            sync();
            aPointer.copyToHost();
        } finally {
            releaseCublasPointers(xPointer, yPointer, aPointer);
        }
        return a;
    }

    /**
     * Rank-1 update {@code a = alpha * x * y^T + a} with {@code cublasSger}; dimension and
     * increment arguments are chosen as in the double overload.
     *
     * @return {@code a}, copied back to the host
     */
    public static INDArray ger(INDArray x, INDArray y, INDArray a, float alpha) {
        DataTypeValidation.assertFloat(new INDArray[]{x, y, a});
        sync();
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        CublasPointer aPointer = null;
        try {
            xPointer = new CublasPointer(x);
            yPointer = new CublasPointer(y);
            aPointer = new CublasPointer(a);
            JCublas.cublasSger(x.rows(), x.columns(), alpha, xPointer, x.rows(), yPointer, y.rows(), aPointer, a.rows());
            sync();
            aPointer.copyToHost();
        } finally {
            releaseCublasPointers(xPointer, yPointer, aPointer);
        }
        return a;
    }

    /**
     * Scales a single-complex array in place with {@code cublasCscal}.
     *
     * @return {@code x}, copied back to the host
     */
    public static IComplexNDArray scal(IComplexFloat alpha, IComplexNDArray x) {
        DataTypeValidation.assertFloat(x);
        sync();
        CublasPointer xPointer = null;
        try {
            xPointer = new CublasPointer((INDArray) x);
            JCublas.cublasCscal(x.length(), toCuComplex(alpha), xPointer, x.majorStride());
            sync();
            xPointer.copyToHost();
        } finally {
            releaseCublasPointers(xPointer);
        }
        return x;
    }

    /**
     * Scales a double-complex array in place with {@code cublasZscal}.
     *
     * @return {@code x}, copied back to the host
     */
    public static IComplexNDArray scal(IComplexDouble alpha, IComplexNDArray x) {
        DataTypeValidation.assertDouble(x);
        sync();
        CublasPointer xPointer = null;
        try {
            xPointer = new CublasPointer((INDArray) x);
            JCublas.cublasZscal(x.length(), toCuDoubleComplex(alpha), xPointer, x.majorStride());
            sync();
            xPointer.copyToHost();
        } finally {
            releaseCublasPointers(xPointer);
        }
        return x;
    }

    /**
     * Returns the unconjugated complex dot product {@code sum(x_i * y_i)}.
     *
     * <p>Uses {@code cublasZdotu} for double data and {@code cublasCdotu} otherwise.
     */
    public static IComplexDouble dotu(IComplexNDArray x, IComplexNDArray y) {
        DataTypeValidation.assertSameDataType(new INDArray[]{x, y});
        sync();
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        try {
            xPointer = new CublasPointer((INDArray) x);
            yPointer = new CublasPointer((INDArray) y);
            IComplexDouble result;
            if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
                cuDoubleComplex dot = JCublas.cublasZdotu(x.length(), xPointer, x.majorStride(), yPointer, y.majorStride());
                result = Nd4j.createDouble(dot.x, dot.y);
            } else {
                cuComplex dot = JCublas.cublasCdotu(x.length(), xPointer, x.majorStride(), yPointer, y.majorStride());
                result = Nd4j.createDouble(dot.x, dot.y);
            }
            sync();
            return result;
        } finally {
            releaseCublasPointers(xPointer, yPointer);
        }
    }

    /**
     * Unconjugated double-complex rank-1 update {@code a = alpha * x * y^T + a}
     * ({@code cublasZgeru}).
     *
     * @return {@code a}, copied back to the host
     */
    public static IComplexNDArray geru(IComplexNDArray x, IComplexNDArray y, IComplexNDArray a, IComplexDouble alpha) {
        sync();
        DataTypeValidation.assertDouble(new INDArray[]{x, y, a});
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        CublasPointer aPointer = null;
        try {
            xPointer = new CublasPointer((INDArray) x);
            yPointer = new CublasPointer((INDArray) y);
            aPointer = new CublasPointer((INDArray) a);
            JCublas.cublasZgeru(x.rows(), x.columns(), toCuDoubleComplex(alpha), xPointer, x.rows(), yPointer, y.rows(), aPointer, a.rows());
            sync();
            aPointer.copyToHost();
        } finally {
            releaseCublasPointers(xPointer, yPointer, aPointer);
        }
        return a;
    }

    /**
     * Conjugated single-complex rank-1 update {@code a = alpha * x * y^H + a}
     * ({@code cublasCgerc}).
     *
     * @return {@code a}, copied back to the host
     */
    public static IComplexNDArray gerc(IComplexNDArray x, IComplexNDArray y, IComplexNDArray a, IComplexFloat alpha) {
        DataTypeValidation.assertFloat(new INDArray[]{x, y, a});
        sync();
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        CublasPointer aPointer = null;
        try {
            xPointer = new CublasPointer((INDArray) x);
            yPointer = new CublasPointer((INDArray) y);
            aPointer = new CublasPointer((INDArray) a);
            JCublas.cublasCgerc(x.rows(), x.columns(), toCuComplex(alpha), xPointer, x.rows(), yPointer, y.rows(), aPointer, a.rows());
            sync();
            aPointer.copyToHost();
        } finally {
            releaseCublasPointers(xPointer, yPointer, aPointer);
        }
        return a;
    }

    /**
     * Unconjugated single-complex rank-1 update {@code a = alpha * x * y^T + a}.
     *
     * <p>Uses {@code cublasCgeru}; previously this called the double-complex routine on float
     * data.
     *
     * @return {@code a}, copied back to the host
     */
    public static IComplexNDArray geru(IComplexNDArray x, IComplexNDArray y, IComplexNDArray a, IComplexFloat alpha) {
        DataTypeValidation.assertFloat(new INDArray[]{x, y, a});
        sync();
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        CublasPointer aPointer = null;
        try {
            xPointer = new CublasPointer((INDArray) x);
            yPointer = new CublasPointer((INDArray) y);
            aPointer = new CublasPointer((INDArray) a);
            JCublas.cublasCgeru(x.rows(), x.columns(), toCuComplex(alpha), xPointer, x.rows(), yPointer, y.rows(), aPointer, a.rows());
            sync();
            aPointer.copyToHost();
        } finally {
            releaseCublasPointers(xPointer, yPointer, aPointer);
        }
        return a;
    }

    /**
     * Conjugated double-complex rank-1 update {@code a = alpha * x * y^H + a}
     * ({@code cublasZgerc}).
     *
     * @return {@code a}, copied back to the host
     */
    public static IComplexNDArray gerc(IComplexNDArray x, IComplexNDArray y, IComplexNDArray a, IComplexDouble alpha) {
        DataTypeValidation.assertDouble(new INDArray[]{x, y, a});
        sync();
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        CublasPointer aPointer = null;
        try {
            xPointer = new CublasPointer((INDArray) x);
            yPointer = new CublasPointer((INDArray) y);
            aPointer = new CublasPointer((INDArray) a);
            JCublas.cublasZgerc(x.rows(), x.columns(), toCuDoubleComplex(alpha), xPointer, x.rows(), yPointer, y.rows(), aPointer, a.rows());
            sync();
            aPointer.copyToHost();
        } finally {
            releaseCublasPointers(xPointer, yPointer, aPointer);
        }
        return a;
    }

    /**
     * Computes {@code y = alpha * x + y} with {@code cublasDaxpy} and copies {@code y} back to
     * the host.
     */
    public static void axpy(double alpha, INDArray x, INDArray y) {
        DataTypeValidation.assertDouble(new INDArray[]{x, y});
        sync();
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        try {
            xPointer = new CublasPointer(x);
            yPointer = new CublasPointer(y);
            JCublas.cublasDaxpy(x.length(), alpha, xPointer, x.majorStride(), yPointer, y.majorStride());
            sync();
            yPointer.copyToHost();
        } finally {
            releaseCublasPointers(xPointer, yPointer);
        }
    }

    /**
     * Computes {@code y = alpha * x + y} with {@code cublasSaxpy} and copies the result
     * ({@code y}, not {@code x}) back to the host.
     */
    public static void saxpy(float alpha, INDArray x, INDArray y) {
        DataTypeValidation.assertFloat(new INDArray[]{x, y});
        sync();
        CublasPointer xPointer = null;
        CublasPointer yPointer = null;
        try {
            xPointer = new CublasPointer(x);
            yPointer = new CublasPointer(y);
            JCublas.cublasSaxpy(x.length(), alpha, xPointer, x.majorStride(), yPointer, y.majorStride());
            sync();
            yPointer.copyToHost();
        } finally {
            releaseCublasPointers(xPointer, yPointer);
        }
    }

    /** Converts an nd4j single-precision complex scalar to a JCuda {@link cuComplex}. */
    private static cuComplex toCuComplex(IComplexFloat value) {
        return cuComplex.cuCmplx(value.realComponent().floatValue(), value.imaginaryComponent().floatValue());
    }

    /** Converts an nd4j double-precision complex scalar to a JCuda {@link cuDoubleComplex}. */
    private static cuDoubleComplex toCuDoubleComplex(IComplexDouble value) {
        return cuDoubleComplex.cuCmplx(value.realComponent().doubleValue(), value.imaginaryComponent().doubleValue());
    }

    /** Describes the runtime class of {@code value} for error messages; null-safe. */
    private static String typeName(Object value) {
        return value == null ? "null" : value.getClass().toString();
    }

    static {
        init();
    }
}
