package com.omega.engine.nn.layer.normalization.gpu;

import com.omega.common.data.Tensor;
import com.omega.common.utils.JsonUtils;
import com.omega.common.utils.MatrixUtils;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.ad.op.TensorOP;
import com.omega.engine.gpu.BaseKernel;
import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.gpu.GPUOP;
import com.omega.engine.nn.layer.normalization.BNType;
import com.omega.engine.nn.layer.normalization.LNLayer;
import com.omega.engine.nn.network.Network;
import com.omega.engine.nn.network.Transformer;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUdeviceptr;
import jcuda.driver.CUfunction;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaError;

/* loaded from: input_file:com/omega/engine/nn/layer/normalization/gpu/LNKernel.class */
/**
 * JCuda wrapper around the layer-normalization CUDA kernels in {@code LNKernel.cu}.
 *
 * <p>Only the two "llm" kernels ({@code layernorm_forward_kernel5} and
 * {@code layernorm_backward_kernel7}) are loaded by {@link #initFunction()}. Every other
 * {@link CUfunction} field, the {@code d_*} device pointers and the intermediate {@link Tensor}
 * fields ({@code mean}, {@code simga}, {@code scale}, ...) are never assigned anywhere in this
 * class, so the legacy code paths that use them ({@link #forward}, {@link #forwardAten},
 * {@link #forward_aten}, {@link #backward}, {@link #backwardAten}, {@link #backward_aten} and
 * their helpers) cannot work as written here; they are kept for source compatibility.
 *
 * <p>Most methods catch every exception and only print the stack trace, so failures are not
 * propagated to the caller.
 *
 * <p>NOTE(review): the role of each tensor argument (input, gamma, beta, output, gradients) is
 * inferred from the argument order handed to the kernels and should be confirmed against the
 * kernel signatures in {@code LNKernel.cu}.
 */
public class LNKernel extends BaseKernel {
    /** Selects how {@link #checkBatch(Tensor)} derives the row count from a tensor's dimensions. */
    public BNType bnType;
    /** Number of rows handed to the kernels; recomputed by {@link #checkBatch(Tensor)}. */
    private int B;
    /** Row width handed to the kernels; fixed at construction time. */
    private int W;
    private CUfunction forward_test_function;
    private CUfunction forward_aten_function;
    private CUfunction mean_var_function;
    private CUfunction fused_params_function;
    private CUfunction forward_fused_function;
    private CUfunction inter_grad_function;
    private CUfunction backward_fused_function;
    private CUfunction ln_backward_function;
    private CUfunction forward_llm_function;
    private CUfunction backward_llm_function;
    private CUfunction backward_aten_function;
    private CUfunction backward_aten_function2;
    private CUfunction backward_aten_gamma_function2;
    private CUfunction backward_ig_function;
    private CUfunction backward_fp_function;
    private CUfunction backward_input_function;
    private CUfunction backward_gamma_function;
    private CUfunction backward_gamma_simple_function;
    private int CAFFE_CUDA_NUM_THREADS = 1024;
    /** Epsilon value passed to the forward kernels that take one. */
    private float eta = 1.0E-5f;
    private int kCUDABlockReduceNumThreads = 512;
    private int kCUDANumThreads = 256;
    private int kColwiseReduceTileSize = 32;
    private Pointer forwardTestParameters;
    private Pointer forwardAtenParameters;
    private Pointer meanVarParameters;
    private Pointer fusedParameters;
    private Pointer forwardFusedParams;
    private Pointer interGradParameters;
    private Pointer backwardFusedParameters;
    private Pointer lnBKParameters;
    private Pointer forwardLLMParameters;
    private Pointer backwardLLMParameters;
    private Pointer backwardAtenParameters;
    private Pointer backwardAtenGammaParameters2;
    private Pointer backwardIGParameters;
    private Pointer backwardFGParameters;
    private Pointer backwardInputParameters;
    private Pointer backwardGammaParameters;
    private Pointer backwardGammaSampleParameters;
    private CUdeviceptr d_mean;
    private CUdeviceptr d_var;
    /** Device work buffer of {@code 2 * W + 1} floats, zeroed before every llm backward launch. */
    private CUdeviceptr scratch;
    private CUdeviceptr d_s;
    private CUdeviceptr d_b;
    private CUdeviceptr d_scale;
    private CUdeviceptr d_bias;
    /** Per-row device buffer of {@code B} floats shared by the forward and backward kernels. */
    private Pointer aten_mean;
    /** Second per-row device buffer of {@code B} floats shared by the forward and backward kernels. */
    private Pointer aten_var;
    private Tensor mean;
    private Tensor simga;
    private Tensor scale;
    private Tensor bias;
    private Tensor ds;
    private Tensor db;
    private Tensor rstd;
    private Tensor g_scale;
    private Tensor X_scale;
    private Tensor ones;

    /**
     * Creates the kernel wrapper and loads the CUDA functions.
     *
     * @param i      row width {@code W} passed to the kernels
     * @param bNType layout selector used by {@link #checkBatch(Tensor)}
     */
    public LNKernel(int i, BNType bNType) {
        this.W = i;
        this.bnType = bNType;
        init();
    }

    /**
     * (Re)allocates the device buffers whose size depends on the current row count {@code B}.
     *
     * <p>The per-row buffers are freed and reallocated on every call. The scratch buffer depends
     * only on {@code W}, which never changes after construction, so it is allocated once instead
     * of requesting a new, identically sized buffer each time the batch size changes.
     */
    private void initKernel() {
        if (this.aten_mean != null) {
            CUDAMemoryManager.free(this.aten_mean);
            CUDAMemoryManager.free(this.aten_var);
        }
        this.aten_mean = CUDAMemoryManager.getPointer(this.B);
        this.aten_var = CUDAMemoryManager.getPointer(this.B);
        if (this.scratch == null) {
            this.scratch = CUDAMemoryManager.getDevice((this.W * 2) + 1);
        }
    }

    /**
     * Lazily looks up the forward and backward "llm" kernels in {@code LNKernel.cu}.
     * Lookup failures are printed and otherwise ignored, leaving the handle {@code null}.
     */
    public void initFunction() {
        try {
            if (this.forward_llm_function == null) {
                this.forward_llm_function = CUDAModules.getLocalFunctionByModule("LNKernel.cu", "layernorm_forward_kernel5");
            }
            if (this.backward_llm_function == null) {
                this.backward_llm_function = CUDAModules.getLocalFunctionByModule("LNKernel.cu", "layernorm_backward_kernel7");
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Initializes the wrapper; currently this only loads the CUDA functions. */
    public void init() {
        initFunction();
    }

    /**
     * Derives the row count from {@code tensor} and compares it with the cached value.
     *
     * <p>For {@code fully_bn} the row count is {@code number * channel * height}; for
     * {@code conv_bn} it is {@code number * channel}; for any other type it is 0.
     *
     * @param tensor tensor whose dimensions determine the row count
     * @return {@code true} if the row count is unchanged; {@code false} if it changed, in which
     *         case the new value has been stored in {@code B}
     */
    public boolean checkBatch(Tensor tensor) {
        int rows = 0;
        switch (this.bnType) {
            case fully_bn:
                rows = tensor.number * tensor.channel * tensor.height;
                break;
            case conv_bn:
                rows = tensor.number * tensor.channel;
                break;
        }
        if (this.B == rows) {
            return true;
        }
        this.B = rows;
        return false;
    }

    /**
     * Builds, once, the argument lists for the kernels launched by {@link #backward}.
     *
     * <p>The lists are cached after the first call, so they keep pointing at the GPU buffers and
     * the {@code B}/{@code W} values seen on that first call.
     */
    public void initBackward(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4, Tensor tensor5, Tensor tensor6) {
        if (this.backwardInputParameters != null) {
            return;
        }
        this.backwardIGParameters = argList(
                argInt(this.W),
                argPtr(tensor2.getGpuData()),
                argPtr(tensor.getGpuData()),
                argPtr(tensor4.getGpuData()),
                argPtr(this.d_s),
                argPtr(this.d_b));
        this.backwardFGParameters = argList(
                argInt(this.B),
                argInt(this.W),
                argPtr(this.d_mean),
                argPtr(this.d_var),
                argPtr(this.d_s),
                argPtr(this.d_b),
                argPtr(this.d_scale),
                argPtr(this.d_bias));
        this.backwardInputParameters = argList(
                argPtr(tensor2.getGpuData()),
                argPtr(tensor.getGpuData()),
                argPtr(this.d_mean),
                argPtr(this.d_var),
                argPtr(tensor4.getGpuData()),
                argPtr(tensor3.getGpuData()),
                argInt(this.W));
        this.backwardGammaSampleParameters = argList(
                argInt(this.B),
                argInt(this.W),
                argPtr(tensor2.getGpuData()),
                argPtr(tensor.getGpuData()),
                argPtr(this.d_mean),
                argPtr(this.d_var),
                argPtr(tensor5.getGpuData()),
                argPtr(tensor6.getGpuData()));
        this.backwardGammaParameters = argList(
                argInt(this.B),
                argInt(this.W),
                argPtr(tensor2.getGpuData()),
                argPtr(tensor.getGpuData()),
                argPtr(this.d_mean),
                argPtr(this.d_var),
                argPtr(tensor5.getGpuData()),
                argPtr(tensor6.getGpuData()));
    }

    /**
     * Returns the number of blocks of {@code CAFFE_CUDA_NUM_THREADS} threads needed to cover
     * {@code i} elements (ceiling division).
     */
    @Override // com.omega.engine.gpu.BaseKernel
    public int CAFFE_GET_BLOCKS(int i) {
        return ((i + this.CAFFE_CUDA_NUM_THREADS) - 1) / this.CAFFE_CUDA_NUM_THREADS;
    }

    /**
     * Legacy forward pass using {@code forward_test_function} with one block of 512 threads per
     * row. The argument list is rebuilt only when the row count changes.
     */
    public void forward(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        try {
            if (!checkBatch(tensor3)) {
                initKernel();
                this.forwardTestParameters = argList(
                        argInt(this.W),
                        argFloat(this.eta),
                        argPtr(tensor3.getGpuData()),
                        argPtr(this.d_mean),
                        argPtr(this.d_var),
                        argPtr(tensor.getGpuData()),
                        argPtr(tensor2.getGpuData()),
                        argPtr(tensor4.getGpuData()));
            }
            // Grid = B rows, block = max(512, 256) threads.
            launchLn(this.forward_test_function, this.B, Math.max(512, 256), 1, 0, this.forwardTestParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Legacy forward pass using {@code forward_aten_function}.
     *
     * @param tensor2 may be {@code null}, in which case a null pointer is passed in its slot
     */
    public void forwardAten(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        try {
            if (!checkBatch(tensor3)) {
                initKernel();
                Pointer optionalPtr = null;
                if (tensor2 != null) {
                    optionalPtr = tensor2.getGpuData();
                }
                this.forwardAtenParameters = argList(
                        argInt(this.W),
                        argFloat(this.eta),
                        argPtr(tensor3.getGpuData()),
                        argPtr(tensor.getGpuData()),
                        argPtr(optionalPtr),
                        argPtr(this.aten_mean),
                        argPtr(this.aten_var),
                        argPtr(tensor4.getGpuData()));
            }
            int blockX = 32;
            int blockY = 256 / 32;
            int sharedBytes = blockY > 1 ? ((blockY * 3) / 2) * 4 : 0;
            checkCUDA(launchLn(this.forward_aten_function, this.B, blockX, blockY, sharedBytes, this.forwardAtenParameters));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Legacy step: launches {@code mean_var_function} with one block per row, then prints the
     * {@code mean} tensor to the console.
     */
    public void meanVar(Tensor tensor) {
        try {
            if (!checkBatch(tensor)) {
                initKernel();
                this.meanVarParameters = argList(
                        argInt(this.W),
                        argPtr(tensor.getGpuData()),
                        argPtr(this.mean.getGpuData()),
                        argPtr(this.simga.getGpuData()));
            }
            launchLn(this.mean_var_function, this.B, this.CAFFE_CUDA_NUM_THREADS, 1, 0, this.meanVarParameters);
            System.err.println("mean:");
            this.mean.showDM(0);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Legacy step: launches {@code fused_params_function} over {@code B} elements. */
    public void fusedParams() {
        try {
            if (this.fusedParameters == null) {
                // NOTE(review): the simga buffer is passed in two consecutive slots - confirm
                // against the kernel signature that this is intended.
                this.fusedParameters = argList(
                        argInt(this.B),
                        argFloat(this.eta),
                        argPtr(this.mean.getGpuData()),
                        argPtr(this.simga.getGpuData()),
                        argPtr(this.simga.getGpuData()),
                        argPtr(this.scale.getGpuData()),
                        argPtr(this.bias.getGpuData()));
            }
            launchLn(this.fused_params_function, CAFFE_GET_BLOCKS(this.B), this.CAFFE_CUDA_NUM_THREADS, 1, 0, this.fusedParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Legacy step: launches {@code forward_fused_function} over {@code B * W} elements. */
    public void forwardFused(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        try {
            if (this.forwardFusedParams == null) {
                this.forwardFusedParams = argList(
                        argInt(this.B),
                        argInt(this.W),
                        argPtr(tensor.getGpuData()),
                        argPtr(this.scale.getGpuData()),
                        argPtr(this.bias.getGpuData()),
                        argPtr(tensor2.getGpuData()),
                        argPtr(tensor3.getGpuData()),
                        argPtr(tensor4.getGpuData()));
            }
            launchLn(this.forward_fused_function, CAFFE_GET_BLOCKS(this.B * this.W), this.CAFFE_CUDA_NUM_THREADS, 1, 0, this.forwardFusedParams);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Legacy three-step forward pass: {@link #meanVar}, {@link #fusedParams}, then
     * {@link #forwardFused} with {@code tensor3} as its first argument.
     */
    public void forward_aten(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        try {
            meanVar(tensor3);
            fusedParams();
            forwardFused(tensor3, tensor, tensor2, tensor4);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Forward pass using {@code layernorm_forward_kernel5}, one block per row.
     *
     * <p>The kernel receives, in order: {@code tensor4}, the two per-row buffers,
     * {@code tensor3}, {@code tensor}, {@code tensor2}, {@code B}, {@code W}.
     * Presumably these are output, mean, rstd, input, gamma and beta - TODO confirm against
     * {@code LNKernel.cu}.
     *
     * @param tensor3 also determines the row count via {@link #checkBatch(Tensor)}; the device
     *                buffers are reallocated when that count changes
     */
    public void forward_llm(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        try {
            if (!checkBatch(tensor3)) {
                initKernel();
            }
            // Rebuilt on every call, so the kernel always sees the current GPU buffers.
            this.forwardLLMParameters = argList(
                    argPtr(tensor4.getGpuData()),
                    argPtr(this.aten_mean),
                    argPtr(this.aten_var),
                    argPtr(tensor3.getGpuData()),
                    argPtr(tensor.getGpuData()),
                    argPtr(tensor2.getGpuData()),
                    argInt(this.B),
                    argInt(this.W));
            checkCUDA(launchLn(this.forward_llm_function, this.B, this.CAFFE_CUDA_NUM_THREADS, 1, 0, this.forwardLLMParameters));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Backward pass using {@code layernorm_backward_kernel7}.
     *
     * <p>Clears {@code tensor3} on the GPU and zeroes the scratch buffer before the launch. The
     * kernel receives, in order: {@code tensor3}, {@code tensor5}, {@code tensor6}, scratch,
     * {@code tensor2}, {@code tensor}, {@code tensor4}, the two per-row buffers, {@code B},
     * {@code 1}, {@code W}. Presumably these are input gradient, gamma gradient, beta gradient,
     * scratch, output gradient, input and gamma - TODO confirm against {@code LNKernel.cu}.
     *
     * <p>Relies on the buffers allocated during a preceding {@link #forward_llm} call. The launch
     * result is passed to {@link #checkCUDA(int)}, matching {@link #forward_llm}, so a failed
     * launch is reported rather than silently ignored.
     */
    public void backward_llm(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4, Tensor tensor5, Tensor tensor6) {
        try {
            tensor3.clearGPU();
            checkCUDA(JCuda.cudaMemset(this.scratch, 0, (1 + (2 * this.W)) * 4));
            this.backwardLLMParameters = argList(
                    argPtr(tensor3.getGpuData()),
                    argPtr(tensor5.getGpuData()),
                    argPtr(tensor6.getGpuData()),
                    argPtr(this.scratch),
                    argPtr(tensor2.getGpuData()),
                    argPtr(tensor.getGpuData()),
                    argPtr(tensor4.getGpuData()),
                    argPtr(this.aten_mean),
                    argPtr(this.aten_var),
                    argInt(this.B),
                    argInt(1),
                    argInt(this.W));
            int gridX = (1024 / this.CAFFE_CUDA_NUM_THREADS) * CUDAModules.props.multiProcessorCount;
            // Dynamic shared memory: 2 * W + 1 four-byte values, the same size as scratch.
            int sharedBytes = ((2 * this.W) + 1) * 4;
            checkCUDA(launchLn(this.backward_llm_function, gridX, this.CAFFE_CUDA_NUM_THREADS, 1, sharedBytes, this.backwardLLMParameters));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Legacy backward pass: launches the internal-gradient, fused-parameter and input-gradient
     * kernels, then one of two gamma/beta kernels depending on whether {@code B} is below 512.
     */
    public void backward(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4, Tensor tensor5, Tensor tensor6) {
        try {
            initBackward(tensor, tensor2, tensor3, tensor4, tensor5, tensor6);
            int rows = this.B;
            int width = this.W;
            launchLn(this.backward_ig_function, rows, this.kCUDABlockReduceNumThreads, 1, 0, this.backwardIGParameters);
            launchLn(this.backward_fp_function, ((rows + this.kCUDANumThreads) - 1) / this.kCUDANumThreads, this.kCUDANumThreads, 1, 0, this.backwardFGParameters);
            launchLn(this.backward_input_function, rows, 128, 1, (128 / 32) * 4, this.backwardInputParameters);
            if (rows < 512) {
                launchLn(this.backward_gamma_simple_function, ((width + this.kCUDANumThreads) - 1) / this.kCUDANumThreads, this.kCUDANumThreads, 1, 0, this.backwardGammaSampleParameters);
            } else {
                launchLn(this.backward_gamma_function, ((width + this.kColwiseReduceTileSize) - 1) / this.kColwiseReduceTileSize, this.kColwiseReduceTileSize, this.kColwiseReduceTileSize / 2, 0, this.backwardGammaParameters);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Legacy gamma/beta gradient step using {@code backward_aten_gamma_function2} with 16x32
     * thread blocks. The argument list is built once and cached.
     *
     * @param tensor4 may be {@code null}, in which case a null pointer is passed in its slot
     */
    public void backwardAtenGamma(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        try {
            if (this.backwardAtenGammaParameters2 == null) {
                Pointer optionalPtr = null;
                if (tensor4 != null) {
                    optionalPtr = tensor4.getGpuData();
                }
                this.backwardAtenGammaParameters2 = argList(
                        argInt(this.B),
                        argInt(this.W),
                        argPtr(tensor2.getGpuData()),
                        argPtr(tensor.getGpuData()),
                        argPtr(this.aten_mean),
                        argPtr(this.aten_var),
                        argPtr(tensor3.getGpuData()),
                        argPtr(optionalPtr));
            }
            int blockX = 16;
            int blockY = 32;
            launchLn(this.backward_aten_gamma_function2, ((this.W + blockX) - 1) / blockX, blockX, blockY, 8 * blockX * blockY, this.backwardAtenGammaParameters2);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Legacy backward pass: runs {@link #backwardAtenGamma} and then launches
     * {@code backward_aten_function} with one block of 256 threads per row. The argument list is
     * built once and cached.
     */
    public void backwardAten(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4, Tensor tensor5, Tensor tensor6) {
        try {
            backwardAtenGamma(tensor, tensor2, tensor5, tensor6);
            if (this.backwardAtenParameters == null) {
                this.backwardAtenParameters = argList(
                        argPtr(tensor2.getGpuData()),
                        argPtr(tensor.getGpuData()),
                        argPtr(this.aten_mean),
                        argPtr(this.aten_var),
                        argPtr(tensor4.getGpuData()),
                        argPtr(tensor3.getGpuData()),
                        argInt(this.W));
            }
            launchLn(this.backward_aten_function, this.B, 256, 1, 32, this.backwardAtenParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Legacy step: multiplies {@code tensor} by {@code tensor2} into {@code tensor4} via
     * {@link TensorOP#mul}, then launches {@code inter_grad_function} with one block per row.
     */
    public void interGrad(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        try {
            TensorOP.mul(tensor, tensor2, tensor4);
            if (this.interGradParameters == null) {
                this.interGradParameters = argList(
                        argInt(this.W),
                        argPtr(tensor4.getGpuData()),
                        argPtr(tensor.getGpuData()),
                        argPtr(tensor3.getGpuData()),
                        argPtr(this.ds.getGpuData()),
                        argPtr(this.db.getGpuData()));
            }
            launchLn(this.inter_grad_function, this.B, this.CAFFE_CUDA_NUM_THREADS, 1, 0, this.interGradParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Legacy step: launches {@code backward_fused_function} over {@code B} elements. */
    public void backwardFusedParams() {
        try {
            if (this.backwardFusedParameters == null) {
                this.backwardFusedParameters = argList(
                        argInt(this.B),
                        argInt(this.W),
                        argPtr(this.mean.getGpuData()),
                        argPtr(this.simga.getGpuData()),
                        argPtr(this.ds.getGpuData()),
                        argPtr(this.db.getGpuData()),
                        argPtr(this.rstd.getGpuData()),
                        argPtr(this.X_scale.getGpuData()),
                        argPtr(this.bias.getGpuData()),
                        argPtr(this.g_scale.getGpuData()));
            }
            launchLn(this.backward_fused_function, CAFFE_GET_BLOCKS(this.B), this.CAFFE_CUDA_NUM_THREADS, 1, 0, this.backwardFusedParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Legacy step: three {@code gemv} calls. The first two write and then accumulate into
     * {@code tensor3} (last argument 0 then 1); the third writes {@code tensor4} from
     * {@code tensor2} and the {@code ones} vector.
     */
    public void gammaBetaBackward(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        try {
            GPUOP gpu = GPUOP.getInstance();
            gpu.gemv(0, this.B, this.W, tensor, this.rstd, tensor3, 1.0f, 0.0f);
            GPUOP.getInstance().gemv(0, this.B, this.W, tensor2, this.g_scale, tensor3, 1.0f, 1.0f);
            GPUOP.getInstance().gemv(0, this.B, this.W, tensor2, this.ones, tensor4, 1.0f, 0.0f);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Legacy step: launches {@code ln_backward_function} over {@code B} elements. */
    public void lnBackward(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        try {
            if (this.lnBKParameters == null) {
                this.lnBKParameters = argList(
                        argInt(this.B),
                        argInt(this.W),
                        argPtr(tensor.getGpuData()),
                        argPtr(tensor2.getGpuData()),
                        argPtr(tensor3.getGpuData()),
                        argPtr(this.rstd.getGpuData()),
                        argPtr(this.X_scale.getGpuData()),
                        argPtr(this.bias.getGpuData()),
                        argPtr(tensor4.getGpuData()));
            }
            launchLn(this.ln_backward_function, CAFFE_GET_BLOCKS(this.B), this.CAFFE_CUDA_NUM_THREADS, 1, 0, this.lnBKParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Legacy four-step backward pass: {@link #interGrad}, {@link #backwardFusedParams},
     * {@link #gammaBetaBackward}, {@link #lnBackward}.
     */
    public void backward_aten(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4, Tensor tensor5, Tensor tensor6) {
        interGrad(tensor2, tensor, tensor4, tensor3);
        backwardFusedParams();
        gammaBetaBackward(tensor3, tensor2, tensor5, tensor6);
        lnBackward(tensor2, tensor, tensor4, tensor3);
    }

    /**
     * Prints and throws if {@code i} is a non-zero CUDA status code.
     *
     * <p>The message text comes from {@link cudaError#stringFor(int)}, including when the code
     * was returned by a driver-API call such as {@code cuLaunchKernel}.
     *
     * @param i status code returned by a CUDA call
     * @throws RuntimeException if {@code i} is not 0
     */
    @Override // com.omega.engine.gpu.BaseKernel
    public void checkCUDA(int i) {
        if (i == 0) {
            return;
        }
        String message = "Error code " + i + ":" + cudaError.stringFor(i);
        System.err.println(message);
        throw new RuntimeException(message);
    }

    /**
     * Copies {@code fArr.length} floats from device memory into {@code fArr} and prints them as
     * JSON, prefixed with {@code str}.
     */
    public void showDM(String str, CUdeviceptr cUdeviceptr, float[] fArr) {
        JCudaDriver.cuMemcpyDtoH(Pointer.to(fArr), cUdeviceptr, fArr.length * 4);
        System.out.println(str + ":" + JsonUtils.toJson(fArr));
    }

    /** Manual smoke test: runs ten forward/backward iterations of an {@link LNLayer} on a 10x3 input. */
    public static void main(String[] strArr) {
        Tensor tensor = new Tensor(10 * 1, 1, 1, 3, RandomUtils.order(10 * 1 * 3, 0.1f, 0.1f), true);
        Tensor tensor2 = new Tensor(10 * 1, 1, 1, 3, MatrixUtils.order(10 * 1 * 3, 0.1f, 0.1f), true);
        // Result is unused; kept because the constructor may have side effects - TODO confirm.
        new Tensor(128, 1, 1, 3, RandomUtils.order(128 * 3, 0.1f, 0.1f), true);
        LNLayer lNLayer = new LNLayer((Network) new Transformer(), true);
        for (int i = 0; i < 10; i++) {
            lNLayer.forward(tensor);
            lNLayer.getOutput().showDM();
            lNLayer.back(tensor2);
            lNLayer.diff.showDM();
        }
    }

    /** Wraps a single {@code int} kernel argument. */
    private static Pointer argInt(int value) {
        return Pointer.to(new int[]{value});
    }

    /** Wraps a single {@code float} kernel argument. */
    private static Pointer argFloat(float value) {
        return Pointer.to(new float[]{value});
    }

    /** Wraps a single pointer kernel argument; {@code pointer} may be {@code null}. */
    private static Pointer argPtr(NativePointerObject pointer) {
        return Pointer.to(new NativePointerObject[]{pointer});
    }

    /** Combines already wrapped arguments into the parameter list expected by {@code cuLaunchKernel}. */
    private static Pointer argList(NativePointerObject... wrappedArgs) {
        return Pointer.to(wrappedArgs);
    }

    /**
     * Launches {@code fn} on a {@code gridX x 1 x 1} grid with {@code blockX x blockY x 1}
     * blocks, passing a {@code null} stream and no extra options.
     *
     * @return the result code of {@code cuLaunchKernel}
     */
    private static int launchLn(CUfunction fn, int gridX, int blockX, int blockY, int sharedMemBytes, Pointer params) {
        return JCudaDriver.cuLaunchKernel(fn, gridX, 1, 1, blockX, blockY, 1, sharedMemBytes, (CUstream) null, params, (Pointer) null);
    }
}
