package com.omega.engine.nn.layer.normalization.gpu;

import com.omega.common.data.Tensor;
import com.omega.common.utils.JsonUtils;
import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.nn.layer.gpu.BNBaseKernel;
import com.omega.engine.nn.network.RunModel;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUdeviceptr;
import jcuda.driver.CUfunction;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.cudaError;

/* loaded from: input_file:com/omega/engine/nn/layer/normalization/gpu/BNKernel3.class */
/**
 * Batch-normalization kernel wrapper that launches the CUDA kernels of {@code BNKernel3.cu}
 * through the JCuda driver API.
 *
 * <p>The kernels are given the batch size {@code N}, the channel count {@code C} and the
 * per-channel spatial size {@code H * W}; the per-channel reductions are launched with one
 * block per channel. Kernel argument lists are cached: the forward lists are rebuilt only when
 * the input batch size changes, and the backward lists are rebuilt when they have not been built
 * yet or when the batch size has changed since they were last built. Tensors passed on later
 * calls with an unchanged batch size are therefore not re-read.
 */
public class BNKernel3 extends BNBaseKernel {

    /** CUDA source module that every kernel of this class is loaded from. */
    private static final String MODULE = "BNKernel3.cu";

    private final int C;
    private final int H;
    private final int W;
    /** {@code H * W}: number of positions per channel. */
    private final int spatial;

    private CUfunction mwa_function;
    private CUfunction fast_mean_function;
    private CUfunction fast_var_function;
    private CUfunction normalize_function;
    private CUfunction normalize_test_function;
    private CUfunction dgamma_function;
    private CUfunction dbeta_function;
    private CUfunction dbeta_full_function;
    private CUfunction dxhat_function;
    private CUfunction fast_dmean_function;
    private CUfunction fast_dvar_function;
    private CUfunction dx_function;

    /** Threads per block for every launch made by this class. */
    private final int CAFFE_CUDA_NUM_THREADS = 1024;
    /** Epsilon passed to the normalize kernels. */
    private final float eta = 1.0E-5f;
    /** Momentum passed to the running-statistics (mwa) kernel. */
    private final float momentum = 0.01f;

    /** Normalized-input buffer written by the training normalize kernel, read by dgamma. */
    private CUdeviceptr d_z;
    private CUdeviceptr d_mean;
    private CUdeviceptr d_var;
    private CUdeviceptr d_dmean;
    private CUdeviceptr d_dvar;

    private Pointer fastMeanParameters;
    private Pointer fastVarParameters;
    private Pointer normalizeParameters;
    private Pointer normalize_test_Parameters;
    private Pointer mwaParameters;
    private Pointer dgammaParameters;
    private Pointer dbetaParameters;
    private Pointer dbetaFullParameters;
    private Pointer dxhatParameters;
    private Pointer fastDmeanParameters;
    private Pointer fastDvarParameters;
    private Pointer dxParameters;

    /** Batch size the cached backward argument lists were built for; -1 means "not built". */
    private int backwardN = -1;

    /**
     * Creates the kernel wrapper, loads the CUDA functions and allocates the per-channel buffers.
     *
     * @param i       channel count {@code C}
     * @param i2      height {@code H}
     * @param i3      width {@code W}
     * @param tensor  running-mean tensor updated by the mwa kernel and read in test mode
     * @param tensor2 running-variance tensor updated by the mwa kernel and read in test mode
     */
    public BNKernel3(int i, int i2, int i3, Tensor tensor, Tensor tensor2) {
        this.C = i;
        this.H = i2;
        this.W = i3;
        this.spatial = i2 * i3;
        this.runingMean = tensor;
        this.runingVar = tensor2;
        init();
    }

    /** Loads every kernel of {@link #MODULE} that has not been loaded yet; failures are printed. */
    public void initFunction() {
        try {
            if (this.fast_mean_function == null) {
                this.fast_mean_function = CUDAModules.getLocalFunctionByModule(MODULE, "fast_mean_kernel");
            }
            if (this.fast_var_function == null) {
                this.fast_var_function = CUDAModules.getLocalFunctionByModule(MODULE, "fast_variance_kernel");
            }
            if (this.normalize_function == null) {
                this.normalize_function = CUDAModules.getLocalFunctionByModule(MODULE, "normalize_kernel");
            }
            if (this.normalize_test_function == null) {
                this.normalize_test_function = CUDAModules.getLocalFunctionByModule(MODULE, "normalize_test_kernel");
            }
            if (this.mwa_function == null) {
                this.mwa_function = CUDAModules.getLocalFunctionByModule(MODULE, "mwa_kernel");
            }
            if (this.dgamma_function == null) {
                this.dgamma_function = CUDAModules.getLocalFunctionByModule(MODULE, "backward_scale_kernel");
            }
            if (this.dbeta_function == null) {
                this.dbeta_function = CUDAModules.getLocalFunctionByModule(MODULE, "backward_bias_kernel");
            }
            if (this.dbeta_full_function == null) {
                this.dbeta_full_function = CUDAModules.getLocalFunctionByModule(MODULE, "backward_bias_conn_kernel");
            }
            if (this.dxhat_function == null) {
                this.dxhat_function = CUDAModules.getLocalFunctionByModule(MODULE, "scale_bias_kernel");
            }
            if (this.fast_dmean_function == null) {
                this.fast_dmean_function = CUDAModules.getLocalFunctionByModule(MODULE, "fast_mean_delta_kernel");
            }
            if (this.fast_dvar_function == null) {
                this.fast_dvar_function = CUDAModules.getLocalFunctionByModule(MODULE, "fast_variance_delta_kernel");
            }
            if (this.dx_function == null) {
                this.dx_function = CUDAModules.getLocalFunctionByModule(MODULE, "normalize_delta_kernel");
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Allocates the four per-channel device buffers (mean, var and their gradients). */
    private void initKernel() {
        this.d_mean = CUDAMemoryManager.getDevice(this.C);
        this.d_var = CUDAMemoryManager.getDevice(this.C);
        this.d_dmean = CUDAMemoryManager.getDevice(this.C);
        this.d_dvar = CUDAMemoryManager.getDevice(this.C);
    }

    /** Loads the kernels and allocates the per-channel buffers. */
    public void init() {
        initFunction();
        initKernel();
    }

    /**
     * (Re)builds the forward kernel argument lists when the input batch size differs from the
     * cached {@code N}; otherwise does nothing.
     *
     * @param input  forward input; its {@code number} is taken as the batch size
     * @param gamma  tensor passed to the normalize kernels after mean/var (presumably the scale)
     * @param beta   tensor passed to the normalize kernels after gamma (presumably the shift)
     * @param output tensor the normalize kernels are given as output
     */
    public void initForward(Tensor input, Tensor gamma, Tensor beta, Tensor output) {
        if (input.number == this.N) {
            return;
        }
        this.N = input.number;
        int count = elementCount();
        this.fastMeanParameters = kernelArgs(
                devArg(input.getGpuData()), intArg(this.N), intArg(this.C), intArg(this.spatial),
                devArg(this.d_mean));
        this.fastVarParameters = kernelArgs(
                devArg(input.getGpuData()), devArg(this.d_mean), intArg(this.N), intArg(this.C),
                intArg(this.spatial), devArg(this.d_var));
        this.mwaParameters = kernelArgs(
                devArg(this.d_mean), devArg(this.d_var), devArg(this.runingMean.getGpuData()),
                devArg(this.runingVar.getGpuData()), intArg(this.C), floatArg(this.momentum));
        // NOTE(review): any previously allocated d_z is not released here; confirm whether
        // CUDAMemoryManager reclaims it, otherwise every batch-size change leaks device memory.
        this.d_z = CUDAMemoryManager.getDevice(count);
        this.normalizeParameters = kernelArgs(
                intArg(count), devArg(input.getGpuData()), devArg(this.d_z), devArg(output.getGpuData()),
                devArg(this.d_mean), devArg(this.d_var), devArg(gamma.getGpuData()),
                devArg(beta.getGpuData()), intArg(this.N), intArg(this.C), intArg(this.spatial),
                floatArg(this.eta));
        this.normalize_test_Parameters = kernelArgs(
                intArg(count), devArg(input.getGpuData()), devArg(output.getGpuData()),
                devArg(this.runingMean.getGpuData()), devArg(this.runingVar.getGpuData()),
                devArg(gamma.getGpuData()), devArg(beta.getGpuData()), intArg(this.N), intArg(this.C),
                intArg(this.spatial), floatArg(this.eta));
    }

    /**
     * (Re)builds the backward kernel argument lists.
     *
     * <p>The lists embed {@code N} and the {@code d_z} buffer, both of which
     * {@link #initForward} replaces when the batch size changes, so they are rebuilt whenever
     * the batch size differs from the one they were built for. Previously they were built only
     * once and went stale after a batch-size change.
     *
     * @param input     forward input, read by the variance-delta and dx kernels
     * @param unused    not read by this implementation
     *                  (NOTE(review): dx is written into {@code delta}; confirm callers expect that)
     * @param delta     incoming gradient; passed to the scale_bias kernel as its first argument
     *                  and to the normalize_delta kernel as its last argument (presumably updated in
     *                  place — confirm against BNKernel3.cu)
     * @param gamma     tensor the scale_bias kernel multiplies {@code delta} by (presumably the scale)
     * @param gammaGrad output argument of the backward_scale kernel (presumably d-gamma)
     * @param betaGrad  output argument of the backward_bias kernels (presumably d-beta)
     */
    public void initBackward(Tensor input, Tensor unused, Tensor delta, Tensor gamma, Tensor gammaGrad, Tensor betaGrad) {
        if (this.dgammaParameters != null && this.backwardN == this.N) {
            return;
        }
        this.backwardN = this.N;
        this.dgammaParameters = kernelArgs(
                devArg(this.d_z), devArg(delta.getGpuData()), intArg(this.N), intArg(this.C),
                intArg(this.spatial), devArg(gammaGrad.getGpuData()));
        this.dbetaParameters = kernelArgs(
                devArg(betaGrad.getGpuData()), devArg(delta.getGpuData()), intArg(this.N), intArg(this.C),
                intArg(this.spatial));
        this.dbetaFullParameters = kernelArgs(
                devArg(betaGrad.getGpuData()), devArg(delta.getGpuData()), intArg(this.N), intArg(this.C));
        this.dxhatParameters = kernelArgs(
                devArg(delta.getGpuData()), devArg(gamma.getGpuData()), intArg(this.C), intArg(this.spatial));
        this.fastDmeanParameters = kernelArgs(
                devArg(delta.getGpuData()), devArg(this.d_var), intArg(this.N), intArg(this.C),
                intArg(this.spatial), devArg(this.d_dmean));
        this.fastDvarParameters = kernelArgs(
                devArg(input.getGpuData()), devArg(delta.getGpuData()), devArg(this.d_mean),
                devArg(this.d_var), intArg(this.N), intArg(this.C), intArg(this.spatial),
                devArg(this.d_dvar));
        this.dxParameters = kernelArgs(
                intArg(elementCount()), devArg(input.getGpuData()), devArg(this.d_mean),
                devArg(this.d_var), devArg(this.d_dmean), devArg(this.d_dvar), intArg(this.N),
                intArg(this.C), intArg(this.spatial), devArg(delta.getGpuData()));
    }

    /** Returns the number of blocks of {@code CAFFE_CUDA_NUM_THREADS} needed to cover {@code i} items. */
    @Override // com.omega.engine.gpu.BaseKernel
    public int CAFFE_GET_BLOCKS(int i) {
        return ((i + this.CAFFE_CUDA_NUM_THREADS) - 1) / this.CAFFE_CUDA_NUM_THREADS;
    }

    /**
     * Runs the forward pass. In {@link RunModel#TRAIN} mode it computes batch mean/variance,
     * updates the running statistics and normalizes with the batch statistics. In any other mode
     * it normalizes with the running statistics.
     *
     * @param runModel training or inference mode
     * @param gamma    scale tensor handed to the normalize kernels (presumed gamma)
     * @param beta     shift tensor handed to the normalize kernels (presumed beta)
     * @param input    layer input
     * @param output   layer output
     */
    @Override // com.omega.engine.nn.layer.gpu.BNBaseKernel
    public void forward(RunModel runModel, Tensor gamma, Tensor beta, Tensor input, Tensor output) {
        initForward(input, gamma, beta, output);
        if (runModel != RunModel.TRAIN) {
            normalize_test(input, gamma, beta, output);
            return;
        }
        fast_mean();
        fast_var();
        mwa();
        normalize_train(input, gamma, beta, output);
    }

    /** Launches fast_mean_kernel with one block per channel, writing into {@code d_mean}. */
    public void fast_mean() {
        launch("fast_mean", this.fast_mean_function, this.C, 1, 1, this.fastMeanParameters);
    }

    /** Launches fast_variance_kernel with one block per channel, writing into {@code d_var}. */
    public void fast_var() {
        launch("fast_var", this.fast_var_function, this.C, 1, 1, this.fastVarParameters);
    }

    /** Launches normalize_test_kernel over every element using the running statistics. The tensor arguments are not read; the cached argument list is used. */
    public void normalize_test(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        launch("normalize_test", this.normalize_test_function, CAFFE_GET_BLOCKS(elementCount()), 1, 1,
                this.normalize_test_Parameters);
    }

    /** Launches normalize_kernel over every element using the batch statistics. The tensor arguments are not read; the cached argument list is used. */
    public void normalize_train(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        launch("normalize_train", this.normalize_function, CAFFE_GET_BLOCKS(elementCount()), 1, 1,
                this.normalizeParameters);
    }

    /** Launches mwa_kernel over the {@code C} channels to update the running mean/variance. */
    public void mwa() {
        launch("mwa", this.mwa_function, CAFFE_GET_BLOCKS(this.C), 1, 1, this.mwaParameters);
    }

    /**
     * Runs the backward pass: bias gradient, scale gradient, delta scaling, mean/variance
     * gradients, then the input gradient. See {@link #initBackward} for the argument roles.
     */
    @Override // com.omega.engine.nn.layer.gpu.BNBaseKernel
    public void backward(Tensor input, Tensor unused, Tensor delta, Tensor gamma, Tensor gammaGrad, Tensor betaGrad) {
        initBackward(input, unused, delta, gamma, gammaGrad, betaGrad);
        // With a 1x1 spatial extent the "conn" variant of the bias-gradient kernel is used.
        if (this.spatial == 1) {
            dbetaFull();
        } else {
            dbeta();
        }
        dgamma();
        dxhat();
        fastDmean();
        fastDvar();
        dx();
    }

    /** Launches backward_scale_kernel with one block per channel. */
    public void dgamma() {
        launch("dgamma", this.dgamma_function, this.C, 1, 1, this.dgammaParameters);
    }

    /** Launches backward_bias_kernel with one block per channel. */
    public void dbeta() {
        launch("dbeta", this.dbeta_function, this.C, 1, 1, this.dbetaParameters);
    }

    /** Launches backward_bias_conn_kernel over the {@code C} channels. */
    public void dbetaFull() {
        launch("dbetaFull", this.dbeta_full_function, CAFFE_GET_BLOCKS(this.C), 1, 1, this.dbetaFullParameters);
    }

    /** Launches scale_bias_kernel on a (spatial blocks, C, N) grid. */
    public void dxhat() {
        int spatialBlocks = ((this.spatial - 1) / this.CAFFE_CUDA_NUM_THREADS) + 1;
        launch("dxhat", this.dxhat_function, spatialBlocks, this.C, this.N, this.dxhatParameters);
    }

    /** Launches fast_mean_delta_kernel with one block per channel, writing into {@code d_dmean}. */
    public void fastDmean() {
        launch("fastDmean", this.fast_dmean_function, this.C, 1, 1, this.fastDmeanParameters);
    }

    /** Launches fast_variance_delta_kernel with one block per channel, writing into {@code d_dvar}. */
    public void fastDvar() {
        launch("fastDvar", this.fast_dvar_function, this.C, 1, 1, this.fastDvarParameters);
    }

    /** Launches normalize_delta_kernel over every element. */
    public void dx() {
        launch("dx", this.dx_function, CAFFE_GET_BLOCKS(elementCount()), 1, 1, this.dxParameters);
    }

    /**
     * Prints an error line for a non-zero status code. The text is looked up in the CUDA
     * runtime ({@code cudaError}) table.
     */
    @Override // com.omega.engine.gpu.BaseKernel
    public void checkCUDA(int i) {
        if (i != 0) {
            System.err.println("Error code " + i + ":" + cudaError.stringFor(i));
        }
    }

    /** Debug helper: copies {@code fArr.length} floats from the device and prints them as JSON. */
    public void showDM(String str, CUdeviceptr cUdeviceptr, float[] fArr) {
        // Widen before multiplying so very large arrays cannot overflow the byte count.
        JCudaDriver.cuMemcpyDtoH(Pointer.to(fArr), cUdeviceptr, (long) fArr.length * Float.BYTES);
        System.out.println(str + ":" + JsonUtils.toJson(fArr));
    }

    /** Total number of elements {@code N * C * H * W} for the current batch. */
    private int elementCount() {
        return this.N * this.C * this.spatial;
    }

    /**
     * Launches {@code function} on the given grid with {@code CAFFE_CUDA_NUM_THREADS} threads per
     * block on the default stream. A non-zero driver result is reported instead of being ignored.
     * Exceptions (thrown when JCuda exceptions are enabled) are printed, as before.
     */
    private void launch(String label, CUfunction function, int gridX, int gridY, int gridZ, Pointer parameters) {
        try {
            int status = JCudaDriver.cuLaunchKernel(function, gridX, gridY, gridZ,
                    this.CAFFE_CUDA_NUM_THREADS, 1, 1, 0, (CUstream) null, parameters, (Pointer) null);
            if (status != 0) {
                System.err.println("BNKernel3." + label + " launch failed with CUresult " + status);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Packs individual kernel arguments into the pointer-to-pointers layout cuLaunchKernel expects. */
    private static Pointer kernelArgs(Pointer... args) {
        return Pointer.to(args);
    }

    /** Wraps a device pointer as a kernel argument. */
    private static Pointer devArg(NativePointerObject pointer) {
        return Pointer.to(pointer);
    }

    /** Wraps an {@code int} value as a kernel argument. */
    private static Pointer intArg(int value) {
        return Pointer.to(new int[]{value});
    }

    /** Wraps a {@code float} value as a kernel argument. */
    private static Pointer floatArg(float value) {
        return Pointer.to(new float[]{value});
    }
}
