package com.omega.engine.gpu;

import com.omega.common.data.Tensor;
import com.omega.common.utils.JsonUtils;
import com.omega.common.utils.MatrixOperation;
import com.omega.common.utils.RandomUtils;
import com.omega.example.transformer.utils.ENTokenizer;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUfunction;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.cudaError;

/* loaded from: input_file:com/omega/engine/gpu/SoftmaxKernel.class */
/**
 * Launches the softmax-family CUDA kernels found in {@code SoftmaxKernel.cu} through the JCuda
 * driver API, and provides plain-Java reference implementations of the same math for comparison.
 *
 * <p>Every GPU entry point builds its kernel-argument block from the tensors passed to that very
 * call. The block embeds the tensors' device pointers and shape-derived counts, so it must never
 * be reused for a call made with other tensors or for a kernel with a different argument list.
 *
 * <p>NOTE(review): the meaning of each kernel argument is defined by {@code SoftmaxKernel.cu},
 * which is not visible here; the argument order below must match those kernel signatures.
 */
public class SoftmaxKernel extends BaseKernel {
    /** Module that all kernel functions of this class are looked up in. */
    private static final String MODULE_NAME = "SoftmaxKernel.cu";

    private CUfunction softmax_function;
    private CUfunction softmax_mask_function;
    private CUfunction log_softmax_function;
    private CUfunction softmax_backward_function;
    // Declared but neither loaded by initFunction() nor launched anywhere in this class.
    private CUfunction softmax_mask_backward_function;
    private CUfunction log_softmax_backward_function;
    private CUfunction log_softmax_backward_function2;
    /** Threads per block (blockDimX) used for every launch in this class. */
    private int CAFFE_CUDA_NUM_THREADS = 1024;
    // The most recently built argument blocks; kept in fields so they stay strongly referenced.
    private Pointer kernelParameters;
    private Pointer kernelMaskParameters;
    private Pointer backKernelParameters;
    private Pointer backKernelParameters2;

    /** Creates the kernel wrapper and resolves the CUDA functions via {@link #init()}. */
    public SoftmaxKernel() {
        init();
    }

    /**
     * Resolves each CUDA function handle that has not been resolved yet.
     *
     * <p>Best effort: any exception is printed and swallowed, in which case the failing handle and
     * every handle after it in the sequence below remain {@code null}.
     */
    public void initFunction() {
        try {
            if (this.softmax_function == null) {
                this.softmax_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, "softmax");
            }
            if (this.softmax_mask_function == null) {
                this.softmax_mask_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, "softmax_mask");
            }
            if (this.log_softmax_function == null) {
                this.log_softmax_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, "log_softmax");
            }
            if (this.softmax_backward_function == null) {
                this.softmax_backward_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, "softmax_back");
            }
            if (this.log_softmax_backward_function == null) {
                this.log_softmax_backward_function = CUDAModules.getLocalFunctionByModule(MODULE_NAME, "log_softmax_back");
            }
            if (this.log_softmax_backward_function2 == null) {
                this.log_softmax_backward_function2 = CUDAModules.getLocalFunctionByModule(MODULE_NAME, "log_softmax_back2");
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Initializes this kernel wrapper; currently this only resolves the CUDA functions. */
    public void init() {
        initFunction();
    }

    /**
     * Returns the number of blocks of {@code CAFFE_CUDA_NUM_THREADS} threads needed to cover
     * {@code i} work items (ceiling division).
     *
     * @param i number of work items
     * @return {@code ceil(i / CAFFE_CUDA_NUM_THREADS)} for non-negative {@code i}
     */
    @Override // com.omega.engine.gpu.BaseKernel
    public int CAFFE_GET_BLOCKS(int i) {
        return ((i + this.CAFFE_CUDA_NUM_THREADS) - 1) / this.CAFFE_CUDA_NUM_THREADS;
    }

    /** Wraps a tensor's GPU data pointer as a single kernel argument. */
    private static Pointer gpuDataArg(Tensor t) {
        return Pointer.to(new NativePointerObject[]{t.getGpuData()});
    }

    /** Wraps an {@code int} value as a single kernel argument. */
    private static Pointer intArg(int value) {
        return Pointer.to(new int[]{value});
    }

    /**
     * Launches {@code function} on a one-dimensional grid of {@code gridDimX} blocks, each with
     * {@code CAFFE_CUDA_NUM_THREADS} threads, no dynamic shared memory and a {@code null} stream.
     * The driver's return code is not inspected here.
     */
    private void launchWithParams(CUfunction function, int gridDimX, Pointer parameters) {
        JCudaDriver.cuLaunchKernel(function, gridDimX, 1, 1, this.CAFFE_CUDA_NUM_THREADS, 1, 1, 0, (CUstream) null, parameters, (Pointer) null);
    }

    /**
     * Launches the {@code softmax} kernel and records {@code tensor2.number} in {@code N}.
     *
     * <p>The argument block is rebuilt on every call so the launch always uses the device pointers
     * and shape of the tensors passed to this call, never those of an earlier call.
     *
     * @param tensor  first kernel argument (presumably the input — confirm against the .cu file);
     *                its {@code number * channel * height} and {@code width} are passed as counts
     * @param tensor2 second kernel argument (presumably the output — confirm against the .cu file)
     */
    public void softmax(Tensor tensor, Tensor tensor2) {
        int rows = tensor.number * tensor.channel * tensor.height;
        this.kernelParameters = Pointer.to(new NativePointerObject[]{gpuDataArg(tensor), gpuDataArg(tensor2), intArg(rows), intArg(tensor.width)});
        this.N = tensor2.number;
        launchWithParams(this.softmax_function, CAFFE_GET_BLOCKS(rows), this.kernelParameters);
    }

    /**
     * Launches the {@code softmax} kernel exactly like {@link #softmax(Tensor, Tensor)} but
     * without updating {@code N}.
     *
     * @param tensor  first kernel argument; supplies the row and width counts
     * @param tensor2 second kernel argument
     */
    public void softmax_out(Tensor tensor, Tensor tensor2) {
        int rows = tensor.number * tensor.channel * tensor.height;
        this.kernelParameters = Pointer.to(new NativePointerObject[]{gpuDataArg(tensor), gpuDataArg(tensor2), intArg(rows), intArg(tensor.width)});
        launchWithParams(this.softmax_function, CAFFE_GET_BLOCKS(rows), this.kernelParameters);
    }

    /**
     * Launches the {@code softmax_mask} kernel and records {@code tensor3.number} in {@code N}.
     *
     * <p>Note the kernel receives the tensors in the order {@code tensor, tensor3, tensor2}.
     *
     * @param tensor  first kernel argument (presumably the input — confirm against the .cu file)
     * @param tensor2 third kernel argument (presumably the mask — confirm against the .cu file)
     * @param tensor3 second kernel argument (presumably the output); its {@code height} is used
     *                in the row count together with {@code tensor.number * tensor.channel}
     * @param f       float passed as the last kernel argument (presumably the value substituted
     *                at masked positions, as in {@link #cpuForwardMask})
     */
    public void softmaxMask(Tensor tensor, Tensor tensor2, Tensor tensor3, float f) {
        int rows = tensor.number * tensor.channel * tensor3.height;
        this.kernelMaskParameters = Pointer.to(new NativePointerObject[]{gpuDataArg(tensor), gpuDataArg(tensor3), gpuDataArg(tensor2), intArg(rows), intArg(tensor.width), Pointer.to(new float[]{f})});
        this.N = tensor3.number;
        launchWithParams(this.softmax_mask_function, CAFFE_GET_BLOCKS(rows), this.kernelMaskParameters);
    }

    /**
     * Launches the {@code log_softmax} kernel with one block per {@code tensor.number} entry and
     * records {@code tensor2.number} in {@code N}.
     *
     * <p>The argument block is rebuilt on every call; it is never shared with
     * {@link #softmax(Tensor, Tensor)}, whose row count is computed differently.
     *
     * @param tensor  first kernel argument; its {@code number} and {@code width} are passed as counts
     * @param tensor2 second kernel argument
     */
    public void log_softmax(Tensor tensor, Tensor tensor2) {
        this.kernelParameters = Pointer.to(new NativePointerObject[]{gpuDataArg(tensor), gpuDataArg(tensor2), intArg(tensor.number), intArg(tensor.width)});
        this.N = tensor2.number;
        launchWithParams(this.log_softmax_function, tensor.number, this.kernelParameters);
    }

    /**
     * Launches the {@code softmax_back} kernel (five arguments: three tensors, row count, width).
     *
     * <p>The argument block is rebuilt on every call. It has a different layout from the one used
     * by {@link #backward(Tensor, Tensor, Tensor)}, so the two must never reuse each other's block.
     *
     * @param tensor  first kernel argument (presumably the forward output — TODO confirm); supplies
     *                the row count {@code number * channel * height} and {@code width}
     * @param tensor2 second kernel argument (presumably the incoming gradient — TODO confirm)
     * @param tensor3 third kernel argument (presumably the resulting gradient — TODO confirm)
     */
    public void backward_noloss(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        int rows = tensor.number * tensor.channel * tensor.height;
        this.backKernelParameters = Pointer.to(new NativePointerObject[]{gpuDataArg(tensor), gpuDataArg(tensor2), gpuDataArg(tensor3), intArg(rows), intArg(tensor.width)});
        launchWithParams(this.softmax_backward_function, CAFFE_GET_BLOCKS(rows), this.backKernelParameters);
    }

    /**
     * Launches the {@code log_softmax_back} kernel (four arguments: three tensors and
     * {@code tensor3.dataLength}).
     *
     * <p>The argument block is rebuilt on every call so it always matches this kernel's argument
     * list and the tensors of the current call.
     *
     * @param tensor  first kernel argument
     * @param tensor2 second kernel argument
     * @param tensor3 third kernel argument; its {@code dataLength} sizes the launch
     */
    public void backward(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        this.backKernelParameters = Pointer.to(new NativePointerObject[]{gpuDataArg(tensor), gpuDataArg(tensor2), gpuDataArg(tensor3), intArg(tensor3.dataLength)});
        launchWithParams(this.log_softmax_backward_function, CAFFE_GET_BLOCKS(tensor3.dataLength), this.backKernelParameters);
    }

    /**
     * Launches the {@code log_softmax_back2} kernel (five arguments: three tensors,
     * {@code tensor3.dataLength} and the current value of {@code N}).
     *
     * <p>The argument block is rebuilt on every call so that the value of {@code N} at call time
     * is the one handed to the kernel.
     *
     * @param tensor  first kernel argument
     * @param tensor2 second kernel argument
     * @param tensor3 third kernel argument; its {@code dataLength} sizes the launch
     */
    public void backward2(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        this.backKernelParameters2 = Pointer.to(new NativePointerObject[]{gpuDataArg(tensor), gpuDataArg(tensor2), gpuDataArg(tensor3), intArg(tensor3.dataLength), intArg(this.N)});
        launchWithParams(this.log_softmax_backward_function2, CAFFE_GET_BLOCKS(tensor3.dataLength), this.backKernelParameters2);
    }

    /**
     * Prints an error line to {@code System.err} when {@code i} is a non-zero status code.
     *
     * @param i status code to check; {@code 0} means success
     */
    @Override // com.omega.engine.gpu.BaseKernel
    public void checkCUDA(int i) {
        if (i != 0) {
            System.err.println("Error code " + i + ":" + cudaError.stringFor(i));
        }
    }

    /**
     * CPU reference softmax taken over each whole sample ({@code channel * height * width} values
     * per {@code number} entry), using the max-subtraction form.
     *
     * @param tensor  source tensor; sample {@code i} is read through {@code tensor.copy(i, buffer)}
     * @param tensor2 destination; results are written into {@code tensor2.data}
     */
    public void cpuForward(Tensor tensor, Tensor tensor2) {
        float[] sample = new float[tensor.channel * tensor.height * tensor.width];
        for (int n = 0; n < tensor.number; n++) {
            tensor.copy(n, sample);
            float[] exps = MatrixOperation.exp(MatrixOperation.subtraction(sample, MatrixOperation.max(sample)));
            float total = MatrixOperation.sum(exps);
            // The destination offset is derived from tensor2's own shape, not tensor's.
            int base = n * tensor2.channel * tensor2.height * tensor2.width;
            for (int k = 0; k < exps.length; k++) {
                tensor2.data[base + k] = exps[k] / total;
            }
        }
    }

    /**
     * CPU reference softmax over rows of length {@code tensor.width}, one row per
     * {@code tensor.number} entry.
     *
     * <p>NOTE(review): rows are addressed as {@code i * width + j} with {@code i < number}, so
     * channel and height do not take part in the indexing.
     *
     * @param tensor  source; read from {@code tensor.data}
     * @param tensor2 destination; written to {@code tensor2.data} at the same indices
     */
    public void cpuForward2(Tensor tensor, Tensor tensor2) {
        final int width = tensor.width;
        for (int row = 0; row < tensor.number; row++) {
            final int base = row * width;
            float max = -Float.MAX_VALUE;
            for (int col = 0; col < width; col++) {
                float value = tensor.data[base + col];
                if (max <= value) {
                    max = value;
                }
            }
            float total = 0.0f;
            for (int col = 0; col < width; col++) {
                float e = (float) Math.exp(tensor.data[base + col] - max);
                total += e;
                tensor2.data[base + col] = e;
            }
            for (int col = 0; col < width; col++) {
                tensor2.data[base + col] = tensor2.data[base + col] / total;
            }
        }
    }

    /**
     * CPU reference masked softmax taken over each {@code height * width} plane of every
     * (number, channel) pair.
     *
     * <p>Wherever the mask holds exactly {@code 1.0f} the source value is replaced by {@code f}
     * before the softmax. The mask is indexed by {@code plane / tensor.channel}, so one mask plane
     * is applied to all channels of a sample.
     *
     * @param tensor  source; read from {@code tensor.data}
     * @param tensor2 destination; written to {@code tensor2.data}
     * @param tensor3 mask; read from {@code tensor3.data}
     * @param f       replacement value used at masked positions
     */
    public void cpuForwardMask(Tensor tensor, Tensor tensor2, Tensor tensor3, float f) {
        final int planeSize = tensor.height * tensor.width;
        final int planes = tensor.number * tensor.channel;
        for (int plane = 0; plane < planes; plane++) {
            final int base = plane * planeSize;
            final int maskBase = (plane / tensor.channel) * planeSize;
            float max = -Float.MAX_VALUE;
            for (int k = 0; k < planeSize; k++) {
                float value = tensor.data[base + k];
                if (tensor3.data[maskBase + k] == 1.0f) {
                    value = f;
                }
                if (max <= value) {
                    max = value;
                }
            }
            float total = 0.0f;
            for (int k = 0; k < planeSize; k++) {
                float value = tensor.data[base + k];
                if (tensor3.data[maskBase + k] == 1.0f) {
                    value = f;
                }
                float e = (float) Math.exp(value - max);
                total += e;
                tensor2.data[base + k] = e;
            }
            for (int k = 0; k < planeSize; k++) {
                tensor2.data[base + k] = tensor2.data[base + k] / total;
            }
        }
    }

    /**
     * CPU reference backward step: writes the element-wise difference
     * {@code tensor.data[i] - tensor2.data[i]} into {@code tensor3.data[i]}.
     *
     * @param tensor  minuend; its {@code getDataLength()} bounds the loop
     * @param tensor2 subtrahend
     * @param tensor3 destination
     */
    public void cpuBackward(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        final int length = tensor.getDataLength();
        for (int k = 0; k < length; k++) {
            tensor3.data[k] = tensor.data[k] - tensor2.data[k];
        }
    }

    /**
     * CPU reference for the softmax backward pass without a fused loss. For every row it computes
     * {@code dot = sum_j tensor[j] * tensor2[j]} and then
     * {@code tensor3[j] = (tensor2[j] - dot) * tensor[j]}, and finally calls
     * {@code tensor3.hostToDevice()}.
     *
     * <p>NOTE(review): rows are addressed as {@code i * width + j} with {@code i < number}, so
     * channel and height do not take part in the indexing.
     *
     * @param tensor  first operand (presumably the softmax output — TODO confirm)
     * @param tensor2 second operand (presumably the incoming gradient — TODO confirm)
     * @param tensor3 destination; written on the host and then pushed via {@code hostToDevice()}
     */
    public static void cpuBackwardNoLoss(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        final int width = tensor.width;
        for (int row = 0; row < tensor.number; row++) {
            final int base = row * width;
            float dot = 0.0f;
            for (int col = 0; col < width; col++) {
                dot += tensor.data[base + col] * tensor2.data[base + col];
            }
            for (int col = 0; col < width; col++) {
                tensor3.data[base + col] = (tensor2.data[base + col] - dot) * tensor.data[base + col];
            }
        }
        tensor3.hostToDevice();
    }

    /**
     * Manual smoke test: runs the masked softmax on the GPU and on the CPU for a 2x5x4x4 tensor,
     * prints both results, then runs {@link #backward_noloss} on the GPU result and prints it.
     *
     * @param strArr unused
     */
    public static void main(String[] strArr) {
        float[] order = RandomUtils.order(2 * 5 * 4 * 4, 0.1f, 0.1f);
        Tensor triu = ENTokenizer.triu(2, 5, 4, 4, 1.0f);
        triu.showDM();
        Tensor tensor = new Tensor(2, 5, 4, 4, order, true);
        Tensor tensor2 = new Tensor(2, 5, 4, 4, true);
        Tensor tensor3 = new Tensor(2, 5, 4, 4);
        SoftmaxKernel softmaxKernel = new SoftmaxKernel();
        // GPU path: note the (input, mask, output) argument order of softmaxMask.
        softmaxKernel.softmaxMask(tensor, triu, tensor2, -1.0E9f);
        tensor2.showDM();
        // CPU path: note the (input, output, mask) argument order of cpuForwardMask.
        softmaxKernel.cpuForwardMask(tensor, tensor3, triu, -1.0E9f);
        System.out.println("output2:" + JsonUtils.toJson(tensor3.data));
        Tensor tensor4 = new Tensor(2, 5, 4, 4, RandomUtils.order(2 * 5 * 4 * 4, 0.1f, 0.0f), true);
        Tensor tensor5 = new Tensor(2, 5, 4, 4, true);
        softmaxKernel.backward_noloss(tensor2, tensor4, tensor5);
        tensor5.showDM();
    }
}
