package com.omega.engine.gpu;

import com.omega.common.data.Tensor;
import com.omega.common.utils.PrintUtils;
import com.omega.common.utils.RandomUtils;
import java.util.List;
import java.util.Locale;
import jcuda.CudaException;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.Sizeof;
import jcuda.driver.CUdeviceptr;
import jcuda.driver.CUfunction;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.jcublas.JCublas2;
import jcuda.jcublas.cublasHandle;
import jcuda.jcublas.cublasStatus;
import jcuda.jcurand.JCurand;
import jcuda.jcurand.curandGenerator;
import jcuda.jcurand.curandStatus;
import jcuda.runtime.JCuda;

/* loaded from: input_file:com/omega/engine/gpu/GPUOP.class */
/**
 * Host-side helpers that run matrix products, an im2col-based convolution and random fills on
 * the GPU through JCuda (cuBLAS, cuRAND and a custom kernel).
 *
 * <p>One cuBLAS handle is created in the constructor and used by every BLAS call below;
 * {@link #getInstance()} lazily creates a shared instance. Nothing here is synchronized.
 *
 * <p>Row-major convention: the GEMM wrappers treat their matrices as densely packed row-major
 * data and call cuBLAS (which is column-major) with the two operands and the {@code m}/{@code n}
 * dimensions swapped. cuBLAS therefore computes {@code C^T = op(B)^T * op(A)^T}, which is the
 * row-major {@code C = op(A) * op(B)}. Transpose flags are the raw cuBLAS operation values
 * ({@code 0} = as stored, {@code 1} = transposed).
 *
 * <p>Parameter names were reconstructed from how each value is used (buffer sizes, cuBLAS
 * argument positions). NOTE(review): the convolution geometry names (height, width, stride, ...)
 * are inferred from the output-size formulas in {@link #im2col}; confirm against Im2colKernel.cu.
 */
public class GPUOP {
    /** cuBLAS {@code CUBLAS_OP_N}: use the matrix as stored. */
    private static final int OP_N = 0;
    /** cuBLAS {@code CUBLAS_OP_T}: use the transposed matrix. */
    private static final int OP_T = 1;
    /** CUDA runtime {@code cudaMemcpyHostToDevice}. */
    private static final int MEMCPY_HOST_TO_DEVICE = 1;
    /** CUDA runtime {@code cudaMemcpyDeviceToHost}. */
    private static final int MEMCPY_DEVICE_TO_HOST = 2;
    /** cuRAND {@code CURAND_RNG_PSEUDO_DEFAULT}. */
    private static final int CURAND_RNG_PSEUDO_DEFAULT = 100;
    /** Maximum threads per block used when launching the im2col kernel. */
    private static final int IM2COL_THREADS_PER_BLOCK = 1024;

    /** Lazily created shared instance returned by {@link #getInstance()}. */
    private static GPUOP instance;

    /** cuBLAS handle shared by every BLAS call on this instance. */
    private final cublasHandle handle = new cublasHandle();

    /** cuRAND generator, created on first use by {@link #getGenerator()}. */
    private curandGenerator generator;

    /** Creates the instance together with its cuBLAS handle. */
    public GPUOP() {
        JCublas2.cublasCreate(this.handle);
    }

    /**
     * Creates the cuBLAS handle again.
     *
     * <p>NOTE(review): this overwrites the current handle without destroying it, so it is
     * presumably meant to be called only after {@link #clear()}.
     */
    public void init() {
        JCublas2.cublasCreate(this.handle);
    }

    /** Destroys the cuBLAS handle; it must not be used again until {@link #init()} is called. */
    public void clear() {
        JCublas2.cublasDestroy(this.handle);
    }

    /**
     * Returns the cuRAND generator, creating a default pseudo-random generator on first use.
     *
     * @return the cached generator
     * @throws CudaException if cuRAND fails to create the generator (nothing is cached then)
     */
    public curandGenerator getGenerator() {
        if (this.generator == null) {
            curandGenerator created = new curandGenerator();
            checkCURANDResult(JCurand.curandCreateGenerator(created, CURAND_RNG_PSEUDO_DEFAULT));
            // Only cache a generator that was actually created.
            this.generator = created;
        }
        return this.generator;
    }

    /**
     * Convolves {@code input} with {@code kernel} via im2col followed by a single SGEMM and copies
     * {@code kernelCount * outRows} floats of result into {@code output}.
     *
     * <p>The output spatial size is {@code (height - kernelH) / stride + 1} by
     * {@code (width - kernelW) / stride + 1} (no padding).
     *
     * @param input       host input data, uploaded by {@link #im2col}
     * @param kernel      host kernel weights; all {@code kernel.length} floats are uploaded
     * @param output      host buffer receiving the result
     * @param batch       leading factor of the im2col row count
     * @param channels    input channels (multiplies the patch size)
     * @param height      input height
     * @param width       input width
     * @param kernelCount number of kernels (the GEMM {@code m} and kernel leading dimension)
     * @param kernelH     kernel height
     * @param kernelW     kernel width
     * @param stride      stride used in both spatial directions
     */
    public void conv(float[] input, float[] kernel, float[] output, int batch, int channels,
            int height, int width, int kernelCount, int kernelH, int kernelW, int stride) {
        Pointer columns = im2col(input, batch, channels, height, width, kernelH, kernelW, stride);
        int outRows = batch * ((height - kernelH) / stride + 1) * ((width - kernelW) / stride + 1);
        int patchSize = kernelH * kernelW * channels;
        Pointer dKernel = new Pointer();
        Pointer dOutput = new Pointer();
        try {
            JCuda.cudaMalloc(dKernel, (long) kernel.length * Sizeof.FLOAT);
            JCuda.cudaMalloc(dOutput, (long) output.length * Sizeof.FLOAT);
            JCublas2.cublasSetVector(kernel.length, Sizeof.FLOAT, Pointer.to(kernel), 1, dKernel, 1);
            JCublas2.cublasSgemm(this.handle, OP_N, OP_N, kernelCount, outRows, patchSize,
                    Pointer.to(new float[] {1.0f}), dKernel, kernelCount, columns, patchSize,
                    Pointer.to(new float[] {0.0f}), dOutput, kernelCount);
            JCublas2.cublasGetVector(outRows * kernelCount, Sizeof.FLOAT, dOutput, 1,
                    Pointer.to(output), 1);
        } finally {
            freeAll(columns, dKernel, dOutput);
        }
    }

    /**
     * Unfolds {@code input} into a column matrix on the device with the {@code im2col_gpuv4}
     * kernel from {@code Im2colKernel.cu}.
     *
     * <p>The returned device buffer holds {@code rows * cols} floats, where
     * {@code rows = batch * outH * outW}, {@code cols = channels * kernelH * kernelW},
     * {@code outH = (height - kernelH) / stride + 1} and
     * {@code outW = (width - kernelW) / stride + 1}. The caller owns it and must free it.
     *
     * @return device pointer to the unfolded matrix
     */
    private Pointer im2col(float[] input, int batch, int channels, int height, int width,
            int kernelH, int kernelW, int stride) {
        CUfunction function = CUDAModules.getLocalFunctionByModule("Im2colKernel.cu", "im2col_gpuv4");
        int outH = (height - kernelH) / stride + 1;
        int outW = (width - kernelW) / stride + 1;
        int rows = batch * outH * outW;
        int cols = channels * kernelH * kernelW;
        int total = rows * cols;
        long inputBytes = (long) input.length * Sizeof.FLOAT;

        CUdeviceptr dInput = new CUdeviceptr();
        CUdeviceptr dColumns = new CUdeviceptr();
        JCudaDriver.cuMemAlloc(dInput, inputBytes);
        try {
            JCudaDriver.cuMemcpyHtoD(dInput, Pointer.to(input), inputBytes);
            JCudaDriver.cuMemAlloc(dColumns, (long) total * Sizeof.FLOAT);
            // Kernel parameters: each entry points at the host copy of one argument.
            Pointer params = Pointer.to(new NativePointerObject[] {
                Pointer.to(dInput), Pointer.to(dColumns),
                Pointer.to(new int[] {batch}), Pointer.to(new int[] {channels}),
                Pointer.to(new int[] {height}), Pointer.to(new int[] {width}),
                Pointer.to(new int[] {kernelH}), Pointer.to(new int[] {kernelW}),
                Pointer.to(new int[] {stride}), Pointer.to(new int[] {outH}),
                Pointer.to(new int[] {outW}), Pointer.to(new int[] {rows}),
                Pointer.to(new int[] {cols}), Pointer.to(new int[] {kernelH * kernelW})
            });
            // Launch at least `total` threads (presumably one per output element); a problem that
            // fits in one block gets a single block of exactly `total` threads.
            int blockSize = IM2COL_THREADS_PER_BLOCK;
            int gridSize = (total + blockSize - 1) / blockSize;
            if (total <= blockSize) {
                blockSize = total;
                gridSize = 1;
            }
            JCudaDriver.cuLaunchKernel(function, gridSize, 1, 1, blockSize, 1, 1, 0,
                    (CUstream) null, params, (Pointer) null);
            JCuda.cudaDeviceSynchronize();
        } finally {
            // The uploaded input is only needed by the kernel; release it here.
            JCudaDriver.cuMemFree(dInput);
        }
        return dColumns;
    }

    /**
     * Computes the row-major product {@code c = a * b}, where {@code a} is {@code m x k} and
     * {@code b} is {@code k x n}. Device buffers are allocated and released inside the call.
     * Exceptions raised while computing are printed and swallowed.
     */
    public void multiplyFloat(int m, int k, int n, float[] a, float[] b, float[] c) {
        Pointer dA = new Pointer();
        Pointer dB = new Pointer();
        Pointer dC = new Pointer();
        try {
            JCuda.cudaMalloc(dA, (long) m * k * Sizeof.FLOAT);
            JCuda.cudaMalloc(dB, (long) k * n * Sizeof.FLOAT);
            JCuda.cudaMalloc(dC, (long) m * n * Sizeof.FLOAT);
            JCublas2.cublasSetVector(m * k, Sizeof.FLOAT, Pointer.to(a), 1, dA, 1);
            JCublas2.cublasSetVector(k * n, Sizeof.FLOAT, Pointer.to(b), 1, dB, 1);
            JCublas2.cublasSgemm(this.handle, OP_N, OP_N, n, m, k, Pointer.to(new float[] {1.0f}),
                    dB, n, dA, k, Pointer.to(new float[] {0.0f}), dC, n);
            JCuda.cudaDeviceSynchronize();
            JCublas2.cublasGetVector(m * n, Sizeof.FLOAT, dC, 1, Pointer.to(c), 1);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            freeAll(dA, dB, dC);
        }
    }

    /**
     * Computes the row-major product {@code c = a * b} on operands already on the device
     * ({@code a}: {@code m x k}, {@code b}: {@code k x n}) and copies the {@code m * n} result
     * into {@code hostC}. Exceptions are printed and swallowed.
     */
    public void multiplyFloat(int m, int k, int n, float[] hostC, Pointer a, Pointer b, Pointer c) {
        try {
            JCublas2.cublasSgemm(this.handle, OP_N, OP_N, n, m, k, Pointer.to(new float[] {1.0f}),
                    b, n, a, k, Pointer.to(new float[] {0.0f}), c, n);
            JCublas2.cublasGetVector(m * n, Sizeof.FLOAT, c, 1, Pointer.to(hostC), 1);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Computes the row-major {@code c = alpha * op(a) * op(b) + beta * c} on device operands
     * ({@code op(a)}: {@code m x k}, {@code op(b)}: {@code k x n}) and copies
     * {@code hostC.length} floats of {@code c} back into {@code hostC}.
     * Exceptions are printed and swallowed.
     *
     * @param transA cuBLAS operation applied to {@code a} (0 = as stored, 1 = transposed)
     * @param transB cuBLAS operation applied to {@code b}
     * @param alpha  scale of the product (previously passed to cuBLAS as beta by mistake)
     * @param beta   scale of the existing {@code c}
     */
    public void multiplyFloat(int m, int n, int k, float[] hostC, Pointer a, Pointer b, Pointer c,
            int transA, int transB, float alpha, float beta) {
        try {
            JCublas2.cublasSgemm(this.handle, transB, transA, n, m, k, Pointer.to(new float[] {alpha}),
                    b, transB == OP_N ? n : k, a, transA == OP_N ? k : m,
                    Pointer.to(new float[] {beta}), c, n);
            JCublas2.cublasGetVector(hostC.length, Sizeof.FLOAT, c, 1, Pointer.to(hostC), 1);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Computes the row-major {@code c = alpha * op(a) * op(b) + beta * c} entirely on the device
     * ({@code op(a)}: {@code m x k}, {@code op(b)}: {@code k x n}). A failing cuBLAS status is
     * turned into a {@link CudaException}, which is then printed and swallowed here.
     */
    public void multiplyFloat(int m, int n, int k, Pointer a, Pointer b, Pointer c, int transA,
            int transB, float alpha, float beta) {
        try {
            checkCUBLASResult(JCublas2.cublasSgemm(this.handle, transB, transA, n, m, k,
                    Pointer.to(new float[] {alpha}), b, transB == OP_N ? n : k,
                    a, transA == OP_N ? k : m, Pointer.to(new float[] {beta}), c, n));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Computes {@code y = alpha * op(A) * x + beta * y} with {@code A} a row-major
     * {@code rows x cols} matrix (seen by cuBLAS as a column-major {@code cols x rows} matrix with
     * leading dimension {@code cols}). Exceptions are printed and swallowed.
     *
     * @param trans cuBLAS operation passed straight through to {@code cublasSgemv}
     */
    public void gemv(int trans, int rows, int cols, Tensor a, Tensor x, Tensor y, float alpha, float beta) {
        try {
            JCublas2.cublasSgemv(this.handle, trans, cols, rows, Pointer.to(new float[] {alpha}),
                    a.getGpuData(), cols, x.getGpuData(), 1, Pointer.to(new float[] {beta}),
                    y.getGpuData(), 1);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Batched row-major {@code c[i] = alpha * op(a[i]) * op(b[i])} over {@code batch} densely
     * packed matrices ({@code op(a)}: {@code m x k}, {@code op(b)}: {@code k x n}).
     * {@code alpha} is 1 when {@code scaleDim == 0}, otherwise {@code sqrt(1 / scaleDim)}.
     * The cuBLAS status is not checked.
     */
    public void launchMultiply(Pointer a, Pointer b, Pointer c, int batch, int m, int n, int k,
            boolean transA, boolean transB, int scaleDim) {
        float alpha = scaleDim == 0 ? 1.0f : (float) Math.sqrt(1.0d / scaleDim);
        stridedBatchedGemm(batch, m, n, k, transA ? OP_T : OP_N, transB ? OP_T : OP_N, alpha, 0.0f,
                a, b, c);
    }

    /**
     * Batched row-major {@code c[i] = op(a[i]) * op(b[i])}; see
     * {@link #bmm(Pointer, Pointer, Pointer, int, int, int, int, int, int, float, float)}.
     */
    public void bmm(Pointer a, Pointer b, Pointer c, int batch, int m, int n, int k, int transA,
            int transB) {
        bmm(a, b, c, batch, m, n, k, transA, transB, 1.0f, 0.0f);
    }

    /**
     * Batched row-major {@code c[i] = alpha * op(a[i]) * op(b[i]) + beta * c[i]} over
     * {@code batch} densely packed matrices ({@code op(a)}: {@code m x k}, {@code op(b)}:
     * {@code k x n}).
     *
     * @throws CudaException if cuBLAS reports an error
     */
    public void bmm(Pointer a, Pointer b, Pointer c, int batch, int m, int n, int k, int transA,
            int transB, float alpha, float beta) {
        checkCUBLASResult(stridedBatchedGemm(batch, m, n, k, transA, transB, alpha, beta, a, b, c));
    }

    /**
     * Direct pass-through to {@code cublasSgemmStridedBatched} (column-major arguments, exactly
     * as cuBLAS expects them).
     *
     * @throws CudaException if cuBLAS reports an error
     */
    public void bmm(int transa, int transb, int m, int n, int k, float alpha, Pointer A, int lda,
            long strideA, Pointer B, int ldb, long strideB, float beta, Pointer C, int ldc,
            long strideC, int batchCount) {
        checkCUBLASResult(JCublas2.cublasSgemmStridedBatched(this.handle, transa, transb, m, n, k,
                Pointer.to(new float[] {alpha}), A, lda, strideA, B, ldb, strideB,
                Pointer.to(new float[] {beta}), C, ldc, strideC, batchCount));
    }

    /**
     * Computes the row-major {@code c = alpha * op(a) * op(b) + beta * c} from host arrays, using
     * the leading dimensions that densely packed matrices have. See the overload with explicit
     * leading dimensions.
     */
    public void multiplyFloat(int m, int n, int k, float[] a, float[] b, float[] c, int transA,
            int transB, float alpha, float beta) {
        multiplyFloat(m, n, k, a, b, c, transA, transB, alpha, beta,
                transA == OP_N ? k : m, transB == OP_N ? n : k, n);
    }

    /**
     * Computes the row-major {@code c = alpha * op(a) * op(b) + beta * c} from host arrays with
     * explicit leading dimensions. {@code a} ({@code m * k} floats), {@code b} ({@code k * n}) and,
     * when {@code beta != 0}, {@code c} ({@code m * n}) are uploaded; {@code m * n} floats are
     * copied back into {@code c}. Exceptions are printed and swallowed.
     *
     * @param ldA leading dimension of {@code a} as passed to cuBLAS
     * @param ldB leading dimension of {@code b} as passed to cuBLAS
     * @param ldC leading dimension of {@code c} as passed to cuBLAS
     */
    public void multiplyFloat(int m, int n, int k, float[] a, float[] b, float[] c, int transA,
            int transB, float alpha, float beta, int ldA, int ldB, int ldC) {
        Pointer dA = new Pointer();
        Pointer dB = new Pointer();
        Pointer dC = new Pointer();
        try {
            JCuda.cudaMalloc(dA, (long) m * k * Sizeof.FLOAT);
            JCuda.cudaMalloc(dB, (long) k * n * Sizeof.FLOAT);
            JCuda.cudaMalloc(dC, (long) m * n * Sizeof.FLOAT);
            JCublas2.cublasSetVector(m * k, Sizeof.FLOAT, Pointer.to(a), 1, dA, 1);
            JCublas2.cublasSetVector(k * n, Sizeof.FLOAT, Pointer.to(b), 1, dB, 1);
            if (beta != 0.0f) {
                // beta scales the existing C, so the device copy must hold the caller's values
                // rather than uninitialised memory (cuBLAS only ignores C when beta == 0).
                JCublas2.cublasSetVector(m * n, Sizeof.FLOAT, Pointer.to(c), 1, dC, 1);
            }
            JCublas2.cublasSgemm(this.handle, transB, transA, n, m, k, Pointer.to(new float[] {alpha}),
                    dB, ldB, dA, ldA, Pointer.to(new float[] {beta}), dC, ldC);
            JCuda.cudaDeviceSynchronize();
            JCublas2.cublasGetVector(m * n, Sizeof.FLOAT, dC, 1, Pointer.to(c), 1);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            freeAll(dA, dB, dC);
        }
    }

    /**
     * Computes {@code y = alpha * op(A) * x + beta * y} on device buffers, with {@code A} a
     * row-major {@code rows x cols} matrix, then synchronizes the device.
     * Exceptions are printed and swallowed.
     *
     * @param trans cuBLAS operation passed straight through to {@code cublasSgemv}
     */
    public void gpu_gemv(int rows, int cols, Pointer a, Pointer x, Pointer y, int trans, float alpha,
            float beta) {
        try {
            JCublas2.cublasSgemv(this.handle, trans, cols, rows, Pointer.to(new float[] {alpha}), a,
                    cols, x, 1, Pointer.to(new float[] {beta}), y, 1);
            JCuda.cudaDeviceSynchronize();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * For each {@code i < batch}, computes the row-major {@code cs[i] = a * bs[i]}, where
     * {@code a} is {@code m x k} and every {@code bs[i]} is {@code k x n}. Each result is copied
     * into the pre-allocated host array {@code cs.get(i)}.
     */
    public void multiplyFloatBatch(int m, int k, int n, int batch, float[] a, List<float[]> bs,
            List<float[]> cs) {
        long aBytes = (long) m * k * Sizeof.FLOAT;
        long bBytes = (long) k * n * Sizeof.FLOAT;
        long cBytes = (long) m * n * Sizeof.FLOAT;
        long pointerArrayBytes = (long) batch * Sizeof.POINTER;
        Pointer dA = new Pointer();
        Pointer[] aSlots = new Pointer[batch];
        Pointer[] bSlots = new Pointer[batch];
        Pointer[] cSlots = new Pointer[batch];
        Pointer dAArray = new Pointer();
        Pointer dBArray = new Pointer();
        Pointer dCArray = new Pointer();
        try {
            // `a` is the same for every batch entry: upload it once and reference it from all slots.
            JCuda.cudaMalloc(dA, aBytes);
            JCuda.cudaMemcpy(dA, Pointer.to(a), aBytes, MEMCPY_HOST_TO_DEVICE);
            for (int i = 0; i < batch; i++) {
                aSlots[i] = dA;
                bSlots[i] = new Pointer();
                JCuda.cudaMalloc(bSlots[i], bBytes);
                JCuda.cudaMemcpy(bSlots[i], Pointer.to(bs.get(i)), bBytes, MEMCPY_HOST_TO_DEVICE);
                cSlots[i] = new Pointer();
                JCuda.cudaMalloc(cSlots[i], cBytes);
            }
            // Device-side arrays of `batch` matrix pointers, as cublasSgemmBatched expects.
            JCuda.cudaMalloc(dAArray, pointerArrayBytes);
            JCuda.cudaMalloc(dBArray, pointerArrayBytes);
            JCuda.cudaMalloc(dCArray, pointerArrayBytes);
            JCuda.cudaMemcpy(dAArray, Pointer.to(aSlots), pointerArrayBytes, MEMCPY_HOST_TO_DEVICE);
            JCuda.cudaMemcpy(dBArray, Pointer.to(bSlots), pointerArrayBytes, MEMCPY_HOST_TO_DEVICE);
            JCuda.cudaMemcpy(dCArray, Pointer.to(cSlots), pointerArrayBytes, MEMCPY_HOST_TO_DEVICE);
            JCublas2.cublasSgemmBatched(this.handle, OP_N, OP_N, n, m, k, Pointer.to(new float[] {1.0f}),
                    dBArray, n, dAArray, k, Pointer.to(new float[] {0.0f}), dCArray, n, batch);
            JCuda.cudaDeviceSynchronize();
            for (int i = 0; i < batch; i++) {
                JCuda.cudaMemcpy(Pointer.to(cs.get(i)), cSlots[i], cBytes, MEMCPY_DEVICE_TO_HOST);
            }
        } finally {
            freeAll(dAArray, dBArray, dCArray);
            freeAll(bSlots);
            freeAll(cSlots);
            freeAll(dA);
        }
    }

    /**
     * Fills the tensor's device buffer with {@code tensor.getDataLength()} uniform random floats,
     * after reseeding the generator from {@code RandomUtils.rand()}.
     * Exceptions are printed and swallowed.
     */
    public void cudaRandom(Tensor tensor) {
        try {
            checkCURANDResult(JCurand.curandSetPseudoRandomGeneratorSeed(getGenerator(), RandomUtils.rand()));
            checkCURANDResult(JCurand.curandGenerateUniform(getGenerator(), tensor.getGpuData(), tensor.getDataLength()));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Releases a device buffer with {@code cudaFree}. */
    public void free(Pointer pointer) {
        JCuda.cudaFree(pointer);
    }

    /**
     * Computes the row-major double-precision product {@code c = a * b}, where {@code a} is
     * {@code m x k} and {@code b} is {@code k x n}.
     *
     * <p>Uses the shared cuBLAS handle. The previous version re-created and then destroyed that
     * handle, which broke every later call, and ran single-precision SGEMM on double data.
     */
    public void multiplyDouble(int m, int k, int n, double[] a, double[] b, double[] c) {
        Pointer dA = new Pointer();
        Pointer dB = new Pointer();
        Pointer dC = new Pointer();
        try {
            JCuda.cudaMalloc(dA, (long) m * k * Sizeof.DOUBLE);
            JCuda.cudaMalloc(dB, (long) k * n * Sizeof.DOUBLE);
            JCuda.cudaMalloc(dC, (long) m * n * Sizeof.DOUBLE);
            JCublas2.cublasSetVector(m * k, Sizeof.DOUBLE, Pointer.to(a), 1, dA, 1);
            JCublas2.cublasSetVector(k * n, Sizeof.DOUBLE, Pointer.to(b), 1, dB, 1);
            // beta == 0, so cuBLAS does not read C and it need not be uploaded.
            JCublas2.cublasDgemm(this.handle, OP_N, OP_N, n, m, k, Pointer.to(new double[] {1.0}),
                    dB, n, dA, k, Pointer.to(new double[] {0.0}), dC, n);
            JCuda.cudaDeviceSynchronize();
            JCublas2.cublasGetVector(m * n, Sizeof.DOUBLE, dC, 1, Pointer.to(c), 1);
        } finally {
            freeAll(dA, dB, dC);
        }
    }

    /** Returns the shared instance, creating it on first call (not synchronized). */
    public static GPUOP getInstance() {
        if (instance == null) {
            instance = new GPUOP();
        }
        return instance;
    }

    /** Manual smoke test entry point. */
    public static void main(String[] strArr) {
        testBatch();
    }

    /** Manual smoke test of the device-pointer GEMM with a shifting byte offset into the left operand. */
    public static void test() {
        float[] order = RandomUtils.order(5 * 4, 1.0f, 1.0f);
        float[] order2 = RandomUtils.order(1 * 3, 1.0f, 1.0f);
        Tensor tensor = new Tensor(5, 1, 1, 4, order, true);
        Tensor tensor2 = new Tensor(1, 1, 1, 3, order2, true);
        Tensor tensor3 = new Tensor(5, 1, 1, 3, true);
        for (int i = 0; i < 4; i++) {
            getInstance().multiplyFloat(5, 3, 4, tensor.getGpuData().withByteOffset(i * 4), tensor2.getGpuData(), tensor3.getGpuData(), 0, 0, 1.0f, 0.0f);
            PrintUtils.printImage(tensor3.syncHost());
            System.out.println("");
        }
    }

    /** Manual smoke test of {@link #bmm} with the second operand transposed. */
    public static void testBatch() {
        float[] order = RandomUtils.order(2 * 5 * 4, 1.0f, 0.0f);
        float[] order2 = RandomUtils.order(2 * 3 * 4, 1.0f, 0.0f);
        Tensor tensor = new Tensor(2, 1, 5, 4, order, true);
        Tensor tensor2 = new Tensor(2, 1, 3, 4, order2, true);
        Tensor tensor3 = new Tensor(2, 1, 5, 3, true);
        getInstance().bmm(tensor.getGpuData(), tensor2.getGpuData(), tensor3.getGpuData(), 2, 5, 3, 4, 0, 1);
        tensor.showDM();
        tensor3.showDM();
        System.out.println("");
    }

    /**
     * Formats {@code fArr} as rows of {@code i} values, each printed with {@code "%7.4f "} and
     * rows separated by {@code '\n'}.
     */
    public static String toString2D(float[] fArr, int i) {
        StringBuilder sb = new StringBuilder();
        for (int i2 = 0; i2 < fArr.length; i2++) {
            if (i2 > 0 && i2 % i == 0) {
                sb.append("\n");
            }
            sb.append(String.format(Locale.ENGLISH, "%7.4f ", Float.valueOf(fArr[i2])));
        }
        return sb.toString();
    }

    /**
     * Formats {@code dArr} as rows of {@code i} values, each printed with {@code "%7.4f "} and
     * rows separated by {@code '\n'}.
     */
    public static String toString2D(double[] dArr, int i) {
        StringBuilder sb = new StringBuilder();
        for (int i2 = 0; i2 < dArr.length; i2++) {
            if (i2 > 0 && i2 % i == 0) {
                sb.append("\n");
            }
            sb.append(String.format(Locale.ENGLISH, "%7.4f ", Double.valueOf(dArr[i2])));
        }
        return sb.toString();
    }

    /**
     * Releases every non-null device buffer with {@code cudaFree}. A {@code Pointer} that was never
     * allocated wraps address 0, for which {@code cudaFree} performs no operation.
     */
    private static void freeAll(Pointer... pointers) {
        for (Pointer p : pointers) {
            if (p != null) {
                JCuda.cudaFree(p);
            }
        }
    }

    /**
     * Issues a row-major strided-batched SGEMM
     * {@code c[i] = alpha * op(a[i]) * op(b[i]) + beta * c[i]} over {@code batch} densely packed
     * matrices and returns the raw cuBLAS status. Strides are computed in {@code long} to avoid
     * int overflow.
     */
    private int stridedBatchedGemm(int batch, int m, int n, int k, int transA, int transB,
            float alpha, float beta, Pointer a, Pointer b, Pointer c) {
        return JCublas2.cublasSgemmStridedBatched(this.handle, transB, transA, n, m, k,
                Pointer.to(new float[] {alpha}),
                b, transB == OP_N ? n : k, (long) n * k,
                a, transA == OP_N ? k : m, (long) m * k,
                Pointer.to(new float[] {beta}),
                c, n, (long) m * n, batch);
    }

    /**
     * Returns {@code i} when it is {@code CUBLAS_STATUS_SUCCESS} (0); otherwise logs the status
     * and throws.
     *
     * @throws CudaException carrying the cuBLAS status name
     */
    private static int checkCUBLASResult(int i) {
        if (i == 0) {
            return i;
        }
        System.err.println("cuda error code:" + i + "[" + cublasStatus.stringFor(i) + "]");
        throw new CudaException(cublasStatus.stringFor(i));
    }

    /**
     * Returns {@code i} when it is {@code CURAND_STATUS_SUCCESS} (0); otherwise logs the status
     * and throws.
     *
     * @throws CudaException carrying the cuRAND status name (previously the cuBLAS name was used)
     */
    private static int checkCURANDResult(int i) {
        if (i == 0) {
            return i;
        }
        System.err.println("curand error code:" + i + "[" + curandStatus.stringFor(i) + "]");
        throw new CudaException(curandStatus.stringFor(i));
    }
}
