package com.omega.engine.loss.gpu;

import com.omega.common.data.Tensor;
import com.omega.common.utils.JsonUtils;
import com.omega.common.utils.MatrixOperation;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.gpu.BaseKernel;
import com.omega.engine.gpu.CUDAModules;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUfunction;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.cudaError;

/* loaded from: input_file:com/omega/engine/loss/gpu/CrossEntropyKernel.class */
/**
 * Host-side launcher for the cross-entropy kernels loaded from {@code CrossEntropyKernel.cu}.
 *
 * <p>Every launch in this class uses {@code tensor.number} blocks on the grid x axis and
 * {@code CAFFE_CUDA_NUM_THREADS} threads per block on the x axis, on the default (null) stream.
 */
public class CrossEntropyKernel extends BaseKernel {
    private CUfunction loss_function;
    private CUfunction nl_loss_function;
    private CUfunction log_softmax_nl_loss_function;
    private CUfunction log_softmax_nl_loss_ignore_function;
    private CUfunction check_function;
    private CUfunction loss_backward_function;
    private CUfunction loss_ignore_backward_function;
    private final int CAFFE_CUDA_NUM_THREADS = 1024;
    // Argument arrays of the most recent launches. They are rebuilt on every call, so they never
    // carry stale device pointers or scalars from an earlier call.
    private Pointer log_softmax_nl_loss_kernelParameters;
    private Pointer checkParameters;
    private Pointer backKernelParameters;

    public CrossEntropyKernel() {
        init();
    }

    /**
     * Looks up, once, every kernel this class launches from the {@code CrossEntropyKernel.cu}
     * module.
     *
     * <p>A failure is only reported through {@code printStackTrace}; the remaining handles that
     * were not yet resolved stay {@code null}.
     */
    public void initFunction() {
        try {
            if (this.loss_function == null) {
                this.loss_function = CUDAModules.getLocalFunctionByModule("CrossEntropyKernel.cu", "loss");
            }
            if (this.nl_loss_function == null) {
                this.nl_loss_function = CUDAModules.getLocalFunctionByModule("CrossEntropyKernel.cu", "nl_loss");
            }
            if (this.log_softmax_nl_loss_function == null) {
                this.log_softmax_nl_loss_function = CUDAModules.getLocalFunctionByModule("CrossEntropyKernel.cu", "log_softmax_nl_loss");
            }
            if (this.log_softmax_nl_loss_ignore_function == null) {
                // The kernel symbol is spelled "igone" here; it must match the .cu source exactly.
                this.log_softmax_nl_loss_ignore_function = CUDAModules.getLocalFunctionByModule("CrossEntropyKernel.cu", "log_softmax_nl_loss_igone");
            }
            if (this.check_function == null) {
                this.check_function = CUDAModules.getLocalFunctionByModule("CrossEntropyKernel.cu", "check");
            }
            if (this.loss_backward_function == null) {
                this.loss_backward_function = CUDAModules.getLocalFunctionByModule("CrossEntropyKernel.cu", "loss_back2");
            }
            if (this.loss_ignore_backward_function == null) {
                this.loss_ignore_backward_function = CUDAModules.getLocalFunctionByModule("CrossEntropyKernel.cu", "loss_back_igonre2");
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Resolves the kernel handles; called from the constructor. */
    public void init() {
        initFunction();
    }

    /**
     * Returns the number of blocks needed to cover {@code i} elements with
     * {@code CAFFE_CUDA_NUM_THREADS} threads per block (ceiling division).
     */
    @Override // com.omega.engine.gpu.BaseKernel
    public int CAFFE_GET_BLOCKS(int i) {
        return ((i + this.CAFFE_CUDA_NUM_THREADS) - 1) / this.CAFFE_CUDA_NUM_THREADS;
    }

    /**
     * Launches the {@code log_softmax_nl_loss} kernel.
     *
     * <p>In {@link #check()} the arguments are random scores, one-hot targets and a one-column
     * output whose values are printed next to {@link #loss_cpu(float[], float[])}.
     *
     * @param tensor  input scores; the kernel also receives {@code tensor.number} and
     *                {@code tensor.width}
     * @param tensor2 targets, same shape as {@code tensor} (one-hot in {@link #check()})
     * @param tensor3 output receiving the loss (one value per sample in {@link #check()})
     */
    public void forward(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        this.log_softmax_nl_loss_kernelParameters = kernelParams(tensor, tensor2, tensor3, tensor.number, tensor.width);
        this.N = tensor3.number;
        launch(this.log_softmax_nl_loss_function, tensor.number, this.log_softmax_nl_loss_kernelParameters);
    }

    /**
     * Launches the {@code log_softmax_nl_loss_igone} kernel.
     *
     * @param tensor  input scores; the kernel also receives {@code tensor.number} and
     *                {@code tensor.width}
     * @param tensor2 targets (layout expected by the ignore kernel — TODO confirm against the .cu)
     * @param tensor3 output receiving the loss
     * @param i       extra int passed as the last kernel argument; presumably the label value to
     *                ignore — TODO confirm
     */
    public void forward(Tensor tensor, Tensor tensor2, Tensor tensor3, int i) {
        this.log_softmax_nl_loss_kernelParameters = kernelParams(tensor, tensor2, tensor3, tensor.number, tensor.width, i);
        this.N = tensor3.number;
        launch(this.log_softmax_nl_loss_ignore_function, tensor.number, this.log_softmax_nl_loss_kernelParameters);
    }

    /**
     * Launches the {@code check} kernel with the same argument layout as the plain
     * {@link #forward(Tensor, Tensor, Tensor)}.
     */
    public void forwardCheck(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        this.checkParameters = kernelParams(tensor, tensor2, tensor3, tensor.number, tensor.width);
        this.N = tensor3.number;
        launch(this.check_function, tensor.number, this.checkParameters);
    }

    /**
     * Launches the {@code loss_back2} kernel.
     *
     * <p>The argument array is rebuilt on every call. Previously it was cached after the first
     * call, so later calls with other tensors reused stale device pointers.
     *
     * <p>After the launch, {@code tensor3.syncHost()} is checked for NaN, and if one is found
     * {@code tensor.showDMByNumber(0)} is called. NOTE(review): this debug check presumably
     * forces a device-to-host read on every call — confirm whether it is still wanted.
     *
     * @param tensor  first kernel argument; the kernel also receives its {@code number} and
     *                {@code width}
     * @param tensor2 second kernel argument (presumably the targets — TODO confirm)
     * @param tensor3 third kernel argument (presumably the gradient output — TODO confirm)
     */
    public void backward(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        this.backKernelParameters = kernelParams(tensor, tensor2, tensor3, tensor.number, tensor.width);
        launch(this.loss_backward_function, tensor.number, this.backKernelParameters);
        if (MatrixOperation.isNaN(tensor3.syncHost())) {
            tensor.showDMByNumber(0);
        }
    }

    /**
     * Launches the {@code loss_back_igonre2} kernel.
     *
     * <p>The argument array is rebuilt on every call. Previously it was reused whenever the
     * batch size matched, so new tensors, a new width or a new {@code i} were silently ignored.
     * {@code BN} is still updated to the last batch size for compatibility.
     *
     * @param i extra int passed as the last kernel argument; presumably the label value to
     *          ignore — TODO confirm
     */
    public void backward(Tensor tensor, Tensor tensor2, Tensor tensor3, int i) {
        this.backKernelParameters = kernelParams(tensor, tensor2, tensor3, tensor.number, tensor.width, i);
        this.BN = tensor.number;
        launch(this.loss_ignore_backward_function, tensor.number, this.backKernelParameters);
    }

    /** Prints a non-zero error code and its {@code cudaError} description to stderr. */
    @Override // com.omega.engine.gpu.BaseKernel
    public void checkCUDA(int i) {
        if (i != 0) {
            System.err.println("Error code " + i + ":" + cudaError.stringFor(i));
        }
    }

    /**
     * Builds the kernel argument array used by every launch in this class. It contains the
     * device data of the three tensors in order, followed by each scalar as its own
     * {@code int} argument.
     */
    private static Pointer kernelParams(Tensor first, Tensor second, Tensor third, int... scalars) {
        NativePointerObject[] args = new NativePointerObject[3 + scalars.length];
        args[0] = Pointer.to(first.getGpuData());
        args[1] = Pointer.to(second.getGpuData());
        args[2] = Pointer.to(third.getGpuData());
        for (int k = 0; k < scalars.length; k++) {
            args[3 + k] = Pointer.to(new int[]{scalars[k]});
        }
        return Pointer.to(args);
    }

    /**
     * Launches {@code function} with {@code blocks} x 1 x 1 blocks of
     * {@code CAFFE_CUDA_NUM_THREADS} x 1 x 1 threads, no dynamic shared memory, on the null
     * stream.
     */
    private void launch(CUfunction function, int blocks, Pointer params) {
        JCudaDriver.cuLaunchKernel(function, blocks, 1, 1, this.CAFFE_CUDA_NUM_THREADS, 1, 1, 0, (CUstream) null, params, (Pointer) null);
    }

    /**
     * Manual smoke test that requires a GPU. It runs the forward kernel on two random 10-wide
     * samples with one-hot targets and prints the GPU result next to {@link #loss_cpu}.
     */
    public static void check() {
        Tensor tensor = new Tensor(2, 1, 1, 10, RandomUtils.x2Random(2 * 1 * 1 * 10), true);
        Tensor tensor2 = new Tensor(2, 1, 1, 10, new float[]{0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f}, true);
        Tensor tensor3 = new Tensor(2, 1, 1, 1, 1, true);
        loss_gpu(tensor, tensor2, tensor3);
        System.out.println("gpu:" + JsonUtils.toJson(tensor3.syncHost()));
        for (int i = 0; i < 2; i++) {
            System.out.println(loss_cpu(tensor.getByNumber(i), tensor2.getByNumber(i)));
        }
    }

    /** Runs {@link #forward(Tensor, Tensor, Tensor)} on a freshly constructed kernel. */
    public static void loss_gpu(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        new CrossEntropyKernel().forward(tensor, tensor2, tensor3);
    }

    /**
     * CPU reference for one sample. It computes the log-softmax cross entropy
     * {@code -sum_i target[i] * ((x[i] - max) - log(sum_j exp(x[j] - max)))}, shifting by the
     * maximum score for numerical stability.
     *
     * @param fArr  scores of one sample
     * @param fArr2 target weights of the same length (one-hot in {@link #check()})
     * @return the cross-entropy loss of the sample
     */
    public static float loss_cpu(float[] fArr, float[] fArr2) {
        float max = MatrixOperation.max(fArr);
        float sumExp = 0.0f;
        for (float v : fArr) {
            sumExp = (float) (sumExp + Math.exp(v - max));
        }
        // Loop-invariant; computed once instead of once per element (same double value).
        double logSumExp = Math.log(sumExp);
        float loss = 0.0f;
        for (int i = 0; i < fArr.length; i++) {
            loss += (float) ((-((fArr[i] - max) - logSumExp)) * fArr2[i]);
        }
        return loss;
    }

    public static void main(String[] strArr) {
        check();
    }
}
