package org.nd4j.linalg.jcublas;

import jcuda.LogLevel;
import jcuda.cuComplex;
import jcuda.cuDoubleComplex;
import jcuda.jcublas.JCublas;
import org.nd4j.linalg.api.complex.IComplexDouble;
import org.nd4j.linalg.api.complex.IComplexFloat;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.complex.JCublasComplexNDArray;

/* loaded from: input_file:org/nd4j/linalg/jcublas/SimpleJCublas.class */
/**
 * Static wrappers that run nd4j arrays through the legacy JCublas API.
 *
 * <p>Each operation follows the same pattern:
 * <ol>
 *   <li>cast the nd4j arrays to their JCublas implementations;</li>
 *   <li>{@code alloc()} them;</li>
 *   <li>invoke the cuBLAS routine on their {@code pointer()}s;</li>
 *   <li>call {@code getData()} on every array the routine writes;</li>
 *   <li>{@code free()} the arrays.</li>
 * </ol>
 * The {@code free()} calls run in a {@code finally} block, so a cuBLAS failure does not skip them.
 * Failures surface as exceptions because the static initializer enables JCublas exceptions.
 *
 * <p>Real-valued operations use the single-precision ({@code S}) cuBLAS routines. The complex
 * operations mix single-precision ({@code C}) and double-precision ({@code Z}) routines.
 * NOTE(review): confirm which precision the complex arrays actually store.
 */
public class SimpleJCublas {

    static {
        JCublas.setLogLevel(LogLevel.LOG_DEBUG);
        // Make cuBLAS errors surface as Java exceptions instead of silent status codes.
        JCublas.setExceptionsEnabled(true);
        JCublas.cublasInit();
        // Shut cuBLAS down when the JVM exits.
        Runtime.getRuntime().addShutdownHook(new Thread() {
            @Override
            public void run() {
                JCublas.cublasShutdown();
            }
        });
    }

    /** Calls {@code alloc()} on each given complex array, in order. */
    public static void alloc(JCublasComplexNDArray... arrays) {
        for (JCublasComplexNDArray array : arrays) {
            array.alloc();
        }
    }

    /** Calls {@code free()} on each given complex array, in order. */
    public static void free(JCublasComplexNDArray... arrays) {
        for (JCublasComplexNDArray array : arrays) {
            array.free();
        }
    }

    /** Calls {@code alloc()} on each given real array, in order. */
    public static void alloc(JCublasNDArray... arrays) {
        for (JCublasNDArray array : arrays) {
            array.alloc();
        }
    }

    /** Calls {@code free()} on each given real array, in order. */
    public static void free(JCublasNDArray... arrays) {
        for (JCublasNDArray array : arrays) {
            array.free();
        }
    }

    /**
     * Single-precision matrix-vector product via {@code cublasSgemv} with no transpose.
     *
     * @param a     the matrix; its row count is passed as the leading dimension
     * @param x     the input vector, read with unit stride
     * @param y     the output vector, written with unit stride
     * @param alpha scale applied to {@code a * x}
     * @param beta  scale applied to the existing contents of {@code y}
     * @return {@code y}
     */
    public static INDArray gemv(INDArray a, INDArray x, INDArray y, float alpha, float beta) {
        JCublasNDArray aCu = (JCublasNDArray) a;
        JCublasNDArray xCu = (JCublasNDArray) x;
        JCublasNDArray yCu = (JCublasNDArray) y;
        alloc(aCu, xCu, yCu);
        try {
            JCublas.cublasSgemv('N', a.rows(), a.columns(), alpha, aCu.pointer(), a.rows(),
                    xCu.pointer(), 1, beta, yCu.pointer(), 1);
            yCu.getData();
        } finally {
            free(aCu, xCu, yCu);
        }
        return y;
    }

    /**
     * Single-precision complex matrix product via {@code cublasCgemm}, with no transposes.
     *
     * <p>The sizes passed are m = {@code a.rows()}, n = {@code b.columns()} and
     * k = {@code b.rows()}. Each leading dimension is the row count of its matrix.
     *
     * @param a     left operand
     * @param b     right operand
     * @param c     output matrix, overwritten with the result
     * @param alpha real scale for {@code a * b}; its imaginary part is 0
     * @param beta  real scale for the existing {@code c}; its imaginary part is 0
     * @return {@code c}
     */
    public static IComplexNDArray gemm(IComplexNDArray a, IComplexNDArray b, IComplexNDArray c,
            float alpha, float beta) {
        // Complex operands are JCublasComplexNDArray, as in every other complex method here.
        JCublasComplexNDArray aCu = (JCublasComplexNDArray) a;
        JCublasComplexNDArray bCu = (JCublasComplexNDArray) b;
        JCublasComplexNDArray cCu = (JCublasComplexNDArray) c;
        alloc(aCu, bCu, cCu);
        try {
            JCublas.cublasCgemm('n', 'n', a.rows(), b.columns(), b.rows(),
                    cuComplex.cuCmplx(alpha, 0.0f), aCu.pointer(), a.rows(),
                    bCu.pointer(), b.rows(),
                    cuComplex.cuCmplx(beta, 0.0f), cCu.pointer(), c.rows());
            cCu.getData();
        } finally {
            free(aCu, bCu, cCu);
        }
        return c;
    }

    /**
     * Single-precision real matrix product via {@code cublasSgemm}, with no transposes.
     *
     * <p>The sizes passed are m = {@code a.rows()}, n = {@code b.columns()} and
     * k = {@code b.rows()}. Each leading dimension is the row count of its matrix.
     *
     * @param a     left operand
     * @param b     right operand
     * @param c     output matrix, overwritten with the result
     * @param alpha scale for {@code a * b}
     * @param beta  scale for the existing {@code c}
     * @return {@code c}
     */
    public static INDArray gemm(INDArray a, INDArray b, INDArray c, float alpha, float beta) {
        JCublasNDArray aCu = (JCublasNDArray) a;
        JCublasNDArray bCu = (JCublasNDArray) b;
        JCublasNDArray cCu = (JCublasNDArray) c;
        alloc(aCu, bCu, cCu);
        try {
            JCublas.cublasSgemm('n', 'n', a.rows(), b.columns(), b.rows(), alpha,
                    aCu.pointer(), a.rows(), bCu.pointer(), b.rows(),
                    beta, cCu.pointer(), c.rows());
            cCu.getData();
        } finally {
            free(aCu, bCu, cCu);
        }
        return c;
    }

    /** Unimplemented stub: does nothing. */
    public static void dcopy(int n, float[] dx, int dxIdx, int incx, float[] dy, int dyIdx, int incy) {
    }

    /**
     * Runs {@code cublasSnrm2} over {@code length()} floats of the complex buffer, with stride 2.
     *
     * <p>NOTE(review): if the buffer holds interleaved (re, im) floats, stride 2 visits only the
     * real parts, so the result is not the complex 2-norm. {@code cublasScnrm2(n, x, 1)} may be
     * what was intended. Confirm the storage layout before changing this.
     */
    public static float nrm2(IComplexNDArray x) {
        JCublasComplexNDArray xCu = (JCublasComplexNDArray) x;
        alloc(xCu);
        try {
            return JCublas.cublasSnrm2(x.length(), xCu.pointer(), 2);
        } finally {
            free(xCu);
        }
    }

    /**
     * Copies {@code x} into {@code y} with {@code cublasScopy}, using {@code x.length()} elements
     * and unit strides.
     *
     * <p>NOTE(review): {@code cublasScopy} counts floats, not complex elements. If each complex
     * element occupies two floats, only half of the data is copied, and {@code cublasCcopy} may
     * be intended. Verify this.
     */
    public static void copy(IComplexNDArray x, IComplexNDArray y) {
        JCublasComplexNDArray xCu = (JCublasComplexNDArray) x;
        JCublasComplexNDArray yCu = (JCublasComplexNDArray) y;
        alloc(xCu, yCu);
        try {
            JCublas.cublasScopy(x.length(), xCu.pointer(), 1, yCu.pointer(), 1);
            yCu.getData();
        } finally {
            free(xCu, yCu);
        }
    }

    /**
     * Index of the complex element with the largest magnitude, via {@code cublasIzamax}.
     *
     * <p>The stride passed is {@code x.stride()[0]}. Per the cuBLAS documentation, i?amax returns
     * a 1-based index, and this method returns it unchanged.
     * NOTE(review): {@code Izamax} is double precision, while several sibling complex ops use
     * single precision ({@code C*}). Confirm this is intended.
     */
    public static int iamax(IComplexNDArray x) {
        JCublasComplexNDArray xCu = (JCublasComplexNDArray) x;
        alloc(xCu);
        try {
            return JCublas.cublasIzamax(x.length(), xCu.pointer(), x.stride()[0]);
        } finally {
            free(xCu);
        }
    }

    /** Sum of absolute values of the complex vector, via {@code cublasScasum} with unit stride. */
    public static float asum(IComplexNDArray x) {
        JCublasComplexNDArray xCu = (JCublasComplexNDArray) x;
        alloc(xCu);
        try {
            return JCublas.cublasScasum(x.length(), xCu.pointer(), 1);
        } finally {
            free(xCu);
        }
    }

    /** Unimplemented stub: always returns 0. */
    public static int dznrm2(int n, float[] x, int offset, int incx) {
        return 0;
    }

    /** Unimplemented stub: always returns 0. */
    public static int dzasum(int n, float[] x, int offset, int incx) {
        return 0;
    }

    /** Unimplemented stub: always returns 0. */
    public static int izamax(int n, float[] x, int offset, int incx) {
        return 0;
    }

    /**
     * Swaps the contents of {@code x} and {@code y} via {@code cublasSswap}, using unit strides and
     * {@code x.length()} elements.
     *
     * <p>Both vectors are modified, so {@code getData()} is called on both.
     */
    public static void swap(INDArray x, INDArray y) {
        JCublasNDArray xCu = (JCublasNDArray) x;
        JCublasNDArray yCu = (JCublasNDArray) y;
        alloc(xCu, yCu);
        try {
            JCublas.cublasSswap(xCu.length(), xCu.pointer(), 1, yCu.pointer(), 1);
            xCu.getData();
            yCu.getData();
        } finally {
            free(xCu, yCu);
        }
    }

    /** Sum of absolute values of a real vector, via {@code cublasSasum} with unit stride. */
    public static float asum(INDArray x) {
        // Real arrays are JCublasNDArray, matching nrm2(INDArray) and the other real-valued ops.
        JCublasNDArray xCu = (JCublasNDArray) x;
        alloc(xCu);
        try {
            return JCublas.cublasSasum(x.length(), xCu.pointer(), 1);
        } finally {
            free(xCu);
        }
    }

    /** Euclidean norm of a real vector, via {@code cublasSnrm2} with unit stride. */
    public static float nrm2(INDArray x) {
        JCublasNDArray xCu = (JCublasNDArray) x;
        alloc(xCu);
        try {
            return JCublas.cublasSnrm2(x.length(), xCu.pointer(), 1);
        } finally {
            free(xCu);
        }
    }

    /**
     * Index of the element with the largest absolute value, via {@code cublasIsamax}.
     *
     * <p>The single-precision routine matches the float buffers used by every other real op here.
     * The stride passed is {@code x.stride()[0]}. Per the cuBLAS documentation, the returned index
     * is 1-based, and this method returns it unchanged.
     */
    public static int iamax(INDArray x) {
        JCublasNDArray xCu = (JCublasNDArray) x;
        alloc(xCu);
        try {
            return JCublas.cublasIsamax(x.length(), xCu.pointer(), x.stride()[0]);
        } finally {
            free(xCu);
        }
    }

    /**
     * Computes {@code y = alpha * x + y} via {@code cublasSaxpy}, using {@code x.length()} elements.
     *
     * <p>For a {@code 'c'}-ordered {@code x}, {@code x.stride()[0]} is the x increment; otherwise
     * x is read with unit stride. {@code y} always uses unit stride.
     */
    public static void axpy(float alpha, INDArray x, INDArray y) {
        JCublasNDArray xCu = (JCublasNDArray) x;
        JCublasNDArray yCu = (JCublasNDArray) y;
        alloc(xCu, yCu);
        try {
            int incX = xCu.ordering() == 'c' ? xCu.stride()[0] : 1;
            JCublas.cublasSaxpy(xCu.length(), alpha, xCu.pointer(), incX, yCu.pointer(), 1);
            yCu.getData();
        } finally {
            free(xCu, yCu);
        }
    }

    /**
     * Computes {@code y = alpha * x + y} for complex vectors via {@code cublasCaxpy}, using unit
     * strides. {@code alpha} is narrowed to single precision.
     */
    public static void axpy(IComplexNumber alpha, IComplexNDArray x, IComplexNDArray y) {
        JCublasComplexNDArray xCu = (JCublasComplexNDArray) x;
        JCublasComplexNDArray yCu = (JCublasComplexNDArray) y;
        alloc(xCu, yCu);
        try {
            cuComplex scale = cuComplex.cuCmplx(alpha.realComponent().floatValue(),
                    alpha.imaginaryComponent().floatValue());
            JCublas.cublasCaxpy(xCu.length(), scale, xCu.pointer(), 1, yCu.pointer(), 1);
            yCu.getData();
        } finally {
            free(xCu, yCu);
        }
    }

    /**
     * Scales {@code x} in place by {@code alpha} via {@code cublasSscal}, with stride
     * {@code x.stride()[0]}.
     *
     * @return {@code x}
     */
    public static INDArray scal(float alpha, INDArray x) {
        JCublasNDArray xCu = (JCublasNDArray) x;
        alloc(xCu);
        try {
            JCublas.cublasSscal(xCu.length(), alpha, xCu.pointer(), x.stride()[0]);
            xCu.getData();
        } finally {
            free(xCu);
        }
        return x;
    }

    /**
     * Copies {@code x} into {@code y} via {@code cublasScopy}, using unit strides and
     * {@code x.length()} elements.
     *
     * <p>The single-precision routine matches the float buffers used by the other real ops. The
     * earlier version called the double-precision {@code cublasDcopy}.
     */
    public static void copy(INDArray x, INDArray y) {
        JCublasNDArray xCu = (JCublasNDArray) x;
        JCublasNDArray yCu = (JCublasNDArray) y;
        alloc(xCu, yCu);
        try {
            JCublas.cublasScopy(x.length(), xCu.pointer(), 1, yCu.pointer(), 1);
            yCu.getData();
        } finally {
            free(xCu, yCu);
        }
    }

    /** Dot product of two real vectors, via {@code cublasSdot} with unit strides. */
    public static float dot(INDArray x, INDArray y) {
        JCublasNDArray xCu = (JCublasNDArray) x;
        JCublasNDArray yCu = (JCublasNDArray) y;
        alloc(xCu, yCu);
        try {
            return JCublas.cublasSdot(x.length(), xCu.pointer(), 1, yCu.pointer(), 1);
        } finally {
            free(xCu, yCu);
        }
    }

    /**
     * Conjugated complex dot product via {@code cublasZdotc}, with unit strides.
     *
     * <p>NOTE(review): {@code Zdotc} is double precision, while several sibling complex ops use
     * single precision. Confirm this matches the array storage.
     *
     * @return the result wrapped with {@code Nd4j.createDouble}
     */
    public static IComplexDouble dot(IComplexNDArray x, IComplexNDArray y) {
        JCublasComplexNDArray xCu = (JCublasComplexNDArray) x;
        JCublasComplexNDArray yCu = (JCublasComplexNDArray) y;
        alloc(xCu, yCu);
        try {
            cuDoubleComplex result = JCublas.cublasZdotc(x.length(), xCu.pointer(), 1, yCu.pointer(), 1);
            return Nd4j.createDouble(result.x, result.y);
        } finally {
            free(xCu, yCu);
        }
    }

    /**
     * Rank-1 update of {@code a} via {@code cublasSger}.
     *
     * <p>NOTE(review): {@code cublasSger} takes (m, n, alpha, x, incx, y, incy, A, lda). This call
     * passes {@code x.rows()} as incx and {@code y.rows()} as incy, and takes m and n from
     * {@code x}'s shape rather than from {@code a}'s. Verify this against the callers.
     *
     * @return {@code a}
     */
    public static INDArray ger(INDArray x, INDArray y, INDArray a, float alpha) {
        JCublasNDArray xCu = (JCublasNDArray) x;
        JCublasNDArray yCu = (JCublasNDArray) y;
        JCublasNDArray aCu = (JCublasNDArray) a;
        alloc(xCu, yCu, aCu);
        try {
            JCublas.cublasSger(x.rows(), x.columns(), alpha, xCu.pointer(), x.rows(),
                    yCu.pointer(), y.rows(), aCu.pointer(), a.rows());
            aCu.getData();
        } finally {
            free(xCu, yCu, aCu);
        }
        return a;
    }

    /**
     * Scales a complex vector in place by a single-precision complex factor, via
     * {@code cublasCscal} with unit stride.
     *
     * @return {@code x}
     */
    public static IComplexNDArray zscal(IComplexFloat alpha, IComplexNDArray x) {
        JCublasComplexNDArray xCu = (JCublasComplexNDArray) x;
        alloc(xCu);
        try {
            JCublas.cublasCscal(x.length(), cuComplex.cuCmplx(alpha.realComponent().floatValue(),
                    alpha.imaginaryComponent().floatValue()), xCu.pointer(), 1);
            xCu.getData();
        } finally {
            free(xCu);
        }
        return x;
    }

    /**
     * Scales a complex vector in place by a double-precision complex factor, via
     * {@code cublasZscal} with unit stride.
     *
     * @return {@code x}
     */
    public static IComplexNDArray zscal(IComplexDouble alpha, IComplexNDArray x) {
        JCublasComplexNDArray xCu = (JCublasComplexNDArray) x;
        alloc(xCu);
        try {
            JCublas.cublasZscal(x.length(), cuDoubleComplex.cuCmplx(alpha.realComponent().doubleValue(),
                    alpha.imaginaryComponent().doubleValue()), xCu.pointer(), 1);
            xCu.getData();
        } finally {
            free(xCu);
        }
        return x;
    }

    /**
     * Unconjugated complex dot product via {@code cublasZdotu}.
     *
     * <p>Each vector uses its own {@code stride()[0]} as the increment.
     */
    public static IComplexDouble dotu(IComplexNDArray x, IComplexNDArray y) {
        JCublasComplexNDArray xCu = (JCublasComplexNDArray) x;
        JCublasComplexNDArray yCu = (JCublasComplexNDArray) y;
        alloc(xCu, yCu);
        try {
            cuDoubleComplex result = JCublas.cublasZdotu(x.length(), xCu.pointer(), x.stride()[0],
                    yCu.pointer(), yCu.stride()[0]);
            return Nd4j.createDouble(result.x, result.y);
        } finally {
            free(xCu, yCu);
        }
    }

    /**
     * Unconjugated complex rank-1 update of {@code a} via {@code cublasZgeru}.
     *
     * <p>NOTE(review): this uses the same argument mapping as {@link #ger}: the row counts are
     * passed as increments, and m and n come from {@code x}. Verify this against the callers.
     *
     * @return {@code a}
     */
    public static IComplexNDArray geru(IComplexNDArray x, IComplexNDArray y, IComplexNDArray a,
            IComplexDouble alpha) {
        JCublasComplexNDArray xCu = (JCublasComplexNDArray) x;
        JCublasComplexNDArray yCu = (JCublasComplexNDArray) y;
        JCublasComplexNDArray aCu = (JCublasComplexNDArray) a;
        alloc(xCu, yCu, aCu);
        try {
            JCublas.cublasZgeru(x.rows(), x.columns(),
                    cuDoubleComplex.cuCmplx(alpha.realComponent().doubleValue(),
                            alpha.imaginaryComponent().doubleValue()),
                    xCu.pointer(), x.rows(), yCu.pointer(), y.rows(), aCu.pointer(), a.rows());
            aCu.getData();
        } finally {
            free(xCu, yCu, aCu);
        }
        return a;
    }

    /**
     * Conjugated complex rank-1 update of {@code a} via {@code cublasZgerc}.
     *
     * <p>NOTE(review): this uses the same argument mapping as {@link #ger}: the row counts are
     * passed as increments, and m and n come from {@code x}. Verify this against the callers.
     *
     * @return {@code a}
     */
    public static IComplexNDArray gerc(IComplexNDArray x, IComplexNDArray y, IComplexNDArray a,
            IComplexDouble alpha) {
        JCublasComplexNDArray xCu = (JCublasComplexNDArray) x;
        JCublasComplexNDArray yCu = (JCublasComplexNDArray) y;
        JCublasComplexNDArray aCu = (JCublasComplexNDArray) a;
        alloc(xCu, yCu, aCu);
        try {
            JCublas.cublasZgerc(x.rows(), x.columns(),
                    cuDoubleComplex.cuCmplx(alpha.realComponent().doubleValue(),
                            alpha.imaginaryComponent().doubleValue()),
                    xCu.pointer(), x.rows(), yCu.pointer(), y.rows(), aCu.pointer(), a.rows());
            aCu.getData();
        } finally {
            free(xCu, yCu, aCu);
        }
        return a;
    }
}
