package com.omega.engine.nn.layer.gpu;

import com.omega.common.data.Tensor;
import com.omega.engine.gpu.BaseKernel;
import com.omega.engine.gpu.CUDAModules;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUfunction;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;

/* loaded from: input_file:com/omega/engine/nn/layer/gpu/RNNKernel.class */
/**
 * Launches the element-wise CUDA kernels used by an RNN layer: bias addition, output accumulation
 * (with and without bias) and the bias-gradient kernel. All kernels are loaded from
 * {@code RNNKernel.cu}; the exact meaning of each kernel argument is defined in that file.
 *
 * <p>The kernel parameter block is rebuilt on every call. As a result, each launch uses exactly the
 * tensors and element offset passed to that call, never pointers left over from an earlier call.
 *
 * <p>All launches go to the default stream ({@code null} {@link CUstream}). Exceptions are
 * reported with {@code printStackTrace()} and not rethrown, consistently across every method.
 */
public class RNNKernel extends BaseKernel {
    /** CUDA source file that defines every kernel launched by this class. */
    private static final String KERNEL_FILE = "RNNKernel.cu";

    /**
     * Bytes per element used to turn element offsets into byte offsets. This assumes 4-byte
     * (float) device data; confirm against the Tensor storage type.
     */
    private static final long ELEMENT_BYTES = 4L;

    /** Threads per 1-D block for every launch in this class. */
    private static final int CAFFE_CUDA_NUM_THREADS = 1024;

    private CUfunction bias_function;
    private CUfunction output_function;
    private CUfunction output_bias_function;
    private CUfunction back_function;

    public RNNKernel() {
        init();
    }

    public void init() {
        initFunction();
    }

    /**
     * Loads every kernel handle that has not been loaded yet. Failures are printed, not thrown; a
     * handle that fails to load stays {@code null}.
     */
    public void initFunction() {
        try {
            if (bias_function == null) {
                bias_function = CUDAModules.getLocalFunctionByModule(KERNEL_FILE, "add_bias");
            }
            if (output_function == null) {
                output_function = CUDAModules.getLocalFunctionByModule(KERNEL_FILE, "add_output");
            }
            if (output_bias_function == null) {
                output_bias_function = CUDAModules.getLocalFunctionByModule(KERNEL_FILE, "add_output_bias");
            }
            if (back_function == null) {
                back_function = CUDAModules.getLocalFunctionByModule(KERNEL_FILE, "backward_bias_conn_kernel");
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Returns the number of {@code CAFFE_CUDA_NUM_THREADS}-sized blocks needed to cover {@code i}
     * work items (ceiling division). The sum is computed in {@code long} so that counts near
     * {@code Integer.MAX_VALUE} do not overflow.
     */
    @Override // com.omega.engine.gpu.BaseKernel
    public int CAFFE_GET_BLOCKS(int i) {
        return (int) (((long) i + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS);
    }

    /**
     * Launches {@code add_bias} on the start of {@code tensor}'s device buffer. This is
     * equivalent to {@code addBias(tensor, tensor2, 0)}.
     *
     * @param tensor first kernel argument; its number and width are also passed
     * @param tensor2 second kernel argument (presumably the bias vector; confirm in RNNKernel.cu)
     */
    public void addBias(Tensor tensor, Tensor tensor2) {
        addBias(tensor, tensor2, 0);
    }

    /**
     * Launches {@code add_bias} on {@code tensor}'s device buffer, starting {@code i} elements in.
     * The grid is sized from {@code tensor.dataLength}.
     *
     * @param tensor first kernel argument; its number and width are also passed
     * @param tensor2 second kernel argument, passed without an offset
     * @param i element offset applied to {@code tensor}'s device pointer
     */
    public void addBias(Tensor tensor, Tensor tensor2, int i) {
        try {
            launch(bias_function, tensor.dataLength,
                    Pointer.to(elementOffset(tensor.getGpuData(), i)),
                    Pointer.to(tensor2.getGpuData()),
                    Pointer.to(new int[] {tensor.getNumber()}),
                    Pointer.to(new int[] {tensor.getWidth()}),
                    // Trailing constant 1 is forwarded unchanged; its meaning is defined by the kernel.
                    Pointer.to(new int[] {1}));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code add_output_bias}. Both {@code tensor} and {@code tensor2} are offset by
     * {@code i} elements; {@code tensor3} is passed without an offset.
     *
     * @param tensor first kernel argument; its number and width are also passed
     * @param tensor2 second kernel argument
     * @param tensor3 third kernel argument (presumably the bias; confirm in RNNKernel.cu)
     * @param i element offset applied to {@code tensor} and {@code tensor2}
     */
    public void addOutputBias(Tensor tensor, Tensor tensor2, Tensor tensor3, int i) {
        try {
            launch(output_bias_function, tensor.dataLength,
                    Pointer.to(elementOffset(tensor.getGpuData(), i)),
                    Pointer.to(elementOffset(tensor2.getGpuData(), i)),
                    Pointer.to(tensor3.getGpuData()),
                    Pointer.to(new int[] {tensor.getNumber()}),
                    Pointer.to(new int[] {tensor.getWidth()}),
                    Pointer.to(new int[] {1}));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code add_output}, with both tensors offset by {@code i} elements.
     *
     * @param tensor first kernel argument; its number and width are also passed
     * @param tensor2 second kernel argument
     * @param i element offset applied to both device pointers
     */
    public void addOutput(Tensor tensor, Tensor tensor2, int i) {
        try {
            launch(output_function, tensor.dataLength,
                    Pointer.to(elementOffset(tensor.getGpuData(), i)),
                    Pointer.to(elementOffset(tensor2.getGpuData(), i)),
                    Pointer.to(new int[] {tensor.getNumber()}),
                    Pointer.to(new int[] {tensor.getWidth()}),
                    Pointer.to(new int[] {1}));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Calls {@code tensor.clearGPU()} and then launches {@code backward_bias_conn_kernel} with
     * {@code tensor} as output and {@code tensor2} as input. One thread is launched per column of
     * {@code tensor2} ({@code getWidth()}). Presumably this reduces {@code tensor2} over
     * {@code number} into a bias gradient; confirm in RNNKernel.cu.
     */
    public void backwardBias(Tensor tensor, Tensor tensor2) {
        try {
            tensor.clearGPU();
            backwardBias(tensor.getGpuData(), tensor2.getGpuData(), tensor2.getNumber(), tensor2.getWidth());
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code backward_bias_conn_kernel} on raw device pointers, with one thread per
     * {@code i2}.
     *
     * @param pointer first kernel argument (output buffer)
     * @param pointer2 second kernel argument (input buffer)
     * @param i third kernel argument (row count in the Tensor overload)
     * @param i2 fourth kernel argument (column count); also sizes the grid
     */
    public void backwardBias(Pointer pointer, Pointer pointer2, int i, int i2) {
        try {
            launch(back_function, i2,
                    Pointer.to(pointer),
                    Pointer.to(pointer2),
                    Pointer.to(new int[] {i}),
                    Pointer.to(new int[] {i2}));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Returns {@code base} advanced by {@code elementOffset} elements. When the offset is zero,
     * {@code base} itself is returned. The byte offset is computed in {@code long} to avoid int
     * overflow.
     */
    private static Pointer elementOffset(Pointer base, int elementOffset) {
        if (elementOffset == 0) {
            return base;
        }
        return base.withByteOffset((long) elementOffset * ELEMENT_BYTES);
    }

    /**
     * Packs {@code args} into a fresh kernel-parameter block and launches {@code function} on the
     * default stream. The launch uses a 1-D grid of {@code CAFFE_CUDA_NUM_THREADS}-sized blocks
     * covering {@code workItems}. The driver copies kernel parameters at launch time, so the block
     * does not need to outlive this call. A non-positive {@code workItems} would produce a
     * zero-sized grid, which is not a valid launch configuration, so the launch is skipped.
     */
    private void launch(CUfunction function, int workItems, NativePointerObject... args) {
        if (workItems <= 0) {
            return;
        }
        Pointer kernelParameters = Pointer.to(args);
        JCudaDriver.cuLaunchKernel(function,
                CAFFE_GET_BLOCKS(workItems), 1, 1,
                CAFFE_CUDA_NUM_THREADS, 1, 1,
                0, (CUstream) null,
                kernelParameters, (Pointer) null);
    }
}
