package com.omega.engine.nn.layer.active.gpu;

import com.omega.common.data.Tensor;
import com.omega.common.utils.MatrixUtils;
import com.omega.engine.gpu.BaseKernel;
import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUfunction;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;

/* loaded from: input_file:com/omega/engine/nn/layer/active/gpu/TanhKernel.class */
/**
 * Launches the tanh activation CUDA kernels of {@code activeFunction.cu} through the JCuda driver API.
 *
 * <p>Three device functions are resolved once per instance: {@code tanh_forward},
 * {@code tanh_backward} and {@code tanh_backward_temp}. Every launch uses a one-dimensional grid of
 * {@link #CAFFE_GET_BLOCKS(int)} blocks with 1024 threads each, no dynamic shared memory and a
 * {@code null} stream.
 *
 * <p>Any {@link Exception} raised while building the arguments or launching is caught and its stack
 * trace printed; it is never rethrown to the caller.
 *
 * <p>NOTE(review): byte offsets are computed as {@code elements * 4}, i.e. the device buffers are
 * assumed to hold 4-byte floats — confirm against {@code Tensor}'s storage type.
 */
public class TanhKernel extends BaseKernel {
    /** Size of one element in the device buffers, kept as {@code long} so offset maths never wraps. */
    private static final long BYTES_PER_FLOAT = 4L;

    private CUfunction function;
    private CUfunction function_back;
    private CUfunction function_back_temp;
    /** Threads per block used for every launch. */
    private final int CAFFE_CUDA_NUM_THREADS = 1024;
    /** Argument block of the most recent forward launch. */
    private Pointer forwardKernelParameters;
    /** Argument block of the most recent backward launch. */
    private Pointer backwardKernelParameters;

    /** Creates the kernel wrapper and resolves its device functions. */
    public TanhKernel() {
        init();
    }

    /** Initialises this wrapper; currently only delegates to {@link #initFunction()}. */
    public void init() {
        initFunction();
    }

    /**
     * Looks up each device function that has not been resolved yet. Lookup failures are printed and
     * leave the corresponding handle {@code null}.
     */
    public void initFunction() {
        try {
            if (this.function == null) {
                this.function = CUDAModules.getLocalFunctionByModule("activeFunction.cu", "tanh_forward");
            }
            if (this.function_back == null) {
                this.function_back = CUDAModules.getLocalFunctionByModule("activeFunction.cu", "tanh_backward");
            }
            if (this.function_back_temp == null) {
                this.function_back_temp = CUDAModules.getLocalFunctionByModule("activeFunction.cu", "tanh_backward_temp");
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Returns the number of blocks needed to cover {@code i} elements with 1024 threads per block
     * (ceiling division).
     *
     * @param i number of elements to process
     * @return {@code ceil(i / 1024)}
     */
    @Override // com.omega.engine.gpu.BaseKernel
    public int CAFFE_GET_BLOCKS(int i) {
        // Widen before adding: (i + 1023) wraps negative in int arithmetic when i is close to Integer.MAX_VALUE.
        return (int) (((long) i + this.CAFFE_CUDA_NUM_THREADS - 1) / this.CAFFE_CUDA_NUM_THREADS);
    }

    /**
     * Launches {@code tanh_forward} on a sub-range of two tensors.
     *
     * @param tensor  tensor whose device buffer is the first kernel argument (presumably the input)
     * @param tensor2 tensor whose device buffer is the second kernel argument (presumably the output)
     * @param i       element offset applied to both device buffers
     * @param i2      element count passed to the kernel and used to size the grid
     */
    public void forward(Tensor tensor, Tensor tensor2, int i, int i2) {
        try {
            // Computed in long: an int product wraps once the offset reaches 2 GiB.
            long byteOffset = i * BYTES_PER_FLOAT;
            launchForward(tensor.getGpuData().withByteOffset(byteOffset), tensor2.getGpuData().withByteOffset(byteOffset), i2);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code tanh_forward} on raw device pointers.
     *
     * @param pointer  first kernel argument (presumably the input buffer)
     * @param pointer2 second kernel argument (presumably the output buffer)
     * @param i        element count passed to the kernel and used to size the grid
     */
    public void forward(Pointer pointer, Pointer pointer2, int i) {
        try {
            launchForward(pointer, pointer2, i);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code tanh_forward} over {@code tensor2.dataLength} elements and records
     * {@code tensor2.number} in the inherited {@code N} field.
     *
     * @param tensor  tensor whose device buffer is the first kernel argument (presumably the input)
     * @param tensor2 tensor whose device buffer is the second kernel argument (presumably the output)
     */
    public void forward(Tensor tensor, Tensor tensor2) {
        try {
            Pointer first = tensor.getGpuData();
            Pointer second = tensor2.getGpuData();
            int length = tensor2.dataLength;
            this.N = tensor2.number;
            launchForward(first, second, length);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code tanh_backward} over {@code tensor.dataLength} elements.
     *
     * <p>NOTE(review): {@link #main(String[])} passes the forward result, a tensor of ones and an
     * empty tensor, so the arguments are presumably (output, delta, diff) — confirm against the kernel.
     *
     * @param tensor  tensor whose device buffer is the first kernel argument; also supplies the length
     * @param tensor2 tensor whose device buffer is the second kernel argument
     * @param tensor3 tensor whose device buffer is the third kernel argument
     */
    public void backward(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        try {
            launchBackward(this.function_back, tensor.getGpuData(), tensor2.getGpuData(), tensor3.getGpuData(), tensor.dataLength);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code tanh_backward} on the {@code i}-th slice of {@code tensor.getOnceSize()}
     * elements in each of the three tensors.
     *
     * @param tensor  tensor whose device buffer is the first kernel argument; also supplies the slice size
     * @param tensor2 tensor whose device buffer is the second kernel argument
     * @param tensor3 tensor whose device buffer is the third kernel argument
     * @param i       zero-based slice index; the element offset is {@code i * tensor.getOnceSize()}
     */
    public void backward(Tensor tensor, Tensor tensor2, Tensor tensor3, int i) {
        try {
            int sliceLength = tensor.getOnceSize();
            // Cast first so the whole product is evaluated in long instead of wrapping in int.
            long byteOffset = (long) i * sliceLength * BYTES_PER_FLOAT;
            launchBackward(
                    this.function_back,
                    tensor.getGpuData().withByteOffset(byteOffset),
                    tensor2.getGpuData().withByteOffset(byteOffset),
                    tensor3.getGpuData().withByteOffset(byteOffset),
                    sliceLength);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code tanh_backward_temp} over {@code tensor.dataLength} elements, with the same
     * argument layout as {@link #backward(Tensor, Tensor, Tensor)}.
     *
     * @param tensor  tensor whose device buffer is the first kernel argument; also supplies the length
     * @param tensor2 tensor whose device buffer is the second kernel argument
     * @param tensor3 tensor whose device buffer is the third kernel argument
     */
    public void backwardTemp(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        try {
            launchBackward(this.function_back_temp, tensor.getGpuData(), tensor2.getGpuData(), tensor3.getGpuData(), tensor.dataLength);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code tanh_backward} on a sub-range of three tensors.
     *
     * @param tensor  tensor whose device buffer is the first kernel argument
     * @param tensor2 tensor whose device buffer is the second kernel argument
     * @param tensor3 tensor whose device buffer is the third kernel argument
     * @param i       element offset applied to all three device buffers
     * @param i2      element count passed to the kernel and used to size the grid
     */
    public void backward(Tensor tensor, Tensor tensor2, Tensor tensor3, int i, int i2) {
        try {
            // Computed in long: an int product wraps once the offset reaches 2 GiB.
            long byteOffset = i * BYTES_PER_FLOAT;
            launchBackward(
                    this.function_back,
                    tensor.getGpuData().withByteOffset(byteOffset),
                    tensor2.getGpuData().withByteOffset(byteOffset),
                    tensor3.getGpuData().withByteOffset(byteOffset),
                    i2);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code tanh_backward} on raw device pointers.
     *
     * @param pointer  first kernel argument
     * @param pointer2 second kernel argument
     * @param pointer3 third kernel argument
     * @param i        element count passed to the kernel and used to size the grid
     */
    public void backward(Pointer pointer, Pointer pointer2, Pointer pointer3, int i) {
        try {
            launchBackward(this.function_back, pointer, pointer2, pointer3, i);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Builds the (pointer, pointer, int) argument block and launches the forward kernel. */
    private void launchForward(Pointer first, Pointer second, int length) {
        this.forwardKernelParameters = Pointer.to(deviceArg(first), deviceArg(second), intArg(length));
        launch(this.function, length, this.forwardKernelParameters);
    }

    /** Builds the (pointer, pointer, pointer, int) argument block and launches the given backward kernel. */
    private void launchBackward(CUfunction kernel, Pointer first, Pointer second, Pointer third, int length) {
        this.backwardKernelParameters = Pointer.to(deviceArg(first), deviceArg(second), deviceArg(third), intArg(length));
        launch(kernel, length, this.backwardKernelParameters);
    }

    /** Launches {@code kernel} on a 1-D grid sized for {@code length} elements. */
    private void launch(CUfunction kernel, int length, Pointer kernelParameters) {
        JCudaDriver.cuLaunchKernel(kernel, CAFFE_GET_BLOCKS(length), 1, 1, this.CAFFE_CUDA_NUM_THREADS, 1, 1, 0, (CUstream) null, kernelParameters, (Pointer) null);
    }

    /** Wraps a device pointer as a single kernel argument. */
    private static Pointer deviceArg(NativePointerObject devicePointer) {
        return Pointer.to(new NativePointerObject[]{devicePointer});
    }

    /** Wraps an {@code int} as a single kernel argument. */
    private static Pointer intArg(int value) {
        return Pointer.to(new int[]{value});
    }

    /**
     * Manual smoke test: runs forward once and backward twice on a 2x1x1x8 tensor, printing the
     * results, then releases GPU memory. Requires a CUDA device.
     */
    public static void main(String[] strArr) {
        float[] values = {1.0f, 2.0f, 3.0f, 4.0f, -5.0f, 6.0f, -7.0f, -8.0f, 9.0f, 10.0f, 11.0f, -12.0f, 13.0f, 14.0f, 15.0f, -16.0f};
        float[] deltaValues = MatrixUtils.one(values.length);
        Tensor input = new Tensor(2, 1, 1, 8, values, true);
        Tensor output = new Tensor(2, 1, 1, 8, true);
        Tensor delta = new Tensor(2, 1, 1, 8, deltaValues, true);
        Tensor diff = new Tensor(2, 1, 1, 8, true);
        TanhKernel kernel = new TanhKernel();
        kernel.forward(input, output);
        kernel.backward(output, delta, diff);
        output.showDM();
        diff.showDM();
        // Second backward pass over the same buffers; its result is printed again below.
        kernel.backward(output, delta, diff);
        diff.showDM();
        CUDAMemoryManager.free();
    }
}
