package com.omega.engine.nn.layer.normalization.gpu;

import com.omega.common.data.Tensor;
import com.omega.common.utils.CheckArrayUtils;
import com.omega.common.utils.JsonUtils;
import com.omega.common.utils.MatrixOperation;
import com.omega.common.utils.MatrixUtils;
import com.omega.common.utils.PrintUtils;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.nn.layer.gpu.BNBaseKernel;
import com.omega.engine.nn.layer.normalization.BNType;
import com.omega.engine.nn.network.RunModel;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUdeviceptr;
import jcuda.driver.CUfunction;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.cudaError;

/* loaded from: input_file:com/omega/engine/nn/layer/normalization/gpu/BNKernel.class */
public class BNKernel extends BNBaseKernel {
    // Normalization mode; selects between the fully_bn kernels and the per-channel ("fast") kernels.
    private BNType bnType;
    // Feature-map shape handed to the constructor: channels, height, width.
    private int C;
    private int H;
    private int W;
    // Number of per-feature statistics: W for fully_bn, C otherwise (assigned in the constructor).
    private int meanNumber;
    // Kernel handles; each is resolved by initFunction() the first time it is found to be null.
    private CUfunction mean_function;
    private CUfunction var_function;
    private CUfunction std_function;
    private CUfunction mwa_function;
    private CUfunction culOutput_function;
    private CUfunction fast_mean_function;
    private CUfunction fast_var_function;
    private CUfunction normalize_function;
    private CUfunction computeDiff_function;
    private CUfunction computeDelta_function;
    private CUfunction computeDelta_full_function;
    private CUfunction meanDzSum_function;
    private CUfunction dgama_function;
    private CUfunction dbeta_function;
    private CUfunction dxhat_function;
    private CUfunction full_dmean_function;
    private CUfunction full_dmean_ov_function;
    private CUfunction full_dvar_function;
    private CUfunction fast_dmean_function;
    private CUfunction fast_dmean_ov_function;
    private CUfunction fast_dvar_function;
    private CUfunction dx_function;
    private CUfunction dx_full_function;
    private CUfunction computeDParams_function;
    // Threads per block used by every kernel launch in this class.
    private int CAFFE_CUDA_NUM_THREADS = 1024;
    // Last argument of the normalize kernel (presumably the variance epsilon — confirm against normalize_kernel).
    private float eta = 1.0E-5f;
    // Last argument of the mwa kernel (presumably the running-statistics momentum — confirm against mwa).
    private float momentum = 0.9f;
    // Device buffers: d_z is (re)allocated in initForward(); the remaining ones in initKernel().
    private CUdeviceptr d_z;
    private CUdeviceptr d_mean;
    private CUdeviceptr d_var;
    private CUdeviceptr d_runingMean;
    private CUdeviceptr d_runingVar;
    private CUdeviceptr d_dmean;
    private CUdeviceptr d_dvar;
    // Cached kernel argument lists, built by initForward(), initBackward() and computeDParams().
    private Pointer meanParameters;
    private Pointer varParameters;
    private Pointer fastMeanParameters;
    private Pointer fastVarParameters;
    private Pointer normalizeParameters;
    private Pointer normalize_test_Parameters;
    private Pointer mwaParameters;
    private Pointer dgamaParameters;
    private Pointer dbetaParameters;
    private Pointer computeDelta_full_Parameters;
    private Pointer computeDParams_Parameters;
    private Pointer dxhatParameters;
    private Pointer fullDmeanParameters;
    private Pointer fullDMeanOVParameters;
    private Pointer fullDvarParameters;
    private Pointer fastDmeanParameters;
    private Pointer fastDvarParameters;
    private Pointer fastDMeanOVParameters;
    private Pointer dxParameters;
    private Pointer dx_fullParameters;

    /**
     * Creates a batch-normalization kernel wrapper for the given mode and feature-map shape,
     * then calls {@link #init()} to resolve the CUDA functions and allocate the statistic buffers.
     *
     * @param bNType normalization mode; {@code fully_bn} keeps one statistic per width element,
     *               any other mode keeps one per channel
     * @param i      channel count, stored as {@code C}
     * @param i2     height, stored as {@code H}
     * @param i3     width, stored as {@code W}
     */
    public BNKernel(BNType bNType, int i, int i2, int i3) {
        this.bnType = bNType;
        this.C = i;
        this.H = i2;
        this.W = i3;
        this.meanNumber = (bNType == BNType.fully_bn) ? i3 : i;
        init();
    }

    /**
     * Resolves every CUDA function handle that is still {@code null}.
     *
     * <p>Handles are looked up by module file and kernel name through
     * {@code CUDAModules.getLocalFunctionByModule}. If any lookup throws, the stack trace is
     * printed and the remaining handles are left unresolved.
     */
    public void initFunction() {
        final String bnModule = "BNKernel.cu";
        final String mathModule = "MathKernel.cu";
        try {
            if (computeDParams_function == null) {
                computeDParams_function = CUDAModules.getLocalFunctionByModule(bnModule, "computeDParams");
            }
            if (mean_function == null) {
                mean_function = CUDAModules.getLocalFunctionByModule(mathModule, "mean_full");
            }
            if (fast_mean_function == null) {
                fast_mean_function = CUDAModules.getLocalFunctionByModule(mathModule, "fast_mean_kernel");
            }
            if (var_function == null) {
                var_function = CUDAModules.getLocalFunctionByModule(mathModule, "var_full");
            }
            if (fast_var_function == null) {
                fast_var_function = CUDAModules.getLocalFunctionByModule(mathModule, "fast_variance_kernel");
            }
            if (normalize_function == null) {
                normalize_function = CUDAModules.getLocalFunctionByModule(bnModule, "normalize_kernel");
            }
            if (std_function == null) {
                std_function = CUDAModules.getLocalFunctionByModule(mathModule, "std_fn");
            }
            if (mwa_function == null) {
                mwa_function = CUDAModules.getLocalFunctionByModule(mathModule, "mwa");
            }
            if (culOutput_function == null) {
                culOutput_function = CUDAModules.getLocalFunctionByModule(bnModule, "culOutput_cov");
            }
            if (computeDelta_function == null) {
                computeDelta_function = CUDAModules.getLocalFunctionByModule(bnModule, "computeDelta");
            }
            if (computeDelta_full_function == null) {
                computeDelta_full_function = CUDAModules.getLocalFunctionByModule(bnModule, "computeDelta_full");
            }
            if (meanDzSum_function == null) {
                meanDzSum_function = CUDAModules.getLocalFunctionByModule(bnModule, "meanDzSum");
            }
            if (computeDiff_function == null) {
                computeDiff_function = CUDAModules.getLocalFunctionByModule(bnModule, "computeDiff");
            }
            if (dgama_function == null) {
                dgama_function = CUDAModules.getLocalFunctionByModule(bnModule, "dgama_kernel");
            }
            if (dbeta_function == null) {
                dbeta_function = CUDAModules.getLocalFunctionByModule(bnModule, "dbeta_kernel");
            }
            if (dxhat_function == null) {
                dxhat_function = CUDAModules.getLocalFunctionByModule(bnModule, "dxhat_kernel2");
            }
            if (full_dmean_function == null) {
                full_dmean_function = CUDAModules.getLocalFunctionByModule(bnModule, "full_mean_delta_kernel");
            }
            if (full_dmean_ov_function == null) {
                full_dmean_ov_function = CUDAModules.getLocalFunctionByModule(bnModule, "full_mean_delta_ov_kernel");
            }
            if (fast_dmean_function == null) {
                fast_dmean_function = CUDAModules.getLocalFunctionByModule(bnModule, "fast_mean_delta_kernel");
            }
            if (fast_dmean_ov_function == null) {
                fast_dmean_ov_function = CUDAModules.getLocalFunctionByModule(bnModule, "fast_mean_delta_ov_kernel");
            }
            if (full_dvar_function == null) {
                full_dvar_function = CUDAModules.getLocalFunctionByModule(bnModule, "full_var_delta_kernel");
            }
            if (fast_dvar_function == null) {
                fast_dvar_function = CUDAModules.getLocalFunctionByModule(bnModule, "fast_variance_delta_kernel");
            }
            if (dx_function == null) {
                dx_function = CUDAModules.getLocalFunctionByModule(bnModule, "dx_kernel");
            }
            if (dx_full_function == null) {
                dx_full_function = CUDAModules.getLocalFunctionByModule(bnModule, "dx_kernel_full");
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Requests six device buffers of {@code meanNumber} elements each from
     * {@code CUDAMemoryManager}: batch mean/variance, running mean/variance and the
     * mean/variance gradients.
     */
    private void initKernel() {
        final int statCount = this.meanNumber;
        this.d_mean = CUDAMemoryManager.getDevice(statCount);
        this.d_var = CUDAMemoryManager.getDevice(statCount);
        this.d_runingMean = CUDAMemoryManager.getDevice(statCount);
        this.d_runingVar = CUDAMemoryManager.getDevice(statCount);
        this.d_dmean = CUDAMemoryManager.getDevice(statCount);
        this.d_dvar = CUDAMemoryManager.getDevice(statCount);
    }

    /**
     * One-time setup invoked from the constructor: resolves the CUDA function handles, then
     * allocates the per-feature statistic buffers.
     */
    public void init() {
        initFunction();
        initKernel();
    }

    /**
     * (Re)builds the cached forward-pass kernel argument lists when the batch size changes.
     *
     * <p>Nothing happens while {@code tensor.number} equals the cached {@code N}. Otherwise
     * {@code N} is updated, {@code d_z} is re-requested for {@code N * C * H * W} elements and the
     * mean/variance, moving-average and normalize argument lists are rebuilt.
     *
     * <p>The backward argument lists built by {@code initBackward} capture both {@code N} and
     * the {@code d_z} buffer object, and {@code initBackward} only rebuilds them while
     * {@code dgamaParameters} is {@code null}. This method therefore clears
     * {@code dgamaParameters} so the next backward pass does not launch with the previous
     * batch size and the previous {@code d_z} buffer.
     *
     * @param runModel unused here
     * @param tensor   input whose {@code number} gives the batch size and whose GPU data feeds
     *                 the statistics and normalize kernels
     * @param tensor2  passed to the normalize kernel right after the variance buffer
     *                 (presumably gamma — confirm against normalize_kernel)
     * @param tensor3  passed to the normalize kernel after {@code tensor2}
     *                 (presumably beta — confirm against normalize_kernel)
     * @param tensor4  passed to the normalize kernel right after {@code d_z}
     *                 (presumably the output — confirm against normalize_kernel)
     */
    public void initForward(RunModel runModel, Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        if (tensor.number == this.N) {
            return;
        }
        this.N = tensor.number;
        // Force initBackward to rebuild: its cached lists reference the old N and the old d_z.
        this.dgamaParameters = null;
        if (this.bnType == BNType.fully_bn) {
            this.meanParameters = Pointer.to(
                    Pointer.to(tensor.getGpuData()),
                    Pointer.to(this.d_mean),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.W}));
            this.varParameters = Pointer.to(
                    Pointer.to(tensor.getGpuData()),
                    Pointer.to(this.d_mean),
                    Pointer.to(this.d_var),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.W}));
        } else {
            this.fastMeanParameters = Pointer.to(
                    Pointer.to(tensor.getGpuData()),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.C}),
                    Pointer.to(new int[]{this.H * this.W}),
                    Pointer.to(this.d_mean));
            this.fastVarParameters = Pointer.to(
                    Pointer.to(tensor.getGpuData()),
                    Pointer.to(this.d_mean),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.C}),
                    Pointer.to(new int[]{this.H * this.W}),
                    Pointer.to(this.d_var));
        }
        this.mwaParameters = Pointer.to(
                Pointer.to(this.d_mean),
                Pointer.to(this.d_var),
                Pointer.to(this.d_runingMean),
                Pointer.to(this.d_runingVar),
                Pointer.to(new int[]{this.meanNumber}),
                Pointer.to(new float[]{this.momentum}));
        // Elements per feature inside one sample: H * W for conv_bn, 1 for every other mode.
        final int spatial = (this.bnType == BNType.conv_bn) ? this.H * this.W : 1;
        final int total = this.N * this.C * this.H * this.W;
        this.d_z = CUDAMemoryManager.getDevice(total);
        // Training variant: normalizes with the statistics of the current batch.
        this.normalizeParameters = Pointer.to(
                Pointer.to(new int[]{total}),
                Pointer.to(tensor.getGpuData()),
                Pointer.to(this.d_z),
                Pointer.to(tensor4.getGpuData()),
                Pointer.to(this.d_mean),
                Pointer.to(this.d_var),
                Pointer.to(tensor2.getGpuData()),
                Pointer.to(tensor3.getGpuData()),
                Pointer.to(new int[]{this.N}),
                Pointer.to(new int[]{this.meanNumber}),
                Pointer.to(new int[]{spatial}),
                Pointer.to(new float[]{this.eta}));
        // Non-training variant: identical except that it reads the running statistics.
        this.normalize_test_Parameters = Pointer.to(
                Pointer.to(new int[]{total}),
                Pointer.to(tensor.getGpuData()),
                Pointer.to(this.d_z),
                Pointer.to(tensor4.getGpuData()),
                Pointer.to(this.d_runingMean),
                Pointer.to(this.d_runingVar),
                Pointer.to(tensor2.getGpuData()),
                Pointer.to(tensor3.getGpuData()),
                Pointer.to(new int[]{this.N}),
                Pointer.to(new int[]{this.meanNumber}),
                Pointer.to(new int[]{spatial}),
                Pointer.to(new float[]{this.eta}));
    }

    /**
     * Builds the cached backward-pass kernel argument lists the first time it is called, i.e.
     * while {@code dgamaParameters} is still {@code null}; later calls return immediately.
     *
     * <p>NOTE(review): in fully_bn mode the dgama/dbeta lists pass {@code C} and {@code H * W}
     * while the kernels are launched with {@code meanNumber} (= {@code W}) blocks — confirm
     * this matches the indexing of dgama_kernel / dbeta_kernel.
     *
     * @param tensor  forward input (its GPU data is read by the dx / dvar kernels)
     * @param tensor2 incoming gradient
     * @param tensor3 buffer shared by the dxhat, dmean, dvar and dx argument lists
     * @param tensor4 fourth argument of the dxhat kernel (presumably gamma — confirm)
     * @param tensor5 last argument of the dgama kernel (presumably the gamma gradient — confirm)
     * @param tensor6 first argument of the dbeta kernel (presumably the beta gradient — confirm)
     */
    public void initBackward(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4, Tensor tensor5, Tensor tensor6) {
        if (this.dgamaParameters != null) {
            return;
        }
        // The dgama and dbeta argument lists are the same for both normalization modes.
        this.dgamaParameters = Pointer.to(
                Pointer.to(this.d_z),
                Pointer.to(tensor2.getGpuData()),
                Pointer.to(new int[]{this.N}),
                Pointer.to(new int[]{this.C}),
                Pointer.to(new int[]{this.H * this.W}),
                Pointer.to(tensor5.getGpuData()));
        this.dbetaParameters = Pointer.to(
                Pointer.to(tensor6.getGpuData()),
                Pointer.to(tensor2.getGpuData()),
                Pointer.to(new int[]{this.N}),
                Pointer.to(new int[]{this.C}),
                Pointer.to(new int[]{this.H * this.W}));
        if (this.bnType == BNType.fully_bn) {
            this.computeDelta_full_Parameters = Pointer.to(
                    Pointer.to(tensor2.getGpuData()),
                    Pointer.to(tensor5.getGpuData()),
                    Pointer.to(tensor6.getGpuData()),
                    Pointer.to(this.d_z),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.W}));
            this.dx_fullParameters = Pointer.to(
                    Pointer.to(tensor.getGpuData()),
                    Pointer.to(this.d_mean),
                    Pointer.to(this.d_var),
                    Pointer.to(this.d_dmean),
                    Pointer.to(this.d_dvar),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.W}),
                    Pointer.to(tensor3.getGpuData()));
            this.fullDmeanParameters = Pointer.to(
                    Pointer.to(tensor3.getGpuData()),
                    Pointer.to(this.d_var),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.W}),
                    Pointer.to(this.d_dmean));
            this.fullDvarParameters = Pointer.to(
                    Pointer.to(tensor.getGpuData()),
                    Pointer.to(tensor3.getGpuData()),
                    Pointer.to(this.d_mean),
                    Pointer.to(this.d_var),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.W}),
                    Pointer.to(this.d_dvar));
            this.fullDMeanOVParameters = Pointer.to(
                    Pointer.to(tensor3.getGpuData()),
                    Pointer.to(this.d_var),
                    Pointer.to(this.d_mean),
                    Pointer.to(tensor.getGpuData()),
                    Pointer.to(this.d_dvar),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.W}),
                    Pointer.to(this.d_dmean));
        } else {
            this.dxParameters = Pointer.to(
                    Pointer.to(new int[]{this.N * this.C * this.H * this.W}),
                    Pointer.to(tensor.getGpuData()),
                    Pointer.to(this.d_mean),
                    Pointer.to(this.d_var),
                    Pointer.to(this.d_dmean),
                    Pointer.to(this.d_dvar),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.C}),
                    Pointer.to(new int[]{this.H * this.W}),
                    Pointer.to(tensor3.getGpuData()));
            this.fastDmeanParameters = Pointer.to(
                    Pointer.to(tensor3.getGpuData()),
                    Pointer.to(this.d_var),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.C}),
                    Pointer.to(new int[]{this.H * this.W}),
                    Pointer.to(this.d_dmean));
            this.fastDMeanOVParameters = Pointer.to(
                    Pointer.to(tensor3.getGpuData()),
                    Pointer.to(this.d_var),
                    Pointer.to(this.d_mean),
                    Pointer.to(tensor.getGpuData()),
                    Pointer.to(this.d_dvar),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.C}),
                    Pointer.to(new int[]{this.H * this.W}),
                    Pointer.to(this.d_dmean));
            this.fastDvarParameters = Pointer.to(
                    Pointer.to(tensor.getGpuData()),
                    Pointer.to(tensor3.getGpuData()),
                    Pointer.to(this.d_mean),
                    Pointer.to(this.d_var),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.C}),
                    Pointer.to(new int[]{this.H * this.W}),
                    Pointer.to(this.d_dvar));
        }
        // Elements per feature inside one sample: H * W for conv_bn, 1 for every other mode.
        final int spatial = (this.bnType == BNType.conv_bn) ? this.H * this.W : 1;
        this.dxhatParameters = Pointer.to(
                Pointer.to(new int[]{this.N * this.C * this.H * this.W}),
                Pointer.to(tensor2.getGpuData()),
                Pointer.to(tensor3.getGpuData()),
                Pointer.to(tensor4.getGpuData()),
                Pointer.to(new int[]{this.meanNumber}),
                Pointer.to(new int[]{spatial}));
    }

    /**
     * Returns the number of {@code CAFFE_CUDA_NUM_THREADS}-sized blocks needed to cover
     * {@code i} threads, i.e. {@code i / threads} rounded up.
     *
     * @param i total number of threads required
     * @return block count for a 1-D launch
     */
    @Override // com.omega.engine.gpu.BaseKernel
    public int CAFFE_GET_BLOCKS(int i) {
        final int threadsPerBlock = this.CAFFE_CUDA_NUM_THREADS;
        return (i + threadsPerBlock - 1) / threadsPerBlock;
    }

    /**
     * Runs the forward pass.
     *
     * <p>In {@code RunModel.TRAIN} the batch mean and variance are computed, the running
     * statistics are updated through {@link #mwa()} and the input is normalized with the batch
     * statistics. In any other mode only the normalize kernel runs, using the running statistics.
     *
     * @param runModel execution mode
     * @param tensor   handed to {@code initForward} as its third argument (presumably gamma — confirm)
     * @param tensor2  handed to {@code initForward} as its fourth argument (presumably beta — confirm)
     * @param tensor3  input; handed to {@code initForward} as its second argument
     * @param tensor4  handed to {@code initForward} as its last argument (presumably the output — confirm)
     */
    @Override // com.omega.engine.nn.layer.gpu.BNBaseKernel
    public void forward(RunModel runModel, Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        initForward(runModel, tensor3, tensor, tensor2, tensor4);
        if (runModel == RunModel.TRAIN) {
            if (this.bnType == BNType.fully_bn) {
                mean();
                var();
            } else {
                fast_mean();
                fast_var();
            }
            mwa();
            normalize_train(tensor3, tensor, tensor2, tensor4);
        } else {
            normalize(tensor3, tensor, tensor2, tensor4);
        }
    }

    /**
     * Launches {@code mean_function} with the cached {@code meanParameters}, using enough 1-D
     * blocks to cover {@code meanNumber} threads. An exception from the launch is printed and
     * otherwise ignored.
     */
    public void mean() {
        try {
            final int blocks = CAFFE_GET_BLOCKS(this.meanNumber);
            final int threads = this.CAFFE_CUDA_NUM_THREADS;
            JCudaDriver.cuLaunchKernel(this.mean_function, blocks, 1, 1, threads, 1, 1, 0, null, this.meanParameters, null);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code var_function} with the cached {@code varParameters}, using enough 1-D
     * blocks to cover {@code meanNumber} threads. An exception from the launch is printed and
     * otherwise ignored.
     */
    public void var() {
        try {
            final int blocks = CAFFE_GET_BLOCKS(this.meanNumber);
            final int threads = this.CAFFE_CUDA_NUM_THREADS;
            JCudaDriver.cuLaunchKernel(this.var_function, blocks, 1, 1, threads, 1, 1, 0, null, this.varParameters, null);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code fast_mean_function} with the cached {@code fastMeanParameters}, one block
     * per statistic ({@code meanNumber} blocks). An exception from the launch is printed and
     * otherwise ignored.
     */
    public void fast_mean() {
        try {
            final int blocks = this.meanNumber;
            final int threads = this.CAFFE_CUDA_NUM_THREADS;
            JCudaDriver.cuLaunchKernel(this.fast_mean_function, blocks, 1, 1, threads, 1, 1, 0, null, this.fastMeanParameters, null);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code fast_var_function} with the cached {@code fastVarParameters}, one block
     * per statistic ({@code meanNumber} blocks). An exception from the launch is printed and
     * otherwise ignored.
     */
    public void fast_var() {
        try {
            final int blocks = this.meanNumber;
            final int threads = this.CAFFE_CUDA_NUM_THREADS;
            JCudaDriver.cuLaunchKernel(this.fast_var_function, blocks, 1, 1, threads, 1, 1, 0, null, this.fastVarParameters, null);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches the normalize kernel with {@code normalize_test_Parameters}, the argument list
     * that reads the running mean and variance. The grid covers {@code N * C * H * W} threads.
     * The tensor arguments are not used here; the launch relies on the cached argument list.
     * An exception from the launch is printed and otherwise ignored.
     */
    public void normalize(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        try {
            final int blocks = CAFFE_GET_BLOCKS(this.N * this.C * this.H * this.W);
            final int threads = this.CAFFE_CUDA_NUM_THREADS;
            JCudaDriver.cuLaunchKernel(this.normalize_function, blocks, 1, 1, threads, 1, 1, 0, null, this.normalize_test_Parameters, null);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches the normalize kernel with {@code normalizeParameters}, the argument list that
     * reads the batch mean and variance. The grid covers {@code N * C * H * W} threads.
     * The tensor arguments are not used here; the launch relies on the cached argument list.
     * An exception from the launch is printed and otherwise ignored.
     */
    public void normalize_train(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        try {
            final int blocks = CAFFE_GET_BLOCKS(this.N * this.C * this.H * this.W);
            final int threads = this.CAFFE_CUDA_NUM_THREADS;
            JCudaDriver.cuLaunchKernel(this.normalize_function, blocks, 1, 1, threads, 1, 1, 0, null, this.normalizeParameters, null);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code mwa_function} with the cached {@code mwaParameters} (batch statistics,
     * running statistics, {@code meanNumber} and {@code momentum}), using enough 1-D blocks to
     * cover {@code meanNumber} threads. An exception from the launch is printed and otherwise
     * ignored.
     */
    public void mwa() {
        try {
            final int blocks = CAFFE_GET_BLOCKS(this.meanNumber);
            final int threads = this.CAFFE_CUDA_NUM_THREADS;
            JCudaDriver.cuLaunchKernel(this.mwa_function, blocks, 1, 1, threads, 1, 1, 0, null, this.mwaParameters, null);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Runs the backward pass: dgama, dbeta and dxhat first, then the variance gradient, the
     * mean gradient and finally dx. The last three use the "full" kernels for
     * {@code fully_bn} and the per-channel kernels otherwise.
     *
     * <p>The arguments are only forwarded to {@code initBackward}; the launches themselves use
     * the cached argument lists.
     */
    @Override // com.omega.engine.nn.layer.gpu.BNBaseKernel
    public void backward(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4, Tensor tensor5, Tensor tensor6) {
        initBackward(tensor, tensor2, tensor3, tensor4, tensor5, tensor6);
        // Parameter gradients and dxhat are launched the same way in both modes.
        computeDgama();
        computeDbeta();
        computeDxhat();
        if (this.bnType == BNType.fully_bn) {
            computeFullDvar();
            computeFullDmean();
            computeDx_full();
        } else {
            computeDvar();
            computeDmean();
            computeDx();
        }
    }

    /**
     * Rebuilds {@code computeDParams_Parameters} from the three tensors, {@code d_z} and the
     * N/C/H/W dimensions, then launches {@code computeDParams_function} with
     * {@code meanNumber} blocks.
     */
    private void computeDParams(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        this.computeDParams_Parameters = Pointer.to(
                Pointer.to(tensor.getGpuData()),
                Pointer.to(tensor2.getGpuData()),
                Pointer.to(tensor3.getGpuData()),
                Pointer.to(this.d_z),
                Pointer.to(new int[]{this.N}),
                Pointer.to(new int[]{this.C}),
                Pointer.to(new int[]{this.H}),
                Pointer.to(new int[]{this.W}));
        final int blocks = this.meanNumber;
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.computeDParams_function, blocks, 1, 1, threads, 1, 1, 0, null, this.computeDParams_Parameters, null);
    }

    /** Launches {@code dgama_function} with {@code dgamaParameters}, one block per statistic. */
    private void computeDgama() {
        final int blocks = this.meanNumber;
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.dgama_function, blocks, 1, 1, threads, 1, 1, 0, null, this.dgamaParameters, null);
    }

    /** Launches {@code dbeta_function} with {@code dbetaParameters}, one block per statistic. */
    private void computeDbeta() {
        final int blocks = this.meanNumber;
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.dbeta_function, blocks, 1, 1, threads, 1, 1, 0, null, this.dbetaParameters, null);
    }

    /**
     * Launches {@code computeDelta_full_function} with {@code computeDelta_full_Parameters},
     * using enough blocks to cover {@code meanNumber} threads.
     */
    private void computeDelta_full() {
        final int blocks = CAFFE_GET_BLOCKS(this.meanNumber);
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.computeDelta_full_function, blocks, 1, 1, threads, 1, 1, 0, null, this.computeDelta_full_Parameters, null);
    }

    /**
     * Launches {@code dxhat_function} with {@code dxhatParameters}, using enough blocks to
     * cover {@code N * C * H * W} threads.
     */
    private void computeDxhat() {
        final int blocks = CAFFE_GET_BLOCKS(this.N * this.C * this.H * this.W);
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.dxhat_function, blocks, 1, 1, threads, 1, 1, 0, null, this.dxhatParameters, null);
    }

    /** Launches {@code fast_dmean_function} with {@code fastDmeanParameters}, one block per statistic. */
    private void computeDmean() {
        final int blocks = this.meanNumber;
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.fast_dmean_function, blocks, 1, 1, threads, 1, 1, 0, null, this.fastDmeanParameters, null);
    }

    /** Launches {@code fast_dmean_ov_function} with {@code fastDMeanOVParameters}, one block per statistic. */
    private void computeDmeanOV() {
        final int blocks = this.meanNumber;
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.fast_dmean_ov_function, blocks, 1, 1, threads, 1, 1, 0, null, this.fastDMeanOVParameters, null);
    }

    /**
     * Launches {@code full_dmean_function} with {@code fullDmeanParameters}, using enough
     * blocks to cover {@code meanNumber} threads.
     */
    private void computeFullDmean() {
        final int blocks = CAFFE_GET_BLOCKS(this.meanNumber);
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.full_dmean_function, blocks, 1, 1, threads, 1, 1, 0, null, this.fullDmeanParameters, null);
    }

    /**
     * Launches {@code full_dmean_ov_function} with {@code fullDMeanOVParameters}, using enough
     * blocks to cover {@code meanNumber} threads.
     */
    private void computeFullDmeanOV() {
        final int blocks = CAFFE_GET_BLOCKS(this.meanNumber);
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.full_dmean_ov_function, blocks, 1, 1, threads, 1, 1, 0, null, this.fullDMeanOVParameters, null);
    }

    /** Launches {@code fast_dvar_function} with {@code fastDvarParameters}, one block per statistic. */
    private void computeDvar() {
        final int blocks = this.meanNumber;
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.fast_dvar_function, blocks, 1, 1, threads, 1, 1, 0, null, this.fastDvarParameters, null);
    }

    /**
     * Launches {@code full_dvar_function} with {@code fullDvarParameters}, using enough blocks
     * to cover {@code meanNumber} threads.
     */
    private void computeFullDvar() {
        final int blocks = CAFFE_GET_BLOCKS(this.meanNumber);
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.full_dvar_function, blocks, 1, 1, threads, 1, 1, 0, null, this.fullDvarParameters, null);
    }

    /**
     * Launches {@code dx_function} with {@code dxParameters}, using enough blocks to cover
     * {@code N * C * H * W} threads.
     */
    private void computeDx() {
        final int blocks = CAFFE_GET_BLOCKS(this.N * this.C * this.H * this.W);
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.dx_function, blocks, 1, 1, threads, 1, 1, 0, null, this.dxParameters, null);
    }

    /**
     * Launches {@code dx_full_function} with {@code dx_fullParameters}, using enough blocks to
     * cover {@code meanNumber} threads.
     */
    private void computeDx_full() {
        final int blocks = CAFFE_GET_BLOCKS(this.meanNumber);
        final int threads = this.CAFFE_CUDA_NUM_THREADS;
        JCudaDriver.cuLaunchKernel(this.dx_full_function, blocks, 1, 1, threads, 1, 1, 0, null, this.dx_fullParameters, null);
    }

    /**
     * Prints a message to standard error when {@code i} is a non-zero status code; zero is
     * treated as success and produces no output.
     *
     * @param i status code to check; its text is looked up with {@code cudaError.stringFor}
     */
    @Override // com.omega.engine.gpu.BaseKernel
    public void checkCUDA(int i) {
        if (i == 0) {
            return;
        }
        System.err.println("Error code " + i + ":" + cudaError.stringFor(i));
    }

    /**
     * Debug helper: copies {@code fArr.length} floats from device memory into {@code fArr} and
     * prints them as JSON, prefixed with {@code str} and a colon.
     *
     * @param str         label printed before the values
     * @param cUdeviceptr device buffer to read from
     * @param fArr        host array that receives the copy (4 bytes per element)
     */
    public void showDM(String str, CUdeviceptr cUdeviceptr, float[] fArr) {
        final Pointer hostTarget = Pointer.to(fArr);
        final int byteCount = fArr.length * 4;
        JCudaDriver.cuMemcpyDtoH(hostTarget, cUdeviceptr, byteCount);
        final String json = JsonUtils.toJson(fArr);
        System.out.println(str + ":" + json);
    }

    /**
     * Manual test entry point: initializes the CUDA context and runs {@link #test2d()}.
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        CUDAModules.initContext();
        test2d();
    }

    /**
     * Manual check of the {@code conv_bn} path on a 2x3x5x5 input: runs one GPU forward and
     * backward pass, repeats the computation with the CPU reference ({@code foward_cpu}) and
     * prints both results together with the values returned by {@code CheckArrayUtils.check}.
     */
    public static void test2d() {
        float[] inputData = {0.9827f, 0.5268f, 0.4057f, 0.2853f, 0.1708f, 0.4791f, 0.5626f, 0.129f, 0.954f, 0.7471f, 0.5806f, 0.8789f, 0.9766f, 0.8142f, 0.9557f, 0.2814f, 0.7667f, 0.5963f, 0.0016f, 0.5944f, 0.4617f, 0.0975f, 0.3558f, 0.3318f, 0.5196f, 0.7558f, 0.7438f, 0.4061f, 0.2737f, 0.1826f, 0.76f, 0.3608f, 0.3924f, 0.2537f, 0.7536f, 0.798f, 0.5246f, 0.6428f, 0.0571f, 0.9973f, 0.7106f, 0.5854f, 0.3122f, 0.2741f, 0.2868f, 0.4628f, 0.2696f, 0.0436f, 0.1222f, 0.4933f, 0.5372f, 0.4992f, 0.2837f, 0.8462f, 0.2095f, 0.1916f, 0.183f, 0.1934f, 0.8305f, 0.0776f, 0.9014f, 0.1835f, 0.7673f, 0.0999f, 0.5783f, 0.7816f, 0.2961f, 0.923f, 0.3454f, 0.603f, 0.4821f, 0.0113f, 0.9629f, 0.8698f, 0.844f, 0.9763f, 0.7661f, 0.2085f, 0.4248f, 0.7407f, 0.5092f, 0.5272f, 0.8521f, 0.1649f, 0.9759f, 0.9084f, 0.3206f, 0.3061f, 0.9648f, 0.3377f, 0.6753f, 0.6662f, 0.457f, 0.9556f, 0.0918f, 0.8788f, 0.6432f, 0.4928f, 0.8778f, 0.5665f, 0.7979f, 0.5639f, 0.597f, 0.4987f, 0.1227f, 0.4963f, 0.6865f, 0.5728f, 0.1927f, 0.1199f, 0.5015f, 0.0221f, 0.0826f, 0.0077f, 0.0568f, 0.7569f, 0.7684f, 0.1536f, 0.4406f, 0.2919f, 0.3006f, 0.9501f, 0.1994f, 0.3314f, 0.5612f, 0.3303f, 0.8773f, 0.3262f, 0.1926f, 0.8667f, 0.336f, 0.5357f, 0.3332f, 0.2044f, 0.5538f, 0.0607f, 0.2203f, 0.7994f, 0.6357f, 0.6469f, 0.8163f, 0.7764f, 0.6821f, 0.6798f, 0.0553f, 0.0609f, 0.2305f, 0.7183f, 0.8135f, 0.7688f};
        float[] deltaData = RandomUtils.val(2 * 3 * 5 * 5, 1.0f);
        float[] gammaData = MatrixUtils.one(3);
        Tensor input = new Tensor(2, 3, 5, 5, inputData, true);
        Tensor gamma = new Tensor(1, 1, 1, 3, gammaData, true);
        Tensor beta = new Tensor(1, 1, 1, 3, true);
        Tensor delta = new Tensor(2, 3, 5, 5, deltaData, true);
        Tensor output = new Tensor(2, 3, 5, 5, true);
        Tensor diff = new Tensor(2, 3, 5, 5, true);
        Tensor dgamma = new Tensor(1, 1, 1, 3, true);
        Tensor dbeta = new Tensor(1, 1, 1, 3, true);
        BNKernel kernel = new BNKernel(BNType.conv_bn, 3, 5, 5);
        // GPU pass: a single forward/backward round.
        for (int round = 0; round < 1; round++) {
            kernel.forward(RunModel.TRAIN, gamma, beta, input, output);
            JCudaDriver.cuCtxSynchronize();
            kernel.backward(input, delta, diff, gamma, dgamma, dbeta);
        }
        output.syncHost();
        diff.syncHost();
        PrintUtils.printImage(input.data);
        System.out.println("");
        // CPU reference on the same data.
        float[][][][] cpuInput = MatrixUtils.transform(inputData, 2, 3, 5, 5);
        float[][][][] cpuDelta = MatrixUtils.transform(deltaData, 2, 3, 5, 5);
        float[][][][] cpuOutput = new float[2][3][5][5];
        float[][][][] cpuDiff = new float[2][3][5][5];
        float[] cpuDgamma = new float[3];
        kernel.foward_cpu(cpuInput, cpuOutput, cpuDelta, cpuDiff, gamma.data, beta.data, cpuDgamma, new float[3], 1);
        System.out.println("=======output==============");
        PrintUtils.printImage(output.data);
        System.out.println("");
        System.out.println("=======output-cpu==============");
        System.out.println(JsonUtils.toJson(MatrixUtils.transform(cpuOutput)));
        System.out.println("==========diff-cpu===========");
        System.out.println(JsonUtils.toJson(MatrixUtils.transform(cpuDiff)));
        System.out.println("=======diff==============");
        PrintUtils.printImage(diff.data);
        System.out.println("==========gd===========");
        PrintUtils.printImage(cpuDgamma);
        System.out.println("");
        System.out.println("output-error:" + CheckArrayUtils.check(output.data, MatrixUtils.transform(cpuOutput)));
        System.out.println("diff-error:" + CheckArrayUtils.check(diff.data, MatrixUtils.transform(cpuDiff)));
    }

    /**
     * Manual check of the {@code fully_bn} path on a 2x1x1x10 input: runs one GPU forward and
     * backward pass, repeats the computation with the CPU reference ({@code foward_cpu}) and
     * prints both results together with the values returned by {@code CheckArrayUtils.check}.
     */
    public static void test1d() {
        float[] inputData = {56.773f, -7.231f, 39.634f, 24.728f, -17.959f, 55.251f, -52.316f, -36.322f, -29.619f, 55.24f, 26.773f, -1.231f, 19.634f, 4.728f, 7.958f, -65.251f, 52.316f, -36.322f, -23.619f, -5.247f};
        float[] deltaData = RandomUtils.val(2 * 1 * 1 * 10, 1.0f);
        float[] gammaData = MatrixUtils.one(10);
        Tensor input = new Tensor(2, 1, 1, 10, inputData, true);
        Tensor gamma = new Tensor(1, 1, 1, 10, gammaData, true);
        Tensor beta = new Tensor(1, 1, 1, 10, true);
        Tensor delta = new Tensor(2, 1, 1, 10, deltaData, true);
        Tensor output = new Tensor(2, 1, 1, 10, true);
        Tensor diff = new Tensor(2, 1, 1, 10, true);
        Tensor dgamma = new Tensor(1, 1, 1, 10, true);
        Tensor dbeta = new Tensor(1, 1, 1, 10, true);
        BNKernel kernel = new BNKernel(BNType.fully_bn, 1, 1, 10);
        // GPU pass: a single forward/backward round.
        for (int round = 0; round < 1; round++) {
            kernel.forward(RunModel.TRAIN, gamma, beta, input, output);
            kernel.backward(input, delta, diff, gamma, dgamma, dbeta);
        }
        output.syncHost();
        diff.syncHost();
        dgamma.syncHost();
        // CPU reference on the same data.
        float[][][][] cpuInput = MatrixUtils.transform(inputData, 2, 1, 1, 10);
        float[][][][] cpuDelta = MatrixUtils.transform(deltaData, 2, 1, 1, 10);
        float[][][][] cpuOutput = new float[2][1][1][10];
        float[][][][] cpuDiff = new float[2][1][1][10];
        kernel.foward_cpu(cpuInput, cpuOutput, cpuDelta, cpuDiff, gamma.data, beta.data, new float[10], new float[10], 0);
        System.out.println("output-gpu:");
        System.out.println(JsonUtils.toJson(output.data));
        System.out.println("output-cpu:");
        System.out.println(JsonUtils.toJson(MatrixUtils.transform(cpuOutput)));
        System.out.println("=======diff-gpu==============");
        System.out.println(JsonUtils.toJson(diff.data));
        System.out.println("=======diff-cpu==============");
        System.out.println(JsonUtils.toJson(MatrixUtils.transform(cpuDiff)));
        System.out.println(JsonUtils.toJson(dgamma.data));
        System.out.println("output-error:" + CheckArrayUtils.check(output.data, MatrixUtils.transform(cpuOutput)));
        System.out.println("diff-error:" + CheckArrayUtils.check(diff.data, MatrixUtils.transform(cpuDiff)));
    }

    /**
     * CPU reference pass: computes the statistics of {@code fArr} via
     * {@code MatrixOperation.meanV2}/{@code varV2}, normalises the input with
     * {@link #culOutput} and then immediately runs {@link #backward_cpu}.
     *
     * @param fArr  input tensor
     * @param fArr2 receives the scaled and shifted output (written by {@code culOutput})
     * @param fArr3 incoming gradient, forwarded to {@code backward_cpu}
     * @param fArr4 receives the input gradient (written by {@code backward_cpu})
     * @param fArr5 per-slot scale
     * @param fArr6 per-slot shift
     * @param fArr7 passed through to {@code backward_cpu} unchanged
     * @param fArr8 passed through to {@code backward_cpu} unchanged
     * @param i     mode flag; {@code 1} sizes the statistics by {@code C}, any other value by {@code W}
     */
    public void foward_cpu(float[][][][] fArr, float[][][][] fArr2, float[][][][] fArr3, float[][][][] fArr4, float[] fArr5, float[] fArr6, float[] fArr7, float[] fArr8, int i) {
        // One statistic slot per channel in mode 1, otherwise one per W position.
        final int slotCount = (i == 1) ? this.C : this.W;
        final float[] mean = new float[slotCount];
        final float[] variance = new float[slotCount];

        MatrixOperation.meanV2(fArr, mean, i);
        MatrixOperation.varV2(fArr, mean, variance, i);

        // culOutput fills fArr2 and hands back the normalised (pre scale/shift) values.
        final float[][][][] normalized = culOutput(fArr, fArr2, mean, variance, fArr5, fArr6, i);
        backward_cpu(fArr, mean, variance, normalized, fArr3, fArr5, fArr4, fArr7, fArr8, i);
    }

    /**
     * CPU reference backward pass. Runs the individual gradient helpers in
     * sequence, prints the intermediate results and writes the input gradient
     * into {@code fArr7}. A second formulation ({@code dx2_cpu}) is computed
     * into a scratch tensor and printed for comparison only.
     *
     * @param fArr  input tensor
     * @param fArr2 per-slot mean
     * @param fArr3 per-slot variance
     * @param fArr4 normalised input (as returned by {@code culOutput})
     * @param fArr5 incoming gradient
     * @param fArr6 per-slot scale
     * @param fArr7 receives the input gradient computed by {@code dx_cpu}
     * @param fArr8 not used by this method
     * @param fArr9 not used by this method
     * @param i     mode flag; {@code 1} sizes the per-slot buffers by {@code C}, any other value by {@code W}
     */
    public void backward_cpu(float[][][][] fArr, float[] fArr2, float[] fArr3, float[][][][] fArr4, float[][][][] fArr5, float[] fArr6, float[][][][] fArr7, float[] fArr8, float[] fArr9, int i) {
        final int slotCount = (i == 1) ? this.C : this.W;

        final float[] dgamma = new float[slotCount];
        dgamma_cpu(fArr5, fArr4, dgamma, i);

        final float[] dgammaAlt = new float[slotCount];
        dgamma2_cpu(fArr5, fArr, fArr2, fArr3, dgammaAlt, i);

        final float[][][][] dxhat = new float[this.N][this.C][this.H][this.W];
        dxhat_cpu(dxhat, fArr5, fArr6, i);

        final float[] dvar = new float[slotCount];
        dvar_cpu(dvar, fArr2, fArr3, fArr, dxhat, i);

        final float[] dmean = new float[slotCount];
        dmean_cpu(dmean, fArr2, fArr3, dvar, fArr, dxhat, i);

        System.out.println("dvar-cpu:");
        PrintUtils.printImage(dvar);
        System.out.println("");

        System.out.println("dmean-cpu:");
        PrintUtils.printImage(dmean);
        System.out.println("");

        System.out.println("dgamma-cpu:");
        PrintUtils.printImage(dgamma);
        System.out.println("");

        System.out.println("dgamma2-cpu:");
        PrintUtils.printImage(dgammaAlt);
        System.out.println("");

        dx_cpu(fArr7, dxhat, fArr2, fArr3, dmean, dvar, fArr, i);

        // Alternative closed-form gradient, kept only for the printed comparison.
        final float[][][][] dxAlt = new float[this.N][this.C][this.H][this.W];
        dx2_cpu(dxAlt, fArr5, fArr2, fArr3, fArr6, fArr, i);

        System.out.println("dx-cpu:");
        PrintUtils.printImage(fArr7);
        System.out.println("");

        System.out.println("dx-cpu2:");
        PrintUtils.printImage(dxAlt);
        System.out.println("");
    }

    /**
     * Writes into {@code fArr3} the per-slot sum of the element-wise product
     * of {@code fArr} and {@code fArr2}.
     *
     * @param fArr  first factor (the incoming gradient when called from {@code backward_cpu})
     * @param fArr2 second factor (the normalised input when called from {@code backward_cpu})
     * @param fArr3 output, one entry per slot; every entry is reset before accumulation
     * @param i     mode flag; {@code 1} reduces over N, H, W for each channel,
     *              any other value reduces over N at {@code [n][0][0][w]} for each W position
     */
    private void dgamma_cpu(float[][][][] fArr, float[][][][] fArr2, float[] fArr3, int i) {
        if (i == 1) {
            for (int c = 0; c < this.C; c++) {
                fArr3[c] = 0.0f;
                for (int n = 0; n < this.N; n++) {
                    for (int h = 0; h < this.H; h++) {
                        for (int w = 0; w < this.W; w++) {
                            fArr3[c] += fArr[n][c][h][w] * fArr2[n][c][h][w];
                        }
                    }
                }
            }
            return;
        }

        for (int w = 0; w < this.W; w++) {
            fArr3[w] = 0.0f;
            for (int n = 0; n < this.N; n++) {
                // Trace of every second-factor value visited in this mode.
                System.out.println(fArr2[n][0][0][w] + ":" + n + ":" + w);
                fArr3[w] += fArr[n][0][0][w] * fArr2[n][0][0][w];
            }
        }
    }

    /**
     * Alternative scale-gradient computation working from the raw input:
     * for every slot, sums {@code (fArr2 - fArr3[slot]) * fArr} and divides the
     * sum by {@code sqrt(fArr4[slot] + eta)}.
     *
     * <p>In mode 1 the individual products are also kept in a scratch tensor,
     * summed over all elements and the grand total is printed.
     *
     * @param fArr  incoming gradient
     * @param fArr2 input tensor
     * @param fArr3 per-slot mean
     * @param fArr4 per-slot variance
     * @param fArr5 output, one entry per slot; every entry is reset before accumulation
     * @param i     mode flag; {@code 1} reduces over N, H, W for each channel,
     *              any other value reduces over N at {@code [n][0][0][w]} for each W position
     */
    private void dgamma2_cpu(float[][][][] fArr, float[][][][] fArr2, float[] fArr3, float[] fArr4, float[] fArr5, int i) {
        if (i != 1) {
            for (int w = 0; w < this.W; w++) {
                fArr5[w] = 0.0f;
                for (int n = 0; n < this.N; n++) {
                    fArr5[w] += (fArr2[n][0][0][w] - fArr3[w]) * fArr[n][0][0][w];
                }
                // Previously indexed with an undeclared decompiler temporary; the
                // value being rescaled is the slot that was just accumulated.
                fArr5[w] = (float) (fArr5[w] / Math.sqrt(fArr4[w] + this.eta));
            }
            return;
        }

        final float[][][][] products = new float[this.N][this.C][this.H][this.W];
        for (int c = 0; c < this.C; c++) {
            fArr5[c] = 0.0f;
            for (int n = 0; n < this.N; n++) {
                for (int h = 0; h < this.H; h++) {
                    for (int w = 0; w < this.W; w++) {
                        products[n][c][h][w] = (fArr2[n][c][h][w] - fArr3[c]) * fArr[n][c][h][w];
                        fArr5[c] += (fArr2[n][c][h][w] - fArr3[c]) * fArr[n][c][h][w];
                    }
                }
            }
            // Same fix as above: rescale the slot that was just accumulated.
            fArr5[c] = (float) (fArr5[c] * (1.0d / Math.sqrt(fArr4[c] + this.eta)));
        }

        // Grand total of the un-normalised products, summed in N, C, H, W order.
        float total = 0.0f;
        for (int n = 0; n < this.N; n++) {
            for (int c = 0; c < this.C; c++) {
                for (int h = 0; h < this.H; h++) {
                    for (int w = 0; w < this.W; w++) {
                        total += products[n][c][h][w];
                    }
                }
            }
        }
        System.out.println(total);
    }

    /**
     * Multiplies every element of {@code fArr2} by its slot's entry in
     * {@code fArr3} and stores the product in {@code fArr}.
     *
     * @param fArr  output tensor
     * @param fArr2 incoming gradient
     * @param fArr3 per-slot scale
     * @param i     mode flag; {@code 1} processes every {@code [n][c][h][w]} with the scale of channel {@code c},
     *              any other value processes only {@code [n][0][0][w]} with the scale of position {@code w}
     */
    private void dxhat_cpu(float[][][][] fArr, float[][][][] fArr2, float[] fArr3, int i) {
        if (i == 1) {
            for (int c = 0; c < this.C; c++) {
                for (int n = 0; n < this.N; n++) {
                    for (int h = 0; h < this.H; h++) {
                        for (int w = 0; w < this.W; w++) {
                            fArr[n][c][h][w] = fArr2[n][c][h][w] * fArr3[c];
                        }
                    }
                }
            }
            return;
        }

        for (int w = 0; w < this.W; w++) {
            for (int n = 0; n < this.N; n++) {
                fArr[n][0][0][w] = fArr2[n][0][0][w] * fArr3[w];
            }
        }
    }

    /**
     * Computes the per-slot gradient with respect to the variance:
     * {@code sum(fArr5 * (fArr4 - fArr2[slot]) * -0.5) * (fArr3[slot] + eta)^-1.5}.
     *
     * @param fArr  output, one entry per slot
     * @param fArr2 per-slot mean
     * @param fArr3 per-slot variance
     * @param fArr4 input tensor
     * @param fArr5 gradient with respect to the normalised input
     * @param i     mode flag; {@code 1} reduces over N, H, W for each channel,
     *              any other value reduces over N at {@code [n][0][0][w]} for each W position
     */
    private void dvar_cpu(float[] fArr, float[] fArr2, float[] fArr3, float[][][][] fArr4, float[][][][] fArr5, int i) {
        if (i == 1) {
            for (int c = 0; c < this.C; c++) {
                float weighted = 0.0f;
                for (int n = 0; n < this.N; n++) {
                    for (int h = 0; h < this.H; h++) {
                        for (int w = 0; w < this.W; w++) {
                            weighted += fArr5[n][c][h][w] * (fArr4[n][c][h][w] - fArr2[c]) * (-0.5f);
                        }
                    }
                }
                fArr[c] = (float) (weighted * Math.pow(fArr3[c] + this.eta, -1.5d));
            }
            return;
        }

        for (int w = 0; w < this.W; w++) {
            float weighted = 0.0f;
            for (int n = 0; n < this.N; n++) {
                weighted += fArr5[n][0][0][w] * (fArr4[n][0][0][w] - fArr2[w]) * (-0.5f);
            }
            fArr[w] = (float) (weighted * Math.pow(fArr3[w] + this.eta, -1.5d));
        }
    }

    /**
     * Computes the per-slot gradient with respect to the mean as the sum of
     * two terms: the negated normalised-input gradient divided by
     * {@code sqrt(var + eta)}, and the variance gradient times the mean of
     * {@code -2 * (x - mean)}.
     *
     * @param fArr  output, one entry per slot
     * @param fArr2 per-slot mean
     * @param fArr3 per-slot variance
     * @param fArr4 per-slot variance gradient
     * @param fArr5 input tensor
     * @param fArr6 gradient with respect to the normalised input
     * @param i     mode flag; {@code 1} reduces over N, H, W for each channel,
     *              any other value reduces over N at {@code [n][0][0][w]} for each W position
     */
    private void dmean_cpu(float[] fArr, float[] fArr2, float[] fArr3, float[] fArr4, float[][][][] fArr5, float[][][][] fArr6, int i) {
        if (i == 1) {
            for (int c = 0; c < this.C; c++) {
                fArr[c] = 0.0f;
                float gradTerm = 0.0f;
                float varTerm = 0.0f;
                for (int n = 0; n < this.N; n++) {
                    for (int h = 0; h < this.H; h++) {
                        for (int w = 0; w < this.W; w++) {
                            // In this mode both terms are scaled element by element.
                            gradTerm = (float) (gradTerm + (((-1.0f) * fArr6[n][c][h][w]) / Math.sqrt(fArr3[c] + this.eta)));
                            varTerm += ((fArr4[c] * (-2.0f)) / ((this.N * this.H) * this.W)) * (fArr5[n][c][h][w] - fArr2[c]);
                        }
                    }
                }
                fArr[c] = gradTerm + varTerm;
            }
            return;
        }

        for (int w = 0; w < this.W; w++) {
            fArr[w] = 0.0f;
            float negGradSum = 0.0f;
            float centeredSum = 0.0f;
            for (int n = 0; n < this.N; n++) {
                negGradSum += (-1.0f) * fArr6[n][0][0][w];
                centeredSum += (-2.0f) * (fArr5[n][0][0][w] - fArr2[w]);
            }
            // In this mode the sums are scaled once, after the reduction.
            fArr[w] = (float) ((negGradSum / Math.sqrt(fArr3[w] + this.eta)) + ((fArr4[w] * centeredSum) / this.N));
        }
    }

    /**
     * Combines the three gradient paths into the input gradient:
     * {@code dxhat / sqrt(var + eta) + dvar * 2 * (x - mean) / m + dmean / m},
     * where {@code m} is {@code N} (or {@code N*H*W} in mode 1).
     *
     * @param fArr  output tensor (input gradient)
     * @param fArr2 gradient with respect to the normalised input
     * @param fArr3 per-slot mean
     * @param fArr4 per-slot variance
     * @param fArr5 per-slot mean gradient
     * @param fArr6 per-slot variance gradient
     * @param fArr7 input tensor
     * @param i     mode flag; {@code 1} processes every {@code [n][c][h][w]} per channel,
     *              any other value processes only {@code [n][0][0][w]} per W position
     */
    private void dx_cpu(float[][][][] fArr, float[][][][] fArr2, float[] fArr3, float[] fArr4, float[] fArr5, float[] fArr6, float[][][][] fArr7, int i) {
        if (i == 1) {
            for (int c = 0; c < this.C; c++) {
                for (int n = 0; n < this.N; n++) {
                    for (int h = 0; h < this.H; h++) {
                        for (int w = 0; w < this.W; w++) {
                            fArr[n][c][h][w] = (float) ((fArr2[n][c][h][w] / Math.sqrt(fArr4[c] + this.eta))
                                    + (((fArr6[c] * 2.0f) * (fArr7[n][c][h][w] - fArr3[c])) / ((this.N * this.H) * this.W))
                                    + (fArr5[c] / ((this.N * this.H) * this.W)));
                        }
                    }
                }
            }
            return;
        }

        for (int w = 0; w < this.W; w++) {
            for (int n = 0; n < this.N; n++) {
                fArr[n][0][0][w] = (float) ((fArr2[n][0][0][w] / Math.sqrt(fArr4[w] + this.eta))
                        + (((fArr6[w] * 2.0f) * (fArr7[n][0][0][w] - fArr3[w])) / this.N)
                        + (fArr5[w] / this.N));
            }
        }
    }

    /**
     * Closed-form variant of the input gradient. Per slot it first reduces
     * {@code sum((x - mean) * delta)} and {@code mean(delta)}, then evaluates
     * {@code ((delta - mean(delta)) - (x - mean) * proj) * gamma / sqrt(var + eta)}
     * with {@code proj = sum((x - mean) * delta) / m / (var + eta)}.
     *
     * @param fArr  output tensor (input gradient)
     * @param fArr2 incoming gradient
     * @param fArr3 per-slot mean; its length also sizes the scratch buffers
     * @param fArr4 per-slot variance
     * @param fArr5 per-slot scale
     * @param fArr6 input tensor
     * @param i     mode flag; {@code 1} works per channel over N, H, W,
     *              any other value works per W position over N at {@code [n][0][0][w]}
     */
    private void dx2_cpu(float[][][][] fArr, float[][][][] fArr2, float[] fArr3, float[] fArr4, float[] fArr5, float[][][][] fArr6, int i) {
        final int slots = fArr3.length;
        final float[] centeredDot = new float[slots];
        final float[] deltaMean = new float[slots];
        final float[] projection = new float[slots];
        final float[] outScale = new float[slots];
        final float[] invStd = new float[slots];

        if (i == 1) {
            final float invCount = ((1.0f / this.N) / this.H) / this.W;

            for (int c = 0; c < this.C; c++) {
                for (int n = 0; n < this.N; n++) {
                    for (int h = 0; h < this.H; h++) {
                        for (int w = 0; w < this.W; w++) {
                            centeredDot[c] += (fArr6[n][c][h][w] - fArr3[c]) * fArr2[n][c][h][w];
                            deltaMean[c] += fArr2[n][c][h][w] * invCount;
                        }
                    }
                }
            }

            for (int c = 0; c < this.C; c++) {
                invStd[c] = 1.0f / ((float) Math.sqrt(fArr4[c] + this.eta));
                projection[c] = centeredDot[c] * invCount * invStd[c] * invStd[c];
                outScale[c] = invStd[c] * fArr5[c];
            }

            for (int n = 0; n < this.N; n++) {
                for (int c = 0; c < this.C; c++) {
                    for (int h = 0; h < this.H; h++) {
                        for (int w = 0; w < this.W; w++) {
                            fArr[n][c][h][w] = ((fArr2[n][c][h][w] - deltaMean[c]) - ((fArr6[n][c][h][w] - fArr3[c]) * projection[c])) * outScale[c];
                        }
                    }
                }
            }
            return;
        }

        final float invBatch = 1.0f / this.N;

        for (int w = 0; w < this.W; w++) {
            for (int n = 0; n < this.N; n++) {
                centeredDot[w] += (fArr6[n][0][0][w] - fArr3[w]) * fArr2[n][0][0][w];
                deltaMean[w] += fArr2[n][0][0][w] * invBatch;
            }
        }

        for (int w = 0; w < this.W; w++) {
            invStd[w] = 1.0f / ((float) Math.sqrt(fArr4[w] + this.eta));
            projection[w] = centeredDot[w] * invBatch * invStd[w] * invStd[w];
            outScale[w] = invStd[w] * fArr5[w];
        }

        for (int n = 0; n < this.N; n++) {
            for (int w = 0; w < this.W; w++) {
                fArr[n][0][0][w] = ((fArr2[n][0][0][w] - deltaMean[w]) - ((fArr6[n][0][0][w] - fArr3[w]) * projection[w])) * outScale[w];
            }
        }
    }

    /**
     * Single-pass reduction that produces three results at once:
     * {@code fArr4[slot] = sum(fArr * fArr2)}, {@code fArr5[slot] = sum(fArr)}
     * and {@code fArr6 = fArr * fArr3[slot]} element-wise.
     *
     * @param fArr  first tensor (presumably the incoming gradient; no caller is visible in this part of the file)
     * @param fArr2 second tensor multiplied element-wise with {@code fArr}
     * @param fArr3 per-slot factor applied to {@code fArr}
     * @param fArr4 output: per-slot sum of products, reset before accumulation
     * @param fArr5 output: per-slot sum of {@code fArr}, reset before accumulation
     * @param fArr6 output tensor receiving the scaled copy of {@code fArr}
     * @param i     mode flag; {@code 1} works per channel over N, H, W,
     *              any other value works per W position over N at {@code [n][0][0][w]}
     */
    private void computeDelta_cpu(float[][][][] fArr, float[][][][] fArr2, float[] fArr3, float[] fArr4, float[] fArr5, float[][][][] fArr6, int i) {
        if (i == 1) {
            for (int c = 0; c < this.C; c++) {
                fArr4[c] = 0.0f;
                fArr5[c] = 0.0f;
                for (int n = 0; n < this.N; n++) {
                    for (int h = 0; h < this.H; h++) {
                        for (int w = 0; w < this.W; w++) {
                            fArr4[c] += fArr[n][c][h][w] * fArr2[n][c][h][w];
                            fArr5[c] += fArr[n][c][h][w];
                            fArr6[n][c][h][w] = fArr[n][c][h][w] * fArr3[c];
                        }
                    }
                }
            }
            return;
        }

        for (int w = 0; w < this.W; w++) {
            fArr4[w] = 0.0f;
            fArr5[w] = 0.0f;
            for (int n = 0; n < this.N; n++) {
                fArr4[w] += fArr[n][0][0][w] * fArr2[n][0][0][w];
                fArr5[w] += fArr[n][0][0][w];
                fArr6[n][0][0][w] = fArr[n][0][0][w] * fArr3[w];
            }
        }
    }

    /**
     * Computes the variance gradient and the mean gradient in one reduction.
     * {@code fArr[slot] = sum((x - mean) * g) * -0.5 * (var + eta)^-1.5} and
     * {@code fArr4[slot] = sum(-g / sqrt(var + eta)) + fArr[slot] * sum(-2 * (x - mean)) / m},
     * where {@code m} is {@code N} (or {@code N*H*W} in mode 1).
     *
     * @param fArr  output: per-slot variance gradient
     * @param fArr2 per-slot mean
     * @param fArr3 per-slot variance
     * @param fArr4 output: per-slot mean gradient
     * @param fArr5 input tensor
     * @param fArr6 gradient tensor {@code g}
     * @param i     mode flag; {@code 1} reduces over N, H, W for each channel,
     *              any other value reduces over N at {@code [n][0][0][w]} for each W position
     */
    private void meanDzSum_cpu(float[] fArr, float[] fArr2, float[] fArr3, float[] fArr4, float[][][][] fArr5, float[][][][] fArr6, int i) {
        if (i == 1) {
            for (int c = 0; c < this.C; c++) {
                float centeredDot = 0.0f;
                float negScaledGrad = 0.0f;
                float centeredSum = 0.0f;
                for (int n = 0; n < this.N; n++) {
                    for (int h = 0; h < this.H; h++) {
                        for (int w = 0; w < this.W; w++) {
                            centeredDot += (fArr5[n][c][h][w] - fArr2[c]) * fArr6[n][c][h][w];
                            negScaledGrad += ((-1.0f) * fArr6[n][c][h][w]) / ((float) Math.sqrt(fArr3[c] + this.eta));
                            centeredSum += (-2.0f) * (fArr5[n][c][h][w] - fArr2[c]);
                        }
                    }
                }
                fArr[c] = (float) (centeredDot * (-0.5f) * Math.pow(fArr3[c] + this.eta, -1.5d));
                fArr4[c] = negScaledGrad + ((((fArr[c] * centeredSum) / this.N) / this.H) / this.W);
            }
            return;
        }

        for (int w = 0; w < this.W; w++) {
            float centeredDot = 0.0f;
            float negScaledGrad = 0.0f;
            float centeredSum = 0.0f;
            for (int n = 0; n < this.N; n++) {
                centeredDot += (fArr5[n][0][0][w] - fArr2[w]) * fArr6[n][0][0][w];
                negScaledGrad += ((-1.0f) * fArr6[n][0][0][w]) / ((float) Math.sqrt(fArr3[w] + this.eta));
                centeredSum += (-2.0f) * (fArr5[n][0][0][w] - fArr2[w]);
            }
            fArr[w] = (float) (centeredDot * (-0.5f) * Math.pow(fArr3[w] + this.eta, -1.5d));
            fArr4[w] = negScaledGrad + ((fArr[w] * centeredSum) / this.N);
        }
    }

    /**
     * Rewrites {@code fArr} in place as
     * {@code fArr / sqrt(var + eta) + 2 * dvar * (x - mean) / m + dmean / m},
     * where {@code m} is {@code N} (or {@code N*H*W} in mode 1).
     *
     * @param fArr  in/out tensor: read as the gradient term and overwritten with the result
     * @param fArr2 input tensor
     * @param fArr3 per-slot mean
     * @param fArr4 per-slot variance
     * @param fArr5 per-slot mean gradient
     * @param fArr6 per-slot variance gradient
     * @param i     mode flag; {@code 1} processes every {@code [n][c][h][w]} per channel,
     *              any other value processes only {@code [n][0][0][w]} per W position
     */
    private void computeDiff_cpu(float[][][][] fArr, float[][][][] fArr2, float[] fArr3, float[] fArr4, float[] fArr5, float[] fArr6, int i) {
        if (i == 1) {
            final float invCount = 1.0f / ((this.N * this.H) * this.W);
            for (int n = 0; n < this.N; n++) {
                for (int c = 0; c < this.C; c++) {
                    for (int h = 0; h < this.H; h++) {
                        for (int w = 0; w < this.W; w++) {
                            fArr[n][c][h][w] = (fArr[n][c][h][w] / ((float) Math.sqrt(fArr4[c] + this.eta)))
                                    + (2.0f * fArr6[c] * (fArr2[n][c][h][w] - fArr3[c]) * invCount)
                                    + (fArr5[c] * invCount);
                        }
                    }
                }
            }
            return;
        }

        final float invBatch = 1.0f / this.N;
        for (int n = 0; n < this.N; n++) {
            for (int w = 0; w < this.W; w++) {
                fArr[n][0][0][w] = (fArr[n][0][0][w] / ((float) Math.sqrt(fArr4[w] + this.eta)))
                        + (2.0f * fArr6[w] * (fArr2[n][0][0][w] - fArr3[w]) * invBatch)
                        + (fArr5[w] * invBatch);
            }
        }
    }

    /**
     * Normalises {@code fArr} with the supplied statistics and writes the
     * scaled and shifted result into {@code fArr2}. The tensor dimensions are
     * taken from {@code fArr} itself and printed as {@code N:C:H:W}.
     *
     * <p>Note that the mode test here is {@code i == 0}; every other value
     * (not just {@code 1}) selects the per-channel statistics.
     *
     * @param fArr  input tensor
     * @param fArr2 receives {@code normalised * fArr5[slot] + fArr6[slot]}
     * @param fArr3 per-slot mean
     * @param fArr4 per-slot variance
     * @param fArr5 per-slot scale
     * @param fArr6 per-slot shift
     * @param i     mode flag; {@code 0} indexes the statistics by the last (W) index, otherwise by the channel index
     * @return a newly allocated tensor holding {@code (fArr - mean) / sqrt(var + eta)}
     */
    private float[][][][] culOutput(float[][][][] fArr, float[][][][] fArr2, float[] fArr3, float[] fArr4, float[] fArr5, float[] fArr6, int i) {
        final int batch = fArr.length;
        final int channels = fArr[0].length;
        final int height = fArr[0][0].length;
        final int width = fArr[0][0][0].length;
        System.out.println(batch + ":" + channels + ":" + height + ":" + width);

        final float[][][][] normalized = new float[batch][channels][height][width];
        for (int n = 0; n < batch; n++) {
            for (int c = 0; c < channels; c++) {
                for (int h = 0; h < height; h++) {
                    for (int w = 0; w < width; w++) {
                        // Which statistic applies to this element depends on the mode.
                        final int slot = (i == 0) ? w : c;
                        normalized[n][c][h][w] = (fArr[n][c][h][w] - fArr3[slot]) / ((float) Math.sqrt(fArr4[slot] + this.eta));
                        fArr2[n][c][h][w] = (normalized[n][c][h][w] * fArr5[slot]) + fArr6[slot];
                    }
                }
            }
        }
        return normalized;
    }

    /**
     * Central-difference check: runs the forward pass on two inputs and
     * compares {@code (out1 - out2) / (2 * f)} against {@code tensor7}.
     *
     * @param tensor  first input (forwarded into {@code tensor5})
     * @param tensor2 second input (forwarded into {@code tensor6})
     * @param tensor3 third argument of {@code forward} (same position as the scale tensor in the other callers in this file)
     * @param tensor4 fourth argument of {@code forward} (same position as the shift tensor in the other callers in this file)
     * @param tensor5 output buffer for the first forward pass
     * @param tensor6 output buffer for the second forward pass
     * @param tensor7 reference values the numerical estimate is compared with
     * @param f       perturbation size; the difference is divided by {@code 2 * f}
     * @return the value produced by {@code CheckArrayUtils.check}
     */
    public float gradientCheck(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4, Tensor tensor5, Tensor tensor6, Tensor tensor7, float f) {
        final float twoStep = 2.0f * f;

        forward(RunModel.TRAIN, tensor3, tensor4, tensor, tensor5);
        forward(RunModel.TRAIN, tensor3, tensor4, tensor2, tensor6);

        // Host syncs happen in the order tensor7, tensor5, tensor6.
        return CheckArrayUtils.check(
                tensor7.syncHost(),
                MatrixOperation.division(
                        MatrixOperation.subtraction(tensor5.syncHost(), tensor6.syncHost()),
                        twoStep));
    }

    /**
     * Per-channel variance: the mean over N, H, W of the squared deviation
     * from the supplied per-channel mean.
     *
     * @param fArr  input tensor indexed {@code [n][c][h][w]}
     * @param fArr2 per-channel mean
     * @param fArr3 output, one entry per channel; reset before accumulation
     */
    public void var(float[][][][] fArr, float[] fArr2, float[] fArr3) {
        final int count = this.N * this.H * this.W;
        for (int c = 0; c < this.C; c++) {
            fArr3[c] = 0.0f;
            for (int n = 0; n < this.N; n++) {
                for (int h = 0; h < this.H; h++) {
                    for (int w = 0; w < this.W; w++) {
                        fArr3[c] += (fArr[n][c][h][w] - fArr2[c]) * (fArr[n][c][h][w] - fArr2[c]);
                    }
                }
            }
            fArr3[c] = fArr3[c] / count;
        }
    }

    /**
     * Per-channel standard deviation computed as
     * {@code sqrt((sumOfSquaredDeviations + eta) / (N*H*W))}.
     *
     * <p>NOTE(review): {@code eta} is added to the raw sum before the division
     * by the element count, unlike the {@code sqrt(var + eta)} form used by the
     * other CPU helpers in this class. Confirm whether that is intended.
     *
     * @param fArr  input tensor indexed {@code [n][c][h][w]}
     * @param fArr2 per-channel mean
     * @param fArr3 output, one entry per channel; reset before accumulation
     */
    public void std(float[][][][] fArr, float[] fArr2, float[] fArr3) {
        final int count = this.N * this.H * this.W;
        for (int c = 0; c < this.C; c++) {
            fArr3[c] = 0.0f;
            for (int n = 0; n < this.N; n++) {
                for (int h = 0; h < this.H; h++) {
                    for (int w = 0; w < this.W; w++) {
                        fArr3[c] += (fArr[n][c][h][w] - fArr2[c]) * (fArr[n][c][h][w] - fArr2[c]);
                    }
                }
            }
            fArr3[c] = (float) Math.sqrt((fArr3[c] + this.eta) / count);
        }
    }

    /**
     * Scales every element of {@code fArr3} by its channel's entry in
     * {@code fArr2} and stores the result in {@code fArr}.
     *
     * @param fArr  output tensor indexed {@code [n][c][h][w]}
     * @param fArr2 per-channel factor
     * @param fArr3 source tensor
     */
    public void dxhat(float[][][][] fArr, float[] fArr2, float[][][][] fArr3) {
        for (int c = 0; c < this.C; c++) {
            for (int n = 0; n < this.N; n++) {
                for (int h = 0; h < this.H; h++) {
                    for (int w = 0; w < this.W; w++) {
                        fArr[n][c][h][w] = fArr3[n][c][h][w] * fArr2[c];
                    }
                }
            }
        }
    }

    /**
     * Per-channel gradient formula evaluated in two passes. The first pass
     * reduces {@code S = sum(fArr4)} and {@code P = sum(fArr4 * fArr3)} over
     * N, H, W; the second writes
     * {@code (m * fArr4 - S + P * S / m - fArr3 * P) / (m * fArr2[c])}
     * with {@code m = N*H*W}.
     *
     * @param fArr  output tensor indexed {@code [n][c][h][w]}
     * @param fArr2 per-channel divisor (presumably the standard deviation; no caller is visible in this part of the file)
     * @param fArr3 tensor multiplied with {@code fArr4} in the reduction
     * @param fArr4 tensor whose sum and weighted sum are reduced
     */
    public void dx2(float[][][][] fArr, float[] fArr2, float[][][][] fArr3, float[][][][] fArr4) {
        final float[] plainSum = new float[this.C];
        final float[] weightedSum = new float[this.C];
        final int count = this.N * this.H * this.W;

        for (int c = 0; c < this.C; c++) {
            for (int n = 0; n < this.N; n++) {
                for (int h = 0; h < this.H; h++) {
                    for (int w = 0; w < this.W; w++) {
                        weightedSum[c] += fArr4[n][c][h][w] * fArr3[n][c][h][w];
                        plainSum[c] += fArr4[n][c][h][w];
                    }
                }
            }
        }

        for (int c = 0; c < this.C; c++) {
            for (int n = 0; n < this.N; n++) {
                for (int h = 0; h < this.H; h++) {
                    for (int w = 0; w < this.W; w++) {
                        fArr[n][c][h][w] = ((((count * fArr4[n][c][h][w]) - plainSum[c]) + ((weightedSum[c] * plainSum[c]) / count)) - (fArr3[n][c][h][w] * weightedSum[c])) / (count * fArr2[c]);
                    }
                }
            }
        }
    }

    /**
     * Per-channel gradient formula evaluated in two passes. The first pass
     * reduces {@code sum(fArr4)} and {@code sum(fArr4 * fArr3)} over N, H, W;
     * the second writes
     * {@code (1 / fArr2[c]) * (fArr4 - mean(fArr4) - fArr3 * mean(fArr4 * fArr3))}.
     *
     * @param fArr  output tensor indexed {@code [n][c][h][w]}
     * @param fArr2 per-channel divisor (presumably the standard deviation; no caller is visible in this part of the file)
     * @param fArr3 tensor multiplied with {@code fArr4} in the reduction
     * @param fArr4 tensor whose mean and weighted mean are reduced
     */
    public void dx3(float[][][][] fArr, float[] fArr2, float[][][][] fArr3, float[][][][] fArr4) {
        final float[] plainSum = new float[this.C];
        final float[] weightedSum = new float[this.C];
        final int count = this.N * this.H * this.W;

        for (int c = 0; c < this.C; c++) {
            for (int n = 0; n < this.N; n++) {
                for (int h = 0; h < this.H; h++) {
                    for (int w = 0; w < this.W; w++) {
                        weightedSum[c] += fArr4[n][c][h][w] * fArr3[n][c][h][w];
                        plainSum[c] += fArr4[n][c][h][w];
                    }
                }
            }
        }

        for (int c = 0; c < this.C; c++) {
            final float plainMean = plainSum[c] / count;
            final float weightedMean = weightedSum[c] / count;
            for (int n = 0; n < this.N; n++) {
                for (int h = 0; h < this.H; h++) {
                    for (int w = 0; w < this.W; w++) {
                        fArr[n][c][h][w] = (1.0f / fArr2[c]) * ((fArr4[n][c][h][w] - plainMean) - (fArr3[n][c][h][w] * weightedMean));
                    }
                }
            }
        }
    }

    /**
     * Self-contained gradient check on a fixed 4x1x1x3 input. Forwards the
     * input shifted by +1e-6 and -1e-6 through two kernels, forms the central
     * difference of the two outputs, then runs forward and backward on the
     * unshifted input with a third kernel and compares its input gradient with
     * that difference. Both arrays are printed as JSON.
     *
     * @return the value produced by {@code CheckArrayUtils.check} for the backward result versus the central difference
     */
    public static float gradientCheck() {
        final int batch = 4;
        final int width = 3;
        final float epsilon = 1.0E-6f;

        float[] base = {1.0f, 1.0f, 1.0f, 2.0f, -3.0f, 5.0f, 0.1f, -0.3f, 0.5f, 1.2f, -1.3f, 1.5f};
        float[] shiftedUp = MatrixOperation.add(base, epsilon);
        float[] shiftedDown = MatrixOperation.subtraction(base, epsilon);
        float[] upstream = RandomUtils.val(batch * 1 * 1 * width, 1.0f);
        float[] ones = MatrixUtils.one(width);

        Tensor input = new Tensor(batch, 1, 1, width, base, true);
        Tensor inputUp = new Tensor(batch, 1, 1, width, shiftedUp, true);
        Tensor inputDown = new Tensor(batch, 1, 1, width, shiftedDown, true);
        Tensor gamma = new Tensor(1, 1, 1, width, ones, true);
        Tensor beta = new Tensor(1, 1, 1, width, true);
        Tensor delta = new Tensor(batch, 1, 1, width, upstream, true);
        Tensor output = new Tensor(batch, 1, 1, width, true);
        Tensor outputUp = new Tensor(batch, 1, 1, width, true);
        Tensor outputDown = new Tensor(batch, 1, 1, width, true);
        Tensor diff = new Tensor(batch, 1, 1, width, true);
        Tensor dgamma = new Tensor(1, 1, 1, width, true);
        Tensor dbeta = new Tensor(1, 1, 1, width, true);

        // Separate kernel instances for the two perturbed forward passes.
        BNKernel kernelUp = new BNKernel(BNType.fully_bn, 1, 1, width);
        BNKernel kernelDown = new BNKernel(BNType.fully_bn, 1, 1, width);
        kernelUp.forward(RunModel.TRAIN, gamma, beta, inputUp, outputUp);
        kernelDown.forward(RunModel.TRAIN, gamma, beta, inputDown, outputDown);
        outputUp.syncHost();
        outputDown.syncHost();

        float[] numeric = MatrixOperation.division(MatrixOperation.subtraction(outputUp.data, outputDown.data), 2.0f * epsilon);
        System.out.println(JsonUtils.toJson(numeric));

        BNKernel kernel = new BNKernel(BNType.fully_bn, 1, 1, width);
        kernel.forward(RunModel.TRAIN, gamma, beta, input, output);
        kernel.backward(input, delta, diff, gamma, dgamma, dbeta);
        diff.syncHost();
        System.out.println(JsonUtils.toJson(diff.data));

        return CheckArrayUtils.check(diff.data, numeric);
    }
}
