package com.omega.engine.gpu;

import com.omega.common.data.Tensor;
import com.omega.common.utils.JsonUtils;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUfunction;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaError;

/* loaded from: input_file:com/omega/engine/gpu/BaseKernel.class */
/**
 * Host-side launcher for the kernels that this class resolves from the
 * {@code BaseKernel.cu} module ({@code copy_kernel}, {@code axpy_kernel},
 * {@code fill_kernel}, {@code scal_add_kernel}, {@code constrain_kernel}).
 *
 * <p>Every launch uses a one-dimensional grid of
 * {@link #CAFFE_GET_BLOCKS(int)} blocks with 1024 threads per block, no
 * dynamic shared memory, and the default (null) stream.
 *
 * <p>Error handling is best-effort: exceptions raised while preparing or
 * issuing a launch are caught and their stack trace printed, and a non-zero
 * status code is reported on {@code System.err} by {@link #checkCUDA(int)}.
 * No launcher method propagates a failure to its caller.
 *
 * <p>NOTE(review): the meaning of each kernel argument (element count,
 * offset, stride, ...) is defined by {@code BaseKernel.cu}, which is not
 * visible from this file. The Javadoc below therefore only states which Java
 * argument is forwarded at which position; confirm semantics against the
 * kernel source.
 */
public class BaseKernel {
    public int N = 0;
    public int BN = 0;

    /** Module that all kernels of this class are resolved from. */
    private static final String MODULE_NAME = "BaseKernel.cu";

    /** Threads per block used for every launch. */
    private static final int CAFFE_CUDA_NUM_THREADS = 1024;

    /** Size in bytes of one element; the device buffers are addressed as 4-byte values. */
    private static final long BYTES_PER_ELEMENT = 4L;

    /** Numeric value of {@code cudaMemcpyKind.cudaMemcpyDeviceToHost}. */
    private static final int CUDA_MEMCPY_DEVICE_TO_HOST = 2;

    private CUfunction copy_gpu_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, "copy_kernel");
    private CUfunction axpy_gpu_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, "axpy_kernel");
    private CUfunction fill_gpu_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, "fill_kernel");
    private CUfunction scal_add_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, "scal_add_kernel");
    private CUfunction constrain_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, "constrain_kernel");

    /**
     * Launches {@code constrain_kernel} with the arguments
     * {@code (i, f, tensor.getGpuData(), i2)}.
     *
     * @param i      first kernel argument; also used to size the grid
     * @param f      second kernel argument
     * @param tensor tensor whose GPU buffer is passed as the third kernel argument
     * @param i2     fourth kernel argument
     */
    public void constrain_gpu(int i, float f, Tensor tensor, int i2) {
        try {
            if (this.constrain_function == null) {
                this.constrain_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, "constrain_kernel");
            }
            launch(this.constrain_function, i,
                    intArg(i),
                    floatArg(f),
                    pointerArg(tensor.getGpuData()),
                    intArg(i2));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code constrain_kernel} like
     * {@link #constrain_gpu(int, float, Tensor, int)}, but with the tensor's
     * GPU buffer advanced by {@code i3} four-byte elements.
     *
     * @param i      first kernel argument; also used to size the grid
     * @param f      second kernel argument
     * @param tensor tensor whose (offset) GPU buffer is the third kernel argument
     * @param i2     fourth kernel argument
     * @param i3     number of 4-byte elements to skip at the start of the buffer
     */
    public void constrain_gpu(int i, float f, Tensor tensor, int i2, int i3) {
        try {
            if (this.constrain_function == null) {
                this.constrain_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, "constrain_kernel");
            }
            launch(this.constrain_function, i,
                    intArg(i),
                    floatArg(f),
                    pointerArg(offsetBy(tensor.getGpuData(), i3)),
                    intArg(i2));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code fill_kernel} with the arguments
     * {@code (tensor.getDataLength(), f, tensor.getGpuData(), 1)}.
     *
     * @param tensor tensor whose GPU buffer is handed to the kernel; its data
     *               length also sizes the grid
     * @param f      second kernel argument (presumably the fill value — confirm
     *               against the kernel source)
     */
    public void fill_gpu(Tensor tensor, float f) {
        try {
            if (this.fill_gpu_function == null) {
                this.fill_gpu_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, "fill_kernel");
            }
            int length = tensor.getDataLength();
            launch(this.fill_gpu_function, length,
                    intArg(length),
                    floatArg(f),
                    pointerArg(tensor.getGpuData()),
                    intArg(1));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code scal_add_kernel} with the arguments
     * {@code (i, f, f2, buffer, i3)}, where {@code buffer} is the tensor's GPU
     * buffer advanced by {@code i2} four-byte elements.
     *
     * @param tensor tensor whose (offset) GPU buffer is the fourth kernel argument
     * @param i      first kernel argument; also used to size the grid
     * @param f      second kernel argument
     * @param f2     third kernel argument
     * @param i2     number of 4-byte elements to skip at the start of the buffer
     * @param i3     fifth kernel argument
     */
    public void scal_add_gpu(Tensor tensor, int i, float f, float f2, int i2, int i3) {
        try {
            // Re-resolve the handle if it is missing, as the other launchers do.
            if (this.scal_add_function == null) {
                this.scal_add_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, "scal_add_kernel");
            }
            launch(this.scal_add_function, i,
                    intArg(i),
                    floatArg(f),
                    floatArg(f2),
                    pointerArg(offsetBy(tensor.getGpuData(), i2)),
                    intArg(i3));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Convenience overload of
     * {@link #axpy_gpu(Tensor, Tensor, int, float, int, int, int, int)} that
     * passes {@code 0} for both {@code i2} and {@code i4} of the full overload.
     *
     * @param tensor  first tensor
     * @param tensor2 second tensor
     * @param i       first kernel argument; also used to size the grid
     * @param f       second kernel argument
     * @param i2      forwarded as {@code i3} of the full overload
     * @param i3      forwarded as {@code i5} of the full overload
     */
    public void axpy_gpu(Tensor tensor, Tensor tensor2, int i, float f, int i2, int i3) {
        // The count must be i; it was previously i3, which left i unused and
        // tied the number of processed elements to the last kernel argument.
        axpy_gpu(tensor, tensor2, i, f, 0, i2, 0, i3);
    }

    /**
     * Launches {@code axpy_kernel} with the arguments
     * {@code (i, f, tensor.getGpuData(), i2, i3, tensor2.getGpuData(), i4, i5)}.
     *
     * @param tensor  tensor whose GPU buffer is the third kernel argument
     * @param tensor2 tensor whose GPU buffer is the sixth kernel argument
     * @param i       first kernel argument; also used to size the grid
     * @param f       second kernel argument
     * @param i2      fourth kernel argument
     * @param i3      fifth kernel argument
     * @param i4      seventh kernel argument
     * @param i5      eighth kernel argument
     */
    public void axpy_gpu(Tensor tensor, Tensor tensor2, int i, float f, int i2, int i3, int i4, int i5) {
        try {
            if (this.axpy_gpu_function == null) {
                this.axpy_gpu_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, "axpy_kernel");
            }
            launch(this.axpy_gpu_function, i,
                    intArg(i),
                    floatArg(f),
                    pointerArg(tensor.getGpuData()),
                    intArg(i2),
                    intArg(i3),
                    pointerArg(tensor2.getGpuData()),
                    intArg(i4),
                    intArg(i5));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code copy_kernel} on raw device pointers with the arguments
     * {@code (i, pointer, 0, i2, pointer2, 0, i3)}.
     *
     * @param pointer  device pointer passed as the second kernel argument
     * @param pointer2 device pointer passed as the fifth kernel argument
     * @param i        first kernel argument; also used to size the grid
     * @param i2       fourth kernel argument
     * @param i3       seventh kernel argument
     */
    public void copy_gpu(Pointer pointer, Pointer pointer2, int i, int i2, int i3) {
        try {
            if (this.copy_gpu_function == null) {
                this.copy_gpu_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, "copy_kernel");
            }
            launch(this.copy_gpu_function, i,
                    intArg(i),
                    pointerArg(pointer),
                    intArg(0),
                    intArg(i2),
                    pointerArg(pointer2),
                    intArg(0),
                    intArg(i3));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code copy_kernel} with the arguments
     * {@code (i, tensor.getGpuData(), 0, i2, tensor2.getGpuData(), 0, i3)}.
     *
     * @param tensor  tensor whose GPU buffer is the second kernel argument
     * @param tensor2 tensor whose GPU buffer is the fifth kernel argument
     * @param i       first kernel argument; also used to size the grid
     * @param i2      fourth kernel argument
     * @param i3      seventh kernel argument
     */
    public void copy_gpu(Tensor tensor, Tensor tensor2, int i, int i2, int i3) {
        try {
            if (this.copy_gpu_function == null) {
                this.copy_gpu_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, "copy_kernel");
            }
            // The launch status is now reported instead of being discarded.
            launch(this.copy_gpu_function, i,
                    intArg(i),
                    pointerArg(tensor.getGpuData()),
                    intArg(0),
                    intArg(i2),
                    pointerArg(tensor2.getGpuData()),
                    intArg(0),
                    intArg(i3));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code copy_kernel} with the arguments
     * {@code (i, tensor.getGpuData(), i2, i3, tensor2.getGpuData(), i4, i5)}.
     *
     * @param tensor  tensor whose GPU buffer is the second kernel argument
     * @param tensor2 tensor whose GPU buffer is the fifth kernel argument
     * @param i       first kernel argument; also used to size the grid
     * @param i2      third kernel argument
     * @param i3      fourth kernel argument
     * @param i4      sixth kernel argument
     * @param i5      seventh kernel argument
     */
    public void copy_gpu(Tensor tensor, Tensor tensor2, int i, int i2, int i3, int i4, int i5) {
        try {
            if (this.copy_gpu_function == null) {
                this.copy_gpu_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, "copy_kernel");
            }
            // The launch status is now reported instead of being discarded.
            launch(this.copy_gpu_function, i,
                    intArg(i),
                    pointerArg(tensor.getGpuData()),
                    intArg(i2),
                    intArg(i3),
                    pointerArg(tensor2.getGpuData()),
                    intArg(i4),
                    intArg(i5));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Returns the number of 1024-thread blocks needed to cover {@code i}
     * work items, i.e. {@code (i + 1023) / 1024} with truncating division.
     *
     * @param i number of work items
     * @return block count for a one-dimensional grid
     */
    public int CAFFE_GET_BLOCKS(int i) {
        // Computed in long so that i close to Integer.MAX_VALUE cannot overflow.
        return (int) (((long) i + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS);
    }

    /**
     * Copies {@code fArr.length} four-byte values from {@code pointer} into
     * {@code fArr} and prints the array as JSON on {@code System.out}.
     *
     * @param pointer source pointer of the device-to-host copy
     * @param fArr    host array that receives the values
     */
    public void showDM(Pointer pointer, float[] fArr) {
        checkCUDA(JCuda.cudaMemcpy(Pointer.to(fArr), pointer, fArr.length * BYTES_PER_ELEMENT, CUDA_MEMCPY_DEVICE_TO_HOST));
        System.out.println(JsonUtils.toJson(fArr));
    }

    /**
     * Copies {@code iArr.length} four-byte values from {@code pointer} into
     * {@code iArr} and prints the array as JSON on {@code System.out}.
     *
     * @param pointer source pointer of the device-to-host copy
     * @param iArr    host array that receives the values
     */
    public void showDM(Pointer pointer, int[] iArr) {
        checkCUDA(JCuda.cudaMemcpy(Pointer.to(iArr), pointer, iArr.length * BYTES_PER_ELEMENT, CUDA_MEMCPY_DEVICE_TO_HOST));
        System.out.println(JsonUtils.toJson(iArr));
    }

    /**
     * Copies {@code i} four-byte values from {@code pointer} into a freshly
     * allocated {@code float[]} and prints it as JSON on {@code System.out}.
     *
     * @param pointer source pointer of the device-to-host copy
     * @param i       number of floats to copy
     */
    public void showDM(Pointer pointer, int i) {
        float[] host = new float[i];
        checkCUDA(JCuda.cudaMemcpy(Pointer.to(host), pointer, host.length * BYTES_PER_ELEMENT, CUDA_MEMCPY_DEVICE_TO_HOST));
        System.out.println(JsonUtils.toJson(host));
    }

    /**
     * Reports a non-zero status code on {@code System.err}; a zero code is
     * silently accepted. This method never throws for a failed status.
     *
     * <p>NOTE(review): the message text is looked up with
     * {@code cudaError.stringFor}, while the kernel launches in this class
     * return driver-API status codes — confirm the printed name is the
     * intended one for launch failures.
     *
     * @param i status code returned by a CUDA call
     */
    public void checkCUDA(int i) {
        if (i != 0) {
            System.err.println("Error code " + i + ":" + cudaError.stringFor(i));
        }
    }

    /** Issues a 1-D launch of {@code function} sized for {@code count} items and reports its status. */
    private void launch(CUfunction function, int count, NativePointerObject... kernelArgs) {
        int status = JCudaDriver.cuLaunchKernel(
                function,
                CAFFE_GET_BLOCKS(count), 1, 1,      // grid dimensions
                CAFFE_CUDA_NUM_THREADS, 1, 1,       // block dimensions
                0,                                  // dynamic shared memory bytes
                null,                               // default stream
                Pointer.to(kernelArgs),
                null);                              // no extra launch options
        checkCUDA(status);
    }

    /** Wraps a single {@code int} as a kernel argument. */
    private static Pointer intArg(int value) {
        return Pointer.to(new int[]{value});
    }

    /** Wraps a single {@code float} as a kernel argument. */
    private static Pointer floatArg(float value) {
        return Pointer.to(new float[]{value});
    }

    /** Wraps a pointer as a kernel argument. */
    private static Pointer pointerArg(Pointer value) {
        return Pointer.to(new NativePointerObject[]{value});
    }

    /** Returns {@code base} advanced by {@code elements} four-byte elements, using long arithmetic. */
    private static Pointer offsetBy(Pointer base, int elements) {
        return base.withByteOffset(elements * BYTES_PER_ELEMENT);
    }
}
