package com.omega.engine.gpu.cudnn;

import com.omega.common.data.Tensor;
import com.omega.engine.nn.layer.gpu.PoolingBaseKernel;
import com.omega.engine.pooling.PoolingType;
import jcuda.Pointer;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnDataType;
import jcuda.jcudnn.cudnnNanPropagation;
import jcuda.jcudnn.cudnnPoolingDescriptor;
import jcuda.jcudnn.cudnnPoolingMode;
import jcuda.jcudnn.cudnnStatus;
import jcuda.jcudnn.cudnnTensorDescriptor;
import jcuda.jcudnn.cudnnTensorFormat;

/* loaded from: input_file:com/omega/engine/gpu/cudnn/PoolingCudnnKernel.class */
/**
 * 2D pooling kernel implemented on top of cuDNN via JCudnn.
 *
 * <p>The pooling descriptor is configured once in the constructor. The input/output tensor
 * descriptors (NCHW, float) are (re)configured lazily by {@link #init(int)} whenever the batch
 * size changes. Call {@link #destroyDescriptors()} to release the native descriptors once the
 * kernel is no longer needed.
 */
public class PoolingCudnnKernel extends PoolingBaseKernel {
    /** Input channel count. */
    private final int C;
    /** Input height. */
    private final int H;
    /** Input width. */
    private final int W;
    /** Output channel count (always equal to the input channel count). */
    private final int oc;
    /** Output height. */
    private final int oh;
    /** Output width. */
    private final int ow;
    /** Pooling window width. */
    private final int pWidth;
    /** Pooling window height. */
    private final int pHeight;
    /** Per-side padding, applied symmetrically in both spatial dimensions. */
    private final int padding;
    /** Stride, applied in both spatial dimensions. */
    private final int stride;
    /** Scaling factor applied to the pooling result (1.0). */
    private final Pointer alpha_P = Pointer.to(new float[]{1.0f});
    /** Scaling factor applied to the prior destination contents (0.0, i.e. overwrite). */
    private final Pointer beta_P = Pointer.to(new float[]{0.0f});
    private final cudnnTensorDescriptor xDesc = new cudnnTensorDescriptor();
    private final cudnnPoolingDescriptor poolingDesc = new cudnnPoolingDescriptor();
    private final cudnnTensorDescriptor yDesc = new cudnnTensorDescriptor();

    /**
     * Creates the cuDNN descriptors and configures the pooling descriptor.
     *
     * @param poolingType {@code MAX_POOLING} maps to {@code CUDNN_POOLING_MAX}, {@code AVG_POOLING}
     *     to {@code CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING}; any other value falls back to max
     *     pooling
     * @param channel number of input channels (also used as the output channel count)
     * @param height input height
     * @param width input width
     * @param oHeight output height
     * @param oWidth output width
     * @param pWidth pooling window width
     * @param pHeight pooling window height
     * @param stride stride used in both spatial dimensions
     * @param pad halved (integer division) to obtain the per-side padding
     * @throws RuntimeException if any cuDNN call fails
     */
    public PoolingCudnnKernel(PoolingType poolingType, int channel, int height, int width, int oHeight, int oWidth, int pWidth, int pHeight, int stride, int pad) {
        this.C = channel;
        this.H = height;
        this.W = width;
        this.oc = channel;
        this.oh = oHeight;
        this.ow = oWidth;
        this.pWidth = pWidth;
        this.pHeight = pHeight;
        // NOTE(review): only half of `pad` is used as the per-side padding; confirm callers rely on this.
        this.padding = pad / 2;
        this.stride = stride;
        handle(JCudnn.cudnnCreateTensorDescriptor(this.xDesc));
        handle(JCudnn.cudnnCreatePoolingDescriptor(this.poolingDesc));
        handle(JCudnn.cudnnCreateTensorDescriptor(this.yDesc));
        handle(JCudnn.cudnnSetPooling2dDescriptor(this.poolingDesc, toCudnnPoolingMode(poolingType), cudnnNanPropagation.CUDNN_PROPAGATE_NAN, this.pHeight, this.pWidth, this.padding, this.padding, this.stride, this.stride));
    }

    /**
     * Maps the engine's pooling type onto a cuDNN pooling mode.
     *
     * @param poolingType the engine pooling type; must not be null
     * @return the cuDNN pooling mode; unrecognised types fall back to max pooling
     */
    private static int toCudnnPoolingMode(PoolingType poolingType) {
        switch (poolingType) {
            case AVG_POOLING:
                return cudnnPoolingMode.CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
            case MAX_POOLING:
            default:
                return cudnnPoolingMode.CUDNN_POOLING_MAX;
        }
    }

    /**
     * (Re)configures the input/output tensor descriptors for the given batch size. This is a
     * no-op when the batch size is unchanged.
     *
     * <p>The batch size is recorded only after both descriptors are configured successfully. A
     * failed call therefore does not leave the kernel believing it is already configured.
     *
     * @param number batch size
     * @throws RuntimeException if a cuDNN call fails
     */
    public void init(int number) {
        if (this.N == number) {
            return;
        }
        handle(JCudnn.cudnnSetTensor4dDescriptor(this.xDesc, cudnnTensorFormat.CUDNN_TENSOR_NCHW, cudnnDataType.CUDNN_DATA_FLOAT, number, this.C, this.H, this.W));
        handle(JCudnn.cudnnSetTensor4dDescriptor(this.yDesc, cudnnTensorFormat.CUDNN_TENSOR_NCHW, cudnnDataType.CUDNN_DATA_FLOAT, number, this.oc, this.oh, this.ow));
        this.N = number;
    }

    /**
     * Runs the pooling forward pass, overwriting {@code y} (beta = 0).
     *
     * @param x input tensor; its {@code number} determines the batch size
     * @param y output tensor
     */
    @Override // com.omega.engine.nn.layer.gpu.PoolingBaseKernel
    public void forward(Tensor x, Tensor y) {
        init(x.number);
        handle(JCudnn.cudnnPoolingForward(CudnnHandleManager.getHandle(), this.poolingDesc, this.alpha_P, this.xDesc, x.getGpuData(), this.beta_P, this.yDesc, y.getGpuData()));
    }

    /**
     * Runs the pooling backward pass, overwriting {@code dx} (beta = 0).
     *
     * <p>The descriptors are (re)configured from {@code x}'s batch size, so this method is safe to
     * call even if {@link #forward} has not run with the same batch size.
     *
     * @param x forward-pass input
     * @param y forward-pass output
     * @param dy gradient with respect to {@code y}
     * @param dx receives the gradient with respect to {@code x}
     */
    @Override // com.omega.engine.nn.layer.gpu.PoolingBaseKernel
    public void backward(Tensor x, Tensor y, Tensor dy, Tensor dx) {
        init(x.number);
        handle(JCudnn.cudnnPoolingBackward(CudnnHandleManager.getHandle(), this.poolingDesc, this.alpha_P, this.yDesc, y.getGpuData(), this.yDesc, dy.getGpuData(), this.xDesc, x.getGpuData(), this.beta_P, this.xDesc, dx.getGpuData()));
    }

    /**
     * Releases the cuDNN descriptors created in the constructor. The kernel must not be used
     * after this call.
     *
     * @throws RuntimeException if a cuDNN call fails
     */
    public void destroyDescriptors() {
        handle(JCudnn.cudnnDestroyTensorDescriptor(this.xDesc));
        handle(JCudnn.cudnnDestroyTensorDescriptor(this.yDesc));
        handle(JCudnn.cudnnDestroyPoolingDescriptor(this.poolingDesc));
    }

    /**
     * Fails fast on a non-success cuDNN status.
     *
     * @param i cuDNN status code returned by a JCudnn call
     * @throws RuntimeException carrying the status name when {@code i} is non-zero; the name is
     *     also printed to stderr
     */
    public static void handle(int i) {
        if (i != 0) {
            System.err.println(cudnnStatus.stringFor(i));
            throw new RuntimeException(cudnnStatus.stringFor(i));
        }
    }

    /**
     * Converts a cuDNN status code to text.
     *
     * @param i cuDNN status code
     * @return {@code "success"} for zero, otherwise the cuDNN status name
     */
    public static String checkError(int i) {
        return i != 0 ? cudnnStatus.stringFor(i) : "success";
    }
}
