package com.omega.engine.nn.layer.gpu;

import com.omega.common.data.Tensor;
import com.omega.common.utils.MatrixUtils;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.pooling.PoolingType;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUfunction;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.cudaError;

/* loaded from: input_file:com/omega/engine/nn/layer/gpu/PoolingKernel.class */
/**
 * Launches the pooling CUDA kernels found in {@code PoolingV2Kernel.cu} through the JCuda
 * driver API.
 *
 * <p>One instance is bound to a single {@link PoolingType} and a fixed input geometry
 * (channels, height, width, window size, stride, padding). The batch size is taken from the
 * tensor handed to the forward methods and remembered in the inherited {@code N} field.
 *
 * <p>Kernel launch arguments are rebuilt on every call. They capture the GPU pointers of the
 * tensors that are passed in, so caching them would launch against stale buffers whenever the
 * caller passes different tensors, a tensor re-allocates its GPU storage, or the batch size
 * changes between a forward and a backward pass.
 *
 * <p>NOTE(review): {@link #forward} and {@link #backward} only dispatch {@code MAX_POOLING} and
 * {@code MEAN_POOLING}; for any other type they clear the destination tensor and return.
 * {@link #avgpooling} and {@link #avgpoolingDiff} must be called directly. Confirm whether that
 * is intended before wiring them into the dispatch.
 */
public class PoolingKernel extends PoolingBaseKernel {
    /** CUDA module that contains every kernel this class looks up. */
    private static final String MODULE_NAME = "PoolingV2Kernel.cu";

    private final PoolingType type;
    private final int C;
    private final int H;
    private final int W;
    private final int ph;
    private final int pw;
    private final int s;
    private final int padding;
    private final int oHeight;
    private final int oWidth;
    private int numKernels;
    private int max_f_n;
    private int max_b_n;
    private CUfunction forward_function;
    private CUfunction backward_function;
    /** Threads per block used for every launch. */
    private final int CAFFE_CUDA_NUM_THREADS = 1024;
    /** Side buffer shared by the max-pooling forward and backward kernels (presumably the argmax indexes - confirm against the .cu file). */
    private Pointer dm;
    private Pointer forwardKernelParameters;
    private Pointer backwardKernelParameters;

    /**
     * Creates a kernel wrapper without padding.
     *
     * @param poolingType pooling variant whose kernels are loaded
     * @param channel number of channels
     * @param height input height
     * @param width input width
     * @param poolHeight pooling window height
     * @param poolWidth pooling window width
     * @param stride window stride
     */
    public PoolingKernel(PoolingType poolingType, int channel, int height, int width, int poolHeight, int poolWidth, int stride) {
        this(poolingType, channel, height, width, poolHeight, poolWidth, stride, 0);
    }

    /**
     * Creates a kernel wrapper with padding.
     *
     * <p>The output size is {@code (size + pad - window) / stride + 1}; the padding value is
     * added once, not once per side.
     *
     * @param poolingType pooling variant whose kernels are loaded
     * @param channel number of channels
     * @param height input height
     * @param width input width
     * @param poolHeight pooling window height
     * @param poolWidth pooling window width
     * @param stride window stride
     * @param pad padding amount forwarded to the kernels
     */
    public PoolingKernel(PoolingType poolingType, int channel, int height, int width, int poolHeight, int poolWidth, int stride, int pad) {
        this.type = poolingType;
        this.C = channel;
        this.H = height;
        this.W = width;
        this.ph = poolHeight;
        this.pw = poolWidth;
        this.s = stride;
        this.padding = pad;
        this.oHeight = (((height + pad) - poolHeight) / stride) + 1;
        this.oWidth = (((width + pad) - poolWidth) / stride) + 1;
        this.numKernels = 0;
        this.max_f_n = 0;
        this.max_b_n = 0;
        init();
    }

    /**
     * Resolves the forward and backward CUDA functions for the configured pooling type.
     * A function that is already resolved is left untouched. Any failure is printed and
     * otherwise ignored, which leaves the corresponding function {@code null}.
     */
    public void initFunction() {
        try {
            if (this.forward_function == null) {
                String forwardName = kernelName(this.type, "forward");
                if (forwardName != null) {
                    this.forward_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, forwardName);
                }
            }
            if (this.backward_function == null) {
                String backwardName = kernelName(this.type, "backward");
                if (backwardName != null) {
                    this.backward_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, backwardName);
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Maps a pooling type to the kernel name used inside the CUDA module.
     *
     * @param poolingType pooling variant
     * @param direction either {@code "forward"} or {@code "backward"}
     * @return the kernel name, or {@code null} when the type has no kernel here
     */
    private static String kernelName(PoolingType poolingType, String direction) {
        switch (poolingType) {
            case MAX_POOLING:
                return "maxpool_" + direction;
            case MEAN_POOLING:
                return "meanpool_" + direction;
            case AVG_POOLING:
                return "avgpool_" + direction;
            default:
                return null;
        }
    }

    /** Initialises the instance; currently this only resolves the CUDA functions. */
    public void init() {
        initFunction();
    }

    /**
     * Returns the number of blocks needed to cover {@code i} work items, rounding up.
     *
     * @param i number of work items
     * @return {@code ceil(i / threadsPerBlock)} for non-negative {@code i}
     */
    @Override // com.omega.engine.gpu.BaseKernel
    public int CAFFE_GET_BLOCKS(int i) {
        return ((i + this.CAFFE_CUDA_NUM_THREADS) - 1) / this.CAFFE_CUDA_NUM_THREADS;
    }

    /**
     * Clears {@code output} on the GPU and runs the forward kernel for the configured type.
     * Types other than max and mean pooling only get the clear.
     *
     * @param input tensor to pool
     * @param output tensor that receives the pooled values
     */
    @Override // com.omega.engine.nn.layer.gpu.PoolingBaseKernel
    public void forward(Tensor input, Tensor output) {
        output.clearGPU();
        switch (this.type) {
            case MAX_POOLING:
                maxpooling(input, output);
                return;
            case MEAN_POOLING:
                meanpooling(input, output);
                return;
            default:
                return;
        }
    }

    /**
     * Clears {@code diff} on the GPU and runs the backward kernel for the configured type.
     * Types other than max and mean pooling only get the clear. The first two arguments are
     * not used by this implementation.
     *
     * @param input unused
     * @param output unused
     * @param delta gradient handed to the backward kernel as its first pointer argument
     * @param diff tensor handed to the backward kernel as its second pointer argument
     */
    @Override // com.omega.engine.nn.layer.gpu.PoolingBaseKernel
    public void backward(Tensor input, Tensor output, Tensor delta, Tensor diff) {
        diff.clearGPU();
        switch (this.type) {
            case MAX_POOLING:
                maxpoolingDiff(delta, diff);
                return;
            case MEAN_POOLING:
                meanpoolingDiff(delta, diff);
                return;
            default:
                return;
        }
    }

    /**
     * Launches the max-pooling forward kernel over {@code oHeight * oWidth * C * N} items.
     * The side buffer {@code dm} is (re)allocated whenever the batch size changes.
     *
     * @param tensor input tensor; its {@code number} becomes the current batch size
     * @param tensor2 output tensor
     */
    public void maxpooling(Tensor tensor, Tensor tensor2) {
        try {
            if (this.dm == null || tensor.number != this.N) {
                // One 4-byte slot per output element.
                this.dm = CUDAMemoryManager.getPointer(tensor.number * this.C * this.oHeight * this.oWidth, 4);
                this.N = tensor.number;
            }
            this.max_f_n = this.oHeight * this.oWidth * this.C * this.N;
            this.forwardKernelParameters = windowKernelParams(this.max_f_n, tensor, tensor2, true);
            launch(this.forward_function, this.max_f_n, this.forwardKernelParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches the mean-pooling forward kernel over {@code oHeight * oWidth * C * N} items.
     *
     * @param tensor input tensor; its {@code number} becomes the current batch size
     * @param tensor2 output tensor
     */
    public void meanpooling(Tensor tensor, Tensor tensor2) {
        try {
            this.N = tensor.number;
            this.max_f_n = this.oHeight * this.oWidth * this.C * this.N;
            this.forwardKernelParameters = windowKernelParams(this.max_f_n, tensor, tensor2, false);
            launch(this.forward_function, this.max_f_n, this.forwardKernelParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches the average-pooling forward kernel over {@code N * tensor.channel} items.
     *
     * @param tensor input tensor; its {@code number} becomes the current batch size
     * @param tensor2 output tensor
     */
    public void avgpooling(Tensor tensor, Tensor tensor2) {
        try {
            this.N = tensor.number;
            this.numKernels = this.N * tensor.channel;
            this.forwardKernelParameters = planeKernelParams(this.numKernels, tensor, tensor2);
            launch(this.forward_function, this.numKernels, this.forwardKernelParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches the max-pooling backward kernel over {@code H * W * C * N} items, using the
     * batch size and side buffer left behind by the last {@link #maxpooling} call.
     *
     * @param tensor first pointer argument of the kernel (the incoming gradient)
     * @param tensor2 second pointer argument of the kernel
     */
    public void maxpoolingDiff(Tensor tensor, Tensor tensor2) {
        try {
            if (this.dm == null) {
                // Without a forward pass there is no side buffer; launching would hand the
                // kernel a null device pointer.
                throw new IllegalStateException("maxpoolingDiff called before maxpooling: index buffer is not allocated");
            }
            this.max_b_n = this.H * this.W * this.C * this.N;
            this.backwardKernelParameters = windowKernelParams(this.max_b_n, tensor, tensor2, true);
            launch(this.backward_function, this.max_b_n, this.backwardKernelParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches the mean-pooling backward kernel over {@code H * W * C * N} items, using the
     * batch size recorded by the last forward call.
     *
     * @param tensor first pointer argument of the kernel (the incoming gradient)
     * @param tensor2 second pointer argument of the kernel
     */
    public void meanpoolingDiff(Tensor tensor, Tensor tensor2) {
        try {
            this.max_b_n = this.H * this.W * this.C * this.N;
            this.backwardKernelParameters = windowKernelParams(this.max_b_n, tensor, tensor2, false);
            launch(this.backward_function, this.max_b_n, this.backwardKernelParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches the average-pooling backward kernel over the item count computed by the last
     * {@link #avgpooling} call.
     *
     * @param tensor first pointer argument of the kernel
     * @param tensor2 second pointer argument of the kernel
     */
    public void avgpoolingDiff(Tensor tensor, Tensor tensor2) {
        try {
            this.backwardKernelParameters = planeKernelParams(this.numKernels, tensor, tensor2);
            launch(this.backward_function, this.numKernels, this.backwardKernelParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches the backward kernel for a single sample, addressing that sample through byte
     * offsets into the three buffers.
     *
     * <p>NOTE(review): the argument layout used here differs from the other backward methods;
     * confirm which kernel signature this is meant for.
     *
     * @param tensor buffer offset by {@code i * C * oHeight * oWidth} elements
     * @param tensor2 buffer offset by {@code i * C * H * W} elements
     * @param i sample index inside the batch
     */
    public void poolingDiff(Tensor tensor, Tensor tensor2, int i) {
        try {
            // Offsets are computed in 64-bit so large tensors cannot wrap around int range.
            long outputElements = (long) i * this.C * this.oHeight * this.oWidth;
            long inputElements = (long) i * this.C * this.H * this.W;
            long deltaOffset = outputElements * 4L;
            long indexOffset = outputElements * this.ph * this.pw * 4L;
            long diffOffset = inputElements * 4L;
            this.backwardKernelParameters = Pointer.to(new NativePointerObject[]{
                Pointer.to(new NativePointerObject[]{tensor.getGpuData().withByteOffset(deltaOffset)}),
                Pointer.to(new NativePointerObject[]{this.dm.withByteOffset(indexOffset)}),
                Pointer.to(new NativePointerObject[]{tensor2.getGpuData().withByteOffset(diffOffset)}),
                Pointer.to(new int[]{this.numKernels}),
                Pointer.to(new int[]{this.H}),
                Pointer.to(new int[]{this.W}),
                Pointer.to(new int[]{this.oHeight}),
                Pointer.to(new int[]{this.oWidth}),
                Pointer.to(new int[]{this.ph}),
                Pointer.to(new int[]{this.pw}),
                Pointer.to(new int[]{this.s})
            });
            launch(this.backward_function, this.numKernels, this.backwardKernelParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Prints a message to standard error when {@code i} is a non-zero status code.
     *
     * @param i status code to inspect
     */
    @Override // com.omega.engine.gpu.BaseKernel
    public void checkCUDA(int i) {
        if (i != 0) {
            System.err.println("Error code " + i + ":" + cudaError.stringFor(i));
        }
    }

    /**
     * Builds the argument list shared by the max and mean kernels:
     * {@code (count, H, W, C, stride, window, padding, first, second[, dm])}.
     * Only the window height {@code ph} is passed; {@code pw} is not part of this layout.
     *
     * @param count number of work items
     * @param first tensor whose GPU buffer is the first pointer argument
     * @param second tensor whose GPU buffer is the second pointer argument
     * @param withIndexes whether to append the {@code dm} side buffer as the last argument
     * @return pointer to the kernel argument array
     */
    private Pointer windowKernelParams(int count, Tensor first, Tensor second, boolean withIndexes) {
        NativePointerObject[] args = new NativePointerObject[withIndexes ? 10 : 9];
        args[0] = Pointer.to(new int[]{count});
        args[1] = Pointer.to(new int[]{this.H});
        args[2] = Pointer.to(new int[]{this.W});
        args[3] = Pointer.to(new int[]{this.C});
        args[4] = Pointer.to(new int[]{this.s});
        args[5] = Pointer.to(new int[]{this.ph});
        args[6] = Pointer.to(new int[]{this.padding});
        args[7] = Pointer.to(new NativePointerObject[]{first.getGpuData()});
        args[8] = Pointer.to(new NativePointerObject[]{second.getGpuData()});
        if (withIndexes) {
            args[9] = Pointer.to(new NativePointerObject[]{this.dm});
        }
        return Pointer.to(args);
    }

    /**
     * Builds the argument list used by the average kernels:
     * {@code (count, W, H, C, first, second)}. Width precedes height in this layout.
     *
     * @param count number of work items
     * @param first tensor whose GPU buffer is the first pointer argument
     * @param second tensor whose GPU buffer is the second pointer argument
     * @return pointer to the kernel argument array
     */
    private Pointer planeKernelParams(int count, Tensor first, Tensor second) {
        return Pointer.to(new NativePointerObject[]{
            Pointer.to(new int[]{count}),
            Pointer.to(new int[]{this.W}),
            Pointer.to(new int[]{this.H}),
            Pointer.to(new int[]{this.C}),
            Pointer.to(new NativePointerObject[]{first.getGpuData()}),
            Pointer.to(new NativePointerObject[]{second.getGpuData()})
        });
    }

    /**
     * Launches {@code function} on a one-dimensional grid sized for {@code count} items and
     * reports a non-zero launch status through {@link #checkCUDA(int)}.
     *
     * <p>NOTE(review): {@code checkCUDA} names the error with the runtime-API table while the
     * launch returns a driver-API status; the printed number is reliable, the printed name
     * should be double-checked.
     *
     * @param function kernel to launch
     * @param count number of work items
     * @param params kernel argument array
     */
    private void launch(CUfunction function, int count, Pointer params) {
        int status = JCudaDriver.cuLaunchKernel(function, CAFFE_GET_BLOCKS(count), 1, 1, this.CAFFE_CUDA_NUM_THREADS, 1, 1, 0, (CUstream) null, params, (Pointer) null);
        checkCUDA(status);
    }

    /**
     * Manual smoke test: runs mean pooling forward twice and backward once on a small tensor
     * and prints the timing and the tensors. Requires a working CUDA context.
     *
     * @param strArr ignored
     */
    public static void main(String[] strArr) {
        CUDAModules.initContext();
        final int batch = 2;
        final int channel = 3;
        final int height = 4;
        final int width = 4;
        final int pool = 4;
        final int stride = 4;
        int outHeight = ((height - pool) / stride) + 1;
        int outWidth = ((width - pool) / stride) + 1;
        float[] inputData = MatrixUtils.order(batch * channel * height * width, 1, 1);
        float[] deltaData = RandomUtils.order(batch * channel * outHeight * outWidth, 0.1f, 0.1f);
        Tensor input = new Tensor(batch, channel, height, width, inputData, true);
        Tensor output = new Tensor(batch, channel, outHeight, outWidth, true);
        Tensor delta = new Tensor(batch, channel, outHeight, outWidth, deltaData, true);
        Tensor diff = new Tensor(batch, channel, height, width, true);
        PoolingKernel poolingKernel = new PoolingKernel(PoolingType.MEAN_POOLING, channel, height, width, pool, pool, stride);
        long startNanos = System.nanoTime();
        for (int run = 0; run < 2; run++) {
            poolingKernel.forward(input, output);
        }
        System.out.println(((System.nanoTime() - startNanos) / 1000000.0d) + "ms.");
        output.showDM();
        input.showDM();
        poolingKernel.backward(input, output, delta, diff);
        delta.showDM();
        diff.showDM();
    }
}
