package com.omega.engine.gpu.cudnn;

import com.omega.common.data.Tensor;
import com.omega.engine.nn.layer.gpu.BNBaseKernel;
import com.omega.engine.nn.layer.normalization.BNType;
import com.omega.engine.nn.network.RunModel;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUfunction;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnTensorDescriptor;

/* loaded from: input_file:com/omega/engine/gpu/cudnn/BNCudnnKernel.class */
/**
 * Batch-normalization kernel backed by cuDNN.
 *
 * <p>Owns the cuDNN tensor descriptors plus the per-batch statistics buffers ({@code mean},
 * {@code var}) that the training forward pass writes and the backward pass reads. The running
 * statistics ({@code runingMean} / {@code runingVar}, declared in {@link BNBaseKernel}) are the
 * tensors handed to the constructor and are updated in place by the training forward pass.
 *
 * <p>NOTE(review): the descriptors created here are never destroyed by this class.
 */
public class BNCudnnKernel extends BNBaseKernel {
    /** cudnnTensorFormat_t value CUDNN_TENSOR_NCHW. */
    private static final int CUDNN_TENSOR_NCHW = 0;
    /** cudnnDataType_t value CUDNN_DATA_FLOAT. */
    private static final int CUDNN_DATA_FLOAT = 0;
    /**
     * cudnnBatchNormMode_t value CUDNN_BATCHNORM_SPATIAL. Used for both BN types: the
     * {@code fully_bn} descriptor has H = W = 1, so spatial statistics reduce to per-feature ones.
     */
    private static final int CUDNN_BATCHNORM_SPATIAL = 1;

    private BNType bnType;
    private int C;
    private int H;
    private int W;
    /** Batch mean saved by the training forward pass (cuDNN {@code resultSaveMean}). */
    private Tensor mean;
    /** Saved by the training forward pass in cuDNN's {@code resultSaveInvVariance} slot (inverse std, not variance). */
    private Tensor var;
    /** Descriptor for scale/bias/mean/var: shape (1, features, 1, 1). */
    private cudnnTensorDescriptor normTensorDesc;
    /** Descriptor for input/output/gradients; rebuilt whenever the batch size changes. */
    private cudnnTensorDescriptor dstTensorDesc;
    /** NOTE(review): never assigned anywhere in this class, so {@link #normalize_test} cannot launch successfully as-is. */
    private CUfunction normalize_test_function;
    private Pointer normalize_test_Parameters;
    private int mode = CUDNN_BATCHNORM_SPATIAL;
    private final double eps = 1.0E-5d;
    /** cuDNN exponentialAverageFactor for the running statistics (0.01f widened to double). */
    private final double momentum = 0.009999999776482582d;
    private final Pointer alpha_P = Pointer.to(new float[]{1.0f});
    private final Pointer beta_P = Pointer.to(new float[]{0.0f});
    private final int CAFFE_CUDA_NUM_THREADS = 1024;

    /**
     * @param bNType  BN variant; {@code fully_bn} normalizes over {@code i3} features, otherwise over {@code i} channels
     * @param i       channel count C
     * @param i2      height H
     * @param i3      width W
     * @param tensor  running mean buffer (updated in place during training)
     * @param tensor2 running variance buffer (updated in place during training)
     */
    public BNCudnnKernel(BNType bNType, int i, int i2, int i3, Tensor tensor, Tensor tensor2) {
        this.bnType = bNType;
        this.C = i;
        this.H = i2;
        this.W = i3;
        this.runingMean = tensor;
        this.runingVar = tensor2;
        init();
    }

    /** Number of normalized features: W for {@code fully_bn}, C otherwise. */
    private int featureCount() {
        return this.bnType == BNType.fully_bn ? this.W : this.C;
    }

    /**
     * Allocates the saved-statistics buffers and creates both descriptors. The scale/bias
     * descriptor is configured here; the data descriptor is configured in {@link #initForward}.
     */
    public void init() {
        this.mode = CUDNN_BATCHNORM_SPATIAL;
        int features = featureCount();
        this.mean = new Tensor(1, 1, 1, features, true);
        this.var = new Tensor(1, 1, 1, features, true);
        this.normTensorDesc = new cudnnTensorDescriptor();
        this.dstTensorDesc = new cudnnTensorDescriptor();
        CudnnHandleManager.handle(JCudnn.cudnnCreateTensorDescriptor(this.normTensorDesc));
        CudnnHandleManager.handle(JCudnn.cudnnCreateTensorDescriptor(this.dstTensorDesc));
        CudnnHandleManager.handle(JCudnn.cudnnSetTensor4dDescriptor(
                this.normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, features, 1, 1));
    }

    /**
     * (Re)configures the data descriptor when the batch size of {@code tensor} differs from the
     * cached {@code N}. Shape is (N, W, 1, 1) for {@code fully_bn}, else (N, C, H, W).
     */
    public void initForward(Tensor tensor) {
        if (tensor.number != this.N) {
            this.N = tensor.number;
            CudnnHandleManager.handle(JCudnn.cudnnDestroyTensorDescriptor(this.dstTensorDesc));
            CudnnHandleManager.handle(JCudnn.cudnnCreateTensorDescriptor(this.dstTensorDesc));
            if (this.bnType == BNType.fully_bn) {
                CudnnHandleManager.handle(JCudnn.cudnnSetTensor4dDescriptor(
                        this.dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, this.N, this.W, 1, 1));
            } else {
                CudnnHandleManager.handle(JCudnn.cudnnSetTensor4dDescriptor(
                        this.dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, this.N, this.C, this.H, this.W));
            }
        }
    }

    /**
     * Byte offset of element {@code i2 * i * tensor.getOnceSize()} in a float buffer, computed in
     * {@code long} so large tensors cannot overflow {@code int}.
     * NOTE(review): getOnceSize() is presumably the per-sample element count — confirm.
     */
    private static long sampleByteOffset(Tensor tensor, int i, int i2) {
        return (long) i2 * i * tensor.getOnceSize() * Float.BYTES;
    }

    /**
     * Batch-norm forward pass.
     *
     * @param runModel {@code TRAIN} uses batch statistics and updates running stats; otherwise running stats are used
     * @param tensor   scale (gamma)
     * @param tensor2  bias (beta)
     * @param tensor3  input x
     * @param tensor4  output y
     */
    @Override // com.omega.engine.nn.layer.gpu.BNBaseKernel
    public void forward(RunModel runModel, Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        initForward(tensor3);
        forwardOnPointers(runModel, tensor, tensor2, tensor3.getGpuData(), tensor4.getGpuData());
    }

    /**
     * Same as {@link #forward(RunModel, Tensor, Tensor, Tensor, Tensor)} but reads/writes starting
     * at byte offset {@code i2 * i * tensor3.getOnceSize() * 4} in both input and output.
     * NOTE(review): the descriptor batch size is still {@code tensor3.number}; confirm the buffers
     * extend far enough past the offset.
     */
    public void forward(RunModel runModel, Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4, int i, int i2) {
        initForward(tensor3);
        long offset = sampleByteOffset(tensor3, i, i2);
        forwardOnPointers(runModel, tensor, tensor2,
                tensor3.getGpuData().withByteOffset(offset),
                tensor4.getGpuData().withByteOffset(offset));
    }

    /** Issues the cuDNN training or inference forward call on already-resolved device pointers. */
    private void forwardOnPointers(RunModel runModel, Tensor scale, Tensor bias, Pointer x, Pointer y) {
        if (runModel == RunModel.TRAIN) {
            CudnnHandleManager.handle(JCudnn.cudnnBatchNormalizationForwardTraining(
                    CudnnHandleManager.getHandle(), this.mode, this.alpha_P, this.beta_P,
                    this.dstTensorDesc, x, this.dstTensorDesc, y,
                    this.normTensorDesc, scale.getGpuData(), bias.getGpuData(),
                    this.momentum, this.runingMean.getGpuData(), this.runingVar.getGpuData(),
                    this.eps, this.mean.getGpuData(), this.var.getGpuData()));
        } else {
            CudnnHandleManager.handle(JCudnn.cudnnBatchNormalizationForwardInference(
                    CudnnHandleManager.getHandle(), this.mode, this.alpha_P, this.beta_P,
                    this.dstTensorDesc, x, this.dstTensorDesc, y,
                    this.normTensorDesc, scale.getGpuData(), bias.getGpuData(),
                    this.runingMean.getGpuData(), this.runingVar.getGpuData(), this.eps));
        }
    }

    /**
     * Launches the custom {@code normalize_test} CUDA kernel over N*C*H*W elements.
     * NOTE(review): {@code normalize_test_function} is never assigned in this class, so the launch
     * fails and the exception is only printed. Argument roles for tensor..tensor4 depend on the
     * kernel source, which is not visible here.
     */
    public void normalize_test(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        try {
            int count = this.N * this.C * this.H * this.W;
            this.normalize_test_Parameters = Pointer.to(new NativePointerObject[]{
                    Pointer.to(new int[]{count}),
                    Pointer.to(new NativePointerObject[]{tensor.getGpuData()}),
                    Pointer.to(new NativePointerObject[]{tensor4.getGpuData()}),
                    Pointer.to(new NativePointerObject[]{this.runingMean.getGpuData()}),
                    Pointer.to(new NativePointerObject[]{this.runingVar.getGpuData()}),
                    Pointer.to(new NativePointerObject[]{tensor2.getGpuData()}),
                    Pointer.to(new NativePointerObject[]{tensor3.getGpuData()}),
                    Pointer.to(new int[]{this.N}),
                    Pointer.to(new int[]{this.C}),
                    Pointer.to(new int[]{this.H * this.W}),
                    Pointer.to(new float[]{(float) this.eps})});
            JCudaDriver.cuLaunchKernel(this.normalize_test_function,
                    CAFFE_GET_BLOCKS(count), 1, 1,
                    this.CAFFE_CUDA_NUM_THREADS, 1, 1,
                    0, (CUstream) null, this.normalize_test_Parameters, (Pointer) null);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Batch-norm backward pass; requires a prior training forward pass to have filled
     * {@code mean}/{@code var} and configured the data descriptor.
     *
     * @param tensor  forward input x
     * @param tensor2 upstream gradient dy
     * @param tensor3 output gradient dx (overwritten: betaDataDiff = 0)
     * @param tensor4 scale (gamma)
     * @param tensor5 scale gradient (accumulated: betaParamDiff = 1) — NOTE(review): caller must zero it
     * @param tensor6 bias gradient (accumulated: betaParamDiff = 1) — NOTE(review): caller must zero it
     */
    @Override // com.omega.engine.nn.layer.gpu.BNBaseKernel
    public void backward(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4, Tensor tensor5, Tensor tensor6) {
        backwardOnPointers(tensor.getGpuData(), tensor2.getGpuData(), tensor3.getGpuData(), tensor4, tensor5, tensor6);
    }

    /**
     * Same as the 6-argument backward but x, dy and dx all start at byte offset
     * {@code i2 * i * tensor.getOnceSize() * 4}.
     */
    public void backward(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4, Tensor tensor5, Tensor tensor6, int i, int i2) {
        long offset = sampleByteOffset(tensor, i, i2);
        backwardOnPointers(
                tensor.getGpuData().withByteOffset(offset),
                tensor2.getGpuData().withByteOffset(offset),
                tensor3.getGpuData().withByteOffset(offset),
                tensor4, tensor5, tensor6);
    }

    /** Issues cudnnBatchNormalizationBackward on already-resolved x/dy/dx device pointers. */
    private void backwardOnPointers(Pointer x, Pointer dy, Pointer dx, Tensor scale, Tensor dScale, Tensor dBias) {
        // alphaDataDiff=1, betaDataDiff=0, alphaParamDiff=1, betaParamDiff=1 (alpha_P reused on purpose).
        CudnnHandleManager.handle(JCudnn.cudnnBatchNormalizationBackward(
                CudnnHandleManager.getHandle(), this.mode,
                this.alpha_P, this.beta_P, this.alpha_P, this.alpha_P,
                this.dstTensorDesc, x, this.dstTensorDesc, dy, this.dstTensorDesc, dx,
                this.normTensorDesc, scale.getGpuData(), dScale.getGpuData(), dBias.getGpuData(),
                this.eps, this.mean.getGpuData(), this.var.getGpuData()));
    }
}
