package com.omega.engine.gpu;

import com.omega.common.utils.CheckArrayUtils;
import com.omega.common.utils.Im2colToVector;
import com.omega.common.utils.Im2colUtils;
import com.omega.common.utils.JsonUtils;
import com.omega.common.utils.MatrixUtils;
import com.omega.common.utils.RandomUtils;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUdeviceptr;
import jcuda.driver.CUfunction;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.jcublas.JCublas2;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaError;

/* loaded from: input_file:com/omega/engine/gpu/ConvKernel.class */
/**
 * Single-sample 2D convolution on the GPU, implemented as an im2col kernel launch
 * followed by a matrix multiply ({@code kernel[ko x ih] * columns[ih x iw]}).
 *
 * <p>Device buffers are obtained in {@link #init()} (called by the constructor). Typical use:
 * {@link #setKernel(float[])} once, then per sample {@link #setX(float[])}, {@link #conv()},
 * {@link #getOut()}, and finally {@link #free()}.
 *
 * <p>When im2col would be the identity (1x1 kernel, stride 1, no padding) the input buffer is
 * used directly as the column matrix and the im2col launch is skipped.
 *
 * <p>NOTE(review): no synchronization is performed here; instances are presumably meant to be
 * used from a single thread that owns the CUDA context - confirm against callers.
 */
public class ConvKernel {
    /** Threads per block used for the im2col launch. */
    private static final int CAFFE_CUDA_NUM_THREADS = 1024;

    private String id;
    private float[] x;
    private float[] kernel;
    private float[] out;
    private int C;
    private int H;
    private int W;
    private int ko;
    private int kh;
    private int kw;
    private int s;
    private int p;
    private int oHeight;
    private int oWidth;
    /** Rows of the column matrix: {@code C * kh * kw}. */
    private int ih;
    /** Columns of the column matrix: {@code oHeight * oWidth}. */
    private int iw;
    /** Work-item count handed to the im2col kernel: {@code C * oHeight * oWidth}. */
    private int numKernels;
    private CUfunction function;
    private CUdeviceptr dx;
    private CUdeviceptr dy;
    private Pointer dA;
    private Pointer dC;
    private Pointer kernelParameters;

    /**
     * Creates the convolution helper, derives the output geometry and allocates device buffers.
     *
     * @param str  identifier of this kernel instance (stored only)
     * @param fArr host array that {@link #getOut()} fills and returns; its length determines
     *             how many floats are read back from the device
     * @param i    number of input channels (C)
     * @param i2   input height (H)
     * @param i3   input width (W)
     * @param i4   number of output channels / filters (ko)
     * @param i5   kernel height (kh)
     * @param i6   kernel width (kw)
     * @param i7   stride (s); must be non-zero, it is used as a divisor
     * @param i8   padding (p)
     */
    public ConvKernel(String str, float[] fArr, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8) {
        this.id = str;
        this.C = i;
        this.H = i2;
        this.W = i3;
        this.ko = i4;
        this.kh = i5;
        this.kw = i6;
        this.s = i7;
        this.p = i8;
        this.oHeight = (((i2 + (2 * i8)) - i5) / i7) + 1;
        this.oWidth = (((i3 + (2 * i8)) - i6) / i7) + 1;
        this.out = fArr;
        this.ih = i * i5 * i6;
        this.iw = this.oHeight * this.oWidth;
        this.numKernels = i * this.oHeight * this.oWidth;
        init();
    }

    /**
     * Tells whether im2col is the identity for this geometry, i.e. the raw input can be fed
     * to the matrix multiply directly. That only holds for a 1x1 kernel with stride 1 and no
     * padding; checking the kernel height alone is not sufficient.
     */
    private boolean isIdentityIm2col() {
        return this.kh == 1 && this.kw == 1 && this.s == 1 && this.p == 0;
    }

    /** Prints a diagnostic when a driver API call returned a non-zero (non-success) status. */
    private static void reportDriverStatus(String operation, int status) {
        if (status != 0) {
            System.err.println(operation + " failed with driver status code " + status);
        }
    }

    /**
     * Lazily resolves the im2col CUDA function. A lookup failure is printed and leaves
     * {@code function} unset.
     */
    public void initFunction() {
        try {
            if (this.function == null) {
                this.function = CUDAModules.getLocalFunctionByModule("Im2colKernel.cu", "im2col_gpu_kernelV2");
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Resolves the im2col function, obtains the device buffers and builds the kernel
     * argument list. Invoked by the constructor.
     */
    public void init() {
        initFunction();
        this.dx = CUDAMemoryManager.getDevice(this.C * this.H * this.W);
        if (isIdentityIm2col()) {
            // The input already is the column matrix: alias it instead of allocating a copy.
            this.dy = this.dx;
        } else {
            this.dy = CUDAMemoryManager.getDevice(this.ih * this.iw);
        }
        this.dA = CUDAMemoryManager.getPointer(this.ko * this.ih);
        this.dC = CUDAMemoryManager.getPointer(this.ko * this.iw);
        // Argument order must match the parameter list of im2col_gpu_kernelV2.
        this.kernelParameters = Pointer.to(new NativePointerObject[]{
            Pointer.to(new NativePointerObject[]{this.dx}),
            Pointer.to(new NativePointerObject[]{this.dy}),
            Pointer.to(new int[]{this.numKernels}),
            Pointer.to(new int[]{this.H}),
            Pointer.to(new int[]{this.W}),
            Pointer.to(new int[]{this.kh}),
            Pointer.to(new int[]{this.kw}),
            Pointer.to(new int[]{this.s}),
            Pointer.to(new int[]{this.p}),
            Pointer.to(new int[]{this.oHeight}),
            Pointer.to(new int[]{this.oWidth})
        });
    }

    /**
     * Uploads one input sample to the device.
     *
     * @param fArr input values; the whole array is copied. NOTE(review): the device buffer was
     *             requested with {@code C * H * W} as its size argument, so the array is
     *             presumably expected to hold at most that many floats - confirm.
     */
    public void setX(float[] fArr) {
        this.x = fArr;
        // 4 bytes per float; long arithmetic so very large arrays cannot overflow the byte count.
        JCudaDriver.cuMemcpyHtoD(this.dx, Pointer.to(fArr), fArr.length * 4L);
    }

    /**
     * Uploads the filter weights ({@code ko * C * kh * kw} floats) to the device.
     *
     * @param fArr flattened filter weights
     */
    public void setKernel(float[] fArr) {
        this.kernel = fArr;
        JCublas2.cublasSetVector(this.ko * this.ih, 4, Pointer.to(fArr), 1, this.dA, 1);
    }

    /**
     * Number of blocks needed to cover {@code i} work items with
     * {@code CAFFE_CUDA_NUM_THREADS} threads each (ceiling division).
     *
     * @param i number of work items
     * @return block count
     */
    public int CAFFE_GET_BLOCKS(int i) {
        return ((i + CAFFE_CUDA_NUM_THREADS) - 1) / CAFFE_CUDA_NUM_THREADS;
    }

    /** Runs the convolution for the sample last uploaded with {@link #setX(float[])}. */
    public void conv() {
        if (!isIdentityIm2col()) {
            im2col();
        }
        sgemm();
    }

    /** Multiplies the filter matrix by the column matrix, writing the result to {@code dC}. */
    public void sgemm() {
        GPUOP.getInstance().multiplyFloat(this.ko, this.iw, this.ih, this.dA, (Pointer) this.dy, this.dC, 0, 0, 1.0f, 0.0f);
    }

    /**
     * Launches the im2col kernel (input {@code dx} to columns {@code dy}) and waits for it.
     * Failures are reported on stderr rather than thrown.
     */
    public void im2col() {
        try {
            int launchStatus = JCudaDriver.cuLaunchKernel(
                this.function,
                CAFFE_GET_BLOCKS(this.numKernels), 1, 1,
                CAFFE_CUDA_NUM_THREADS, 1, 1,
                0, (CUstream) null,
                this.kernelParameters, (Pointer) null);
            reportDriverStatus("cuLaunchKernel(im2col)", launchStatus);
            int syncStatus = JCudaDriver.cuCtxSynchronize();
            reportDriverStatus("cuCtxSynchronize(im2col)", syncStatus);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Releases the device buffers held by this instance. Safe to call more than once; the
     * instance must not be used for further computation afterwards.
     */
    public void free() {
        CUdeviceptr input = this.dx;
        CUdeviceptr columns = this.dy;
        this.dx = null;
        this.dy = null;
        if (input != null) {
            JCuda.cudaFree(input);
        }
        // dy aliases dx when im2col is skipped; freeing it again would be a double free.
        if (columns != null && columns != input) {
            JCuda.cudaFree(columns);
        }
        if (this.dA != null) {
            GPUOP.getInstance().free(this.dA);
            GPUOP.getInstance().free(this.dC);
            this.dA = null;
            this.dC = null;
        }
    }

    /**
     * Copies the result from the device into the output array given to the constructor.
     *
     * @return the same output array instance, now holding the latest result
     */
    public float[] getOut() {
        JCublas2.cublasGetVector(this.out.length, 4, this.dC, 1, Pointer.to(this.out), 1);
        return this.out;
    }

    /**
     * Prints a diagnostic for a non-zero CUDA runtime status code.
     *
     * @param i status code; {@code 0} means success and prints nothing
     */
    public void checkCUDA(int i) {
        if (i != 0) {
            System.err.println("Error code " + i + ":" + cudaError.stringFor(i));
        }
    }

    /**
     * Ad-hoc benchmark and self-check: convolves two 3x8x8 samples with two 3x3x3 filters on
     * the GPU path, repeats the computation through the host im2col + multiply path, and
     * prints timings plus the comparison result.
     */
    public static void main(String[] strArr) {
        int outH = (((8 + (2 * 0)) - 3) / 1) + 1;
        int outW = (((8 + (2 * 0)) - 3) / 1) + 1;
        int outArea = outH * outW;
        float[] input = RandomUtils.order(2 * 3 * 8 * 8, 0.1f, 0.1f);
        float[] weights = RandomUtils.order(2 * 3 * 3 * 3, 0.1f, 0.1f);
        float[] sampleOut = new float[2 * outArea];
        float[][][][] gpuResult = new float[2][2][outH][outW];
        float[][][][] cpuResult = new float[2][2][outH][outW];
        float[] sample = new float[3 * 8 * 8];
        float[] gpuFlat = new float[2 * 2 * outH * outW];

        ConvKernel convKernel = new ConvKernel("conv1", sampleOut, 3, 8, 8, 2, 3, 3, 1, 0);
        convKernel.setKernel(weights);

        long gpuStart = System.nanoTime();
        for (int round = 0; round < 20; round++) {
            long roundStart = System.nanoTime();
            for (int n = 0; n < 2; n++) {
                System.arraycopy(input, n * 3 * 8 * 8, sample, 0, 3 * 8 * 8);
                convKernel.setX(sample);
                convKernel.conv();
                // One device-to-host copy per sample; both consumers read the same array.
                float[] result = convKernel.getOut();
                System.arraycopy(result, 0, gpuFlat, n * 2 * outArea, 2 * outArea);
                MatrixUtils.col2im4d(result, gpuResult, n, 2, outH, outW);
            }
            System.out.println(((System.nanoTime() - roundStart) / 1000000.0d) + "ms================>c.:" + round);
        }
        System.out.println(((System.nanoTime() - gpuStart) / 1000000.0d) + "ms.");
        System.out.println(JsonUtils.toJson(gpuFlat));

        float[] columns = new float[2 * outH * outW * 3 * 3 * 3];
        float[][][][] input4d = MatrixUtils.transform(input, 2, 3, 8, 8);
        float[][][][] weights4d = MatrixUtils.transform(weights, 2, 3, 3, 3);
        float[] weightVector = Im2colUtils.kernalToVector(weights4d, false);
        System.out.println("k:" + CheckArrayUtils.check(weights, Im2colUtils.kernalToVector2(weights4d, false)));

        float[] cpuFlat = new float[2 * 2 * outArea];
        long cpuStart = System.nanoTime();
        for (int round = 0; round < 20; round++) {
            long roundStart = System.nanoTime();
            Im2colToVector.im2col(input4d, columns, 3, 3, 1);
            float[] product = new float[2 * 2 * outArea];
            GPUOP.getInstance().multiplyFloat(2 * outH * outW, 3 * 3 * 3, 2, columns, weightVector, product);
            System.out.println(((System.nanoTime() - roundStart) / 1000000.0d) + "ms.cpu:" + round);
            cpuFlat = product;
        }
        System.out.println(((System.nanoTime() - cpuStart) / 1000000.0d) + "ms.cpu-count");
        MatrixUtils.col2imgV2(cpuFlat, cpuResult, 2, 2, outH, outW);
        System.out.println(CheckArrayUtils.check(gpuResult, cpuResult));
        CUDAMemoryManager.free();
    }
}
