package com.omega.engine.gpu;

import com.omega.common.utils.CheckArrayUtils;
import com.omega.common.utils.Im2colToVector;
import com.omega.common.utils.MatrixUtils;
import com.omega.common.utils.RandomUtils;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUdeviceptr;
import jcuda.driver.CUfunction;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.JCuda;

/* loaded from: input_file:com/omega/engine/gpu/Im2colKernelStream.class */
/**
 * Host-side launcher for the {@code im2col_gpuv4} CUDA kernel (loaded from {@code Im2colKernelTmp.cu}).
 *
 * <p>Each call to {@link #im2col()} does the following:
 * <ul>
 *   <li>creates a CUDA stream;</li>
 *   <li>copies {@code x} to the device and launches the kernel with {@code oh * ow} threads;</li>
 *   <li>copies the device result back into {@code out};</li>
 *   <li>releases the stream and both device buffers.</li>
 * </ul>
 *
 * <p>Naming: {@code oh = N * oHeight * oWidth} and {@code ow = C * kh * kw}. In {@link #main} the output
 * buffer is sized {@code oh * ow}.
 */
public class Im2colKernelStream {
    /** Upper bound on threads per block used for the launch configuration. */
    private static final int MAX_THREADS_PER_BLOCK = 1024;

    private final float[] x;
    private final float[] out;
    private final int N;
    private final int C;
    private final int H;
    private final int W;
    private final int kh;
    private final int kw;
    private final int s;
    private final int oHeight;
    private final int oWidth;
    private final int ow;
    private final int oh;
    private final int kSize;
    private CUfunction function;

    /**
     * Creates a launcher and loads the CUDA function.
     *
     * @param x      host input values (presumably N x C x H x W, row-major — TODO confirm against the kernel)
     * @param out    host output buffer that receives the kernel result
     * @param n      batch size
     * @param c      channel count
     * @param h      input height
     * @param w      input width
     * @param kh     kernel height
     * @param kw     kernel width
     * @param stride stride applied in both spatial dimensions
     * @throws IllegalArgumentException if the stride or kernel size is not positive, or the kernel
     *                                  is larger than the input
     */
    public Im2colKernelStream(float[] x, float[] out, int n, int c, int h, int w, int kh, int kw, int stride) {
        if (stride <= 0) {
            throw new IllegalArgumentException("stride must be positive: " + stride);
        }
        if (kh <= 0 || kw <= 0 || kh > h || kw > w) {
            throw new IllegalArgumentException(
                    "invalid kernel " + kh + "x" + kw + " for input " + h + "x" + w);
        }
        this.x = x;
        this.out = out;
        this.N = n;
        this.C = c;
        this.H = h;
        this.W = w;
        this.kh = kh;
        this.kw = kw;
        this.s = stride;
        this.oHeight = ((h - kh) / stride) + 1;
        this.oWidth = ((w - kw) / stride) + 1;
        this.oh = n * this.oHeight * this.oWidth;
        this.ow = c * kh * kw;
        this.kSize = kh * kw;
        initFunction();
    }

    /**
     * Loads {@code im2col_gpuv4} from {@code Im2colKernelTmp.cu} if it has not been loaded yet.
     *
     * <p>Failures are printed and leave {@code function} null. In that case {@link #im2col()}
     * reports the problem instead of launching.
     */
    public void initFunction() {
        try {
            if (this.function == null) {
                this.function = CUDAModules.getLocalFunctionByModule("Im2colKernelTmp.cu", "im2col_gpuv4");
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Runs the kernel once, synchronously from the caller's point of view.
     *
     * <p>The stream and both device buffers are always released, even if an exception is thrown.
     * Errors are printed rather than propagated, matching {@link #initFunction()}.
     */
    public void im2col() {
        // Computed in long: oh * ow can exceed Integer.MAX_VALUE for large tensors.
        final long totalThreads = (long) this.oh * this.ow;
        if (totalThreads <= 0) {
            // Nothing to compute; a launch with a zero-sized block would be rejected by the driver.
            return;
        }
        CUstream stream = null;
        CUdeviceptr deviceInput = null;
        CUdeviceptr deviceOutput = null;
        try {
            if (this.function == null) {
                throw new IllegalStateException("CUDA function im2col_gpuv4 is not loaded");
            }
            CUstream createdStream = new CUstream();
            // Flag value 1 corresponds to CU_STREAM_NON_BLOCKING in the CUDA driver API.
            JCudaDriver.cuStreamCreate(createdStream, 1);
            stream = createdStream;

            // Byte sizes in long so large buffers do not overflow int.
            long inputBytes = (long) this.x.length * Float.BYTES;
            long outputBytes = (long) this.out.length * Float.BYTES;

            CUdeviceptr allocatedInput = new CUdeviceptr();
            JCudaDriver.cuMemAlloc(allocatedInput, inputBytes);
            deviceInput = allocatedInput;

            CUdeviceptr allocatedOutput = new CUdeviceptr();
            JCudaDriver.cuMemAlloc(allocatedOutput, outputBytes);
            deviceOutput = allocatedOutput;

            // Argument order must match the kernel signature in Im2colKernelTmp.cu.
            Pointer kernelParams = Pointer.to(
                    Pointer.to(deviceInput),
                    Pointer.to(deviceOutput),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.C}),
                    Pointer.to(new int[]{this.H}),
                    Pointer.to(new int[]{this.W}),
                    Pointer.to(new int[]{this.kh}),
                    Pointer.to(new int[]{this.kw}),
                    Pointer.to(new int[]{this.s}),
                    Pointer.to(new int[]{this.oHeight}),
                    Pointer.to(new int[]{this.oWidth}),
                    Pointer.to(new int[]{this.oh}),
                    Pointer.to(new int[]{this.ow}),
                    Pointer.to(new int[]{this.kSize}));

            int blockSize;
            int gridSize;
            if (totalThreads <= MAX_THREADS_PER_BLOCK) {
                // Small problem: a single block with exactly one thread per element.
                blockSize = (int) totalThreads;
                gridSize = 1;
            } else {
                blockSize = MAX_THREADS_PER_BLOCK;
                gridSize = (int) ((totalThreads + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK);
            }

            // NOTE(review): these copies go through Pointer.to(float[]) (Java heap arrays), not
            // page-locked memory. Confirm in the JCuda docs whether async copies with Java arrays
            // are supported/truly asynchronous; pinned host memory (cuMemAllocHost) is the usual
            // choice for stream overlap.
            JCudaDriver.cuMemcpyHtoDAsync(deviceInput, Pointer.to(this.x), inputBytes, stream);
            JCudaDriver.cuLaunchKernel(this.function,
                    gridSize, 1, 1,
                    blockSize, 1, 1,
                    0, stream, kernelParams, null);
            JCudaDriver.cuMemcpyDtoHAsync(Pointer.to(this.out), deviceOutput, outputBytes, stream);
            JCudaDriver.cuStreamSynchronize(stream);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            // Release with the driver API, matching the cuMemAlloc/cuStreamCreate calls above.
            if (deviceInput != null) {
                JCudaDriver.cuMemFree(deviceInput);
            }
            if (deviceOutput != null) {
                JCudaDriver.cuMemFree(deviceOutput);
            }
            if (stream != null) {
                JCudaDriver.cuStreamDestroy(stream);
            }
        }
    }

    /**
     * Benchmarks the GPU path against {@code Im2colToVector.im2col} on random data and prints
     * whether both outputs match according to {@code CheckArrayUtils.check}.
     */
    public static void main(String[] args) {
        final int batch = 128;
        final int channels = 64;
        final int height = 64;
        final int width = 64;
        final int kernel = 3;
        final int stride = 1;

        int colCount = channels * kernel * kernel;
        int rowCount = batch * (((height - kernel) / stride) + 1) * (((width - kernel) / stride) + 1);

        float[] input = RandomUtils.gaussianRandom(batch * channels * height * width, 0.1f);
        float[][][][] input4d = MatrixUtils.transform(input, batch, channels, height, width);

        float[] gpuOut = new float[rowCount * colCount];
        for (int run = 0; run < 10; run++) {
            long start = System.nanoTime();
            new Im2colKernelStream(input, gpuOut, batch, channels, height, width, kernel, kernel, stride).im2col();
            System.out.println(((System.nanoTime() - start) / 1000000.0d) + "ms.");
        }
        System.out.println("==============================>");

        float[] cpuOut = new float[rowCount * colCount];
        for (int run = 0; run < 10; run++) {
            long start = System.nanoTime();
            Im2colToVector.im2col(input4d, cpuOut, kernel, kernel, stride);
            System.out.println(((System.nanoTime() - start) / 1000000.0d) + "ms");
        }
        System.out.println(CheckArrayUtils.check(gpuOut, cpuOut));
    }

    /** Returns the output array passed to the constructor (not a copy). */
    public float[] getOut() {
        return this.out;
    }
}
