package com.omega.engine.gpu.cudnn;

import com.omega.common.data.Tensor;
import com.omega.engine.nn.layer.gpu.BNBaseKernel;
import com.omega.engine.nn.network.RunModel;
import jcuda.Pointer;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnTensorDescriptor;

/* loaded from: input_file:com/omega/engine/gpu/cudnn/InstanceNormalizationCudnnKernel.class */
/**
 * Instance normalization implemented with cuDNN's batch-normalization routines.
 *
 * <p>An input of {@code N} samples with {@code C} channels of {@code H x W} is described to cuDNN
 * as a single sample with {@code N * C} channels, i.e. {@code (1, N*C, H, W)} in NCHW float
 * layout. Spatial batch normalization over that view computes one mean/variance per
 * (sample, channel) plane, which is instance normalization.
 *
 * <p>NOTE(review): the scale/bias descriptor has {@code N * C} entries, so the gamma/beta tensors
 * passed to {@link #forward} and {@link #backward} appear to need {@code N * C} values (presumably
 * the per-channel parameters repeated for every sample) — confirm against the calling layer.
 */
public class InstanceNormalizationCudnnKernel extends BNBaseKernel {
    /** cuDNN {@code CUDNN_TENSOR_NCHW} (cudnnTensorFormat_t). */
    private static final int CUDNN_TENSOR_NCHW = 0;
    /** cuDNN {@code CUDNN_DATA_FLOAT} (cudnnDataType_t). */
    private static final int CUDNN_DATA_FLOAT = 0;
    /** cuDNN {@code CUDNN_BATCHNORM_SPATIAL} (cudnnBatchNormMode_t). */
    private static final int CUDNN_BATCHNORM_SPATIAL = 1;
    /**
     * exponentialAverageFactor passed to the forward call. It has no effect here because the
     * running mean/variance pointers are null, i.e. no running statistics are maintained.
     */
    private static final double EXPONENTIAL_AVERAGE_FACTOR = 0.1d;

    /** Channels per sample. */
    private final int C;
    /** Spatial height. */
    private final int H;
    /** Spatial width. */
    private final int W;
    /** Per-plane mean saved by forward (cuDNN resultSaveMean); N*C values. */
    private Tensor mean;
    /** Per-plane inverse variance saved by forward (cuDNN resultSaveInvVariance); N*C values. */
    private Tensor var;
    /** Descriptor for scale/bias and saved statistics: (1, N*C, 1, 1). */
    private cudnnTensorDescriptor normTensorDesc;
    /** Descriptor shared by x, y, dy and dx: (1, N*C, H, W). */
    private cudnnTensorDescriptor dstTensorDesc;
    /** NOTE(review): allocated in init() but never configured or used in this class. */
    private cudnnTensorDescriptor yTensorDesc;
    private int mode = CUDNN_BATCHNORM_SPATIAL;
    private double eps = 1.0E-5d;
    /** NOTE(review): not referenced anywhere in this class. */
    private double momentum = 0.009999999776482582d;
    /** Blend factor 1.0f used for cuDNN alpha arguments. */
    private final Pointer alpha_P = Pointer.to(new float[]{1.0f});
    /** Blend factor 0.0f used for cuDNN beta arguments. */
    private final Pointer beta_P = Pointer.to(new float[]{0.0f});

    /**
     * Creates a kernel for inputs with the given per-sample shape.
     *
     * @param channel number of channels per sample
     * @param height  spatial height
     * @param width   spatial width
     */
    public InstanceNormalizationCudnnKernel(int channel, int height, int width) {
        this.C = channel;
        this.H = height;
        this.W = width;
        init();
    }

    /**
     * Resets the normalization mode and allocates the descriptor objects. The native descriptors
     * are (re)created and configured in {@link #initForward} once the batch size is known.
     */
    public void init() {
        this.mode = CUDNN_BATCHNORM_SPATIAL;
        this.normTensorDesc = new cudnnTensorDescriptor();
        this.dstTensorDesc = new cudnnTensorDescriptor();
        this.yTensorDesc = new cudnnTensorDescriptor();
    }

    /**
     * Prepares the descriptors and saved-statistic buffers for the batch size of {@code input}.
     * Nothing is done when the batch size equals the one seen on the previous call.
     *
     * @param input forward input; only its {@code number} (batch size) is read
     */
    public void initForward(Tensor input) {
        if (input.number == this.N) {
            return;
        }
        this.N = input.number;
        final int planes = this.N * this.C;
        resetDescriptor(this.normTensorDesc, planes, 1, 1);
        resetDescriptor(this.dstTensorDesc, planes, this.H, this.W);
        // NOTE(review): the previous mean/var tensors are replaced without an explicit release of
        // their GPU memory; confirm whether Tensor requires one.
        this.mean = new Tensor(1, 1, 1, planes, true);
        this.var = new Tensor(1, 1, 1, planes, true);
    }

    /**
     * Destroys, recreates and configures {@code desc} as an NCHW float tensor of shape
     * {@code (1, channels, height, width)}, checking the status of every cuDNN call.
     */
    private static void resetDescriptor(cudnnTensorDescriptor desc, int channels, int height, int width) {
        CudnnHandleManager.handle(JCudnn.cudnnDestroyTensorDescriptor(desc));
        CudnnHandleManager.handle(JCudnn.cudnnCreateTensorDescriptor(desc));
        CudnnHandleManager.handle(JCudnn.cudnnSetTensor4dDescriptor(
                desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channels, height, width));
    }

    /**
     * Normalizes {@code input} into {@code output} and saves per-plane mean / inverse variance
     * for {@link #backward}.
     *
     * <p>The training-mode cuDNN kernel is used for every {@code runModel} (the argument is not
     * read), so statistics always come from the current input.
     *
     * @param runModel unused
     * @param gamma    scale (cuDNN bnScale)
     * @param beta     shift (cuDNN bnBias)
     * @param input    x
     * @param output   y; overwritten (blend beta = 0)
     */
    @Override // com.omega.engine.nn.layer.gpu.BNBaseKernel
    public void forward(RunModel runModel, Tensor gamma, Tensor beta, Tensor input, Tensor output) {
        initForward(input);
        CudnnHandleManager.handle(JCudnn.cudnnBatchNormalizationForwardTraining(
                CudnnHandleManager.getHandle(), this.mode, this.alpha_P, this.beta_P,
                this.dstTensorDesc, input.getGpuData(),
                this.dstTensorDesc, output.getGpuData(),
                this.normTensorDesc, gamma.getGpuData(), beta.getGpuData(),
                EXPONENTIAL_AVERAGE_FACTOR, (Pointer) null, (Pointer) null,
                this.eps, this.mean.getGpuData(), this.var.getGpuData()));
    }

    /**
     * Computes the input gradient and the scale/bias gradients using the statistics saved by the
     * most recent {@link #forward} call.
     *
     * <p>Input gradient blending is (alpha = 1, beta = 0), so {@code diff} is overwritten.
     * Parameter gradient blending is (alpha = 1, beta = 1), so the results are ADDED to the
     * existing contents of {@code dgamma}/{@code dbeta}; presumably the caller clears them — verify.
     *
     * @param input  x passed to forward
     * @param delta  dy
     * @param diff   dx (output)
     * @param gamma  scale (cuDNN bnScale)
     * @param dgamma scale gradient (accumulated)
     * @param dbeta  bias gradient (accumulated)
     * @throws IllegalStateException if {@link #forward} has not been called yet
     */
    @Override // com.omega.engine.nn.layer.gpu.BNBaseKernel
    public void backward(Tensor input, Tensor delta, Tensor diff, Tensor gamma, Tensor dgamma, Tensor dbeta) {
        if (this.mean == null || this.var == null) {
            throw new IllegalStateException(
                    "InstanceNormalizationCudnnKernel.backward called before forward");
        }
        CudnnHandleManager.handle(JCudnn.cudnnBatchNormalizationBackward(
                CudnnHandleManager.getHandle(), this.mode,
                this.alpha_P, this.beta_P,   // dx blend: overwrite
                this.alpha_P, this.alpha_P,  // dScale/dBias blend: accumulate
                this.dstTensorDesc, input.getGpuData(),
                this.dstTensorDesc, delta.getGpuData(),
                this.dstTensorDesc, diff.getGpuData(),
                this.normTensorDesc, gamma.getGpuData(), dgamma.getGpuData(), dbeta.getGpuData(),
                this.eps, this.mean.getGpuData(), this.var.getGpuData()));
    }
}
