package org.nd4j.linalg.jcublas;

import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Iterator;
import jcuda.LibUtils;
import jcuda.LogLevel;
import jcuda.Pointer;
import jcuda.cuComplex;
import jcuda.cuDoubleComplex;
import jcuda.jcublas.JCublas;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.nd4j.linalg.api.complex.IComplexDouble;
import org.nd4j.linalg.api.complex.IComplexFloat;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.DataTypeValidation;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.complex.JCublasComplexNDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;

/* loaded from: input_file:org/nd4j/linalg/jcublas/SimpleJCublas.class */
public class SimpleJCublas {
    /** Environment variable read for the CUDA installation directory when {@link #JCUDA_HOME_PROP} is unset. */
    public static final String CUDA_HOME = "CUDA_HOME";
    /** System property naming the CUDA installation directory; takes precedence over {@link #CUDA_HOME}. */
    public static final String JCUDA_HOME_PROP = "jcuda.home";
    // Logger used e.g. to warn when a 64-bit JVM falls back to the 32-bit CUDA lib folder.
    private static Logger log = LoggerFactory.getLogger(SimpleJCublas.class);

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.nd4j.linalg.jcublas.SimpleJCublas$2, reason: invalid class name */
    /* loaded from: input_file:org/nd4j/linalg/jcublas/SimpleJCublas$2.class */
    /**
     * Compiler-synthesized lookup tables translating jcuda {@code LibUtils} enum constants into the
     * integer case labels used by the switch statements in {@code thirtyTwoOrSixtyFour()} and
     * {@code resourceName()}.
     *
     * <p>Each slot is filled in its own try/catch: if a constant is missing from the jcuda version
     * on the runtime classpath ({@link NoSuchFieldError}), that slot stays 0, which matches no case
     * label and therefore falls through to the switch's {@code default} branch.
     */
    public static /* synthetic */ class AnonymousClass2 {
        // ARCHType ordinal -> label: 1 = X86_64, 2 = PPC_64, 0 = unmapped.
        static final /* synthetic */ int[] $SwitchMap$jcuda$LibUtils$ARCHType;
        // OSType ordinal -> label: 1 = APPLE, 2 = LINUX, 3 = SUN, 4 = WINDOWS, 0 = unmapped.
        static final /* synthetic */ int[] $SwitchMap$jcuda$LibUtils$OSType = new int[LibUtils.OSType.values().length];

        static {
            try {
                $SwitchMap$jcuda$LibUtils$OSType[LibUtils.OSType.APPLE.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
                // constant absent at runtime: slot stays 0 (default case)
            }
            try {
                $SwitchMap$jcuda$LibUtils$OSType[LibUtils.OSType.LINUX.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
                // constant absent at runtime: slot stays 0 (default case)
            }
            try {
                $SwitchMap$jcuda$LibUtils$OSType[LibUtils.OSType.SUN.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
                // constant absent at runtime: slot stays 0 (default case)
            }
            try {
                $SwitchMap$jcuda$LibUtils$OSType[LibUtils.OSType.WINDOWS.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
                // constant absent at runtime: slot stays 0 (default case)
            }
            // The ARCHType table is sized and filled after the OSType table.
            $SwitchMap$jcuda$LibUtils$ARCHType = new int[LibUtils.ARCHType.values().length];
            try {
                $SwitchMap$jcuda$LibUtils$ARCHType[LibUtils.ARCHType.X86_64.ordinal()] = 1;
            } catch (NoSuchFieldError e5) {
                // constant absent at runtime: slot stays 0 (default case)
            }
            try {
                $SwitchMap$jcuda$LibUtils$ARCHType[LibUtils.ARCHType.PPC_64.ordinal()] = 2;
            } catch (NoSuchFieldError e6) {
                // constant absent at runtime: slot stays 0 (default case)
            }
        }
    }

    /**
     * Resolves the directory holding the CUDA native libraries beneath {@link #cudaBase()}.
     *
     * <p>The architecture-specific folder ({@code lib64}/{@code Lib64} on 64-bit JVMs, plain
     * {@code lib}/{@code Lib} otherwise) is preferred. On a 64-bit JVM whose CUDA install has no
     * 64-bit folder but does have the plain one, the plain folder is returned and a warning is
     * logged. In every other case the architecture-specific path is returned without further checks.
     *
     * @return path of the CUDA library directory
     * @throws IllegalStateException if no CUDA home is configured
     */
    private static String libDir() {
        final int bits = thirtyTwoOrSixtyFour();
        final String baseLibDir = cudaBase() + File.separator + libFolder();
        final String archLibDir = baseLibDir + (bits == 64 ? "64" : "");
        if (new File(archLibDir).exists()) {
            return archLibDir;
        }
        // The arch-specific folder is missing: only a 64-bit JVM has a different folder to fall back to.
        if (bits == 64 && new File(baseLibDir).exists()) {
            log.warn("Loading 32 bit cuda...no 64 bit found");
            return baseLibDir;
        }
        return archLibDir;
    }

    /**
     * Name of the library sub-folder inside the CUDA installation.
     *
     * @return {@code "Lib"} on Windows, {@code "lib"} on every other OS
     */
    private static String libFolder() {
        final boolean windows = LibUtils.calculateOS() == LibUtils.OSType.WINDOWS;
        if (windows) {
            return "Lib";
        }
        return "lib";
    }

    /**
     * Reports the pointer width of the current architecture as detected by jcuda.
     *
     * @return 64 for {@code X86_64} and {@code PPC_64}, 32 for any other architecture
     */
    private static int thirtyTwoOrSixtyFour() {
        switch (LibUtils.calculateArch()) {
            case X86_64:
            case PPC_64:
                return 64;
            default:
                return 32;
        }
    }

    /**
     * Locates the CUDA installation directory.
     *
     * @return the {@code jcuda.home} system property if set, otherwise the {@code CUDA_HOME}
     *     environment variable
     * @throws IllegalStateException if neither is set
     */
    private static String cudaBase() {
        final String fromEnvironment = System.getenv(CUDA_HOME);
        final String home = System.getProperty(JCUDA_HOME_PROP, fromEnvironment);
        if (home == null) {
            throw new IllegalStateException("Please specify a cuda home property in your environment (export CUDA_HOME=/path/to/your/dir) or via -Djcuda.home=/path/to/your/dir");
        }
        return home;
    }

    /**
     * Builds the file name of the bundled JCublas native library for the current OS/architecture.
     *
     * @return the library file name, or {@code null} if the OS is not one of APPLE, LINUX, SUN or WINDOWS
     */
    private static String resourceName() {
        final LibUtils.OSType os = LibUtils.calculateOS();
        final LibUtils.ARCHType arch = LibUtils.calculateArch();
        final String pattern;
        switch (os) {
            case APPLE:
                pattern = "libJCublas-apple-%s.dylib";
                break;
            case LINUX:
            case SUN:
                // NOTE(review): SUN shares the linux file name; confirm against the shipped natives.
                pattern = "libJCublas-linux-%s.so";
                break;
            case WINDOWS:
                pattern = "libJCublas-windows-%s.dll";
                break;
            default:
                return null;
        }
        return String.format(pattern, arch.toString());
    }

    /**
     * Returns the first entry of {@code java.library.path} in which a file can be created.
     *
     * @return a writable library directory
     * @throws IllegalStateException if no entry is writable
     */
    private static String findWritableLibDir() {
        final String[] candidates = System.getProperty("java.library.path").split(File.pathSeparator);
        for (int i = 0; i < candidates.length; i++) {
            final String candidate = candidates[i];
            if (canWrite(new File(candidate))) {
                return candidate;
            }
        }
        throw new IllegalStateException("Unable to write to any library directories for jcublas");
    }

    /**
     * Checks whether a new file can be created inside the given directory.
     *
     * <p>A uniquely named probe file is created and removed immediately. A fixed probe name would
     * make a writable directory look read-only whenever that file already existed, and it would
     * leave the file in place until JVM exit.
     *
     * @param file directory to test
     * @return {@code true} if a file could be created in {@code file}; {@code false} if it does not
     *     exist or creating the probe failed with an I/O error
     * @throws IllegalArgumentException if {@code file} is a regular file
     */
    private static boolean canWrite(File file) {
        if (!file.exists()) {
            return false;
        }
        if (file.isFile()) {
            throw new IllegalArgumentException("Tests only directories");
        }
        try {
            File probe = File.createTempFile("jcublas-write-probe", null, file);
            if (!probe.delete()) {
                // Best effort: never let the probe outlive the JVM.
                probe.deleteOnExit();
            }
            return true;
        } catch (IOException e) {
            // Failure to create the probe means the directory is treated as not writable.
            return false;
        }
    }

    /**
     * Releases device memory obtained from one of the {@code alloc} methods.
     *
     * @param pointerArr device pointers to release, processed in order
     */
    public static void free(Pointer... pointerArr) {
        for (int i = 0; i < pointerArr.length; i++) {
            JCublas.cublasFree(pointerArr[i]);
        }
    }

    /**
     * Byte width of one scalar in the array's buffer.
     *
     * @return 4 when {@code dataType()} is 1 (the code this class pairs with the float routines),
     *     8 otherwise
     */
    private static int size(INDArray iNDArray) {
        if (iNDArray.data().dataType() == 1) {
            return 4;
        }
        return 8;
    }

    /**
     * Copies {@code length()} values from device memory back into the host buffer of a real array.
     *
     * <p>When the array covers its whole buffer the host side is written with unit stride;
     * otherwise {@code majorStride()} is used as the host stride.
     *
     * @param jCublasNDArray array whose length, offset, stride and data type drive the copy
     * @param pointer device memory to read (contiguous)
     * @param pointer2 host buffer backing the array; shifted by {@code offset() * size} bytes
     */
    public static void getData(JCublasNDArray jCublasNDArray, Pointer pointer, Pointer pointer2) {
        final int elementSize = size(jCublasNDArray);
        final boolean wholeBuffer = jCublasNDArray.length() == jCublasNDArray.data().length();
        final int hostStride = wholeBuffer ? 1 : jCublasNDArray.majorStride();
        final Pointer hostTarget = pointer2.withByteOffset(jCublasNDArray.offset() * elementSize);
        JCublas.cublasGetVector(jCublasNDArray.length(), elementSize, pointer, 1, hostTarget, hostStride);
    }

    /**
     * Allocates device memory for a complex array and uploads its {@code 2 * length()} scalars.
     *
     * @param jCublasComplexNDArray source array
     * @return device pointer holding a contiguous copy; release it with {@link #free(Pointer...)}
     */
    public static Pointer alloc(JCublasComplexNDArray jCublasComplexNDArray) {
        final int elementSize = size(jCublasComplexNDArray);
        // Two scalars (real and imaginary components) per complex element.
        final int scalarCount = jCublasComplexNDArray.length() * 2;
        final Pointer device = new Pointer();
        JCublas.cublasAlloc(scalarCount, elementSize, device);

        final Pointer hostBase;
        if (jCublasComplexNDArray.data().dataType() == 1) {
            hostBase = Pointer.to(jCublasComplexNDArray.data().asFloat());
        } else {
            hostBase = Pointer.to(jCublasComplexNDArray.data().asDouble());
        }
        final Pointer host = hostBase.withByteOffset(jCublasComplexNDArray.offset() * elementSize);

        // Whole-buffer arrays and views are uploaded identically (unit stride on both sides).
        // NOTE(review): strided complex views may need a host stride here; confirm.
        JCublas.cublasSetVector(scalarCount, elementSize, host, 1, device, 1);
        return device;
    }

    /**
     * Copies {@code 2 * length()} scalars of a complex array from device memory back to the host.
     *
     * @param jCublasComplexNDArray array whose length, offset and data type drive the copy
     * @param pointer device memory to read (contiguous)
     * @param pointer2 host buffer backing the array; shifted by {@code offset() * size} bytes
     */
    public static void getData(JCublasComplexNDArray jCublasComplexNDArray, Pointer pointer, Pointer pointer2) {
        // The element width must follow the data type for views as well: a hard-coded 4 bytes
        // would copy half-width elements for double-precision arrays.
        final int elementSize = size(jCublasComplexNDArray);
        JCublas.cublasGetVector(jCublasComplexNDArray.length() * 2, elementSize, pointer, 1,
                pointer2.withByteOffset(jCublasComplexNDArray.offset() * elementSize), 1);
    }

    /**
     * Allocates device memory for a real array and uploads its {@code length()} values.
     *
     * <p>The host side is read with unit stride when the array covers its whole buffer, otherwise
     * with {@code majorStride()}. The device copy is always contiguous.
     *
     * @param jCublasNDArray source array
     * @return device pointer; release it with {@link #free(Pointer...)}
     */
    public static Pointer alloc(JCublasNDArray jCublasNDArray) {
        final int elementSize = size(jCublasNDArray);
        final Pointer hostBase;
        if (jCublasNDArray.data().dataType() == 1) {
            hostBase = Pointer.to(jCublasNDArray.data().asFloat());
        } else {
            hostBase = Pointer.to(jCublasNDArray.data().asDouble());
        }
        final Pointer host = hostBase.withByteOffset(jCublasNDArray.offset() * elementSize);

        final Pointer device = new Pointer();
        JCublas.cublasAlloc(jCublasNDArray.length(), elementSize, device);

        final boolean wholeBuffer = jCublasNDArray.length() == jCublasNDArray.data().length();
        final int hostStride = wholeBuffer ? 1 : jCublasNDArray.majorStride();
        JCublas.cublasSetVector(jCublasNDArray.length(), elementSize, host, hostStride, device, 1);
        return device;
    }

    /**
     * Double-precision {@code y = d * A * x + d2 * y} via {@code cublasDgemv} (no transpose).
     *
     * @param iNDArray matrix {@code A}, passed with leading dimension {@code rows()}
     * @param iNDArray2 vector {@code x}
     * @param iNDArray3 vector {@code y}; overwritten with the result
     * @param d alpha
     * @param d2 beta
     * @return {@code iNDArray3}
     */
    public static INDArray gemv(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, double d, double d2) {
        DataTypeValidation.assertDouble(new INDArray[]{iNDArray, iNDArray2, iNDArray3});
        JCublas.cublasInit();
        final JCublasNDArray y = (JCublasNDArray) iNDArray3;
        final Pointer devA = alloc((JCublasNDArray) iNDArray);
        final Pointer devX = alloc((JCublasNDArray) iNDArray2);
        final Pointer devY = alloc(y);
        final int rows = iNDArray.rows();
        final int cols = iNDArray.columns();
        JCublas.cublasDgemv('N', rows, cols, d, devA, rows, devX, 1, d2, devY, 1);
        final Pointer hostY = Pointer.to(y.data().asDouble());
        getData(y, devY, hostY);
        free(devA, devX, devY);
        return iNDArray3;
    }

    /**
     * Single-precision {@code y = f * A * x + f2 * y} via {@code cublasSgemv} (no transpose).
     *
     * @param iNDArray matrix {@code A}, passed with leading dimension {@code rows()}
     * @param iNDArray2 vector {@code x}
     * @param iNDArray3 vector {@code y}; overwritten with the result
     * @param f alpha
     * @param f2 beta
     * @return {@code iNDArray3}
     */
    public static INDArray gemv(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, float f, float f2) {
        DataTypeValidation.assertFloat(new INDArray[]{iNDArray, iNDArray2, iNDArray3});
        JCublas.cublasInit();
        final JCublasNDArray y = (JCublasNDArray) iNDArray3;
        final Pointer devA = alloc((JCublasNDArray) iNDArray);
        final Pointer devX = alloc((JCublasNDArray) iNDArray2);
        final Pointer devY = alloc(y);
        final int rows = iNDArray.rows();
        final int cols = iNDArray.columns();
        JCublas.cublasSgemv('N', rows, cols, f, devA, rows, devX, 1, f2, devY, 1);
        final Pointer hostY = Pointer.to(y.data().asFloat());
        getData(y, devY, hostY);
        free(devA, devX, devY);
        return iNDArray3;
    }

    /**
     * Double-precision complex {@code y = alpha * A * x + beta * y} via {@code cublasZgemv}.
     *
     * @param iComplexNDArray matrix {@code A} ({@code rows() x columns()}, leading dimension {@code rows()})
     * @param iComplexNDArray2 vector {@code x}
     * @param iComplexDouble alpha
     * @param iComplexNDArray3 vector {@code y}; overwritten with the result
     * @param iComplexDouble2 beta
     * @return {@code iComplexNDArray3}
     */
    public static IComplexNDArray gemv(IComplexNDArray iComplexNDArray, IComplexNDArray iComplexNDArray2, IComplexDouble iComplexDouble, IComplexNDArray iComplexNDArray3, IComplexDouble iComplexDouble2) {
        DataTypeValidation.assertSameDataType(new INDArray[]{iComplexNDArray, iComplexNDArray2, iComplexNDArray3});
        JCublas.cublasInit();
        JCublasComplexNDArray jCublasComplexNDArray = (JCublasComplexNDArray) iComplexNDArray3;
        Pointer alloc = alloc((JCublasComplexNDArray) iComplexNDArray);
        Pointer alloc2 = alloc((JCublasComplexNDArray) iComplexNDArray2);
        Pointer alloc3 = alloc(jCublasComplexNDArray);
        // alpha and beta are each built from their own real AND imaginary components.
        cuDoubleComplex alpha = cuDoubleComplex.cuCmplx(iComplexDouble.realComponent().doubleValue(), iComplexDouble.imaginaryComponent().doubleValue());
        cuDoubleComplex beta = cuDoubleComplex.cuCmplx(iComplexDouble2.realComponent().doubleValue(), iComplexDouble2.imaginaryComponent().doubleValue());
        // n must be the column count of A; passing rows() twice breaks non-square matrices.
        // NOTE(review): incx/incy use secondaryStride() although alloc() produces contiguous device
        // copies; confirm this stride is correct for complex arrays.
        JCublas.cublasZgemv('n', iComplexNDArray.rows(), iComplexNDArray.columns(), alpha, alloc, iComplexNDArray.rows(),
                alloc2, iComplexNDArray2.secondaryStride(), beta, alloc3, iComplexNDArray3.secondaryStride());
        getData(jCublasComplexNDArray, alloc3, Pointer.to(jCublasComplexNDArray.data().asDouble()));
        free(alloc, alloc2, alloc3);
        return iComplexNDArray3;
    }

    /**
     * Single-precision complex {@code y = alpha * A * x + beta * y} via {@code cublasCgemv}.
     *
     * @param iComplexNDArray matrix {@code A} ({@code rows() x columns()}, leading dimension {@code rows()})
     * @param iComplexNDArray2 vector {@code x}
     * @param iComplexFloat alpha
     * @param iComplexNDArray3 vector {@code y}; overwritten with the result
     * @param iComplexFloat2 beta
     * @return {@code iComplexNDArray3}
     */
    public static IComplexNDArray gemv(IComplexNDArray iComplexNDArray, IComplexNDArray iComplexNDArray2, IComplexFloat iComplexFloat, IComplexNDArray iComplexNDArray3, IComplexFloat iComplexFloat2) {
        DataTypeValidation.assertFloat(new INDArray[]{iComplexNDArray, iComplexNDArray2, iComplexNDArray3});
        JCublas.cublasInit();
        JCublasComplexNDArray jCublasComplexNDArray = (JCublasComplexNDArray) iComplexNDArray2;
        JCublasComplexNDArray jCublasComplexNDArray2 = (JCublasComplexNDArray) iComplexNDArray3;
        Pointer alloc = alloc((JCublasComplexNDArray) iComplexNDArray);
        Pointer alloc2 = alloc(jCublasComplexNDArray);
        Pointer alloc3 = alloc(jCublasComplexNDArray2);
        // alpha's imaginary part must come from alpha, not from beta.
        cuComplex alpha = cuComplex.cuCmplx(iComplexFloat.realComponent().floatValue(), iComplexFloat.imaginaryComponent().floatValue());
        cuComplex beta = cuComplex.cuCmplx(iComplexFloat2.realComponent().floatValue(), iComplexFloat2.imaginaryComponent().floatValue());
        // NOTE(review): incx/incy use secondaryStride() although alloc() produces contiguous device copies; confirm.
        JCublas.cublasCgemv('n', iComplexNDArray.rows(), iComplexNDArray.columns(), alpha, alloc, iComplexNDArray.rows(),
                alloc2, jCublasComplexNDArray.secondaryStride(), beta, alloc3, jCublasComplexNDArray2.secondaryStride());
        getData(jCublasComplexNDArray2, alloc3, Pointer.to(jCublasComplexNDArray2.data().asFloat()));
        free(alloc, alloc2, alloc3);
        return iComplexNDArray3;
    }

    /**
     * Double-precision complex {@code C = alpha * A * B + beta * C} via {@code cublasZgemm}
     * (no transposes, leading dimensions equal to each matrix's {@code rows()}).
     *
     * @param iComplexNDArray matrix {@code A}
     * @param iComplexNDArray2 matrix {@code B}
     * @param iComplexDouble alpha
     * @param iComplexNDArray3 matrix {@code C}; overwritten with the result
     * @param iComplexDouble2 beta
     * @return {@code iComplexNDArray3}
     */
    public static IComplexNDArray gemm(IComplexNDArray iComplexNDArray, IComplexNDArray iComplexNDArray2, IComplexDouble iComplexDouble, IComplexNDArray iComplexNDArray3, IComplexDouble iComplexDouble2) {
        DataTypeValidation.assertSameDataType(new INDArray[]{iComplexNDArray, iComplexNDArray2, iComplexNDArray3});
        JCublas.cublasInit();
        JCublasComplexNDArray jCublasComplexNDArray = (JCublasComplexNDArray) iComplexNDArray;
        JCublasComplexNDArray jCublasComplexNDArray2 = (JCublasComplexNDArray) iComplexNDArray3;
        Pointer alloc = alloc(jCublasComplexNDArray);
        Pointer alloc2 = alloc((JCublasComplexNDArray) iComplexNDArray2);
        Pointer alloc3 = alloc(jCublasComplexNDArray2);
        // alpha's imaginary part must come from alpha, not from beta.
        cuDoubleComplex alpha = cuDoubleComplex.cuCmplx(iComplexDouble.realComponent().doubleValue(), iComplexDouble.imaginaryComponent().doubleValue());
        cuDoubleComplex beta = cuDoubleComplex.cuCmplx(iComplexDouble2.realComponent().doubleValue(), iComplexDouble2.imaginaryComponent().doubleValue());
        JCublas.cublasZgemm('n', 'n', jCublasComplexNDArray2.rows(), jCublasComplexNDArray2.columns(), jCublasComplexNDArray.columns(),
                alpha, alloc, iComplexNDArray.rows(), alloc2, iComplexNDArray2.rows(), beta, alloc3, iComplexNDArray3.rows());
        getData(jCublasComplexNDArray2, alloc3, Pointer.to(jCublasComplexNDArray2.data().asDouble()));
        free(alloc, alloc2, alloc3);
        return iComplexNDArray3;
    }

    /**
     * Single-precision complex {@code C = alpha * A * B + beta * C} via {@code cublasCgemm}
     * (no transposes, leading dimensions equal to each matrix's {@code rows()}).
     *
     * @param iComplexNDArray matrix {@code A}
     * @param iComplexNDArray2 matrix {@code B}
     * @param iComplexFloat alpha
     * @param iComplexNDArray3 matrix {@code C}; overwritten with the result
     * @param iComplexFloat2 beta
     * @return {@code iComplexNDArray3}
     */
    public static IComplexNDArray gemm(IComplexNDArray iComplexNDArray, IComplexNDArray iComplexNDArray2, IComplexFloat iComplexFloat, IComplexNDArray iComplexNDArray3, IComplexFloat iComplexFloat2) {
        DataTypeValidation.assertFloat(new INDArray[]{iComplexNDArray, iComplexNDArray2, iComplexNDArray3});
        JCublas.cublasInit();
        JCublasComplexNDArray jCublasComplexNDArray = (JCublasComplexNDArray) iComplexNDArray;
        JCublasComplexNDArray jCublasComplexNDArray2 = (JCublasComplexNDArray) iComplexNDArray3;
        Pointer alloc = alloc(jCublasComplexNDArray);
        Pointer alloc2 = alloc((JCublasComplexNDArray) iComplexNDArray2);
        Pointer alloc3 = alloc(jCublasComplexNDArray2);
        // alpha's imaginary part must come from alpha, not from beta.
        cuComplex alpha = cuComplex.cuCmplx(iComplexFloat.realComponent().floatValue(), iComplexFloat.imaginaryComponent().floatValue());
        cuComplex beta = cuComplex.cuCmplx(iComplexFloat2.realComponent().floatValue(), iComplexFloat2.imaginaryComponent().floatValue());
        JCublas.cublasCgemm('n', 'n', jCublasComplexNDArray2.rows(), jCublasComplexNDArray2.columns(), jCublasComplexNDArray.columns(),
                alpha, alloc, iComplexNDArray.rows(), alloc2, iComplexNDArray2.rows(), beta, alloc3, iComplexNDArray3.rows());
        getData(jCublasComplexNDArray2, alloc3, Pointer.to(jCublasComplexNDArray2.data().asFloat()));
        free(alloc, alloc2, alloc3);
        return iComplexNDArray3;
    }

    /**
     * Double-precision {@code C = d * A * B + d2 * C} via {@code cublasDgemm} (no transposes).
     *
     * @param iNDArray matrix {@code A} (leading dimension {@code rows()})
     * @param iNDArray2 matrix {@code B} (leading dimension {@code rows()})
     * @param iNDArray3 matrix {@code C}; overwritten with the result
     * @param d alpha
     * @param d2 beta
     * @return {@code iNDArray3}
     */
    public static INDArray gemm(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, double d, double d2) {
        DataTypeValidation.assertDouble(new INDArray[]{iNDArray, iNDArray2, iNDArray3});
        JCublas.cublasInit();
        final JCublasNDArray result = (JCublasNDArray) iNDArray3;
        final Pointer devA = alloc((JCublasNDArray) iNDArray);
        final Pointer devB = alloc((JCublasNDArray) iNDArray2);
        final Pointer devC = alloc(result);
        final int m = iNDArray3.rows();
        final int n = iNDArray3.columns();
        final int k = iNDArray.columns();
        JCublas.cublasDgemm('n', 'n', m, n, k, d, devA, iNDArray.rows(), devB, iNDArray2.rows(), d2, devC, m);
        getData(result, devC, Pointer.to(result.data().asDouble()));
        free(devA, devB, devC);
        return iNDArray3;
    }

    /**
     * Single-precision {@code C = f * A * B + f2 * C} via {@code cublasSgemm} (no transposes).
     *
     * @param iNDArray matrix {@code A} (leading dimension {@code rows()})
     * @param iNDArray2 matrix {@code B} (leading dimension {@code rows()})
     * @param iNDArray3 matrix {@code C}; overwritten with the result
     * @param f alpha
     * @param f2 beta
     * @return {@code iNDArray3}
     */
    public static INDArray gemm(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, float f, float f2) {
        DataTypeValidation.assertFloat(new INDArray[]{iNDArray, iNDArray2, iNDArray3});
        JCublas.cublasInit();
        final JCublasNDArray result = (JCublasNDArray) iNDArray3;
        final Pointer devA = alloc((JCublasNDArray) iNDArray);
        final Pointer devB = alloc((JCublasNDArray) iNDArray2);
        final Pointer devC = alloc(result);
        final int m = iNDArray3.rows();
        final int n = iNDArray3.columns();
        final int k = iNDArray.columns();
        JCublas.cublasSgemm('n', 'n', m, n, k, f, devA, iNDArray.rows(), devB, iNDArray2.rows(), f2, devC, m);
        getData(result, devC, Pointer.to(result.data().asFloat()));
        free(devA, devB, devC);
        return iNDArray3;
    }

    /**
     * Euclidean norm of a complex vector, {@code sqrt(sum |z_i|^2)}.
     *
     * <p>Uses the complex routines {@code cublasScnrm2}/{@code cublasDznrm2}. The real routines with
     * stride 2 would only see the real components.
     *
     * @param iComplexNDArray input vector
     * @return the norm (single-precision result widened for float data)
     */
    public static double nrm2(IComplexNDArray iComplexNDArray) {
        JCublas.cublasInit();
        Pointer alloc = alloc((JCublasComplexNDArray) iComplexNDArray);
        try {
            if (iComplexNDArray.data().dataType() == 1) {
                return JCublas.cublasScnrm2(iComplexNDArray.length(), alloc, 1);
            }
            return JCublas.cublasDznrm2(iComplexNDArray.length(), alloc, 1);
        } finally {
            free(alloc);
        }
    }

    /**
     * Copies the complex vector {@code iComplexNDArray} into {@code iComplexNDArray2} on the GPU
     * and writes the result back to the destination's host buffer.
     *
     * @param iComplexNDArray source
     * @param iComplexNDArray2 destination; overwritten
     */
    public static void copy(IComplexNDArray iComplexNDArray, IComplexNDArray iComplexNDArray2) {
        DataTypeValidation.assertSameDataType(new INDArray[]{iComplexNDArray, iComplexNDArray2});
        JCublas.cublasInit();
        JCublasComplexNDArray jCublasComplexNDArray = (JCublasComplexNDArray) iComplexNDArray;
        JCublasComplexNDArray jCublasComplexNDArray2 = (JCublasComplexNDArray) iComplexNDArray2;
        Pointer alloc = alloc(jCublasComplexNDArray);
        Pointer alloc2 = alloc(jCublasComplexNDArray2);
        // alloc() uploads 2 scalars per complex element, so the real copy routines must move
        // 2 * length() scalars. Copying only length() scalars would leave half the data stale.
        int scalarCount = iComplexNDArray.length() * 2;
        if (jCublasComplexNDArray.data().dataType() == 1) {
            JCublas.cublasScopy(scalarCount, alloc, 1, alloc2, 1);
            getData(jCublasComplexNDArray2, alloc2, Pointer.to(jCublasComplexNDArray2.data().asFloat()));
        } else {
            JCublas.cublasDcopy(scalarCount, alloc, 1, alloc2, 1);
            getData(jCublasComplexNDArray2, alloc2, Pointer.to(jCublasComplexNDArray2.data().asDouble()));
        }
        free(alloc, alloc2);
    }

    /**
     * Index of the complex element with the largest {@code |re| + |im|}, as reported by cuBLAS.
     *
     * <p>Float data now uses {@code cublasIcamax}. {@code cublasIsamax} would rank raw scalars and
     * return a scalar index, not an element index.
     *
     * <p>NOTE(review): this returns cuBLAS's 1-based index unchanged, while
     * {@code iamax(INDArray)} subtracts 1. Confirm what callers expect.
     *
     * @param iComplexNDArray input vector
     * @return 1-based index of the maximal element
     */
    public static int iamax(IComplexNDArray iComplexNDArray) {
        // Every other entry point initialises cuBLAS before allocating; this one must as well.
        JCublas.cublasInit();
        JCublasComplexNDArray jCublasComplexNDArray = (JCublasComplexNDArray) iComplexNDArray;
        Pointer alloc = alloc(jCublasComplexNDArray);
        try {
            if (jCublasComplexNDArray.data().dataType() == 1) {
                return JCublas.cublasIcamax(iComplexNDArray.length(), alloc, 1);
            }
            return JCublas.cublasIzamax(iComplexNDArray.length(), alloc, 1);
        } finally {
            free(alloc);
        }
    }

    /**
     * Sum of {@code |re| + |im|} over a complex vector.
     *
     * <p>Double-precision data uses {@code cublasDzasum}. Always calling {@code cublasScasum}
     * would reinterpret double buffers as floats.
     *
     * @param iComplexNDArray input vector
     * @return the sum; narrowed to float for double data to keep the existing return type
     */
    public static float asum(IComplexNDArray iComplexNDArray) {
        JCublas.cublasInit();
        Pointer alloc = alloc((JCublasComplexNDArray) iComplexNDArray);
        try {
            if (iComplexNDArray.data().dataType() == 1) {
                return JCublas.cublasScasum(iComplexNDArray.length(), alloc, 1);
            }
            return (float) JCublas.cublasDzasum(iComplexNDArray.length(), alloc, 1);
        } finally {
            free(alloc);
        }
    }

    /**
     * Exchanges the contents of two real vectors on the GPU.
     *
     * <p>Both operands are modified by a swap, so both device buffers are copied back to their
     * host arrays. Copying back only the second operand left the first one unchanged.
     *
     * @param iNDArray first vector; receives the old contents of {@code iNDArray2}
     * @param iNDArray2 second vector; receives the old contents of {@code iNDArray}
     */
    public static void swap(INDArray iNDArray, INDArray iNDArray2) {
        DataTypeValidation.assertSameDataType(new INDArray[]{iNDArray, iNDArray2});
        JCublas.cublasInit();
        JCublasNDArray jCublasNDArray = (JCublasNDArray) iNDArray;
        JCublasNDArray jCublasNDArray2 = (JCublasNDArray) iNDArray2;
        Pointer alloc = alloc(jCublasNDArray);
        Pointer alloc2 = alloc(jCublasNDArray2);
        if (jCublasNDArray.data().dataType() == 1) {
            JCublas.cublasSswap(jCublasNDArray.length(), alloc, 1, alloc2, 1);
            getData(jCublasNDArray, alloc, Pointer.to(jCublasNDArray.data().asFloat()));
            getData(jCublasNDArray2, alloc2, Pointer.to(jCublasNDArray2.data().asFloat()));
        } else {
            JCublas.cublasDswap(jCublasNDArray.length(), alloc, 1, alloc2, 1);
            getData(jCublasNDArray, alloc, Pointer.to(jCublasNDArray.data().asDouble()));
            getData(jCublasNDArray2, alloc2, Pointer.to(jCublasNDArray2.data().asDouble()));
        }
        free(alloc, alloc2);
    }

    /**
     * Sum of absolute values of a real vector ({@code cublasSasum} for float data,
     * {@code cublasDasum} otherwise).
     *
     * @param iNDArray input vector
     * @return the sum, widened to double for float data
     */
    public static double asum(INDArray iNDArray) {
        JCublas.cublasInit();
        final Pointer device = alloc((JCublasNDArray) iNDArray);
        final double total;
        if (iNDArray.data().dataType() == 1) {
            total = JCublas.cublasSasum(iNDArray.length(), device, 1);
        } else {
            total = JCublas.cublasDasum(iNDArray.length(), device, 1);
        }
        free(device);
        return total;
    }

    /**
     * Euclidean norm of a real vector.
     *
     * <p>Dispatches on the data type like the other routines here. Always calling
     * {@code cublasSnrm2} would reinterpret double buffers as floats.
     *
     * @param iNDArray input vector
     * @return the norm; narrowed to float for double data to keep the existing return type
     */
    public static float nrm2(INDArray iNDArray) {
        JCublas.cublasInit();
        Pointer alloc = alloc((JCublasNDArray) iNDArray);
        try {
            if (iNDArray.data().dataType() == 1) {
                return JCublas.cublasSnrm2(iNDArray.length(), alloc, 1);
            }
            return (float) JCublas.cublasDnrm2(iNDArray.length(), alloc, 1);
        } finally {
            free(alloc);
        }
    }

    /**
     * Index of the element with the largest absolute value in a real vector.
     *
     * <p>Double data uses {@code cublasIdamax}. Always calling {@code cublasIsamax} would
     * reinterpret double buffers as floats.
     *
     * @param iNDArray input vector
     * @return 0-based index (cuBLAS's 1-based result minus one)
     */
    public static int iamax(INDArray iNDArray) {
        JCublas.cublasInit();
        Pointer alloc = alloc((JCublasNDArray) iNDArray);
        try {
            int oneBased = iNDArray.data().dataType() == 1
                    ? JCublas.cublasIsamax(iNDArray.length(), alloc, 1)
                    : JCublas.cublasIdamax(iNDArray.length(), alloc, 1);
            return oneBased - 1;
        } finally {
            free(alloc);
        }
    }

    /**
     * Single-precision {@code y = f * x + y} via {@code cublasSaxpy}; the result is written back
     * into {@code iNDArray2}.
     *
     * @param f scalar multiplier
     * @param iNDArray vector {@code x}
     * @param iNDArray2 vector {@code y}; overwritten
     */
    public static void axpy(float f, INDArray iNDArray, INDArray iNDArray2) {
        JCublas.cublasInit();
        DataTypeValidation.assertFloat(new INDArray[]{iNDArray, iNDArray2});
        final JCublasNDArray x = (JCublasNDArray) iNDArray;
        final JCublasNDArray y = (JCublasNDArray) iNDArray2;
        final Pointer devX = alloc(x);
        final Pointer devY = alloc(y);
        // 'c' and 'f' ordering take the same path: alloc() always yields contiguous device copies.
        JCublas.cublasSaxpy(x.length(), f, devX, 1, devY, 1);
        getData(y, devY, Pointer.to(y.data().asFloat()));
        free(devX, devY);
    }

    /**
     * Single-precision complex {@code y = alpha * x + y} via {@code cublasCaxpy}; the result is
     * written back into {@code iComplexNDArray2}.
     *
     * @param iComplexFloat alpha
     * @param iComplexNDArray vector {@code x}
     * @param iComplexNDArray2 vector {@code y}; overwritten
     */
    public static void axpy(IComplexFloat iComplexFloat, IComplexNDArray iComplexNDArray, IComplexNDArray iComplexNDArray2) {
        DataTypeValidation.assertFloat(new INDArray[]{iComplexNDArray, iComplexNDArray2});
        final JCublasComplexNDArray x = (JCublasComplexNDArray) iComplexNDArray;
        final JCublasComplexNDArray y = (JCublasComplexNDArray) iComplexNDArray2;
        JCublas.cublasInit();
        final Pointer devX = alloc(x);
        final Pointer devY = alloc(y);
        final float re = iComplexFloat.realComponent().floatValue();
        final float im = iComplexFloat.imaginaryComponent().floatValue();
        JCublas.cublasCaxpy(x.length(), cuComplex.cuCmplx(re, im), devX, 1, devY, 1);
        getData(y, devY, Pointer.to(y.data().asFloat()));
        free(devX, devY);
    }

    /**
     * Double-precision complex {@code y = alpha * x + y} via {@code cublasZaxpy}; the result is
     * written back into {@code iComplexNDArray2}.
     *
     * @param iComplexDouble alpha; its components are passed at full double precision
     * @param iComplexNDArray vector {@code x}
     * @param iComplexNDArray2 vector {@code y}; overwritten
     */
    public static void axpy(IComplexDouble iComplexDouble, IComplexNDArray iComplexNDArray, IComplexNDArray iComplexNDArray2) {
        DataTypeValidation.assertDouble(new INDArray[]{iComplexNDArray, iComplexNDArray2});
        JCublasComplexNDArray jCublasComplexNDArray = (JCublasComplexNDArray) iComplexNDArray;
        JCublasComplexNDArray jCublasComplexNDArray2 = (JCublasComplexNDArray) iComplexNDArray2;
        JCublas.cublasInit();
        Pointer alloc = alloc(jCublasComplexNDArray);
        Pointer alloc2 = alloc(jCublasComplexNDArray2);
        // doubleValue(), not floatValue(): rounding alpha through float loses precision in a Z routine.
        cuDoubleComplex alpha = cuDoubleComplex.cuCmplx(iComplexDouble.realComponent().doubleValue(), iComplexDouble.imaginaryComponent().doubleValue());
        JCublas.cublasZaxpy(jCublasComplexNDArray.length(), alpha, alloc, 1, alloc2, 1);
        getData(jCublasComplexNDArray2, alloc2, Pointer.to(jCublasComplexNDArray2.data().asDouble()));
        free(alloc, alloc2);
    }

    /**
     * Scales a double-precision vector in place by {@code d} via {@code cublasDscal}.
     *
     * @param d scale factor
     * @param iNDArray vector to scale; overwritten
     * @return {@code iNDArray}
     */
    public static INDArray scal(double d, INDArray iNDArray) {
        DataTypeValidation.assertDouble(iNDArray);
        JCublas.cublasInit();
        final JCublasNDArray target = (JCublasNDArray) iNDArray;
        final Pointer device = alloc(target);
        JCublas.cublasDscal(target.length(), d, device, 1);
        final Pointer host = Pointer.to(target.data().asDouble());
        getData(target, device, host);
        free(device);
        return iNDArray;
    }

    /**
     * Scales a single-precision vector in place by {@code f} via {@code cublasSscal}.
     *
     * @param f scale factor
     * @param iNDArray vector to scale; overwritten
     * @return {@code iNDArray}
     */
    public static INDArray scal(float f, INDArray iNDArray) {
        DataTypeValidation.assertFloat(iNDArray);
        JCublas.cublasInit();
        final JCublasNDArray target = (JCublasNDArray) iNDArray;
        final Pointer device = alloc(target);
        JCublas.cublasSscal(target.length(), f, device, 1);
        final Pointer host = Pointer.to(target.data().asFloat());
        getData(target, device, host);
        free(device);
        return iNDArray;
    }

    /**
     * Copies the real vector {@code iNDArray} into {@code iNDArray2} on the GPU and writes the
     * result back to the destination's host buffer.
     *
     * @param iNDArray source
     * @param iNDArray2 destination; overwritten
     */
    public static void copy(INDArray iNDArray, INDArray iNDArray2) {
        DataTypeValidation.assertSameDataType(new INDArray[]{iNDArray, iNDArray2});
        // Like every other entry point, make sure cuBLAS is initialised before allocating.
        JCublas.cublasInit();
        JCublasNDArray jCublasNDArray = (JCublasNDArray) iNDArray2;
        Pointer alloc = alloc((JCublasNDArray) iNDArray);
        Pointer alloc2 = alloc(jCublasNDArray);
        if (iNDArray.data().dataType() == 0) {
            JCublas.cublasDcopy(iNDArray.length(), alloc, 1, alloc2, 1);
            getData(jCublasNDArray, alloc2, Pointer.to(jCublasNDArray.data().asDouble()));
        } else {
            JCublas.cublasScopy(iNDArray.length(), alloc, 1, alloc2, 1);
            getData(jCublasNDArray, alloc2, Pointer.to(jCublasNDArray.data().asFloat()));
        }
        free(alloc, alloc2);
    }

    /**
     * Dot product of two real vectors ({@code cublasSdot} for float data, {@code cublasDdot} otherwise).
     *
     * @param iNDArray first vector
     * @param iNDArray2 second vector
     * @return the dot product, widened to double for float data
     */
    public static double dot(INDArray iNDArray, INDArray iNDArray2) {
        DataTypeValidation.assertSameDataType(new INDArray[]{iNDArray, iNDArray2});
        JCublas.cublasInit();
        final Pointer devX = alloc((JCublasNDArray) iNDArray);
        final Pointer devY = alloc((JCublasNDArray) iNDArray2);
        final double product;
        if (iNDArray.data().dataType() == 1) {
            product = JCublas.cublasSdot(iNDArray.length(), devX, 1, devY, 1);
        } else {
            product = JCublas.cublasDdot(iNDArray.length(), devX, 1, devY, 1);
        }
        free(devX, devY);
        return product;
    }

    /**
     * Conjugated complex dot product {@code sum(conj(x_i) * y_i)}.
     *
     * <p>Float data uses {@code cublasCdotc} (as {@code dotu} does with {@code cublasCdotu}).
     * Always calling {@code cublasZdotc} would reinterpret float buffers as doubles.
     *
     * @param iComplexNDArray vector {@code x} (conjugated)
     * @param iComplexNDArray2 vector {@code y}
     * @return the result as a double-precision complex number
     */
    public static IComplexDouble dot(IComplexNDArray iComplexNDArray, IComplexNDArray iComplexNDArray2) {
        DataTypeValidation.assertSameDataType(new INDArray[]{iComplexNDArray, iComplexNDArray2});
        JCublas.cublasInit();
        Pointer alloc = alloc((JCublasComplexNDArray) iComplexNDArray);
        Pointer alloc2 = alloc((JCublasComplexNDArray) iComplexNDArray2);
        IComplexDouble createDouble;
        if (iComplexNDArray.data().dataType() == 1) {
            cuComplex cublasCdotc = JCublas.cublasCdotc(iComplexNDArray.length(), alloc, 1, alloc2, 1);
            createDouble = Nd4j.createDouble(cublasCdotc.x, cublasCdotc.y);
        } else {
            cuDoubleComplex cublasZdotc = JCublas.cublasZdotc(iComplexNDArray.length(), alloc, 1, alloc2, 1);
            createDouble = Nd4j.createDouble(cublasZdotc.x, cublasZdotc.y);
        }
        free(alloc, alloc2);
        return createDouble;
    }

    /**
     * Double-precision rank-1 update {@code A = d * x * y^T + A} via {@code cublasDger}.
     *
     * <p>The problem size is taken from the output matrix {@code A} ({@code m = A.rows()},
     * {@code n = A.columns()}). Unit increments are used because {@link #alloc(JCublasNDArray)}
     * always produces contiguous device copies. Deriving {@code m}/{@code n} from {@code x} and
     * using {@code x.rows()}/{@code y.rows()} as increments read {@code x} and {@code y} out of bounds.
     *
     * <p>NOTE(review): assumes {@code A} is laid out column-major with leading dimension {@code rows()}.
     *
     * @param iNDArray vector {@code x} of length {@code A.rows()}
     * @param iNDArray2 vector {@code y} of length {@code A.columns()}
     * @param iNDArray3 matrix {@code A}; overwritten with the result
     * @param d scalar multiplier
     * @return {@code iNDArray3}
     */
    public static INDArray ger(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, double d) {
        DataTypeValidation.assertDouble(new INDArray[]{iNDArray, iNDArray2, iNDArray3});
        JCublas.cublasInit();
        JCublasNDArray jCublasNDArray = (JCublasNDArray) iNDArray3;
        Pointer alloc = alloc((JCublasNDArray) iNDArray);
        Pointer alloc2 = alloc((JCublasNDArray) iNDArray2);
        Pointer alloc3 = alloc(jCublasNDArray);
        JCublas.cublasDger(iNDArray3.rows(), iNDArray3.columns(), d, alloc, 1, alloc2, 1, alloc3, iNDArray3.rows());
        getData(jCublasNDArray, alloc3, Pointer.to(jCublasNDArray.data().asDouble()));
        free(alloc, alloc2, alloc3);
        return iNDArray3;
    }

    /**
     * Single-precision rank-1 update {@code A = f * x * y^T + A} via {@code cublasSger}.
     *
     * <p>The problem size is taken from the output matrix {@code A} ({@code m = A.rows()},
     * {@code n = A.columns()}). Unit increments are used because {@link #alloc(JCublasNDArray)}
     * always produces contiguous device copies.
     *
     * <p>NOTE(review): assumes {@code A} is laid out column-major with leading dimension {@code rows()}.
     *
     * @param iNDArray vector {@code x} of length {@code A.rows()}
     * @param iNDArray2 vector {@code y} of length {@code A.columns()}
     * @param iNDArray3 matrix {@code A}; overwritten with the result
     * @param f scalar multiplier
     * @return {@code iNDArray3}
     */
    public static INDArray ger(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, float f) {
        DataTypeValidation.assertFloat(new INDArray[]{iNDArray, iNDArray2, iNDArray3});
        JCublas.cublasInit();
        JCublasNDArray jCublasNDArray = (JCublasNDArray) iNDArray3;
        Pointer alloc = alloc((JCublasNDArray) iNDArray);
        Pointer alloc2 = alloc((JCublasNDArray) iNDArray2);
        Pointer alloc3 = alloc(jCublasNDArray);
        JCublas.cublasSger(iNDArray3.rows(), iNDArray3.columns(), f, alloc, 1, alloc2, 1, alloc3, iNDArray3.rows());
        getData(jCublasNDArray, alloc3, Pointer.to(jCublasNDArray.data().asFloat()));
        free(alloc, alloc2, alloc3);
        return iNDArray3;
    }

    /**
     * Scales a single-precision complex vector in place via {@code cublasCscal}.
     *
     * @param iComplexFloat complex scale factor
     * @param iComplexNDArray vector to scale; overwritten
     * @return {@code iComplexNDArray}
     */
    public static IComplexNDArray scal(IComplexFloat iComplexFloat, IComplexNDArray iComplexNDArray) {
        final JCublasComplexNDArray target = (JCublasComplexNDArray) iComplexNDArray;
        DataTypeValidation.assertFloat(iComplexNDArray);
        JCublas.cublasInit();
        final Pointer device = alloc(target);
        final float re = iComplexFloat.realComponent().floatValue();
        final float im = iComplexFloat.imaginaryComponent().floatValue();
        JCublas.cublasCscal(iComplexNDArray.length(), cuComplex.cuCmplx(re, im), device, 1);
        getData(target, device, Pointer.to(target.data().asFloat()));
        free(device);
        return iComplexNDArray;
    }

    /**
     * Scales a double-precision complex vector in place via {@code cublasZscal}.
     *
     * @param iComplexDouble complex scale factor
     * @param iComplexNDArray vector to scale; overwritten
     * @return {@code iComplexNDArray}
     */
    public static IComplexNDArray scal(IComplexDouble iComplexDouble, IComplexNDArray iComplexNDArray) {
        final JCublasComplexNDArray target = (JCublasComplexNDArray) iComplexNDArray;
        DataTypeValidation.assertDouble(iComplexNDArray);
        JCublas.cublasInit();
        final Pointer device = alloc(target);
        final double re = iComplexDouble.realComponent().doubleValue();
        final double im = iComplexDouble.imaginaryComponent().doubleValue();
        JCublas.cublasZscal(iComplexNDArray.length(), cuDoubleComplex.cuCmplx(re, im), device, 1);
        getData(target, device, Pointer.to(target.data().asDouble()));
        free(device);
        return iComplexNDArray;
    }

    /**
     * Unconjugated complex dot product {@code sum(x_i * y_i)} ({@code cublasZdotu} when
     * {@code dataType()} is 0, {@code cublasCdotu} otherwise).
     *
     * @param iComplexNDArray vector {@code x}
     * @param iComplexNDArray2 vector {@code y}
     * @return the result as a double-precision complex number
     */
    public static IComplexDouble dotu(IComplexNDArray iComplexNDArray, IComplexNDArray iComplexNDArray2) {
        IComplexDouble createDouble;
        DataTypeValidation.assertSameDataType(new INDArray[]{iComplexNDArray, iComplexNDArray2});
        // Like every other entry point, make sure cuBLAS is initialised before allocating.
        JCublas.cublasInit();
        Pointer alloc = alloc((JCublasComplexNDArray) iComplexNDArray);
        Pointer alloc2 = alloc((JCublasComplexNDArray) iComplexNDArray2);
        if (iComplexNDArray.data().dataType() == 0) {
            cuDoubleComplex cublasZdotu = JCublas.cublasZdotu(iComplexNDArray.length(), alloc, 1, alloc2, 1);
            createDouble = Nd4j.createDouble(cublasZdotu.x, cublasZdotu.y);
        } else {
            cuComplex cublasCdotu = JCublas.cublasCdotu(iComplexNDArray.length(), alloc, 1, alloc2, 1);
            createDouble = Nd4j.createDouble(cublasCdotu.x, cublasCdotu.y);
        }
        free(alloc, alloc2);
        return createDouble;
    }

    /**
     * Double-precision complex rank-1 update {@code A = alpha * x * y^T + A} via {@code cublasZgeru}.
     *
     * <p>The problem size is taken from the output matrix {@code A} ({@code m = A.rows()},
     * {@code n = A.columns()}). Unit increments are used because
     * {@link #alloc(JCublasComplexNDArray)} always produces contiguous device copies.
     *
     * <p>NOTE(review): assumes {@code A} is laid out column-major with leading dimension {@code rows()}.
     *
     * @param iComplexNDArray vector {@code x} of length {@code A.rows()}
     * @param iComplexNDArray2 vector {@code y} of length {@code A.columns()}
     * @param iComplexNDArray3 matrix {@code A}; overwritten with the result
     * @param iComplexDouble alpha
     * @return {@code iComplexNDArray3}
     */
    public static IComplexNDArray geru(IComplexNDArray iComplexNDArray, IComplexNDArray iComplexNDArray2, IComplexNDArray iComplexNDArray3, IComplexDouble iComplexDouble) {
        DataTypeValidation.assertDouble(new INDArray[]{iComplexNDArray, iComplexNDArray2, iComplexNDArray3});
        // Like every other entry point, make sure cuBLAS is initialised before allocating.
        JCublas.cublasInit();
        JCublasComplexNDArray jCublasComplexNDArray = (JCublasComplexNDArray) iComplexNDArray3;
        Pointer alloc = alloc((JCublasComplexNDArray) iComplexNDArray);
        Pointer alloc2 = alloc((JCublasComplexNDArray) iComplexNDArray2);
        Pointer alloc3 = alloc(jCublasComplexNDArray);
        cuDoubleComplex alpha = cuDoubleComplex.cuCmplx(iComplexDouble.realComponent().doubleValue(), iComplexDouble.imaginaryComponent().doubleValue());
        JCublas.cublasZgeru(iComplexNDArray3.rows(), iComplexNDArray3.columns(), alpha, alloc, 1, alloc2, 1, alloc3, iComplexNDArray3.rows());
        getData(jCublasComplexNDArray, alloc3, Pointer.to(jCublasComplexNDArray.data().asDouble()));
        free(alloc, alloc2, alloc3);
        return iComplexNDArray3;
    }

    /**
     * Conjugated complex rank-1 update for single-precision arrays via cuBLAS {@code Cgerc}.
     * {@code iComplexNDArray3} plays the role of the updated matrix.
     *
     * <p>NOTE(review): m, n and incx are all derived from the first operand's shape, and incy/lda
     * from the row counts of the second and third operands — confirm this matches the intended
     * vector layout.
     *
     * @param iComplexNDArray  the x operand (single precision)
     * @param iComplexNDArray2 the y operand (single precision)
     * @param iComplexNDArray3 the matrix operand receiving the update (single precision)
     * @param iComplexFloat    the complex scale factor alpha
     * @return {@code iComplexNDArray3}; the device result is passed back through {@code getData}
     */
    public static IComplexNDArray gerc(IComplexNDArray iComplexNDArray, IComplexNDArray iComplexNDArray2, IComplexNDArray iComplexNDArray3, IComplexFloat iComplexFloat) {
        DataTypeValidation.assertFloat(new INDArray[]{iComplexNDArray, iComplexNDArray2, iComplexNDArray3});
        JCublasComplexNDArray output = (JCublasComplexNDArray) iComplexNDArray3;
        Pointer xPtr = alloc((JCublasComplexNDArray) iComplexNDArray);
        Pointer yPtr = alloc((JCublasComplexNDArray) iComplexNDArray2);
        Pointer outPtr = alloc(output);

        int m = iComplexNDArray.rows();
        int n = iComplexNDArray.columns();
        cuComplex alpha = cuComplex.cuCmplx(
                iComplexFloat.realComponent().floatValue(),
                iComplexFloat.imaginaryComponent().floatValue());
        JCublas.cublasCgerc(m, n, alpha, xPtr, m, yPtr, iComplexNDArray2.rows(), outPtr, iComplexNDArray3.rows());

        getData(output, outPtr, Pointer.to(output.data().asFloat()));
        free(xPtr, yPtr, outPtr);
        return iComplexNDArray3;
    }

    /**
     * Unconjugated complex rank-1 update for single-precision arrays via cuBLAS {@code Cgeru}.
     * {@code iComplexNDArray3} plays the role of the updated matrix.
     *
     * <p>NOTE(review): m, n and incx are all derived from the first operand's shape, and incy/lda
     * from the row counts of the second and third operands (same as the double-precision
     * overload) — confirm this matches the intended vector layout.
     *
     * @param iComplexNDArray  the x operand (single precision)
     * @param iComplexNDArray2 the y operand (single precision)
     * @param iComplexNDArray3 the matrix operand receiving the update (single precision)
     * @param iComplexFloat    the complex scale factor alpha
     * @return {@code iComplexNDArray3}; the device result is passed back through {@code getData}
     */
    public static IComplexNDArray geru(IComplexNDArray iComplexNDArray, IComplexNDArray iComplexNDArray2, IComplexNDArray iComplexNDArray3, IComplexFloat iComplexFloat) {
        DataTypeValidation.assertFloat(new INDArray[]{iComplexNDArray, iComplexNDArray2, iComplexNDArray3});
        JCublasComplexNDArray output = (JCublasComplexNDArray) iComplexNDArray3;
        Pointer xPtr = alloc((JCublasComplexNDArray) iComplexNDArray);
        Pointer yPtr = alloc((JCublasComplexNDArray) iComplexNDArray2);
        Pointer outPtr = alloc(output);

        // The operands are validated as single precision and read back with asFloat(), so the
        // single-precision complex routine (Cgeru + cuComplex) must be used. Calling Zgeru here
        // would make cuBLAS interpret these float-complex buffers as double-complex data.
        cuComplex alpha = cuComplex.cuCmplx(
                iComplexFloat.realComponent().floatValue(),
                iComplexFloat.imaginaryComponent().floatValue());
        JCublas.cublasCgeru(iComplexNDArray.rows(), iComplexNDArray.columns(), alpha,
                xPtr, iComplexNDArray.rows(),
                yPtr, iComplexNDArray2.rows(),
                outPtr, iComplexNDArray3.rows());

        getData(output, outPtr, Pointer.to(output.data().asFloat()));
        free(xPtr, yPtr, outPtr);
        return iComplexNDArray3;
    }

    /**
     * Conjugated complex rank-1 update for double-precision arrays via cuBLAS {@code Zgerc}.
     * {@code iComplexNDArray3} plays the role of the updated matrix.
     *
     * <p>NOTE(review): m, n and incx are all derived from the first operand's shape, and incy/lda
     * from the row counts of the second and third operands — confirm this matches the intended
     * vector layout.
     *
     * @param iComplexNDArray  the x operand (double precision)
     * @param iComplexNDArray2 the y operand (double precision)
     * @param iComplexNDArray3 the matrix operand receiving the update (double precision)
     * @param iComplexDouble   the complex scale factor alpha
     * @return {@code iComplexNDArray3}; the device result is passed back through {@code getData}
     */
    public static IComplexNDArray gerc(IComplexNDArray iComplexNDArray, IComplexNDArray iComplexNDArray2, IComplexNDArray iComplexNDArray3, IComplexDouble iComplexDouble) {
        DataTypeValidation.assertDouble(new INDArray[]{iComplexNDArray, iComplexNDArray2, iComplexNDArray3});
        JCublasComplexNDArray output = (JCublasComplexNDArray) iComplexNDArray3;
        Pointer xPtr = alloc((JCublasComplexNDArray) iComplexNDArray);
        Pointer yPtr = alloc((JCublasComplexNDArray) iComplexNDArray2);
        Pointer outPtr = alloc(output);

        int m = iComplexNDArray.rows();
        int n = iComplexNDArray.columns();
        cuDoubleComplex alpha = cuDoubleComplex.cuCmplx(
                iComplexDouble.realComponent().doubleValue(),
                iComplexDouble.imaginaryComponent().doubleValue());
        JCublas.cublasZgerc(m, n, alpha, xPtr, m, yPtr, iComplexNDArray2.rows(), outPtr, iComplexNDArray3.rows());

        getData(output, outPtr, Pointer.to(output.data().asDouble()));
        free(xPtr, yPtr, outPtr);
        return iComplexNDArray3;
    }

    /**
     * Double-precision AXPY via cuBLAS {@code Daxpy}: adds {@code d * iNDArray} into
     * {@code iNDArray2} with unit strides over {@code iNDArray.length()} elements, then passes the
     * device result for {@code iNDArray2} back through {@code getData}.
     *
     * @param d         the scalar multiplier
     * @param iNDArray  the x operand (double precision)
     * @param iNDArray2 the y operand, which receives the result; must be a {@link JCublasNDArray}
     */
    public static void axpy(double d, INDArray iNDArray, INDArray iNDArray2) {
        DataTypeValidation.assertDouble(new INDArray[]{iNDArray, iNDArray2});
        JCublas.cublasInit();
        JCublasNDArray y = (JCublasNDArray) iNDArray2;
        Pointer xDevice = alloc((JCublasNDArray) iNDArray);
        Pointer yDevice = alloc(y);

        int count = iNDArray.length();
        JCublas.cublasDaxpy(count, d, xDevice, 1, yDevice, 1);

        Pointer hostY = Pointer.to(y.data().asDouble());
        getData(y, yDevice, hostY);
        free(xDevice, yDevice);
    }

    /**
     * Single-precision AXPY via cuBLAS {@code Saxpy}: adds {@code f * iNDArray} into
     * {@code iNDArray2} with unit strides over {@code iNDArray.length()} elements, then passes the
     * device result for {@code iNDArray2} back through {@code getData}.
     *
     * @param f         the scalar multiplier
     * @param iNDArray  the x operand (single precision)
     * @param iNDArray2 the y operand, which receives the result; must be a {@link JCublasNDArray}
     */
    public static void saxpy(float f, INDArray iNDArray, INDArray iNDArray2) {
        DataTypeValidation.assertFloat(new INDArray[]{iNDArray, iNDArray2});
        JCublas.cublasInit();
        JCublasNDArray y = (JCublasNDArray) iNDArray2;
        Pointer xDevice = alloc((JCublasNDArray) iNDArray);
        Pointer yDevice = alloc(y);

        int count = iNDArray.length();
        JCublas.cublasSaxpy(count, f, xDevice, 1, yDevice, 1);

        Pointer hostY = Pointer.to(y.data().asFloat());
        getData(y, yDevice, hostY);
        free(xDevice, yDevice);
    }

    static {
        String str = "/" + resourceName().substring(3).replace("X", "x");
        ClassPathResource classPathResource = new ClassPathResource(str);
        if (!classPathResource.exists() && str.startsWith("/lib/")) {
            classPathResource = new ClassPathResource(str.replaceAll("/lib/", ""));
        } else if (!classPathResource.exists()) {
            classPathResource = new ClassPathResource(resourceName().replace("X", "x"));
        }
        if (!classPathResource.exists()) {
            throw new IllegalStateException("Unable to find resource with name " + classPathResource.getFilename());
        }
        log.info("Loading jcublas from " + classPathResource.getFilename());
        File file = new File(findWritableLibDir());
        File file2 = new File(file, resourceName().replace("X", "x"));
        try {
            if (file2.exists()) {
                file2.delete();
            }
            file2.createNewFile();
            BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(new FileOutputStream(file2));
            IOUtils.copy(classPathResource.getInputStream(), bufferedOutputStream);
            bufferedOutputStream.flush();
            bufferedOutputStream.close();
            file2.deleteOnExit();
            File file3 = new File(libDir());
            log.info("Loading cuda from " + file3.getAbsolutePath());
            Iterator iterateFiles = FileUtils.iterateFiles(file3, (String[]) null, false);
            while (iterateFiles.hasNext()) {
                File file4 = (File) iterateFiles.next();
                File file5 = new File(file, file4.getName());
                try {
                    FileUtils.copyFile(file4, file5);
                } catch (IOException e) {
                    e.printStackTrace();
                }
                file5.deleteOnExit();
            }
            JCublas.setLogLevel(LogLevel.LOG_DEBUG);
            JCublas.setExceptionsEnabled(true);
            JCublas.cublasInit();
            Runtime.getRuntime().addShutdownHook(new Thread() { // from class: org.nd4j.linalg.jcublas.SimpleJCublas.1
                @Override // java.lang.Thread, java.lang.Runnable
                public void run() {
                    JCublas.cublasShutdown();
                }
            });
        } catch (IOException e2) {
            throw new RuntimeException("Unable to initialize jcublas", e2);
        }
    }
}
