package org.nd4j.linalg.jcublas.blas;

import org.bytedeco.cuda.cudart.CUstream_st;
import org.bytedeco.cuda.cusolver.cusolverDnContext;
import org.bytedeco.cuda.global.cusolver;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t;
import org.nd4j.linalg.api.blas.BlasException;
import org.nd4j.linalg.api.blas.impl.BaseLapack;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.jcublas.CublasPointer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/jcublas/blas/JcublasLapack.class */
public class JcublasLapack extends BaseLapack {
    // Class-wide SLF4J logger.
    private static final Logger log = LoggerFactory.getLogger(JcublasLapack.class);
    // Device native-ops instance obtained from the NativeOpsHolder singleton.
    private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    // Allocator used by the routines below to fetch the device context, resolve
    // device pointers for arrays, and register actions on arrays after a call.
    private Allocator allocator = AtomicAllocator.getInstance();

    /* loaded from: input_file:org/nd4j/linalg/jcublas/blas/JcublasLapack$Workspace.class */
    static class Workspace extends Pointer {
        public Workspace(long j) {
            super(NativeOpsHolder.getInstance().getDeviceNativeOps().mallocDevice(j, 0, 0));
            deallocator(new Pointer.Deallocator() { // from class: org.nd4j.linalg.jcublas.blas.JcublasLapack.Workspace.1
                public void deallocate() {
                    NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice(Workspace.this, 0);
                }
            });
        }
    }

    /**
     * Single-precision LU factorization through cuSOLVER's {@code cusolverDnSgetrf}.
     *
     * <p>The input is copied to 'f' ordering first when it is 'c' ordered; the
     * factorized copy is then assigned back into the original array.
     *
     * @param i         passed to cuSOLVER as the row count and as the leading dimension
     * @param i2        passed to cuSOLVER as the column count
     * @param iNDArray  matrix to factorize; receives the result
     * @param iNDArray2 array whose device buffer is passed in the pivot slot of the call
     * @param iNDArray3 array whose device buffer is passed in the info slot of the call
     * @throws BlasException if setting the stream, the work-size query, or the
     *                       factorization returns a non-zero status
     */
    public void sgetrf(int i, int i2, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3) {
        if (Nd4j.dataType() != DataType.FLOAT) {
            log.warn("FLOAT getrf called in DOUBLE environment");
        }
        // Work on an 'f'-ordered copy when the input is 'c' ordered.
        INDArray a = iNDArray.ordering() == 'c' ? iNDArray.dup('f') : iNDArray;
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        CudaContext cudaContext = (CudaContext) this.allocator.getDeviceContext().getContext();
        cusolverDnHandle_t solverHandle = cudaContext.getSolverHandle();
        cusolverDnContext solverDn = new cusolverDnContext(solverHandle);
        // The solver handle is shared, so stream binding and the calls are serialized on it.
        synchronized (solverHandle) {
            if (cusolver.cusolverDnSetStream(solverDn, new CUstream_st(cudaContext.getOldStream())) != 0) {
                throw new BlasException("solverSetStream failed");
            }
            CublasPointer aPointer = new CublasPointer(a, cudaContext);
            // Single-int buffer that the size query fills with the required work size.
            DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1L);
            int status = cusolver.cusolverDnSgetrf_bufferSize(solverDn, i, i2, aPointer.getDevicePointer(), i,
                    worksizeBuffer.addressPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnSgetrf_bufferSize failed", status);
            }
            int worksize = worksizeBuffer.getInt(0L);
            // Size the scratch block by the routine's element type (float), not the
            // global default type, and in long arithmetic to avoid int overflow.
            Workspace workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType(DataType.FLOAT));
            status = cusolver.cusolverDnSgetrf(solverDn, i, i2, aPointer.getDevicePointer(), i,
                    new CudaPointer(workspace).asFloatPointer(),
                    new CudaPointer(this.allocator.getPointer(iNDArray2, cudaContext)).asIntPointer(),
                    new CudaPointer(this.allocator.getPointer(iNDArray3, cudaContext)).asIntPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnSgetrf failed", status);
            }
        }
        this.allocator.registerAction(cudaContext, a, new INDArray[0]);
        this.allocator.registerAction(cudaContext, iNDArray3, new INDArray[0]);
        this.allocator.registerAction(cudaContext, iNDArray2, new INDArray[0]);
        // Copy the result back when a reordered duplicate was used.
        if (a != iNDArray) {
            iNDArray.assign(a);
        }
    }

    /**
     * Double-precision LU factorization through cuSOLVER's {@code cusolverDnDgetrf}.
     *
     * <p>The input is copied to 'f' ordering first when it is 'c' ordered; the
     * factorized copy is then assigned back into the original array.
     *
     * @param i         passed to cuSOLVER as the row count and as the leading dimension
     * @param i2        passed to cuSOLVER as the column count
     * @param iNDArray  matrix to factorize; receives the result
     * @param iNDArray2 array whose device buffer is passed in the pivot slot of the call
     * @param iNDArray3 array whose device buffer is passed in the info slot of the call
     * @throws BlasException if setting the stream, the work-size query, or the
     *                       factorization returns a non-zero status
     */
    public void dgetrf(int i, int i2, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3) {
        if (Nd4j.dataType() != DataType.DOUBLE) {
            log.warn("DOUBLE getrf called in FLOAT environment");
        }
        // Work on an 'f'-ordered copy when the input is 'c' ordered.
        INDArray a = iNDArray.ordering() == 'c' ? iNDArray.dup('f') : iNDArray;
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        CudaContext cudaContext = (CudaContext) this.allocator.getDeviceContext().getContext();
        cusolverDnHandle_t solverHandle = cudaContext.getSolverHandle();
        cusolverDnContext solverDn = new cusolverDnContext(solverHandle);
        // The solver handle is shared, so stream binding and the calls are serialized on it.
        synchronized (solverHandle) {
            if (cusolver.cusolverDnSetStream(solverDn, new CUstream_st(cudaContext.getOldStream())) != 0) {
                throw new BlasException("solverSetStream failed");
            }
            CublasPointer aPointer = new CublasPointer(a, cudaContext);
            // Single-int buffer that the size query fills with the required work size.
            DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1L);
            int status = cusolver.cusolverDnDgetrf_bufferSize(solverDn, i, i2, aPointer.getDevicePointer(), i,
                    worksizeBuffer.addressPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnDgetrf_bufferSize failed", status);
            }
            int worksize = worksizeBuffer.getInt(0L);
            // Size the scratch block in doubles regardless of the global default type;
            // sizing by a narrower global type would under-allocate the workspace.
            Workspace workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType(DataType.DOUBLE));
            status = cusolver.cusolverDnDgetrf(solverDn, i, i2, aPointer.getDevicePointer(), i,
                    new CudaPointer(workspace).asDoublePointer(),
                    new CudaPointer(this.allocator.getPointer(iNDArray2, cudaContext)).asIntPointer(),
                    new CudaPointer(this.allocator.getPointer(iNDArray3, cudaContext)).asIntPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnDgetrf failed", status);
            }
        }
        this.allocator.registerAction(cudaContext, a, new INDArray[0]);
        this.allocator.registerAction(cudaContext, iNDArray3, new INDArray[0]);
        this.allocator.registerAction(cudaContext, iNDArray2, new INDArray[0]);
        // Copy the result back when a reordered duplicate was used.
        if (a != iNDArray) {
            iNDArray.assign(a);
        }
    }

    /**
     * Single-precision QR factorization: runs {@code cusolverDnSgeqrf} followed by
     * {@code cusolverDnSorgqr} on the same device matrix.
     *
     * <p>Between the two calls, when {@code iNDArray2} is supplied, the leading
     * rows of the factorized matrix are copied into it and the entries below its
     * diagonal are zeroed.
     *
     * @param i         passed to cuSOLVER as the row count and as the leading dimension
     * @param i2        passed to cuSOLVER as the column count (also the length of the
     *                  internally allocated tau vector)
     * @param iNDArray  matrix to factorize; receives the content left by the orgqr call
     * @param iNDArray2 optional output for the upper-triangular factor; may be {@code null}
     * @param iNDArray3 array whose device buffer is passed in the info slot; its first
     *                  element is checked after geqrf
     * @throws IllegalStateException if the solver stream cannot be set
     * @throws BlasException         if a work-size query, geqrf, orgqr, or the info check fails
     */
    public void sgeqrf(int i, int i2, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3) {
        INDArray a = iNDArray;
        INDArray r = iNDArray2;
        if (Nd4j.dataType() != DataType.FLOAT) {
            log.warn("FLOAT geqrf called in DOUBLE environment");
        }
        // Work on 'f'-ordered copies when the caller's arrays are 'c' ordered.
        if (iNDArray.ordering() == 'c') {
            a = iNDArray.dup('f');
        }
        if (iNDArray2 != null && iNDArray2.ordering() == 'c') {
            r = iNDArray2.dup('f');
        }
        // 1 x i2 scratch vector passed in the tau slot of geqrf and orgqr.
        INDArray tau = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createFloat(i2),
                (DataBuffer) Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{1, i2}, iNDArray.dataType()).getFirst());
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        CudaContext cudaContext = (CudaContext) this.allocator.getDeviceContext().getContext();
        cusolverDnHandle_t solverHandle = cudaContext.getSolverHandle();
        cusolverDnContext solverDn = new cusolverDnContext(solverHandle);
        synchronized (solverHandle) {
            if (cusolver.cusolverDnSetStream(solverDn, new CUstream_st(cudaContext.getOldStream())) != 0) {
                throw new IllegalStateException("solverSetStream failed");
            }
            CublasPointer aPointer = new CublasPointer(a, cudaContext);
            CublasPointer tauPointer = new CublasPointer(tau, cudaContext);
            // Single-int buffer that each size query fills with the required work size.
            DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1L);
            int status = cusolver.cusolverDnSgeqrf_bufferSize(solverDn, i, i2, aPointer.getDevicePointer(), i,
                    worksizeBuffer.addressPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnSgeqrf_bufferSize failed", status);
            }
            int worksize = worksizeBuffer.getInt(0L);
            // Scratch block sized by the routine's element type (float), in long arithmetic.
            Workspace workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType(DataType.FLOAT));
            status = cusolver.cusolverDnSgeqrf(solverDn, i, i2, aPointer.getDevicePointer(), i,
                    tauPointer.getDevicePointer(), new CudaPointer(workspace).asFloatPointer(), worksize,
                    new CudaPointer(this.allocator.getPointer(iNDArray3, cudaContext)).asIntPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnSgeqrf failed", status);
            }
            this.allocator.registerAction(cudaContext, a, new INDArray[0]);
            this.allocator.registerAction(cudaContext, iNDArray3, new INDArray[0]);
            if (iNDArray3.getInt(new int[]{0}) != 0) {
                throw new BlasException("cusolverDnSgeqrf failed on INFO", iNDArray3.getInt(new int[]{0}));
            }
            if (r != null) {
                // Take the first columns() rows of the factorized matrix, then zero
                // everything strictly below the diagonal.
                r.assign(a.get(new INDArrayIndex[]{NDArrayIndex.interval(0, a.columns()), NDArrayIndex.all()}));
                INDArrayIndex[] belowDiagonal = new INDArrayIndex[2];
                int limit = Math.min(a.rows(), a.columns());
                for (int row = 1; row < limit; row++) {
                    belowDiagonal[0] = NDArrayIndex.point(row);
                    belowDiagonal[1] = NDArrayIndex.interval(0, row);
                    r.put(belowDiagonal, 0);
                }
            }
            // orgqr has its own work-size requirement: query it, check it, and reallocate.
            status = cusolver.cusolverDnSorgqr_bufferSize(solverDn, i, i2, i2, aPointer.getDevicePointer(), i,
                    tauPointer.getDevicePointer(), worksizeBuffer.addressPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnSorgqr_bufferSize failed", status);
            }
            worksize = worksizeBuffer.getInt(0L);
            workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType(DataType.FLOAT));
            status = cusolver.cusolverDnSorgqr(solverDn, i, i2, i2, aPointer.getDevicePointer(), i,
                    tauPointer.getDevicePointer(), new CudaPointer(workspace).asFloatPointer(), worksize,
                    new CudaPointer(this.allocator.getPointer(iNDArray3, cudaContext)).asIntPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnSorgqr failed", status);
            }
        }
        this.allocator.registerAction(cudaContext, a, new INDArray[0]);
        this.allocator.registerAction(cudaContext, iNDArray3, new INDArray[0]);
        // Copy results back when reordered duplicates were used.
        if (a != iNDArray) {
            iNDArray.assign(a);
        }
        if (r != null && r != iNDArray2) {
            iNDArray2.assign(r);
        }
        log.info("A: {}", iNDArray);
        if (iNDArray2 != null) {
            log.info("R: {}", iNDArray2);
        }
    }

    /**
     * Double-precision QR factorization: runs {@code cusolverDnDgeqrf} followed by
     * {@code cusolverDnDorgqr} on the same device matrix.
     *
     * <p>Between the two calls, when {@code iNDArray2} is supplied, the leading
     * rows of the factorized matrix are copied into it and the entries below its
     * diagonal are zeroed.
     *
     * @param i         passed to cuSOLVER as the row count and as the leading dimension
     * @param i2        passed to cuSOLVER as the column count (also the length of the
     *                  internally allocated tau vector)
     * @param iNDArray  matrix to factorize; receives the content left by the orgqr call
     * @param iNDArray2 optional output for the upper-triangular factor; may be {@code null}
     * @param iNDArray3 array whose device buffer is passed in the info slot; its first
     *                  element is checked after geqrf
     * @throws BlasException if setting the stream, a work-size query, geqrf, orgqr,
     *                       or the info check fails
     */
    public void dgeqrf(int i, int i2, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3) {
        INDArray a = iNDArray;
        INDArray r = iNDArray2;
        if (Nd4j.dataType() != DataType.DOUBLE) {
            log.warn("DOUBLE geqrf called in FLOAT environment");
        }
        // Work on 'f'-ordered copies when the caller's arrays are 'c' ordered.
        if (iNDArray.ordering() == 'c') {
            a = iNDArray.dup('f');
        }
        if (iNDArray2 != null && iNDArray2.ordering() == 'c') {
            r = iNDArray2.dup('f');
        }
        // 1 x i2 scratch vector passed in the tau slot of geqrf and orgqr.
        INDArray tau = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createDouble(i2),
                Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{1, i2}, iNDArray.dataType()));
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        CudaContext cudaContext = (CudaContext) this.allocator.getDeviceContext().getContext();
        cusolverDnHandle_t solverHandle = cudaContext.getSolverHandle();
        cusolverDnContext solverDn = new cusolverDnContext(solverHandle);
        synchronized (solverHandle) {
            if (cusolver.cusolverDnSetStream(solverDn, new CUstream_st(cudaContext.getOldStream())) != 0) {
                throw new BlasException("solverSetStream failed");
            }
            CublasPointer aPointer = new CublasPointer(a, cudaContext);
            CublasPointer tauPointer = new CublasPointer(tau, cudaContext);
            // Single-int buffer that each size query fills with the required work size.
            DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1L);
            int status = cusolver.cusolverDnDgeqrf_bufferSize(solverDn, i, i2, aPointer.getDevicePointer(), i,
                    worksizeBuffer.addressPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnDgeqrf_bufferSize failed", status);
            }
            int worksize = worksizeBuffer.getInt(0L);
            // Scratch block sized in doubles regardless of the global default type.
            Workspace workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType(DataType.DOUBLE));
            status = cusolver.cusolverDnDgeqrf(solverDn, i, i2, aPointer.getDevicePointer(), i,
                    tauPointer.getDevicePointer(), new CudaPointer(workspace).asDoublePointer(), worksize,
                    new CudaPointer(this.allocator.getPointer(iNDArray3, cudaContext)).asIntPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnDgeqrf failed", status);
            }
            this.allocator.registerAction(cudaContext, a, new INDArray[0]);
            this.allocator.registerAction(cudaContext, tau, new INDArray[0]);
            this.allocator.registerAction(cudaContext, iNDArray3, new INDArray[0]);
            if (iNDArray3.getInt(new int[]{0}) != 0) {
                throw new BlasException("cusolverDnDgeqrf failed with info", iNDArray3.getInt(new int[]{0}));
            }
            if (r != null) {
                // Take the first columns() rows of the factorized matrix, then zero
                // everything strictly below the diagonal.
                r.assign(a.get(new INDArrayIndex[]{NDArrayIndex.interval(0, a.columns()), NDArrayIndex.all()}));
                INDArrayIndex[] belowDiagonal = new INDArrayIndex[2];
                int limit = Math.min(a.rows(), a.columns());
                for (int row = 1; row < limit; row++) {
                    belowDiagonal[0] = NDArrayIndex.point(row);
                    belowDiagonal[1] = NDArrayIndex.interval(0, row);
                    r.put(belowDiagonal, 0);
                }
            }
            // orgqr has its own work-size requirement: query it, check it, and reallocate.
            status = cusolver.cusolverDnDorgqr_bufferSize(solverDn, i, i2, i2, aPointer.getDevicePointer(), i,
                    tauPointer.getDevicePointer(), worksizeBuffer.addressPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnDorgqr_bufferSize failed", status);
            }
            worksize = worksizeBuffer.getInt(0L);
            workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType(DataType.DOUBLE));
            status = cusolver.cusolverDnDorgqr(solverDn, i, i2, i2, aPointer.getDevicePointer(), i,
                    tauPointer.getDevicePointer(), new CudaPointer(workspace).asDoublePointer(), worksize,
                    new CudaPointer(this.allocator.getPointer(iNDArray3, cudaContext)).asIntPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnDorgqr failed", status);
            }
        }
        this.allocator.registerAction(cudaContext, a, new INDArray[0]);
        this.allocator.registerAction(cudaContext, iNDArray3, new INDArray[0]);
        // Copy results back when reordered duplicates were used.
        if (a != iNDArray) {
            iNDArray.assign(a);
        }
        if (r != null && r != iNDArray2) {
            iNDArray2.assign(r);
        }
        log.info("A: {}", iNDArray);
        if (iNDArray2 != null) {
            log.info("R: {}", iNDArray2);
        }
    }

    /**
     * Single-precision Cholesky factorization through cuSOLVER's {@code cusolverDnSpotrf}.
     *
     * <p>After the call the result is post-processed on the array itself: for
     * {@code b == 'U'} the array is replaced by its transpose and the entries
     * strictly below the diagonal are zeroed; otherwise the entries strictly
     * above the diagonal are zeroed.
     *
     * @param b         fill-mode selector forwarded to cuSOLVER; {@code 'U'} (85) selects
     *                  the transpose-and-clear-lower post-processing
     * @param i         passed to cuSOLVER as the matrix order and as the leading dimension
     * @param iNDArray  matrix to factorize; receives the result
     * @param iNDArray2 array whose device buffer is passed in the info slot of the call
     * @throws BlasException if setting the stream, the work-size query, or the
     *                       factorization returns a non-zero status
     */
    public void spotrf(byte b, int i, INDArray iNDArray, INDArray iNDArray2) {
        if (iNDArray.dataType() != DataType.FLOAT) {
            log.warn("FLOAT potrf called for " + iNDArray.dataType());
        }
        // Work on an 'f'-ordered copy when the input is 'c' ordered.
        INDArray a = iNDArray.ordering() == 'c' ? iNDArray.dup('f') : iNDArray;
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        CudaContext cudaContext = (CudaContext) this.allocator.getDeviceContext().getContext();
        cusolverDnHandle_t solverHandle = cudaContext.getSolverHandle();
        cusolverDnContext solverDn = new cusolverDnContext(solverHandle);
        synchronized (solverHandle) {
            if (cusolver.cusolverDnSetStream(solverDn, new CUstream_st(cudaContext.getOldStream())) != 0) {
                throw new BlasException("solverSetStream failed");
            }
            CublasPointer aPointer = new CublasPointer(a, cudaContext);
            // Single-int buffer that the size query fills with the required work size.
            DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1L);
            int status = cusolver.cusolverDnSpotrf_bufferSize(solverDn, b, i, aPointer.getDevicePointer(), i,
                    worksizeBuffer.addressPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnSpotrf_bufferSize failed", status);
            }
            int worksize = worksizeBuffer.getInt(0L);
            // Scratch block sized by the routine's element type (float), in long arithmetic.
            Workspace workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType(DataType.FLOAT));
            status = cusolver.cusolverDnSpotrf(solverDn, b, i, aPointer.getDevicePointer(), i,
                    new CudaPointer(workspace).asFloatPointer(), worksize,
                    new CudaPointer(this.allocator.getPointer(iNDArray2, cudaContext)).asIntPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnSpotrf failed", status);
            }
        }
        this.allocator.registerAction(cudaContext, a, new INDArray[0]);
        this.allocator.registerAction(cudaContext, iNDArray2, new INDArray[0]);
        if (a != iNDArray) {
            iNDArray.assign(a);
        }
        INDArrayIndex[] region = new INDArrayIndex[2];
        if (b == 'U') {
            // Transpose in place, then clear everything strictly below the diagonal.
            iNDArray.assign(iNDArray.transpose());
            for (int row = 1; row < Math.min(iNDArray.rows(), iNDArray.columns()); row++) {
                region[0] = NDArrayIndex.point(row);
                region[1] = NDArrayIndex.interval(0, row);
                iNDArray.put(region, 0);
            }
        } else {
            // Clear everything strictly above the diagonal.
            for (int row = 0; row < Math.min(iNDArray.rows(), iNDArray.columns() - 1); row++) {
                region[0] = NDArrayIndex.point(row);
                region[1] = NDArrayIndex.interval(row + 1, iNDArray.columns());
                iNDArray.put(region, 0);
            }
        }
        log.info("A: {}", iNDArray);
    }

    /**
     * Double-precision Cholesky factorization through cuSOLVER's {@code cusolverDnDpotrf}.
     *
     * <p>After the call the result is post-processed on the array itself: for
     * {@code b == 'U'} the array is replaced by its transpose and the entries
     * strictly below the diagonal are zeroed; otherwise the entries strictly
     * above the diagonal are zeroed.
     *
     * @param b         fill-mode selector forwarded to cuSOLVER; {@code 'U'} (85) selects
     *                  the transpose-and-clear-lower post-processing
     * @param i         passed to cuSOLVER as the matrix order and as the leading dimension
     * @param iNDArray  matrix to factorize; receives the result
     * @param iNDArray2 array whose device buffer is passed in the info slot of the call
     * @throws BlasException if setting the stream, the work-size query, or the
     *                       factorization returns a non-zero status
     */
    public void dpotrf(byte b, int i, INDArray iNDArray, INDArray iNDArray2) {
        if (iNDArray.dataType() != DataType.DOUBLE) {
            log.warn("DOUBLE potrf called for " + iNDArray.dataType());
        }
        // Work on an 'f'-ordered copy when the input is 'c' ordered.
        INDArray a = iNDArray.ordering() == 'c' ? iNDArray.dup('f') : iNDArray;
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        CudaContext cudaContext = (CudaContext) this.allocator.getDeviceContext().getContext();
        cusolverDnHandle_t solverHandle = cudaContext.getSolverHandle();
        cusolverDnContext solverDn = new cusolverDnContext(solverHandle);
        synchronized (solverHandle) {
            if (cusolver.cusolverDnSetStream(solverDn, new CUstream_st(cudaContext.getOldStream())) != 0) {
                throw new BlasException("solverSetStream failed");
            }
            CublasPointer aPointer = new CublasPointer(a, cudaContext);
            // Single-int buffer that the size query fills with the required work size.
            DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1L);
            int status = cusolver.cusolverDnDpotrf_bufferSize(solverDn, b, i, aPointer.getDevicePointer(), i,
                    worksizeBuffer.addressPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnDpotrf_bufferSize failed", status);
            }
            int worksize = worksizeBuffer.getInt(0L);
            // Scratch block sized in doubles; long arithmetic avoids int overflow.
            Workspace workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType(DataType.DOUBLE));
            status = cusolver.cusolverDnDpotrf(solverDn, b, i, aPointer.getDevicePointer(), i,
                    new CudaPointer(workspace).asDoublePointer(), worksize,
                    new CudaPointer(this.allocator.getPointer(iNDArray2, cudaContext)).asIntPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnDpotrf failed", status);
            }
        }
        this.allocator.registerAction(cudaContext, a, new INDArray[0]);
        this.allocator.registerAction(cudaContext, iNDArray2, new INDArray[0]);
        if (a != iNDArray) {
            iNDArray.assign(a);
        }
        INDArrayIndex[] region = new INDArrayIndex[2];
        if (b == 'U') {
            // Transpose in place, then clear everything strictly below the diagonal.
            iNDArray.assign(iNDArray.transpose());
            for (int row = 1; row < Math.min(iNDArray.rows(), iNDArray.columns()); row++) {
                region[0] = NDArrayIndex.point(row);
                region[1] = NDArrayIndex.interval(0, row);
                iNDArray.put(region, 0);
            }
        } else {
            // Clear everything strictly above the diagonal.
            for (int row = 0; row < Math.min(iNDArray.rows(), iNDArray.columns() - 1); row++) {
                region[0] = NDArrayIndex.point(row);
                region[1] = NDArrayIndex.interval(row + 1, iNDArray.columns());
                iNDArray.put(region, 0);
            }
        }
        log.info("A: {}", iNDArray);
    }

    /**
     * Intentionally empty in this implementation: the method accepts its
     * arguments and returns without touching any of them.
     * NOTE(review): presumably the matrix-inverse (getri) hook of the base
     * LAPACK class; confirm whether callers rely on it being a no-op here.
     */
    public void getri(int i, INDArray iNDArray, int i2, int[] iArr, INDArray iNDArray2, int i3, int i4) {
    }

    /**
     * Single-precision singular value decomposition through cuSOLVER's
     * {@code cusolverDnSgesvd}.
     *
     * <p>When {@code i < i2} the problem is solved on the transposed matrix
     * instead: the two dimensions and the two job flags are exchanged, the roles
     * of the U and VT outputs are exchanged, and the results are transposed back
     * into the caller's arrays afterwards.
     *
     * @param b         job flag for U
     * @param b2        job flag for VT
     * @param i         row count of {@code iNDArray}
     * @param i2        column count of {@code iNDArray}
     * @param iNDArray  input matrix
     * @param iNDArray2 array whose device buffer is passed in the singular-values slot
     * @param iNDArray3 output for U; may be {@code null}
     * @param iNDArray4 output for VT; may be {@code null}
     * @param iNDArray5 array whose device buffer is passed in the info slot of the call
     * @throws BlasException if setting the stream, the work-size query, or the
     *                       decomposition returns a non-zero status
     */
    public void sgesvd(byte b, byte b2, int i, int i2, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, INDArray iNDArray5) {
        if (Nd4j.dataType() != DataType.FLOAT) {
            log.warn("FLOAT gesvd called in DOUBLE environment");
        }
        INDArray a = iNDArray;
        INDArray u = iNDArray3;
        INDArray vt = iNDArray4;
        boolean transposed = false;
        if (i < i2) {
            // Solve for the transpose: A = U S VT  =>  A' = VT' S' U'.
            // Dimensions and job flags must be genuinely exchanged (via temporaries).
            transposed = true;
            int swappedDim = i2;
            i2 = i;
            i = swappedDim;
            byte swappedJob = b2;
            b2 = b;
            b = swappedJob;
            a = iNDArray.transpose().dup('f');
            u = iNDArray4 == null ? null : iNDArray4.transpose().dup('f');
            vt = iNDArray3 == null ? null : iNDArray3.transpose().dup('f');
        } else {
            // Work on 'f'-ordered copies when the caller's arrays are 'c' ordered.
            if (iNDArray.ordering() == 'c') {
                a = iNDArray.dup('f');
            }
            if (iNDArray3 != null && iNDArray3.ordering() == 'c') {
                u = iNDArray3.dup('f');
            }
            if (iNDArray4 != null && iNDArray4.ordering() == 'c') {
                vt = iNDArray4.dup('f');
            }
        }
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        CudaContext cudaContext = (CudaContext) this.allocator.getDeviceContext().getContext();
        cusolverDnHandle_t solverHandle = cudaContext.getSolverHandle();
        cusolverDnContext solverDn = new cusolverDnContext(solverHandle);
        synchronized (solverHandle) {
            if (cusolver.cusolverDnSetStream(solverDn, new CUstream_st(cudaContext.getOldStream())) != 0) {
                throw new BlasException("solverSetStream failed");
            }
            CublasPointer aPointer = new CublasPointer(a, cudaContext);
            // Single-int buffer that the size query fills with the required work size.
            DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1L);
            int status = cusolver.cusolverDnSgesvd_bufferSize(solverDn, i, i2, worksizeBuffer.addressPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnSgesvd_bufferSize failed", status);
            }
            int worksize = worksizeBuffer.getInt(0L);
            // Scratch block sized by the routine's element type (float), in long arithmetic.
            Workspace workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType(DataType.FLOAT));
            // Extra real scratch buffer of min(rows, cols) - 1 elements passed in the rwork slot.
            DataBuffer rwork = Nd4j.getDataBufferFactory().createFloat(Math.min(i, i2) - 1);
            status = cusolver.cusolverDnSgesvd(solverDn, b, b2, i, i2, aPointer.getDevicePointer(), i,
                    new CudaPointer(this.allocator.getPointer(iNDArray2, cudaContext)).asFloatPointer(),
                    u == null ? null : new CudaPointer(this.allocator.getPointer(u, cudaContext)).asFloatPointer(), i,
                    vt == null ? null : new CudaPointer(this.allocator.getPointer(vt, cudaContext)).asFloatPointer(), i2,
                    new CudaPointer(workspace).asFloatPointer(), worksize,
                    new CudaPointer(this.allocator.getPointer(rwork, cudaContext)).asFloatPointer(),
                    new CudaPointer(this.allocator.getPointer(iNDArray5, cudaContext)).asIntPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnSgesvd failed", status);
            }
        }
        this.allocator.registerAction(cudaContext, iNDArray5, new INDArray[0]);
        this.allocator.registerAction(cudaContext, iNDArray2, new INDArray[0]);
        if (u != null) {
            this.allocator.registerAction(cudaContext, u, new INDArray[0]);
        }
        if (vt != null) {
            this.allocator.registerAction(cudaContext, vt, new INDArray[0]);
        }
        if (transposed) {
            // Undo the role exchange: the computed VT' is the caller's U and vice versa.
            if (vt != null) {
                iNDArray3.assign(vt.transpose());
            }
            if (u != null) {
                iNDArray4.assign(u.transpose());
            }
        } else {
            // Copy back only where a reordered duplicate was used.
            if (u != iNDArray3) {
                iNDArray3.assign(u);
            }
            if (vt != iNDArray4) {
                iNDArray4.assign(vt);
            }
        }
    }

    /**
     * Double-precision singular value decomposition through cuSOLVER's
     * {@code cusolverDnDgesvd}.
     *
     * <p>When {@code i < i2} the problem is solved on the transposed matrix
     * instead: the two dimensions and the two job flags are exchanged, the roles
     * of the U and VT outputs are exchanged, and the results are transposed back
     * into the caller's arrays afterwards.
     *
     * @param b         job flag for U
     * @param b2        job flag for VT
     * @param i         row count of {@code iNDArray}
     * @param i2        column count of {@code iNDArray}
     * @param iNDArray  input matrix
     * @param iNDArray2 array whose device buffer is passed in the singular-values slot
     * @param iNDArray3 output for U; may be {@code null}
     * @param iNDArray4 output for VT; may be {@code null}
     * @param iNDArray5 array whose device buffer is passed in the info slot of the call
     * @throws BlasException if setting the stream, the work-size query, or the
     *                       decomposition returns a non-zero status
     */
    public void dgesvd(byte b, byte b2, int i, int i2, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, INDArray iNDArray5) {
        INDArray a = iNDArray;
        INDArray u = iNDArray3;
        INDArray vt = iNDArray4;
        boolean transposed = false;
        if (i < i2) {
            // Solve for the transpose: A = U S VT  =>  A' = VT' S' U'.
            // Dimensions and job flags must be genuinely exchanged (via temporaries).
            transposed = true;
            int swappedDim = i2;
            i2 = i;
            i = swappedDim;
            byte swappedJob = b2;
            b2 = b;
            b = swappedJob;
            a = iNDArray.transpose().dup('f');
            u = iNDArray4 == null ? null : iNDArray4.transpose().dup('f');
            vt = iNDArray3 == null ? null : iNDArray3.transpose().dup('f');
        } else {
            // Work on 'f'-ordered copies when the caller's arrays are 'c' ordered.
            if (iNDArray.ordering() == 'c') {
                a = iNDArray.dup('f');
            }
            if (iNDArray3 != null && iNDArray3.ordering() == 'c') {
                u = iNDArray3.dup('f');
            }
            if (iNDArray4 != null && iNDArray4.ordering() == 'c') {
                vt = iNDArray4.dup('f');
            }
        }
        if (Nd4j.dataType() != DataType.DOUBLE) {
            log.warn("DOUBLE gesvd called in FLOAT environment");
        }
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        CudaContext cudaContext = (CudaContext) this.allocator.getDeviceContext().getContext();
        cusolverDnHandle_t solverHandle = cudaContext.getSolverHandle();
        cusolverDnContext solverDn = new cusolverDnContext(solverHandle);
        synchronized (solverHandle) {
            if (cusolver.cusolverDnSetStream(solverDn, new CUstream_st(cudaContext.getOldStream())) != 0) {
                throw new BlasException("solverSetStream failed");
            }
            CublasPointer aPointer = new CublasPointer(a, cudaContext);
            // Single-int buffer that the size query fills with the required work size.
            DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1L);
            // Query the double-precision variant so the size matches the routine that runs.
            int status = cusolver.cusolverDnDgesvd_bufferSize(solverDn, i, i2, worksizeBuffer.addressPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnDgesvd_bufferSize failed", status);
            }
            int worksize = worksizeBuffer.getInt(0L);
            // Scratch block sized in doubles regardless of the global default type.
            Workspace workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType(DataType.DOUBLE));
            // Extra real scratch buffer of min(rows, cols) - 1 elements passed in the rwork slot.
            DataBuffer rwork = Nd4j.getDataBufferFactory().createDouble(Math.min(i, i2) - 1);
            status = cusolver.cusolverDnDgesvd(solverDn, b, b2, i, i2, aPointer.getDevicePointer(), i,
                    new CudaPointer(this.allocator.getPointer(iNDArray2, cudaContext)).asDoublePointer(),
                    u == null ? null : new CudaPointer(this.allocator.getPointer(u, cudaContext)).asDoublePointer(), i,
                    vt == null ? null : new CudaPointer(this.allocator.getPointer(vt, cudaContext)).asDoublePointer(), i2,
                    new CudaPointer(workspace).asDoublePointer(), worksize,
                    new CudaPointer(this.allocator.getPointer(rwork, cudaContext)).asDoublePointer(),
                    new CudaPointer(this.allocator.getPointer(iNDArray5, cudaContext)).asIntPointer());
            if (status != 0) {
                throw new BlasException("cusolverDnDgesvd failed", status);
            }
        }
        this.allocator.registerAction(cudaContext, iNDArray5, new INDArray[0]);
        this.allocator.registerAction(cudaContext, iNDArray2, new INDArray[0]);
        this.allocator.registerAction(cudaContext, a, new INDArray[0]);
        if (u != null) {
            this.allocator.registerAction(cudaContext, u, new INDArray[0]);
        }
        if (vt != null) {
            this.allocator.registerAction(cudaContext, vt, new INDArray[0]);
        }
        if (transposed) {
            // Undo the role exchange: the computed VT' is the caller's U and vice versa.
            if (vt != null) {
                iNDArray3.assign(vt.transpose());
            }
            if (u != null) {
                iNDArray4.assign(u.transpose());
            }
        } else {
            // Copy back only where a reordered duplicate was used.
            if (u != iNDArray3) {
                iNDArray3.assign(u);
            }
            if (vt != iNDArray4) {
                iNDArray4.assign(vt);
            }
        }
    }

    /**
     * Single-precision symmetric eigen-decomposition through cuSOLVER's
     * {@code cusolverDnSsyevd}.
     *
     * <p>Failures are reported through the return value rather than an exception:
     * the first non-zero status from stream setup, the work-size query, or the
     * solver call is returned; if all succeed, the value read from the info
     * array is returned.
     *
     * @param c         {@code 'V'} maps to mode 1 in the jobz slot, anything else to 0
     * @param c2        {@code 'L'} maps to 0 in the uplo slot, anything else to 1
     * @param i         not read by this implementation; the order is taken from
     *                  {@code iNDArray.rows()}
     * @param iNDArray  symmetric input matrix; receives the solver's output on success
     * @param iNDArray2 array whose device buffer is passed in the eigenvalue slot
     * @return 0 on success, otherwise the failing status or info value
     */
    public int ssyev(char c, char c2, int i, INDArray iNDArray, INDArray iNDArray2) {
        int status;
        int jobz = c == 'V' ? 1 : 0;
        int uplo = c2 == 'L' ? 0 : 1;
        if (Nd4j.dataType() != DataType.FLOAT) {
            log.warn("FLOAT ssyev called in DOUBLE environment");
        }
        // Work on an 'f'-ordered copy when the input is 'c' ordered.
        INDArray a = iNDArray.ordering() == 'c' ? iNDArray.dup('f') : iNDArray;
        int order = iNDArray.rows();
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        CudaContext cudaContext = (CudaContext) this.allocator.getDeviceContext().getContext();
        cusolverDnHandle_t solverHandle = cudaContext.getSolverHandle();
        cusolverDnContext solverDn = new cusolverDnContext(solverHandle);
        synchronized (solverHandle) {
            status = cusolver.cusolverDnSetStream(solverDn, new CUstream_st(cudaContext.getOldStream()));
            if (status == 0) {
                CublasPointer aPointer = new CublasPointer(a, cudaContext);
                CublasPointer valuesPointer = new CublasPointer(iNDArray2, cudaContext);
                // Single-int buffer that the size query fills with the required work size.
                DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1L);
                status = cusolver.cusolverDnSsyevd_bufferSize(solverDn, jobz, uplo, order, aPointer.getDevicePointer(), order,
                        valuesPointer.getDevicePointer(), worksizeBuffer.addressPointer());
                if (status == 0) {
                    int worksize = worksizeBuffer.getInt(0L);
                    // Scratch block sized by the routine's element type (float), in long arithmetic.
                    Workspace workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType(DataType.FLOAT));
                    // 1 x 1 array backing the info slot of the solver call.
                    INDArray info = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1L),
                            Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{1, 1}, iNDArray.dataType()));
                    status = cusolver.cusolverDnSsyevd(solverDn, jobz, uplo, order, aPointer.getDevicePointer(), order,
                            valuesPointer.getDevicePointer(), new CudaPointer(workspace).asFloatPointer(), worksize,
                            new CudaPointer(this.allocator.getPointer(info, cudaContext)).asIntPointer());
                    this.allocator.registerAction(cudaContext, info, new INDArray[0]);
                    if (status == 0) {
                        // The call itself succeeded; surface the info value as the result.
                        status = info.getInt(new int[]{0});
                    }
                }
            }
        }
        if (status == 0) {
            this.allocator.registerAction(cudaContext, iNDArray2, new INDArray[0]);
            this.allocator.registerAction(cudaContext, a, new INDArray[0]);
            // Copy the result back when a reordered duplicate was used.
            if (a != iNDArray) {
                iNDArray.assign(a);
            }
        }
        return status;
    }

    /**
     * Runs cuSOLVER's double-precision symmetric eigen-solver ({@code cusolverDnDsyevd}) on the GPU.
     *
     * <p>A {@code 'c'}-ordered input is duplicated into {@code 'f'} order before the call; on success
     * the solver's in-place result is copied back into {@code iNDArray}.
     *
     * @param c         job selector: {@code 'V'} is mapped to mode 1 (cuSOLVER's "compute eigenvectors"
     *                  mode), anything else to mode 0
     * @param c2        triangle selector: {@code 'L'} is mapped to fill mode 0 (lower), anything else
     *                  to fill mode 1 (upper)
     * @param i         not used by this implementation; the order is taken from {@code iNDArray.rows()}
     * @param iNDArray  the matrix to decompose; overwritten with the solver's in-place output on success
     * @param iNDArray2 array passed to the solver as the eigenvalue output
     * @return 0 on success; otherwise the failing cuSOLVER status, or the solver's {@code info} value
     *         when the call itself returned 0
     */
    public int dsyev(char c, char c2, int i, INDArray iNDArray, INDArray iNDArray2) {
        int status;
        int jobz = c == 'V' ? 1 : 0;
        int uplo = c2 == 'L' ? 0 : 1;
        if (Nd4j.dataType() != DataType.DOUBLE) {
            log.warn("DOUBLE dsyev called in FLOAT environment");
        }
        // The solver is handed an 'f'-ordered array; work on a copy when the input is 'c'-ordered.
        INDArray a = iNDArray;
        if (iNDArray.ordering() == 'c') {
            a = iNDArray.dup('f');
        }
        int rows = iNDArray.rows();
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        CudaContext cudaContext = (CudaContext) this.allocator.getDeviceContext().getContext();
        cusolverDnHandle_t solverHandle = cudaContext.getSolverHandle();
        cusolverDnContext solver = new cusolverDnContext(solverHandle);
        // All cuSOLVER calls on this handle are serialized on the handle itself.
        synchronized (solverHandle) {
            status = cusolver.cusolverDnSetStream(new cusolverDnContext(solverHandle), new CUstream_st(cudaContext.getOldStream()));
            if (status == 0) {
                CublasPointer aPointer = new CublasPointer(a, cudaContext);
                CublasPointer wPointer = new CublasPointer(iNDArray2, cudaContext);
                // First ask the solver how many work elements it needs.
                DataBuffer lworkBuffer = Nd4j.getDataBufferFactory().createInt(1L);
                status = cusolver.cusolverDnDsyevd_bufferSize(solver, jobz, uplo, rows, aPointer.getDevicePointer(), rows, wPointer.getDevicePointer(), lworkBuffer.addressPointer());
                if (status == 0) {
                    int lwork = lworkBuffer.getInt(0L);
                    // lwork counts double elements, so size the workspace by the double width.
                    // Using the global ND4J data type here would allocate only half the required
                    // bytes in a FLOAT environment. Multiply in long to avoid int overflow.
                    Workspace workspace = new Workspace((long) lwork * Double.BYTES);
                    // NOTE(review): the shape info uses the input's data type while the backing
                    // buffer is an int buffer - confirm this pairing is intended.
                    INDArray info = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1L), Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{1, 1}, iNDArray.dataType()));
                    status = cusolver.cusolverDnDsyevd(solver, jobz, uplo, rows, aPointer.getDevicePointer(), rows, wPointer.getDevicePointer(), new CudaPointer(workspace).asDoublePointer(), lwork, new CudaPointer(this.allocator.getPointer(info, cudaContext)).asIntPointer());
                    this.allocator.registerAction(cudaContext, info, new INDArray[0]);
                    if (status == 0) {
                        // The call was accepted; the solver's own result code becomes the status.
                        status = info.getInt(new int[]{0});
                    }
                }
            }
        }
        if (status == 0) {
            this.allocator.registerAction(cudaContext, iNDArray2, new INDArray[0]);
            this.allocator.registerAction(cudaContext, a, new INDArray[0]);
            // Propagate the result back when the solver worked on an 'f'-ordered copy.
            if (a != iNDArray) {
                iNDArray.assign(a);
            }
        }
        return status;
    }
}
