package com.omega.engine.nn.layer.normalization.gpu;

import com.omega.common.data.Tensor;
import com.omega.common.utils.CheckArrayUtils;
import com.omega.common.utils.JsonUtils;
import com.omega.common.utils.MatrixOperation;
import com.omega.common.utils.MatrixUtils;
import com.omega.common.utils.PrintUtils;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.nn.layer.gpu.BNBaseKernel;
import com.omega.engine.nn.layer.normalization.BNType;
import com.omega.engine.nn.network.RunModel;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUdeviceptr;
import jcuda.driver.CUfunction;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.cudaError;

/* loaded from: input_file:com/omega/engine/nn/layer/normalization/gpu/BNKernel2.class */
public class BNKernel2 extends BNBaseKernel {
    // Normalization mode chosen at construction; it selects which kernels run and how many statistics are kept.
    private BNType bnType;
    // The constructor stores its three int arguments as C, H and W, in that order.
    private int C;
    private int H;
    private int W;
    // Number of per-statistic entries: W for fully_bn, C otherwise (see the constructor). Sizes the d_* buffers.
    private int meanNumber;
    // CUDA function handles; each one is resolved by initFunction() if it is still null.
    // NOTE(review): culOutput_function, computeDiff_function and meanDzSum_function are resolved but never
    // launched in the visible part of this class — confirm whether they are still needed.
    private CUfunction mean_function;
    private CUfunction var_function;
    private CUfunction mwa_function;
    private CUfunction culOutput_function;
    private CUfunction fast_mean_function;
    private CUfunction fast_var_function;
    private CUfunction normalize_function;
    private CUfunction normalize_test_function;
    private CUfunction computeDiff_function;
    private CUfunction computeDelta_full_function;
    private CUfunction fast_mean_xhat_function;
    private CUfunction fast_mean_dxhat_function;
    private CUfunction mean_xhat_function;
    private CUfunction meanDzSum_function;
    private CUfunction dgama_function;
    private CUfunction dbeta_function;
    private CUfunction dxhat_function;
    private CUfunction dx_function;
    private CUfunction dx_full_function;
    // Block x-dimension used for every kernel launch and the divisor in CAFFE_GET_BLOCKS.
    private int CAFFE_CUDA_NUM_THREADS = 1024;
    // Small constant added to the variance before the square root in the CPU reference code; also passed
    // to the normalize kernels.
    private float eta = 1.0E-5f;
    // Passed as the last argument of the "mwa" kernel. Presumably the running-average momentum — TODO confirm.
    private float momentum = 0.01f;
    // Device buffer of N*C*H*W elements, (re)allocated by initForward whenever the batch size changes.
    private CUdeviceptr d_z;
    // Device buffers of meanNumber elements each, allocated once by initKernel().
    private CUdeviceptr d_mean;
    private CUdeviceptr d_var;
    private CUdeviceptr d_std;
    private CUdeviceptr d_runingMean;
    private CUdeviceptr d_runingVar;
    private CUdeviceptr d_mean_dz;
    private CUdeviceptr d_mean_dzxz;
    // Cached kernel-argument blocks. Forward ones are built by initForward, backward ones by initBackward.
    private Pointer meanParameters;
    private Pointer varParameters;
    private Pointer fastMeanParameters;
    private Pointer fastVarParameters;
    private Pointer normalizeParameters;
    private Pointer mwaParameters;
    // initBackward rebuilds all backward argument blocks only while this one is null.
    private Pointer dgamaParameters;
    private Pointer dbetaParameters;
    private Pointer computeDelta_full_Parameters;
    private Pointer normalize_test_Parameters;
    private Pointer dxhatParameters;
    private Pointer dxParameters;
    private Pointer fast_mean_xhat_Parameters;
    private Pointer fast_mean_dxhat_Parameters;
    private Pointer mean_xhat_Parameters;
    private Pointer dx_fullParameters;

    /**
     * Creates the kernel wrapper, then loads the CUDA functions and allocates the per-statistic
     * device buffers through {@link #init()}.
     *
     * @param bNType normalization mode; {@code fully_bn} keeps one statistic per {@code i3} entry,
     *               every other mode keeps one per {@code i} entry
     * @param i      stored as {@code C}
     * @param i2     stored as {@code H}
     * @param i3     stored as {@code W}
     */
    public BNKernel2(BNType bNType, int i, int i2, int i3) {
        this.bnType = bNType;
        this.C = i;
        this.H = i2;
        this.W = i3;
        // One statistic per W entry in fully-connected mode, one per C entry otherwise.
        this.meanNumber = (bNType == BNType.fully_bn) ? i3 : i;
        init();
    }

    /**
     * Resolves every CUDA function handle that is still {@code null}.
     *
     * <p>Handles are resolved in a fixed order inside one {@code try}; if a lookup throws, the stack
     * trace is printed and the remaining handles are left untouched.
     */
    public void initFunction() {
        final String mathModule = "MathKernel2.cu";
        final String bnModule = "BNKernel2.cu";
        try {
            this.mean_function = resolveKernel(this.mean_function, mathModule, "mean_full");
            this.fast_mean_function = resolveKernel(this.fast_mean_function, mathModule, "fast_mean_kernel");
            this.var_function = resolveKernel(this.var_function, mathModule, "var_full");
            this.fast_var_function = resolveKernel(this.fast_var_function, mathModule, "fast_variance_kernel");
            this.normalize_function = resolveKernel(this.normalize_function, bnModule, "normalize_kernel");
            this.normalize_test_function = resolveKernel(this.normalize_test_function, bnModule, "normalize_test_kernel");
            this.mwa_function = resolveKernel(this.mwa_function, mathModule, "mwa");
            this.culOutput_function = resolveKernel(this.culOutput_function, bnModule, "culOutput_cov");
            this.computeDelta_full_function = resolveKernel(this.computeDelta_full_function, bnModule, "computeDelta_full");
            this.meanDzSum_function = resolveKernel(this.meanDzSum_function, bnModule, "meanDzSum");
            this.computeDiff_function = resolveKernel(this.computeDiff_function, bnModule, "computeDiff");
            this.dgama_function = resolveKernel(this.dgama_function, bnModule, "dgama_kernel");
            this.dbeta_function = resolveKernel(this.dbeta_function, bnModule, "dbeta_kernel");
            this.dxhat_function = resolveKernel(this.dxhat_function, bnModule, "dxhat_kernel");
            this.dx_function = resolveKernel(this.dx_function, bnModule, "dx_kernel");
            this.dx_full_function = resolveKernel(this.dx_full_function, bnModule, "dx_kernel_full");
            this.fast_mean_xhat_function = resolveKernel(this.fast_mean_xhat_function, bnModule, "fast_mean_xhat_kernel");
            this.mean_xhat_function = resolveKernel(this.mean_xhat_function, bnModule, "mean_xhat_kernel");
            this.fast_mean_dxhat_function = resolveKernel(this.fast_mean_dxhat_function, bnModule, "fast_mean_dxhat_kernel");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Returns {@code current} when it is already set; otherwise looks the kernel up by module and name.
     */
    private static CUfunction resolveKernel(CUfunction current, String module, String kernel) throws Exception {
        if (current != null) {
            return current;
        }
        return CUDAModules.getLocalFunctionByModule(module, kernel);
    }

    /**
     * Allocates the seven device buffers that hold one value per statistic ({@code meanNumber} each).
     */
    private void initKernel() {
        final int count = this.meanNumber;
        this.d_mean = CUDAMemoryManager.getDevice(count);
        this.d_var = CUDAMemoryManager.getDevice(count);
        this.d_std = CUDAMemoryManager.getDevice(count);
        this.d_runingMean = CUDAMemoryManager.getDevice(count);
        this.d_runingVar = CUDAMemoryManager.getDevice(count);
        this.d_mean_dz = CUDAMemoryManager.getDevice(count);
        this.d_mean_dzxz = CUDAMemoryManager.getDevice(count);
    }

    /**
     * Resolves the CUDA function handles, then allocates the per-statistic device buffers.
     */
    public void init() {
        this.initFunction();
        this.initKernel();
    }

    /**
     * Builds the cached kernel-argument blocks for the forward pass.
     *
     * <p>Nothing happens when {@code tensor.number} equals the cached batch size {@code N}.
     * Otherwise {@code N} is updated, {@code d_z} is re-allocated with {@code N*C*H*W} elements and
     * every forward argument block is rebuilt.
     *
     * <p>The backward argument blocks capture {@code N} by value and the {@code d_z} reference, so
     * they are invalidated here; {@link #initBackward} rebuilds them on its next call. Without this
     * a change of batch size would leave the backward kernels running with the previous {@code N}
     * and the previous {@code d_z}.
     *
     * <p>NOTE(review): the previous {@code d_z} allocation is not released here — confirm whether
     * {@code CUDAMemoryManager} owns and recycles it.
     *
     * @param runModel not used by this method
     * @param tensor   forward input; its {@code number} is the batch size and its GPU data is read by the kernels
     * @param tensor2  passed to the normalize kernels right after the statistics (presumably the scale/gamma — TODO confirm)
     * @param tensor3  passed to the normalize kernels after {@code tensor2} (presumably the shift/beta — TODO confirm)
     * @param tensor4  passed to the normalize kernels after {@code d_z} (presumably the output — TODO confirm)
     */
    public void initForward(RunModel runModel, Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        if (tensor.number == this.N) {
            return;
        }
        this.N = tensor.number;
        // Backward argument blocks were built for the old batch size / old d_z; force a rebuild.
        this.dgamaParameters = null;

        if (this.bnType == BNType.fully_bn) {
            this.meanParameters = Pointer.to(
                    Pointer.to(tensor.getGpuData()),
                    Pointer.to(this.d_mean),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.W}));
            this.varParameters = Pointer.to(
                    Pointer.to(tensor.getGpuData()),
                    Pointer.to(this.d_mean),
                    Pointer.to(this.d_var),
                    Pointer.to(this.d_std),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.W}));
        } else {
            this.fastMeanParameters = Pointer.to(
                    Pointer.to(tensor.getGpuData()),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.C}),
                    Pointer.to(new int[]{this.H * this.W}),
                    Pointer.to(this.d_mean));
            this.fastVarParameters = Pointer.to(
                    Pointer.to(tensor.getGpuData()),
                    Pointer.to(this.d_mean),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.C}),
                    Pointer.to(new int[]{this.H * this.W}),
                    Pointer.to(this.d_var),
                    Pointer.to(this.d_std));
        }

        this.mwaParameters = Pointer.to(
                Pointer.to(this.d_mean),
                Pointer.to(this.d_var),
                Pointer.to(this.d_runingMean),
                Pointer.to(this.d_runingVar),
                Pointer.to(new int[]{this.meanNumber}),
                Pointer.to(new float[]{this.momentum}));

        // Elements that share one statistic within a sample: H*W for conv_bn, 1 otherwise.
        final int perStatistic = (this.bnType == BNType.conv_bn) ? this.H * this.W : 1;
        final int total = this.N * this.C * this.H * this.W;

        this.d_z = CUDAMemoryManager.getDevice(total);
        this.normalizeParameters = Pointer.to(
                Pointer.to(new int[]{total}),
                Pointer.to(tensor.getGpuData()),
                Pointer.to(this.d_z),
                Pointer.to(tensor4.getGpuData()),
                Pointer.to(this.d_mean),
                Pointer.to(this.d_var),
                Pointer.to(tensor2.getGpuData()),
                Pointer.to(tensor3.getGpuData()),
                Pointer.to(new int[]{this.N}),
                Pointer.to(new int[]{this.meanNumber}),
                Pointer.to(new int[]{perStatistic}),
                Pointer.to(new float[]{this.eta}));
        // Same layout as above, but reads the running statistics instead of the batch statistics.
        this.normalize_test_Parameters = Pointer.to(
                Pointer.to(new int[]{total}),
                Pointer.to(tensor.getGpuData()),
                Pointer.to(this.d_z),
                Pointer.to(tensor4.getGpuData()),
                Pointer.to(this.d_runingMean),
                Pointer.to(this.d_runingVar),
                Pointer.to(tensor2.getGpuData()),
                Pointer.to(tensor3.getGpuData()),
                Pointer.to(new int[]{this.N}),
                Pointer.to(new int[]{this.meanNumber}),
                Pointer.to(new int[]{perStatistic}),
                Pointer.to(new float[]{this.eta}));
    }

    /**
     * Builds the cached kernel-argument blocks for the backward pass.
     *
     * <p>The blocks are rebuilt only while {@code dgamaParameters} is {@code null}. That field is
     * assigned only in the non-{@code fully_bn} branch, so in {@code fully_bn} mode the blocks are
     * rebuilt on every call.
     *
     * <p>NOTE(review): {@code dx_fullParameters} passes {@code C} as its last argument while
     * {@code computeDelta_full_Parameters} passes {@code W}; confirm against the
     * {@code dx_kernel_full} signature.
     *
     * @param tensor  not used by this method
     * @param tensor2 read by the dgama/dbeta/computeDelta_full/dxhat kernels (presumably the incoming delta — TODO confirm)
     * @param tensor3 passed to the dxhat, mean_xhat and dx kernels (presumably the resulting diff — TODO confirm)
     * @param tensor4 passed to the dxhat kernel after {@code tensor3} (presumably gamma — TODO confirm)
     * @param tensor5 last argument of the dgama kernel (presumably the gamma gradient — TODO confirm)
     * @param tensor6 first argument of the dbeta kernel (presumably the beta gradient — TODO confirm)
     */
    public void initBackward(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4, Tensor tensor5, Tensor tensor6) {
        if (this.dgamaParameters != null) {
            return;
        }

        if (this.bnType == BNType.fully_bn) {
            this.computeDelta_full_Parameters = Pointer.to(
                    Pointer.to(tensor2.getGpuData()),
                    Pointer.to(tensor5.getGpuData()),
                    Pointer.to(tensor6.getGpuData()),
                    Pointer.to(this.d_z),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.W}));
            this.dx_fullParameters = Pointer.to(
                    Pointer.to(this.d_z),
                    Pointer.to(this.d_std),
                    Pointer.to(tensor3.getGpuData()),
                    Pointer.to(this.d_mean_dz),
                    Pointer.to(this.d_mean_dzxz),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.C}));
        } else {
            this.dxParameters = Pointer.to(
                    Pointer.to(new int[]{this.N * this.C * this.H * this.W}),
                    Pointer.to(this.d_z),
                    Pointer.to(this.d_std),
                    Pointer.to(tensor3.getGpuData()),
                    Pointer.to(this.d_mean_dz),
                    Pointer.to(this.d_mean_dzxz),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.C}),
                    Pointer.to(new int[]{this.H * this.W}));
            this.dgamaParameters = Pointer.to(
                    Pointer.to(this.d_z),
                    Pointer.to(tensor2.getGpuData()),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.C}),
                    Pointer.to(new int[]{this.H * this.W}),
                    Pointer.to(tensor5.getGpuData()));
            this.dbetaParameters = Pointer.to(
                    Pointer.to(tensor6.getGpuData()),
                    Pointer.to(tensor2.getGpuData()),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.C}),
                    Pointer.to(new int[]{this.H * this.W}));
        }

        // Elements that share one statistic within a sample: H*W for conv_bn, 1 otherwise.
        final int perStatistic = (this.bnType == BNType.conv_bn) ? this.H * this.W : 1;

        this.fast_mean_xhat_Parameters = Pointer.to(
                Pointer.to(this.d_z),
                Pointer.to(tensor3.getGpuData()),
                Pointer.to(new int[]{this.N}),
                Pointer.to(new int[]{this.C}),
                Pointer.to(new int[]{perStatistic}),
                Pointer.to(this.d_mean_dz));
        this.fast_mean_dxhat_Parameters = Pointer.to(
                Pointer.to(this.d_z),
                Pointer.to(tensor3.getGpuData()),
                Pointer.to(new int[]{this.N}),
                Pointer.to(new int[]{this.C}),
                Pointer.to(new int[]{perStatistic}),
                Pointer.to(this.d_mean_dzxz));
        this.mean_xhat_Parameters = Pointer.to(
                Pointer.to(this.d_z),
                Pointer.to(tensor3.getGpuData()),
                Pointer.to(this.d_mean_dz),
                Pointer.to(this.d_mean_dzxz),
                Pointer.to(new int[]{this.N}),
                Pointer.to(new int[]{this.C}),
                Pointer.to(new int[]{this.H}),
                Pointer.to(new int[]{this.W}));
        this.dxhatParameters = Pointer.to(
                Pointer.to(new int[]{this.N * this.C * this.H * this.W}),
                Pointer.to(tensor2.getGpuData()),
                Pointer.to(tensor3.getGpuData()),
                Pointer.to(tensor4.getGpuData()),
                Pointer.to(new int[]{this.meanNumber}),
                Pointer.to(new int[]{perStatistic}));
    }

    /**
     * Returns how many blocks of {@code CAFFE_CUDA_NUM_THREADS} threads are needed to cover
     * {@code i} elements (ceiling division).
     *
     * @param i number of elements to cover
     * @return {@code (i + threads - 1) / threads}
     */
    @Override // com.omega.engine.gpu.BaseKernel
    public int CAFFE_GET_BLOCKS(int i) {
        final int roundedUp = (i + this.CAFFE_CUDA_NUM_THREADS) - 1;
        return roundedUp / this.CAFFE_CUDA_NUM_THREADS;
    }

    /**
     * Runs the forward pass.
     *
     * <p>In {@code RunModel.TRAIN} the batch statistics are computed ({@code mean}/{@code var} for
     * {@code fully_bn}, {@code fast_mean}/{@code fast_var} otherwise), then {@code mwa} runs, then
     * the training normalize kernel. In any other mode only the test normalize kernel runs.
     *
     * @param runModel execution mode
     * @param tensor   forwarded to {@link #initForward} as its third argument
     * @param tensor2  forwarded to {@link #initForward} as its fourth argument
     * @param tensor3  forwarded to {@link #initForward} as its second argument (the input)
     * @param tensor4  forwarded to {@link #initForward} as its fifth argument
     */
    @Override // com.omega.engine.nn.layer.gpu.BNBaseKernel
    public void forward(RunModel runModel, Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        initForward(runModel, tensor3, tensor, tensor2, tensor4);
        if (runModel == RunModel.TRAIN) {
            if (this.bnType == BNType.fully_bn) {
                mean();
                var();
            } else {
                fast_mean();
                fast_var();
            }
            mwa();
            normalize_train(tensor3, tensor, tensor2, tensor4);
        } else {
            normalize_test(tensor3, tensor, tensor2, tensor4);
        }
    }

    /**
     * Launches {@code mean_function} with {@code meanParameters} on a 1-D grid that covers
     * {@code meanNumber} elements. Any exception is printed and swallowed.
     */
    public void mean() {
        try {
            final int blocks = CAFFE_GET_BLOCKS(this.meanNumber);
            final int threads = this.CAFFE_CUDA_NUM_THREADS;
            JCudaDriver.cuLaunchKernel(this.mean_function,
                    blocks, 1, 1,
                    threads, 1, 1,
                    0, (CUstream) null, this.meanParameters, (Pointer) null);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code var_function} with {@code varParameters} on a 1-D grid that covers
     * {@code meanNumber} elements. Any exception is printed and swallowed.
     */
    public void var() {
        try {
            final int blocks = CAFFE_GET_BLOCKS(this.meanNumber);
            final int threads = this.CAFFE_CUDA_NUM_THREADS;
            JCudaDriver.cuLaunchKernel(this.var_function,
                    blocks, 1, 1,
                    threads, 1, 1,
                    0, (CUstream) null, this.varParameters, (Pointer) null);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code fast_mean_function} with {@code fastMeanParameters}, using one block per
     * statistic ({@code meanNumber} blocks). Any exception is printed and swallowed.
     */
    public void fast_mean() {
        try {
            final int blocks = this.meanNumber;
            final int threads = this.CAFFE_CUDA_NUM_THREADS;
            JCudaDriver.cuLaunchKernel(this.fast_mean_function,
                    blocks, 1, 1,
                    threads, 1, 1,
                    0, (CUstream) null, this.fastMeanParameters, (Pointer) null);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code fast_var_function} with {@code fastVarParameters}, using one block per
     * statistic ({@code meanNumber} blocks). Any exception is printed and swallowed.
     */
    public void fast_var() {
        try {
            final int blocks = this.meanNumber;
            final int threads = this.CAFFE_CUDA_NUM_THREADS;
            JCudaDriver.cuLaunchKernel(this.fast_var_function,
                    blocks, 1, 1,
                    threads, 1, 1,
                    0, (CUstream) null, this.fastVarParameters, (Pointer) null);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code normalize_test_function} with the cached {@code normalize_test_Parameters}
     * on a 1-D grid that covers {@code N*C*H*W} elements. Any exception is printed and swallowed.
     *
     * <p>The four tensor arguments are not read here; the kernel arguments were captured by
     * {@code initForward}.
     */
    public void normalize_test(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        try {
            final int elements = this.N * this.C * this.H * this.W;
            final int blocks = CAFFE_GET_BLOCKS(elements);
            final int threads = this.CAFFE_CUDA_NUM_THREADS;
            JCudaDriver.cuLaunchKernel(this.normalize_test_function,
                    blocks, 1, 1,
                    threads, 1, 1,
                    0, (CUstream) null, this.normalize_test_Parameters, (Pointer) null);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code normalize_function} with the cached {@code normalizeParameters} on a 1-D
     * grid that covers {@code N*C*H*W} elements. Any exception is printed and swallowed.
     *
     * <p>The four tensor arguments are not read here; the kernel arguments were captured by
     * {@code initForward}.
     */
    public void normalize_train(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        try {
            final int elements = this.N * this.C * this.H * this.W;
            final int blocks = CAFFE_GET_BLOCKS(elements);
            final int threads = this.CAFFE_CUDA_NUM_THREADS;
            JCudaDriver.cuLaunchKernel(this.normalize_function,
                    blocks, 1, 1,
                    threads, 1, 1,
                    0, (CUstream) null, this.normalizeParameters, (Pointer) null);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code mwa_function} with {@code mwaParameters} (batch mean/variance, running
     * mean/variance, statistic count and {@code momentum}) on a 1-D grid that covers
     * {@code meanNumber} elements. Any exception is printed and swallowed.
     */
    public void mwa() {
        try {
            final int blocks = CAFFE_GET_BLOCKS(this.meanNumber);
            final int threads = this.CAFFE_CUDA_NUM_THREADS;
            JCudaDriver.cuLaunchKernel(this.mwa_function,
                    blocks, 1, 1,
                    threads, 1, 1,
                    0, (CUstream) null, this.mwaParameters, (Pointer) null);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Runs the backward pass.
     *
     * <p>Launch order for {@code fully_bn}: computeDelta_full, dxhat, mean_xhat, dx_full.
     * For every other mode: dgama, dbeta, dxhat, mean_xhat, dx.
     *
     * <p>All six tensors are forwarded unchanged to {@link #initBackward}, which captures their GPU
     * buffers in the cached kernel arguments.
     */
    @Override // com.omega.engine.nn.layer.gpu.BNBaseKernel
    public void backward(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4, Tensor tensor5, Tensor tensor6) {
        initBackward(tensor, tensor2, tensor3, tensor4, tensor5, tensor6);
        final boolean fullyConnected = this.bnType == BNType.fully_bn;
        if (fullyConnected) {
            computeDelta_full();
            computeDxhat();
            computeMeanXhat2();
            computeDx_full();
        } else {
            computeDgama();
            computeDbeta();
            computeDxhat();
            computeMeanXhat2();
            computeDx();
        }
    }

    /** Launches {@code dgama_function} with {@code dgamaParameters}, one block per statistic. */
    private void computeDgama() {
        final int blocks = this.meanNumber;
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.dgama_function,
                blocks, 1, 1,
                threads, 1, 1,
                0, (CUstream) null, this.dgamaParameters, (Pointer) null);
    }

    /** Launches {@code dbeta_function} with {@code dbetaParameters}, one block per statistic. */
    private void computeDbeta() {
        final int blocks = this.meanNumber;
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.dbeta_function,
                blocks, 1, 1,
                threads, 1, 1,
                0, (CUstream) null, this.dbetaParameters, (Pointer) null);
    }

    /**
     * Launches {@code computeDelta_full_function} with {@code computeDelta_full_Parameters} on a
     * 1-D grid that covers {@code meanNumber} elements.
     */
    private void computeDelta_full() {
        final int blocks = CAFFE_GET_BLOCKS(this.meanNumber);
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.computeDelta_full_function,
                blocks, 1, 1,
                threads, 1, 1,
                0, (CUstream) null, this.computeDelta_full_Parameters, (Pointer) null);
    }

    /**
     * Launches {@code dxhat_function} with {@code dxhatParameters} on a 1-D grid that covers
     * {@code N*C*H*W} elements.
     */
    private void computeDxhat() {
        final int elements = this.N * this.C * this.H * this.W;
        final int blocks = CAFFE_GET_BLOCKS(elements);
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.dxhat_function,
                blocks, 1, 1,
                threads, 1, 1,
                0, (CUstream) null, this.dxhatParameters, (Pointer) null);
    }

    /**
     * Launches {@code fast_mean_xhat_function} with {@code fast_mean_xhat_Parameters}, one block
     * per statistic. Not called from the visible part of this class.
     */
    private void computeMeanXhat() {
        final int blocks = this.meanNumber;
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.fast_mean_xhat_function,
                blocks, 1, 1,
                threads, 1, 1,
                0, (CUstream) null, this.fast_mean_xhat_Parameters, (Pointer) null);
    }

    /**
     * Launches {@code fast_mean_dxhat_function} with {@code fast_mean_dxhat_Parameters}, one block
     * per statistic. Not called from the visible part of this class.
     */
    private void computeMeanDXhat() {
        final int blocks = this.meanNumber;
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.fast_mean_dxhat_function,
                blocks, 1, 1,
                threads, 1, 1,
                0, (CUstream) null, this.fast_mean_dxhat_Parameters, (Pointer) null);
    }

    /**
     * Launches {@code mean_xhat_function} with {@code mean_xhat_Parameters}, one block per
     * statistic. Its argument block carries both {@code d_mean_dz} and {@code d_mean_dzxz}.
     */
    private void computeMeanXhat2() {
        final int blocks = this.meanNumber;
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.mean_xhat_function,
                blocks, 1, 1,
                threads, 1, 1,
                0, (CUstream) null, this.mean_xhat_Parameters, (Pointer) null);
    }

    /**
     * Launches {@code dx_function} with {@code dxParameters} on a 1-D grid that covers
     * {@code N*C*H*W} elements.
     */
    private void computeDx() {
        final int elements = this.N * this.C * this.H * this.W;
        final int blocks = CAFFE_GET_BLOCKS(elements);
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.dx_function,
                blocks, 1, 1,
                threads, 1, 1,
                0, (CUstream) null, this.dxParameters, (Pointer) null);
    }

    /**
     * Launches {@code dx_full_function} with {@code dx_fullParameters} on a 1-D grid that covers
     * {@code meanNumber} elements.
     */
    private void computeDx_full() {
        final int blocks = CAFFE_GET_BLOCKS(this.meanNumber);
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.dx_full_function,
                blocks, 1, 1,
                threads, 1, 1,
                0, (CUstream) null, this.dx_fullParameters, (Pointer) null);
    }

    /**
     * Prints a message to {@code System.err} when {@code i} is a non-zero status code.
     *
     * <p>NOTE(review): the text comes from the runtime-API table ({@code cudaError.stringFor}),
     * while the launches in this class go through the driver API — confirm which code space
     * callers pass in.
     *
     * @param i status code; {@code 0} is treated as success and prints nothing
     */
    @Override // com.omega.engine.gpu.BaseKernel
    public void checkCUDA(int i) {
        if (i == 0) {
            return;
        }
        final String description = cudaError.stringFor(i);
        System.err.println("Error code " + i + ":" + description);
    }

    /**
     * Debug helper: copies {@code fArr.length} floats from device memory into {@code fArr} and
     * prints them as JSON, prefixed with {@code str} and a colon.
     *
     * @param str         label printed before the values
     * @param cUdeviceptr device buffer to read from
     * @param fArr        host array that receives the copy; its length decides how many bytes are copied
     */
    public void showDM(String str, CUdeviceptr cUdeviceptr, float[] fArr) {
        final Pointer hostTarget = Pointer.to(fArr);
        // Four bytes per float.
        final int byteCount = fArr.length * 4;
        JCudaDriver.cuMemcpyDtoH(hostTarget, cUdeviceptr, byteCount);
        final String json = JsonUtils.toJson(fArr);
        System.out.println(str + ":" + json);
    }

    /**
     * Manual entry point: initialises the CUDA context and runs the convolutional smoke test.
     *
     * @param strArr command-line arguments (ignored)
     */
    public static void main(String[] strArr) {
        CUDAModules.initContext();
        BNKernel2.test2d();
    }

    /**
     * Manual smoke test for {@code conv_bn} mode on a fixed 2x3x5x5 input.
     *
     * <p>Runs one forward pass in TRAIN mode and one backward pass on the GPU, prints the input,
     * the forward output and the backward diff, then runs the CPU reference {@code foward_cpu}
     * on the same data. Nothing is asserted; the results are compared by eye.
     */
    public static void test2d() {
        // Fixed input values, 2 * 3 * 5 * 5 = 150 floats.
        float[] fArr = {-0.5596f, 0.6154f, -1.1204f, -0.1636f, -1.3229f, 0.9092f, -0.8235f, 0.3563f, 1.2746f, 0.6454f, -0.7569f, -0.6933f, -1.0571f, -1.5361f, 1.8608f, 0.4835f, -1.3317f, -2.3606f, 0.847f, 1.1632f, 0.508f, -1.4968f, 1.4136f, 0.8903f, 0.42f, 0.8233f, -0.6349f, 0.4416f, 0.5081f, 0.1545f, 0.3967f, 0.6221f, -0.0245f, 0.8493f, -0.8964f, -0.5109f, -0.7737f, -0.2255f, 1.6705f, 0.2935f, 0.5887f, -1.7415f, 0.6597f, 1.2048f, 0.7282f, 0.8854f, 0.9372f, -0.0824f, 1.6266f, -1.945f, -1.3224f, 0.5002f, -1.0779f, 0.9101f, -0.8541f, -0.513f, -0.4204f, 0.1571f, -0.3905f, -1.3593f, 1.0415f, -0.9938f, 2.69f, -1.1995f, 0.7727f, -0.2714f, 1.1784f, -0.8269f, 0.322f, 0.7001f, -0.0134f, -1.1899f, -0.873f, 1.1819f, 0.9492f, -1.3036f, 0.3672f, -0.1123f, -0.1105f, 0.4664f, -0.7766f, -0.1695f, -1.9371f, -0.4164f, 1.8046f, -0.0946f, 0.8305f, 0.982f, 0.566f, -0.1472f, 0.483f, 1.0134f, 1.3013f, -1.4215f, 0.457f, 1.5848f, -0.2974f, -1.216f, 0.6511f, 0.4922f, -1.2987f, -0.9202f, -1.6065f, 0.6146f, -1.7012f, -0.5577f, 1.3336f, -0.5391f, 0.1539f, -0.7145f, 0.2365f, 1.0505f, 0.2315f, -1.4901f, 0.1007f, 0.7942f, -1.1326f, -1.686f, -0.0734f, 0.9499f, 0.2508f, 1.3307f, -0.966f, -1.3506f, -1.4267f, 1.1793f, -0.3751f, 0.7723f, -0.2359f, -0.2686f, 1.2551f, -0.6165f, -1.1625f, 0.5438f, -1.5241f, -1.8666f, -0.004f, 1.622f, -0.9495f, -0.8307f, -2.1322f, 0.3769f, 0.9336f, 0.2458f, 0.1653f, -0.4835f, -0.7139f, -0.4338f, 0.1007f, 1.4633f};
        // Values fed to backward as its second argument (presumably the upstream delta — TODO confirm).
        float[] val = RandomUtils.val(2 * 3 * 5 * 5, 1.0f);
        float[] one = MatrixUtils.one(3);
        // tensor: input; tensor2: built from MatrixUtils.one(3); tensor3: per-channel tensor with no initial data.
        Tensor tensor = new Tensor(2, 3, 5, 5, fArr, true);
        Tensor tensor2 = new Tensor(1, 1, 1, 3, one, true);
        Tensor tensor3 = new Tensor(1, 1, 1, 3, true);
        Tensor tensor4 = new Tensor(2, 3, 5, 5, val, true);
        // tensor5 receives the forward output, tensor6 the backward diff (both are printed below).
        Tensor tensor5 = new Tensor(2, 3, 5, 5, true);
        Tensor tensor6 = new Tensor(2, 3, 5, 5, true);
        Tensor tensor7 = new Tensor(1, 1, 1, 3, true);
        Tensor tensor8 = new Tensor(1, 1, 1, 3, true);
        BNKernel2 bNKernel2 = new BNKernel2(BNType.conv_bn, 3, 5, 5);
        // Single iteration; the loop bound is the knob for repeating the pass.
        for (int i = 0; i < 1; i++) {
            bNKernel2.forward(RunModel.TRAIN, tensor2, tensor3, tensor, tensor5);
            JCudaDriver.cuCtxSynchronize();
            bNKernel2.backward(tensor, tensor4, tensor6, tensor2, tensor7, tensor8);
        }
        // Copy the GPU results back to the host before printing.
        tensor5.syncHost();
        tensor6.syncHost();
        PrintUtils.printImage(tensor.data);
        System.out.println("");
        System.out.println("=======output==============");
        PrintUtils.printImage(tensor5.data);
        System.out.println("");
        System.out.println("=======diff==============");
        PrintUtils.printImage(tensor6.data);
        // CPU reference run on the same data; mode 1 selects per-channel statistics.
        bNKernel2.foward_cpu(MatrixUtils.transform(fArr, 2, 3, 5, 5), new float[2][3][5][5], MatrixUtils.transform(val, 2, 3, 5, 5), new float[2][3][5][5], tensor2.data, tensor3.data, new float[3], new float[3], 1);
    }

    /**
     * Manual smoke test for {@code fully_bn} mode on a fixed 2x1x1x10 input.
     *
     * <p>Runs one forward pass in TRAIN mode and one backward pass on the GPU, prints the forward
     * output and the backward diff, then runs the CPU reference {@code foward_cpu} (mode 0) and
     * prints its two result arrays. Nothing is asserted.
     */
    public static void test1d() {
        float[] inputValues = {56.773f, -7.231f, 39.634f, 24.728f, -17.959f, 55.251f, -52.316f, -36.322f, -29.619f, 55.24f, 26.773f, -1.231f, 19.634f, 4.728f, 7.958f, -65.251f, 52.316f, -36.322f, -23.619f, -5.247f};
        float[] backwardValues = RandomUtils.val(2 * 1 * 1 * 10, 1.0f);
        float[] onesPerFeature = MatrixUtils.one(10);

        Tensor input = new Tensor(2, 1, 1, 10, inputValues, true);
        Tensor forwardArg1 = new Tensor(1, 1, 1, 10, onesPerFeature, true);
        Tensor forwardArg2 = new Tensor(1, 1, 1, 10, true);
        Tensor backwardArg2 = new Tensor(2, 1, 1, 10, backwardValues, true);
        Tensor output = new Tensor(2, 1, 1, 10, true);
        Tensor diff = new Tensor(2, 1, 1, 10, true);
        Tensor backwardArg5 = new Tensor(1, 1, 1, 10, true);
        Tensor backwardArg6 = new Tensor(1, 1, 1, 10, true);

        BNKernel2 kernel = new BNKernel2(BNType.fully_bn, 1, 1, 10);
        final int passes = 1;
        for (int pass = 0; pass < passes; pass++) {
            kernel.forward(RunModel.TRAIN, forwardArg1, forwardArg2, input, output);
            kernel.backward(input, backwardArg2, diff, forwardArg1, backwardArg5, backwardArg6);
        }

        // Copy the GPU results back to the host before printing.
        output.syncHost();
        diff.syncHost();
        PrintUtils.printImage(output.data);
        System.out.println("");
        System.out.println("=======diff==============");
        PrintUtils.printImage(diff.data);

        // CPU reference run on the same data.
        float[][][][] cpuInput = MatrixUtils.transform(inputValues, 2, 1, 1, 10);
        float[][][][] cpuBackwardValues = MatrixUtils.transform(backwardValues, 2, 1, 1, 10);
        float[][][][] cpuOutput = new float[2][1][1][10];
        float[][][][] cpuDiff = new float[2][1][1][10];
        kernel.foward_cpu(cpuInput, cpuOutput, cpuBackwardValues, cpuDiff, forwardArg1.data, forwardArg2.data, new float[10], new float[10], 0);
        PrintUtils.printImage(cpuOutput);
        System.out.println("=====================");
        PrintUtils.printImage(cpuDiff);
    }

    /**
     * CPU reference for the forward pass plus the input gradient; prints the gradient under the
     * heading {@code dx3:}.
     *
     * <p>Steps: mode-dependent mean and variance ({@code MatrixOperation.meanV2}/{@code varV2}),
     * normalized values and scaled output via {@code culOutput}, the intermediate gradient via
     * {@code dxhat}, a per-channel variance via {@code var}, its {@code sqrt(variance + eta)},
     * and finally {@code dx3}.
     *
     * <p>Fix: the square root previously read an undefined identifier ({@code r0}) instead of the
     * variance array that was passed to {@code var(...)}, so the method did not compile.
     *
     * <p>NOTE(review): all statistic arrays are sized by {@code C}, yet {@code culOutput} indexes
     * them by the last dimension when {@code i == 0}; confirm the sizes for the fully-connected
     * mode where {@code C} is smaller than {@code W}.
     *
     * @param fArr  input values, indexed {@code [n][c][h][w]}
     * @param fArr2 receives the scaled and shifted output from {@code culOutput}
     * @param fArr3 third argument of {@code dxhat} (presumably the upstream delta — TODO confirm)
     * @param fArr4 not used
     * @param fArr5 per-statistic multiplier passed to {@code culOutput} and {@code dxhat}
     * @param fArr6 per-statistic offset passed to {@code culOutput}
     * @param fArr7 not used
     * @param fArr8 not used
     * @param i     statistics mode forwarded to {@code meanV2}, {@code varV2} and {@code culOutput}
     */
    public void foward_cpu(float[][][][] fArr, float[][][][] fArr2, float[][][][] fArr3, float[][][][] fArr4, float[] fArr5, float[] fArr6, float[] fArr7, float[] fArr8, int i) {
        float[] mean = new float[this.C];
        float[] modeVariance = new float[this.C];
        MatrixOperation.meanV2(fArr, mean, i);
        MatrixOperation.varV2(fArr, mean, modeVariance, i);
        float[][][][] normalized = culOutput(fArr, fArr2, mean, modeVariance, fArr5, fArr6, i);

        float[][][][] normalizedGrad = new float[this.N][this.C][this.H][this.W];
        dxhat(normalizedGrad, fArr5, fArr3);

        // Per-channel variance; this is the array the square root below must read.
        float[] channelVariance = new float[this.C];
        var(fArr, mean, channelVariance);
        float[] channelStd = new float[this.C];
        for (int c = 0; c < this.C; c++) {
            channelStd[c] = (float) Math.sqrt(channelVariance[c] + this.eta);
        }

        System.out.println("dx3:");
        float[][][][] inputGrad = new float[this.N][this.C][this.H][this.W];
        dx3(inputGrad, channelStd, normalized, normalizedGrad);
        PrintUtils.printImage(MatrixUtils.transform(inputGrad));
    }

    /**
     * CPU reference that accumulates two per-statistic sums and scales the incoming values.
     *
     * <p>For every statistic index {@code s}: {@code fArr4[s]} becomes the sum of
     * {@code fArr * fArr2}, {@code fArr5[s]} the sum of {@code fArr}, and each element of
     * {@code fArr6} is set to {@code fArr * fArr3[s]}.
     *
     * @param fArr  incoming values, indexed {@code [n][c][h][w]}
     * @param fArr2 values multiplied element-wise with {@code fArr} for the first sum
     * @param fArr3 per-statistic multiplier
     * @param fArr4 output: per-statistic sum of {@code fArr * fArr2}
     * @param fArr5 output: per-statistic sum of {@code fArr}
     * @param fArr6 output: {@code fArr} scaled by {@code fArr3}
     * @param i     {@code 1} for per-channel statistics over N, H and W; any other value for
     *              per-column statistics over N, reading only {@code [n][0][0][w]}
     */
    private void computeDelta_cpu(float[][][][] fArr, float[][][][] fArr2, float[] fArr3, float[] fArr4, float[] fArr5, float[][][][] fArr6, int i) {
        if (i == 1) {
            for (int c = 0; c < this.C; c++) {
                fArr4[c] = 0.0f;
                fArr5[c] = 0.0f;
                for (int n = 0; n < this.N; n++) {
                    for (int h = 0; h < this.H; h++) {
                        for (int w = 0; w < this.W; w++) {
                            final float d = fArr[n][c][h][w];
                            fArr4[c] += d * fArr2[n][c][h][w];
                            fArr5[c] += d;
                            fArr6[n][c][h][w] = d * fArr3[c];
                        }
                    }
                }
            }
            return;
        }
        for (int w = 0; w < this.W; w++) {
            fArr4[w] = 0.0f;
            fArr5[w] = 0.0f;
            for (int n = 0; n < this.N; n++) {
                final float d = fArr[n][0][0][w];
                fArr4[w] += d * fArr2[n][0][0][w];
                fArr5[w] += d;
                fArr6[n][0][0][w] = d * fArr3[w];
            }
        }
    }

    /**
     * CPU reference for two per-statistic reductions.
     *
     * <p>For every statistic index {@code s}:
     * {@code fArr[s] = sum((fArr5 - fArr2[s]) * fArr6) * -0.5 * (fArr3[s] + eta)^-1.5} and
     * {@code fArr4[s] = sum(-fArr6 / sqrt(fArr3[s] + eta))}.
     *
     * @param fArr  output: first reduction
     * @param fArr2 per-statistic value subtracted from {@code fArr5}
     * @param fArr3 per-statistic value to which {@code eta} is added
     * @param fArr4 output: second reduction
     * @param fArr5 values, indexed {@code [n][c][h][w]}
     * @param fArr6 values multiplied into both reductions
     * @param i     {@code 1} for per-channel statistics over N, H and W; any other value for
     *              per-column statistics over N, reading only {@code [n][0][0][w]}
     */
    private void meanDzSum_cpu(float[] fArr, float[] fArr2, float[] fArr3, float[] fArr4, float[][][][] fArr5, float[][][][] fArr6, int i) {
        if (i == 1) {
            for (int c = 0; c < this.C; c++) {
                float centeredSum = 0.0f;
                float scaledSum = 0.0f;
                for (int n = 0; n < this.N; n++) {
                    for (int h = 0; h < this.H; h++) {
                        for (int w = 0; w < this.W; w++) {
                            final float g = fArr6[n][c][h][w];
                            centeredSum += (fArr5[n][c][h][w] - fArr2[c]) * g;
                            scaledSum += ((-1.0f) * g) / ((float) Math.sqrt(fArr3[c] + this.eta));
                        }
                    }
                }
                fArr[c] = (float) (centeredSum * (-0.5f) * Math.pow(fArr3[c] + this.eta, -1.5d));
                fArr4[c] = scaledSum;
            }
            return;
        }
        for (int w = 0; w < this.W; w++) {
            float centeredSum = 0.0f;
            float scaledSum = 0.0f;
            for (int n = 0; n < this.N; n++) {
                final float g = fArr6[n][0][0][w];
                centeredSum += (fArr5[n][0][0][w] - fArr2[w]) * g;
                scaledSum += ((-1.0f) * g) / ((float) Math.sqrt(fArr3[w] + this.eta));
            }
            fArr[w] = (float) (centeredSum * (-0.5f) * Math.pow(fArr3[w] + this.eta, -1.5d));
            fArr4[w] = scaledSum;
        }
    }

    /**
     * CPU reference that rewrites {@code fArr} in place as the sum of three terms:
     * {@code fArr / sqrt(fArr4[s] + eta)}, {@code 2 * fArr6[s] * (fArr2 - fArr3[s]) / count} and
     * {@code fArr5[s] / count}.
     *
     * <p>{@code count} is {@code N*H*W} with per-channel statistics when {@code i == 1}, and
     * {@code N} with per-column statistics otherwise.
     *
     * @param fArr  values updated in place, indexed {@code [n][c][h][w]}
     * @param fArr2 values from which {@code fArr3} is subtracted
     * @param fArr3 per-statistic value subtracted from {@code fArr2}
     * @param fArr4 per-statistic value to which {@code eta} is added before the square root
     * @param fArr5 per-statistic factor of the third term
     * @param fArr6 per-statistic factor of the second term
     * @param i     statistics mode ({@code 1} per channel, anything else per column)
     */
    private void computeDiff_cpu(float[][][][] fArr, float[][][][] fArr2, float[] fArr3, float[] fArr4, float[] fArr5, float[] fArr6, int i) {
        if (i == 1) {
            final float invCount = 1.0f / ((this.N * this.H) * this.W);
            for (int n = 0; n < this.N; n++) {
                for (int c = 0; c < this.C; c++) {
                    for (int h = 0; h < this.H; h++) {
                        for (int w = 0; w < this.W; w++) {
                            final float direct = fArr[n][c][h][w] / ((float) Math.sqrt(fArr4[c] + this.eta));
                            final float viaSecond = 2.0f * fArr6[c] * (fArr2[n][c][h][w] - fArr3[c]) * invCount;
                            final float viaThird = fArr5[c] * invCount;
                            fArr[n][c][h][w] = direct + viaSecond + viaThird;
                        }
                    }
                }
            }
            return;
        }
        final float invCount = 1.0f / this.N;
        for (int n = 0; n < this.N; n++) {
            for (int w = 0; w < this.W; w++) {
                final float direct = fArr[n][0][0][w] / ((float) Math.sqrt(fArr4[w] + this.eta));
                final float viaSecond = 2.0f * fArr6[w] * (fArr2[n][0][0][w] - fArr3[w]) * invCount;
                final float viaThird = fArr5[w] * invCount;
                fArr[n][0][0][w] = direct + viaSecond + viaThird;
            }
        }
    }

    /**
     * CPU reference for normalization: computes {@code (x - fArr3[s]) / sqrt(fArr4[s] + eta)} for
     * every element and writes {@code normalized * fArr5[s] + fArr6[s]} into {@code fArr2}.
     *
     * <p>The statistic index {@code s} is the last-dimension index when {@code i == 0} and the
     * second-dimension index otherwise. The input dimensions are printed first, separated by
     * colons.
     *
     * @param fArr  input values, indexed {@code [n][c][h][w]}
     * @param fArr2 output: scaled and shifted values (same shape as {@code fArr})
     * @param fArr3 per-statistic value subtracted from the input
     * @param fArr4 per-statistic value to which {@code eta} is added before the square root
     * @param fArr5 per-statistic multiplier
     * @param fArr6 per-statistic offset
     * @param i     statistics mode ({@code 0} per last-dimension index, anything else per second-dimension index)
     * @return a new array holding the normalized values before scaling and shifting
     */
    private float[][][][] culOutput(float[][][][] fArr, float[][][][] fArr2, float[] fArr3, float[] fArr4, float[] fArr5, float[] fArr6, int i) {
        final int dimN = fArr.length;
        final int dimC = fArr[0].length;
        final int dimH = fArr[0][0].length;
        final int dimW = fArr[0][0][0].length;
        System.out.println(dimN + ":" + dimC + ":" + dimH + ":" + dimW);
        final float[][][][] normalized = new float[dimN][dimC][dimH][dimW];
        for (int n = 0; n < dimN; n++) {
            for (int c = 0; c < dimC; c++) {
                for (int h = 0; h < dimH; h++) {
                    for (int w = 0; w < dimW; w++) {
                        // Pick which index selects the statistic for this element.
                        final int s = (i == 0) ? w : c;
                        normalized[n][c][h][w] = (fArr[n][c][h][w] - fArr3[s]) / ((float) Math.sqrt(fArr4[s] + this.eta));
                        fArr2[n][c][h][w] = (normalized[n][c][h][w] * fArr5[s]) + fArr6[s];
                    }
                }
            }
        }
        return normalized;
    }

    /**
     * Runs two TRAIN-mode forward passes and compares a supplied gradient
     * against a finite-difference estimate built from their outputs.
     *
     * <p>The first pass calls {@code forward} with {@code tensor} and
     * {@code tensor5}, the second with {@code tensor2} and {@code tensor6};
     * both share {@code tensor3} and {@code tensor4}. The estimate is
     * {@code (tensor5 - tensor6) / (2 * f)}, which is passed together with
     * the host copy of {@code tensor7} to {@code CheckArrayUtils.check}.
     *
     * <p>NOTE(review): this looks like a central-difference check where
     * {@code tensor}/{@code tensor2} are the input perturbed by {@code +f}
     * and {@code -f} - confirm against the caller.
     *
     * @param tensor  third argument of the first forward pass
     * @param tensor2 third argument of the second forward pass
     * @param tensor3 second argument shared by both forward passes
     * @param tensor4 third shared argument of both forward passes
     * @param tensor5 fourth argument of the first forward pass; read back afterwards
     * @param tensor6 fourth argument of the second forward pass; read back afterwards
     * @param tensor7 gradient to compare against the finite-difference estimate
     * @param f       step size; the difference is divided by {@code 2 * f}
     * @return whatever {@code CheckArrayUtils.check} reports for the two arrays
     */
    public float gradientCheck(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4, Tensor tensor5, Tensor tensor6, Tensor tensor7, float f) {
        final float twoStep = 2.0f * f;
        forward(RunModel.TRAIN, tensor3, tensor4, tensor, tensor5);
        forward(RunModel.TRAIN, tensor3, tensor4, tensor2, tensor6);
        // Host syncs happen in the order tensor7, tensor5, tensor6.
        return CheckArrayUtils.check(
                tensor7.syncHost(),
                MatrixOperation.division(
                        MatrixOperation.subtraction(tensor5.syncHost(), tensor6.syncHost()),
                        twoStep));
    }

    /**
     * CPU reference for the per-channel (biased) variance.
     *
     * <p>For each channel {@code c} in {@code [0, C)} this sums
     * {@code (fArr[n][c][h][w] - fArr2[c])^2} over all {@code n, h, w} and
     * divides by {@code N * H * W}. No epsilon is added.
     *
     * @param fArr  input indexed as {@code [n][c][h][w]}
     * @param fArr2 per-channel value subtracted from each element (the mean)
     * @param fArr3 per-channel output; every entry in {@code [0, C)} is overwritten
     */
    public void var(float[][][][] fArr, float[] fArr2, float[] fArr3) {
        final int count = this.N * this.H * this.W;
        for (int c = 0; c < this.C; c++) {
            fArr3[c] = 0.0f;
            for (int n = 0; n < this.N; n++) {
                for (int h = 0; h < this.H; h++) {
                    for (int w = 0; w < this.W; w++) {
                        final float centered = fArr[n][c][h][w] - fArr2[c];
                        fArr3[c] += centered * centered;
                    }
                }
            }
            fArr3[c] /= count;
        }
    }

    /**
     * CPU reference for the per-channel standard deviation used by batch
     * normalization: {@code sqrt(variance + eta)}.
     *
     * <p>For each channel {@code c} in {@code [0, C)} the squared deviations
     * {@code (fArr[n][c][h][w] - fArr2[c])^2} are summed over all
     * {@code n, h, w}, divided by {@code N * H * W} to get the biased
     * variance, and only then stabilised with {@code eta} before the square
     * root. This matches the {@code sqrt(var + eta)} denominator used by the
     * other CPU reference routines in this class; adding {@code eta} to the
     * raw sum instead would shrink it by a factor of {@code N * H * W}.
     *
     * @param fArr  input indexed as {@code [n][c][h][w]}
     * @param fArr2 per-channel value subtracted from each element (the mean)
     * @param fArr3 per-channel output; every entry in {@code [0, C)} is overwritten
     */
    public void std(float[][][][] fArr, float[] fArr2, float[] fArr3) {
        final int count = this.N * this.H * this.W;
        for (int c = 0; c < this.C; c++) {
            fArr3[c] = 0.0f;
            for (int n = 0; n < this.N; n++) {
                for (int h = 0; h < this.H; h++) {
                    for (int w = 0; w < this.W; w++) {
                        final float centered = fArr[n][c][h][w] - fArr2[c];
                        fArr3[c] += centered * centered;
                    }
                }
            }
            // Epsilon is applied to the variance, not to the un-normalised sum.
            fArr3[c] = (float) Math.sqrt((fArr3[c] / count) + this.eta);
        }
    }

    /**
     * Scales every element of {@code fArr3} by the per-channel factor
     * {@code fArr2[c]} and writes the product into {@code fArr}.
     *
     * <p>All three arrays are indexed as {@code [n][c][h][w]} (or {@code [c]}
     * for the factor) over the ranges {@code N, C, H, W} held by this kernel.
     *
     * @param fArr  destination, overwritten element by element
     * @param fArr2 per-channel multiplier (presumably gamma - confirm with caller)
     * @param fArr3 source values (presumably the upstream gradient - confirm with caller)
     */
    public void dxhat(float[][][][] fArr, float[] fArr2, float[][][][] fArr3) {
        for (int c = 0; c < this.C; c++) {
            for (int n = 0; n < this.N; n++) {
                for (int h = 0; h < this.H; h++) {
                    for (int w = 0; w < this.W; w++) {
                        final float upstream = fArr3[n][c][h][w];
                        fArr[n][c][h][w] = upstream * fArr2[c];
                    }
                }
            }
        }
    }

    /**
     * CPU reference for an input-gradient formula of the form
     * {@code (m * g - sum(g) - z * sum(g * z)) / (m * fArr2[c])} with
     * {@code m = N * H * W}, {@code g = fArr4} and {@code z = fArr3}.
     *
     * <p>Works in two passes: the first accumulates, per channel, the sum of
     * {@code fArr4} and the sum of {@code fArr4 * fArr3}; the second combines
     * those sums with each element to fill {@code fArr}.
     *
     * @param fArr  destination indexed as {@code [n][c][h][w]}, overwritten
     * @param fArr2 per-channel divisor (presumably the standard deviation - confirm with caller)
     * @param fArr3 values multiplied with {@code fArr4} (presumably the normalized activations - confirm)
     * @param fArr4 values being reduced and rescaled (presumably the gradient w.r.t. the normalized activations - confirm)
     */
    public void dx2(float[][][][] fArr, float[] fArr2, float[][][][] fArr3, float[][][][] fArr4) {
        float[] gradSum = new float[this.C];
        float[] gradTimesZSum = new float[this.C];
        final int count = this.N * this.H * this.W;

        // Pass 1: per-channel reductions.
        for (int c = 0; c < this.C; c++) {
            for (int n = 0; n < this.N; n++) {
                for (int h = 0; h < this.H; h++) {
                    for (int w = 0; w < this.W; w++) {
                        gradTimesZSum[c] += fArr4[n][c][h][w] * fArr3[n][c][h][w];
                        gradSum[c] += fArr4[n][c][h][w];
                    }
                }
            }
        }

        // Pass 2: combine each element with its channel's reductions.
        for (int c = 0; c < this.C; c++) {
            for (int n = 0; n < this.N; n++) {
                for (int h = 0; h < this.H; h++) {
                    for (int w = 0; w < this.W; w++) {
                        final float scaledGrad = count * fArr4[n][c][h][w];
                        final float correction = fArr3[n][c][h][w] * gradTimesZSum[c];
                        final float numerator = (scaledGrad - gradSum[c]) - correction;
                        fArr[n][c][h][w] = numerator / (count * fArr2[c]);
                    }
                }
            }
        }
    }

    /**
     * Debug variant of the input-gradient reference that takes its reduction
     * operand from {@code d_z} instead of from {@code fArr3}.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>{@code showDM} is invoked with {@code d_z} and a host buffer of
     *       {@code N * C * H * W} floats (presumably filling the buffer from
     *       device memory - confirm), which is then reshaped to
     *       {@code [N][C][H][W]}.</li>
     *   <li>Per channel, the sum of {@code fArr4} and the sum of
     *       {@code fArr4 * z} are accumulated, where {@code z} is the reshaped
     *       buffer.</li>
     *   <li>The flattened {@code fArr3} is printed as JSON, followed by the
     *       result of {@code CheckArrayUtils.check} between the host buffer
     *       and the flattened {@code fArr3}.</li>
     *   <li>{@code fArr} is filled with
     *       {@code (1 / fArr2[c]) * (g - mean(g) - fArr3 * mean(g * z))}.</li>
     * </ol>
     *
     * <p>Note that the reduction uses the buffer obtained via {@code showDM}
     * while the final expression multiplies by {@code fArr3}.
     *
     * @param fArr  destination indexed as {@code [n][c][h][w]}, overwritten
     * @param fArr2 per-channel divisor (presumably the standard deviation - confirm with caller)
     * @param fArr3 values compared against the {@code d_z} buffer and used in the final expression
     * @param fArr4 values being reduced and rescaled (presumably the gradient w.r.t. the normalized activations - confirm)
     */
    public void dx3(float[][][][] fArr, float[] fArr2, float[][][][] fArr3, float[][][][] fArr4) {
        float[] gradSum = new float[this.C];
        float[] gradTimesZSum = new float[this.C];
        float[] hostZ = new float[this.N * this.C * this.H * this.W];
        showDM((Pointer) this.d_z, hostZ);
        final int count = this.N * this.H * this.W;
        float[][][][] z = MatrixUtils.transform(hostZ, this.N, this.C, this.H, this.W);

        // Per-channel reductions against the buffer obtained from d_z.
        for (int c = 0; c < this.C; c++) {
            for (int n = 0; n < this.N; n++) {
                for (int h = 0; h < this.H; h++) {
                    for (int w = 0; w < this.W; w++) {
                        gradSum[c] += fArr4[n][c][h][w];
                        gradTimesZSum[c] += fArr4[n][c][h][w] * z[n][c][h][w];
                    }
                }
            }
        }

        // Diagnostic output: dump fArr3 and compare it with the d_z buffer.
        System.out.println(JsonUtils.toJson(MatrixUtils.transform(fArr3)));
        System.out.println(CheckArrayUtils.check(hostZ, MatrixUtils.transform(fArr3)));

        for (int c = 0; c < this.C; c++) {
            final float gradMean = gradSum[c] / count;
            final float gradTimesZMean = gradTimesZSum[c] / count;
            for (int n = 0; n < this.N; n++) {
                for (int h = 0; h < this.H; h++) {
                    for (int w = 0; w < this.W; w++) {
                        final float centeredGrad = fArr4[n][c][h][w] - gradMean;
                        final float correction = fArr3[n][c][h][w] * gradTimesZMean;
                        fArr[n][c][h][w] = (1.0f / fArr2[c]) * (centeredGrad - correction);
                    }
                }
            }
        }
    }
}
