package com.omega.engine.nn.layer.gpu;

import com.omega.common.data.Tensor;
import com.omega.common.utils.JsonUtils;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.gpu.GPUOP;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUfunction;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaError;

/* loaded from: input_file:com/omega/engine/nn/layer/gpu/ConvKernel.class */
/**
 * GPU 2-D convolution that lowers convolution to matrix multiplication.
 *
 * <p>The forward pass ({@link #conv}) and weight-gradient pass ({@link #dw}) unfold each input sample
 * with the {@code im2col_gpu_kernelV2} CUDA kernel into a column buffer. They then multiply that
 * buffer with {@code GPUOP.multiplyFloat}. The input-gradient pass ({@link #dx}) multiplies into a
 * column buffer and folds it back with {@code col2im_gpu_kernelV2}. For a 1x1 kernel with stride 1
 * and padding 0 the unfold/fold steps are skipped and the sample data is used directly.
 *
 * <p>Per-sample offsets assume each tensor stores {@code number} contiguous samples of 4-byte floats:
 * {@code C*H*W} floats per sample for input/input-gradient, and {@code ko*oHeight*oWidth} floats per
 * sample for output/delta.
 *
 * <p>Instances hold mutable scratch state (column buffers, kernel-parameter pointers) and perform no
 * synchronization.
 */
public class ConvKernel extends ConvBaseKernel {
    /** Threads per block used for every kernel launch in this class. */
    private static final int CAFFE_CUDA_NUM_THREADS = 1024;
    /** Value of {@code cudaMemcpyDeviceToHost} in the CUDA runtime {@code cudaMemcpyKind} enum. */
    private static final int CUDA_MEMCPY_DEVICE_TO_HOST = 2;

    /** Input channels, height and width of one sample. */
    private final int C;
    private final int H;
    private final int W;
    /** Number of output channels (rows of the weight matrix). */
    private final int ko;
    /** Kernel height, kernel width, stride and padding. */
    private final int kh;
    private final int kw;
    private final int s;
    private final int p;
    /** Spatial size of the output feature map. */
    private final int oHeight;
    private final int oWidth;
    /** Rows of the column matrix: {@code C * kh * kw}. */
    private final int ih;
    /** Columns of the column matrix: {@code oHeight * oWidth}. */
    private final int iw;
    /** True for kh == kw == 1, stride 1 and padding 0; im2col/col2im are skipped in that case. */
    private final boolean is_1x1;
    /** Thread count for the im2col launch: {@code C * oHeight * oWidth}. */
    private final int numKernels;

    private CUfunction im2col_function;
    private CUfunction bias_function;
    private CUfunction back_back_function;
    private CUfunction col2im_function;

    private Pointer kernelParameters;
    private Pointer biasKernelParameters;
    private Pointer biasBackKernelParameters;
    private Pointer col2imKernelParameters;

    /**
     * Column buffer filled by im2col. In the 1x1 case it is instead re-pointed at the current input
     * sample.
     */
    private Pointer dy;
    /** Column-gradient buffer written by {@link #sgemmDX} and folded back by {@link #col2im}. */
    private Pointer dx_t;

    /**
     * Creates a convolution kernel for one fixed input geometry, loads its CUDA functions and
     * allocates its scratch buffers.
     *
     * @param channel   input channels (C)
     * @param height    input height (H)
     * @param width     input width (W)
     * @param kernelNum number of output channels (ko)
     * @param kHeight   kernel height
     * @param kWidth    kernel width
     * @param stride    stride (same in both directions)
     * @param padding   zero padding (same on all sides)
     */
    public ConvKernel(int channel, int height, int width, int kernelNum, int kHeight, int kWidth, int stride, int padding) {
        this.C = channel;
        this.H = height;
        this.W = width;
        this.ko = kernelNum;
        this.kh = kHeight;
        this.kw = kWidth;
        this.s = stride;
        this.p = padding;
        this.oHeight = (height + 2 * padding - kHeight) / stride + 1;
        this.oWidth = (width + 2 * padding - kWidth) / stride + 1;
        this.ih = channel * kHeight * kWidth;
        this.iw = this.oHeight * this.oWidth;
        this.numKernels = channel * this.oHeight * this.oWidth;
        this.is_1x1 = kHeight == 1 && kWidth == 1 && stride == 1 && padding == 0;
        init();
    }

    /**
     * Loads the CUDA functions. Unless this is the 1x1 fast path, it also allocates the two
     * {@code ih * iw} scratch buffers ({@code dy} and {@code dx_t}) via
     * {@code CUDAMemoryManager.getPointer}. The size is presumably in floats; confirm against
     * CUDAMemoryManager.
     */
    public void init() {
        initFunction();
        if (this.is_1x1) {
            // The 1x1 path reads and writes tensor memory directly; no scratch buffers needed.
            return;
        }
        this.dy = CUDAMemoryManager.getPointer(this.ih * this.iw);
        this.dx_t = CUDAMemoryManager.getPointer(this.ih * this.iw);
    }

    /**
     * Loads any CUDA functions that are not loaded yet.
     *
     * @throws IllegalStateException if a kernel module or function cannot be loaded. Without the
     *                               kernels every later launch would fail anyway.
     */
    public void initFunction() {
        try {
            if (this.im2col_function == null) {
                this.im2col_function = CUDAModules.getLocalFunctionByModule("Im2colKernel.cu", "im2col_gpu_kernelV2");
            }
            if (this.bias_function == null) {
                this.bias_function = CUDAModules.getLocalFunctionByModule("BiasKernel.cu", "add_bias");
            }
            if (this.back_back_function == null) {
                this.back_back_function = CUDAModules.getLocalFunctionByModule("BiasKernel.cu", "backward_bias_kernel");
            }
            if (this.col2im_function == null) {
                this.col2im_function = CUDAModules.getLocalFunctionByModule("Col2imKernel.cu", "col2im_gpu_kernelV2");
            }
        } catch (Exception e) {
            throw new IllegalStateException("Failed to load convolution CUDA kernels", e);
        }
    }

    /**
     * Forward convolution: for every sample, builds the column matrix and multiplies the weights
     * with it into that sample's slice of the output.
     *
     * @param input  input tensor; {@code input.number} samples are processed
     * @param weight weight tensor
     * @param output output tensor
     */
    @Override
    public void conv(Tensor input, Tensor weight, Tensor output) {
        for (int i = 0; i < input.number; i++) {
            loadColumns(input, i);
            sgemm(weight, output, i);
        }
    }

    /**
     * Weight gradient: zeroes {@code diffW}, then for every sample multiplies that sample's delta by
     * its column matrix with beta = 1. This presumably accumulates the per-sample gradients into
     * {@code diffW}.
     *
     * @param input input tensor of the forward pass
     * @param delta gradient w.r.t. the output
     * @param diffW receives the weight gradient
     */
    @Override
    public void dw(Tensor input, Tensor delta, Tensor diffW) {
        diffW.clearGPU();
        for (int i = 0; i < input.number; i++) {
            loadColumns(input, i);
            sgemmDW(delta, diffW, i);
        }
    }

    /**
     * Input gradient: for every sample, multiplies the transposed weights with the sample's delta.
     * In the 1x1 case the result goes straight into {@code diff}. Otherwise it goes into the column
     * buffer, which col2im then folds back into {@code diff}.
     *
     * @param delta  gradient w.r.t. the output
     * @param weight weight tensor
     * @param diff   receives the gradient w.r.t. the input
     */
    @Override
    public void dx(Tensor delta, Tensor weight, Tensor diff) {
        for (int i = 0; i < delta.number; i++) {
            if (this.is_1x1) {
                sgemmDX(delta, weight, diff.getGpuData().withByteOffset(floatOffset(i, this.ih * this.iw)), i);
            } else {
                sgemmDX(delta, weight, this.dx_t, i);
                col2im(diff, i);
            }
        }
    }

    /**
     * Makes {@code dy} hold the column matrix of sample {@code i}. In the 1x1 case it points
     * {@code dy} at the sample itself; otherwise it runs im2col into the scratch buffer.
     */
    private void loadColumns(Tensor input, int i) {
        if (this.is_1x1) {
            this.dy = input.getGpuData().withByteOffset(floatOffset(i, this.C * this.H * this.W));
        } else {
            im2col(input, i);
        }
    }

    /**
     * Launches {@code im2col_gpu_kernelV2} to unfold sample {@code i} of {@code tensor} into
     * {@code dy}.
     */
    public void im2col(Tensor tensor, int i) {
        try {
            Pointer sample = tensor.getGpuData().withByteOffset(floatOffset(i, this.C * this.H * this.W));
            // Argument order must match the kernel signature in Im2colKernel.cu.
            this.kernelParameters = Pointer.to(
                    Pointer.to(sample),
                    Pointer.to(this.dy),
                    Pointer.to(new int[]{this.numKernels}),
                    Pointer.to(new int[]{this.H}),
                    Pointer.to(new int[]{this.W}),
                    Pointer.to(new int[]{this.kh}),
                    Pointer.to(new int[]{this.kw}),
                    Pointer.to(new int[]{this.s}),
                    Pointer.to(new int[]{this.p}),
                    Pointer.to(new int[]{this.oHeight}),
                    Pointer.to(new int[]{this.oWidth}));
            launch(this.im2col_function, CAFFE_GET_BLOCKS(this.numKernels), this.kernelParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code col2im_gpu_kernelV2} to fold {@code dx_t} back into sample {@code i} of
     * {@code tensor}.
     */
    public void col2im(Tensor tensor, int i) {
        try {
            Pointer sample = tensor.getGpuData().withByteOffset(floatOffset(i, this.C * this.H * this.W));
            // Argument order (note: padding before stride here) must match Col2imKernel.cu.
            this.col2imKernelParameters = Pointer.to(
                    Pointer.to(this.dx_t),
                    Pointer.to(sample),
                    Pointer.to(new int[]{this.C * this.H * this.W}),
                    Pointer.to(new int[]{this.H}),
                    Pointer.to(new int[]{this.W}),
                    Pointer.to(new int[]{this.C}),
                    Pointer.to(new int[]{this.kh}),
                    Pointer.to(new int[]{this.kw}),
                    Pointer.to(new int[]{this.p}),
                    Pointer.to(new int[]{this.s}),
                    Pointer.to(new int[]{this.oHeight}),
                    Pointer.to(new int[]{this.oWidth}));
            launch(this.col2im_function, CAFFE_GET_BLOCKS(this.C * this.H * this.W), this.col2imKernelParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Forward GEMM for sample {@code i}: weights × {@code dy} into that sample's output slice, with
     * alpha = 1 and beta = 0. Arguments follow the (m, n, k, A, B, C, transA, transB, alpha, beta)
     * order that {@code GPUOP.multiplyFloat} appears to use; confirm against GPUOP.
     */
    public void sgemm(Tensor tensor, Tensor tensor2, int i) {
        Pointer outputSlice = tensor2.getGpuData().withByteOffset(floatOffset(i, this.ko * this.oHeight * this.oWidth));
        GPUOP.getInstance().multiplyFloat(this.ko, this.iw, this.ih, tensor.getGpuData(), this.dy, outputSlice, 0, 0, 1.0f, 0.0f);
    }

    /**
     * Weight-gradient GEMM for sample {@code i}: delta slice × transposed {@code dy} into
     * {@code tensor2}, with alpha = 1 and beta = 1. The result is presumably accumulated.
     */
    public void sgemmDW(Tensor tensor, Tensor tensor2, int i) {
        Pointer deltaSlice = tensor.getGpuData().withByteOffset(floatOffset(i, this.ko * this.iw));
        GPUOP.getInstance().multiplyFloat(this.ko, this.ih, this.iw, deltaSlice, this.dy, tensor2.getGpuData(), 0, 1, 1.0f, 1.0f);
    }

    /**
     * Input-gradient GEMM for sample {@code i}: transposed weights × delta slice into
     * {@code pointer}, with alpha = 1 and beta = 0.
     */
    public void sgemmDX(Tensor tensor, Tensor tensor2, Pointer pointer, int i) {
        Pointer deltaSlice = tensor.getGpuData().withByteOffset(floatOffset(i, this.ko * this.iw));
        GPUOP.getInstance().multiplyFloat(this.ih, this.iw, this.ko, tensor2.getGpuData(), deltaSlice, pointer, 1, 0, 1.0f, 0.0f);
    }

    /**
     * Launches {@code add_bias} over {@code tensor} using the bias in {@code tensor2}.
     *
     * <p>Kernel parameters are rebuilt on every call. The earlier cache was keyed only on the batch
     * size, so it reused stale device pointers whenever a different tensor with the same batch size
     * was passed.
     */
    public void addBias(Tensor tensor, Tensor tensor2) {
        try {
            this.biasKernelParameters = Pointer.to(
                    Pointer.to(tensor.getGpuData()),
                    Pointer.to(tensor2.getGpuData()),
                    Pointer.to(new int[]{tensor.getNumber()}),
                    Pointer.to(new int[]{tensor.channel}),
                    Pointer.to(new int[]{tensor.height * tensor.width}));
            this.N = tensor.number;
            launch(this.bias_function, CAFFE_GET_BLOCKS(tensor.dataLength), this.biasKernelParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Zeroes {@code tensor}, then launches {@code backward_bias_kernel} with one block per channel of
     * {@code tensor2}.
     *
     * <p>Kernel parameters are rebuilt on every call. The earlier build-once cache kept the first
     * call's batch size and pointers, which broke when the batch size changed (e.g. a smaller final
     * batch).
     */
    public void backwardBias(Tensor tensor, Tensor tensor2) {
        try {
            tensor.clearGPU();
            this.biasBackKernelParameters = Pointer.to(
                    Pointer.to(tensor.getGpuData()),
                    Pointer.to(tensor2.getGpuData()),
                    Pointer.to(new int[]{tensor2.getNumber()}),
                    Pointer.to(new int[]{tensor2.getChannel()}),
                    Pointer.to(new int[]{tensor2.height * tensor2.width}));
            launch(this.back_back_function, tensor2.getChannel(), this.biasBackKernelParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code function} on a 1-D grid of {@code gridSize} blocks of
     * {@link #CAFFE_CUDA_NUM_THREADS} threads, with no shared memory and the default stream, and
     * reports a non-zero result via {@link #checkCUDA}.
     */
    private void launch(CUfunction function, int gridSize, Pointer parameters) {
        checkCUDA(JCudaDriver.cuLaunchKernel(function, gridSize, 1, 1, CAFFE_CUDA_NUM_THREADS, 1, 1, 0, (CUstream) null, parameters, (Pointer) null));
    }

    /**
     * Byte offset of sample {@code index}, where each sample is {@code floatsPerSample} 4-byte
     * floats. Computed in {@code long} so large tensors do not overflow {@code int}.
     */
    private static long floatOffset(int index, int floatsPerSample) {
        return (long) index * floatsPerSample * Float.BYTES;
    }

    /**
     * Number of blocks needed to cover {@code i} threads (ceiling division). The arithmetic is done
     * in {@code long} so counts near {@code Integer.MAX_VALUE} do not overflow.
     */
    @Override
    public int CAFFE_GET_BLOCKS(int i) {
        return (int) (((long) i + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS);
    }

    /**
     * Prints the error code to stderr when it is non-zero.
     *
     * <p>NOTE(review): in this class the codes come from the driver API ({@code cuLaunchKernel},
     * i.e. {@code CUresult}). They are decoded with the runtime's {@code cudaError.stringFor}, so
     * the text may not match the actual error. Confirm whether {@code CUresult.stringFor} is
     * intended.
     */
    @Override
    public void checkCUDA(int i) {
        if (i != 0) {
            System.err.println("Error code " + i + ":" + cudaError.stringFor(i));
        }
    }

    /** Debug helper: copies {@code fArr.length} floats from device {@code pointer} into {@code fArr} and prints them as JSON. */
    @Override
    public void showDM(Pointer pointer, float[] fArr) {
        JCuda.cudaMemcpy(Pointer.to(fArr), pointer, (long) fArr.length * Float.BYTES, CUDA_MEMCPY_DEVICE_TO_HOST);
        System.out.println(JsonUtils.toJson(fArr));
    }

    /** Manual smoke test: 1x1 kernel with stride 2 on a 2×64×8×8 input, printing output, dW and dX. */
    public static void main(String[] strArr) {
        CUDAModules.initContext();
        int batch = 2;
        int channel = 64;
        int height = 8;
        int width = 8;
        int kernelNum = 128;
        int kSize = 1;
        int stride = 2;
        int padding = 0;
        int oHeight = (height + 2 * padding - kSize) / stride + 1;
        int oWidth = (width + 2 * padding - kSize) / stride + 1;

        float[] inputData = RandomUtils.order(batch * channel * height * width, 0.1f, 0.1f);
        float[] weightData = RandomUtils.order(kernelNum * channel * kSize * kSize, 0.1f, 0.1f);
        float[] deltaData = RandomUtils.order(batch * kernelNum * oHeight * oWidth, 0.1f, 0.1f);

        Tensor input = new Tensor(batch, channel, height, width, inputData, true);
        Tensor delta = new Tensor(batch, kernelNum, oHeight, oWidth, deltaData, true);
        Tensor weight = new Tensor(kernelNum, channel, kSize, kSize, weightData, true);
        Tensor output = new Tensor(batch, kernelNum, oHeight, oWidth, true);
        Tensor diff = new Tensor(batch, channel, height, width, true);
        Tensor diffW = new Tensor(kernelNum, channel, kSize, kSize, true);

        ConvKernel convKernel = new ConvKernel(channel, height, width, kernelNum, kSize, kSize, stride, padding);
        convKernel.conv(input, weight, output);
        output.syncHost();
        System.out.println("output:" + JsonUtils.toJson(output.data));
        convKernel.dw(input, delta, diffW);
        diffW.syncHost();
        System.out.println("diffW:" + JsonUtils.toJson(diffW.data));
        convKernel.dx(delta, weight, diff);
        diff.syncHost();
        System.out.println("diff:" + JsonUtils.toJson(diff.data));
    }

    /** Not implemented: this method does nothing and leaves {@code tensor3} untouched. */
    @Override
    public void convTranspose(Tensor tensor, Tensor tensor2, Tensor tensor3) {
    }
}
