package com.omega.engine.gpu.cudnn;

import com.omega.common.data.Tensor;
import com.omega.engine.nn.layer.gpu.ConvBaseKernel;
import com.omega.engine.nn.network.Network;
import jcuda.Pointer;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnConvolutionBwdDataAlgoPerf;
import jcuda.jcudnn.cudnnConvolutionBwdFilterAlgoPerf;
import jcuda.jcudnn.cudnnConvolutionDescriptor;
import jcuda.jcudnn.cudnnConvolutionFwdAlgoPerf;
import jcuda.jcudnn.cudnnFilterDescriptor;
import jcuda.jcudnn.cudnnStatus;
import jcuda.jcudnn.cudnnTensorDescriptor;
import jcuda.runtime.JCuda;

/* loaded from: input_file:com/omega/engine/gpu/cudnn/ConvTransposeCudnnKernel.class */
/**
 * cuDNN implementation of a single-precision, NCHW, 2-D transposed convolution.
 *
 * <p>Every pass is expressed through cuDNN's ordinary-convolution primitives with the roles of the
 * two activations swapped: {@code yDesc} (this layer's output) is used where an ordinary convolution
 * expects its input, and {@code xDesc} (this layer's input) where it expects its output.
 * <ul>
 *   <li>forward ({@link #convTranspose}) is {@code cudnnConvolutionBackwardData};</li>
 *   <li>input gradient ({@link #dx}) is {@code cudnnConvolutionForward};</li>
 *   <li>weight gradient ({@link #dw}) is {@code cudnnConvolutionBackwardFilter}.</li>
 * </ul>
 *
 * <p>Descriptors, algorithm choices and the workspace size are (re)computed by {@link #init(int)}
 * whenever the batch size changes. The scratch buffer lives in {@code network.workspace} and is only
 * reallocated when a larger size is required.
 */
public class ConvTransposeCudnnKernel extends ConvBaseKernel {
    /** cudnnTensorFormat_t CUDNN_TENSOR_NCHW. */
    private static final int TENSOR_FORMAT_NCHW = 0;
    /** cudnnDataType_t CUDNN_DATA_FLOAT. */
    private static final int DATA_TYPE_FLOAT = 0;
    /** cudnnConvolutionMode_t CUDNN_CROSS_CORRELATION. */
    private static final int MODE_CROSS_CORRELATION = 1;
    /** Number of spatial dimensions described by the convolution descriptor. */
    private static final int SPATIAL_DIMS = 2;
    /** Format used when logging each benchmarked algorithm. */
    private static final String ALGO_REPORT_FORMAT = "^^^^ for Algo %d: %f time requiring %d memory %s \n";

    // Input geometry of the transposed convolution: channels, height, width.
    private final int C;
    private final int H;
    private final int W;
    // Kernel: ko output channels with a kh x kw spatial window.
    private final int ko;
    private final int kh;
    private final int kw;
    // Output geometry, derived in init().
    private int on;
    private int oc;
    private int oh;
    private int ow;
    // Hyper-parameters, each applied identically to both spatial axes.
    private final int padding;
    private final int outputPadding;
    private final int dilation;
    private final int stride;
    // Algorithms selected by init() for the three cuDNN calls.
    private int fwAlgo;
    private int bkfAlgo;
    private int bkdAlgo;
    private final Network network;
    /** Forced forward algorithm; a negative value means "benchmark with cudnnFind*". */
    private int convAlgorithm = -1;
    // Scaling factors for cuDNN: result = 1 * op(...) + 0 * previous, i.e. outputs are overwritten.
    private final Pointer alpha = Pointer.to(new float[]{1.0f});
    private final Pointer beta = Pointer.to(new float[]{0.0f});
    private final cudnnTensorDescriptor xDesc = new cudnnTensorDescriptor();
    private final cudnnFilterDescriptor kernelDesc = new cudnnFilterDescriptor();
    private final cudnnTensorDescriptor yDesc = new cudnnTensorDescriptor();
    private final cudnnConvolutionDescriptor convDesc = new cudnnConvolutionDescriptor();

    /**
     * Creates the kernel and allocates (but does not yet configure) its cuDNN descriptors.
     *
     * @param network       owner of the shared cuDNN workspace
     * @param channel       input channels (C)
     * @param height        input height (H)
     * @param width         input width (W)
     * @param kernelNum     output channels (ko)
     * @param kHeight       kernel height
     * @param kWidth        kernel width
     * @param stride        stride, applied to both spatial axes
     * @param padding       zero padding, applied to both spatial axes
     * @param dilation      dilation, applied to both spatial axes
     * @param outputPadding extra size added to each output spatial extent
     * @throws RuntimeException if cuDNN fails to create a descriptor
     */
    public ConvTransposeCudnnKernel(Network network, int channel, int height, int width, int kernelNum,
            int kHeight, int kWidth, int stride, int padding, int dilation, int outputPadding) {
        this.network = network;
        this.C = channel;
        this.H = height;
        this.W = width;
        this.ko = kernelNum;
        this.kh = kHeight;
        this.kw = kWidth;
        this.stride = stride;
        this.padding = padding;
        this.dilation = dilation;
        this.outputPadding = outputPadding;
        handle(JCudnn.cudnnCreateTensorDescriptor(this.xDesc));
        handle(JCudnn.cudnnCreateFilterDescriptor(this.kernelDesc));
        handle(JCudnn.cudnnCreateTensorDescriptor(this.yDesc));
        handle(JCudnn.cudnnCreateConvolutionDescriptor(this.convDesc));
    }

    /**
     * Configures descriptors, selects algorithms and sizes the workspace for a batch of
     * {@code number} samples. Does nothing if the batch size equals the one last configured.
     *
     * @param number batch size N
     * @throws RuntimeException if any cuDNN/CUDA call fails; the batch size is then not recorded,
     *     so the next call retries the configuration
     */
    public void init(int number) {
        if (this.N == number) {
            return;
        }
        int[] pads = {this.padding, this.padding};
        int[] strides = {this.stride, this.stride};
        int[] dilations = {this.dilation, this.dilation};
        // Filter dims as seen by the ordinary convolution that maps this layer's output (ko channels)
        // back onto its input (C channels).
        int[] filterDims = {this.C, this.ko, this.kh, this.kw};
        handle(JCudnn.cudnnSetTensor4dDescriptor(this.xDesc, TENSOR_FORMAT_NCHW, DATA_TYPE_FLOAT,
                number, this.C, this.H, this.W));
        handle(JCudnn.cudnnSetFilterNdDescriptor(this.kernelDesc, DATA_TYPE_FLOAT, TENSOR_FORMAT_NCHW,
                filterDims.length, filterDims));
        handle(JCudnn.cudnnSetConvolutionNdDescriptor(this.convDesc, SPATIAL_DIMS, pads, strides, dilations,
                MODE_CROSS_CORRELATION, DATA_TYPE_FLOAT));
        this.on = number;
        this.oc = this.ko;
        this.oh = outputExtent(this.H, this.kh);
        this.ow = outputExtent(this.W, this.kw);
        handle(JCudnn.cudnnSetTensor4dDescriptor(this.yDesc, TENSOR_FORMAT_NCHW, DATA_TYPE_FLOAT,
                this.on, this.oc, this.oh, this.ow));
        this.fwAlgo = getForwardAlgorithm(this.convAlgorithm, this.yDesc, this.kernelDesc, this.convDesc, this.xDesc);
        this.bkfAlgo = getBKFGO(SPATIAL_DIMS, this.yDesc, this.xDesc, this.kernelDesc, this.convDesc);
        this.bkdAlgo = getBKDGO(SPATIAL_DIMS, this.yDesc, this.xDesc, this.kernelDesc, this.convDesc);
        getWorkSpace();
        // Commit the batch size only once everything above succeeded.
        this.N = number;
    }

    /** Output extent along one axis: (in - 1) * stride - 2 * padding + dilation * (k - 1) + outputPadding + 1. */
    private int outputExtent(int in, int k) {
        return (in - 1) * this.stride - 2 * this.padding + this.dilation * (k - 1) + this.outputPadding + 1;
    }

    /**
     * Transposed-convolution forward pass; (re)initialises for {@code input.number} samples first.
     *
     * @param input  layer input, laid out per {@code xDesc}
     * @param kernel weights, laid out per {@code kernelDesc}
     * @param output layer output (overwritten), laid out per {@code yDesc}
     */
    @Override // com.omega.engine.nn.layer.gpu.ConvBaseKernel
    public void convTranspose(Tensor input, Tensor kernel, Tensor output) {
        init(input.number);
        handle(JCudnn.cudnnConvolutionBackwardData(CudnnHandleManager.getHandle(), this.alpha,
                this.kernelDesc, kernel.getGpuData(), this.xDesc, input.getGpuData(), this.convDesc,
                this.bkdAlgo, this.network.workspace, this.network.workspaceSize, this.beta,
                this.yDesc, output.getGpuData()));
    }

    /**
     * Weight gradient. Relies on descriptors configured by a prior {@link #init(int)}; does not call it.
     *
     * @param input layer input, laid out per {@code xDesc}
     * @param delta gradient w.r.t. the layer output, laid out per {@code yDesc}
     * @param diffW weight gradient (overwritten), laid out per {@code kernelDesc}
     */
    @Override // com.omega.engine.nn.layer.gpu.ConvBaseKernel
    public void dw(Tensor input, Tensor delta, Tensor diffW) {
        handle(JCudnn.cudnnConvolutionBackwardFilter(CudnnHandleManager.getHandle(), this.alpha,
                this.yDesc, delta.getGpuData(), this.xDesc, input.getGpuData(), this.convDesc,
                this.bkfAlgo, this.network.workspace, this.network.workspaceSize, this.beta,
                this.kernelDesc, diffW.getGpuData()));
    }

    /**
     * Input gradient. Relies on descriptors configured by a prior {@link #init(int)}; does not call it.
     *
     * @param delta  gradient w.r.t. the layer output, laid out per {@code yDesc}
     * @param kernel weights, laid out per {@code kernelDesc}
     * @param diff   gradient w.r.t. the layer input (overwritten), laid out per {@code xDesc}
     */
    @Override // com.omega.engine.nn.layer.gpu.ConvBaseKernel
    public void dx(Tensor delta, Tensor kernel, Tensor diff) {
        handle(JCudnn.cudnnConvolutionForward(CudnnHandleManager.getHandle(), this.alpha,
                this.yDesc, delta.getGpuData(), this.kernelDesc, kernel.getGpuData(), this.convDesc,
                this.fwAlgo, this.network.workspace, this.network.workspaceSize, this.beta,
                this.xDesc, diff.getGpuData()));
    }

    /**
     * Benchmarks backward-data algorithms and returns the fastest one that cuDNN reports as usable.
     *
     * @param dims     ignored
     * @param dxTensor descriptor of the gradient being produced
     * @param dyTensor descriptor of the incoming gradient
     * @param wFilter  filter descriptor
     * @param conv     convolution descriptor
     * @return chosen {@code cudnnConvolutionBwdDataAlgo_t}
     * @throws RuntimeException      if the search call itself fails
     * @throws IllegalStateException if no algorithm succeeded
     */
    public int getBKDGO(int dims, cudnnTensorDescriptor dxTensor, cudnnTensorDescriptor dyTensor,
            cudnnFilterDescriptor wFilter, cudnnConvolutionDescriptor conv) {
        cudnnConvolutionBwdDataAlgoPerf[] perf = new cudnnConvolutionBwdDataAlgoPerf[12];
        int[] returned = {0};
        System.out.println("Testing cudnnFindConvolutionBackwardDataAlgorithm ...");
        handle(JCudnn.cudnnFindConvolutionBackwardDataAlgorithm(CudnnHandleManager.getHandle(), wFilter,
                dyTensor, conv, dxTensor, perf.length, returned, perf));
        int count = returned[0];
        for (int k = 0; k < count; k++) {
            System.out.printf(ALGO_REPORT_FORMAT, perf[k].algo, perf[k].time, perf[k].memory,
                    "[" + checkError(perf[k].status) + "]");
        }
        // Results come back ordered by compute time; take the first one that actually ran.
        for (int k = 0; k < count; k++) {
            if (perf[k].status == cudnnStatus.CUDNN_STATUS_SUCCESS) {
                return perf[k].algo;
            }
        }
        throw new IllegalStateException("cudnnFindConvolutionBackwardDataAlgorithm returned no usable algorithm ("
                + count + " candidates)");
    }

    /**
     * Benchmarks backward-filter algorithms and returns the fastest one that cuDNN reports as usable.
     *
     * @param dims      ignored
     * @param srcTensor descriptor of the ordinary-convolution input
     * @param dyTensor  descriptor of the incoming gradient
     * @param wFilter   filter (gradient) descriptor
     * @param conv      convolution descriptor
     * @return chosen {@code cudnnConvolutionBwdFilterAlgo_t}
     * @throws RuntimeException      if the search call itself fails
     * @throws IllegalStateException if no algorithm succeeded
     */
    public int getBKFGO(int dims, cudnnTensorDescriptor srcTensor, cudnnTensorDescriptor dyTensor,
            cudnnFilterDescriptor wFilter, cudnnConvolutionDescriptor conv) {
        cudnnConvolutionBwdFilterAlgoPerf[] perf = new cudnnConvolutionBwdFilterAlgoPerf[12];
        int[] returned = {0};
        System.out.println("Testing cudnnFindConvolutionBackwardFilterAlgorithm ...");
        // Request as many results as the array can hold; a smaller count could omit candidates.
        handle(JCudnn.cudnnFindConvolutionBackwardFilterAlgorithm(CudnnHandleManager.getHandle(), srcTensor,
                dyTensor, conv, wFilter, perf.length, returned, perf));
        int count = returned[0];
        for (int k = 0; k < count; k++) {
            System.out.printf(ALGO_REPORT_FORMAT, perf[k].algo, perf[k].time, perf[k].memory,
                    "[" + checkError(perf[k].status) + "]");
        }
        for (int k = 0; k < count; k++) {
            if (perf[k].status == cudnnStatus.CUDNN_STATUS_SUCCESS) {
                return perf[k].algo;
            }
        }
        throw new IllegalStateException("cudnnFindConvolutionBackwardFilterAlgorithm returned no usable algorithm ("
                + count + " candidates)");
    }

    /**
     * Returns {@code forced} if it is non-negative. Otherwise benchmarks forward algorithms and
     * returns the fastest one that cuDNN reports as usable.
     *
     * @param forced    algorithm to use without benchmarking, or negative to benchmark
     * @param srcTensor descriptor of the ordinary-convolution input
     * @param wFilter   filter descriptor
     * @param conv      convolution descriptor
     * @param dstTensor descriptor of the ordinary-convolution output
     * @return chosen {@code cudnnConvolutionFwdAlgo_t}
     * @throws RuntimeException      if the search call itself fails
     * @throws IllegalStateException if no algorithm succeeded
     */
    public int getForwardAlgorithm(int forced, cudnnTensorDescriptor srcTensor, cudnnFilterDescriptor wFilter,
            cudnnConvolutionDescriptor conv, cudnnTensorDescriptor dstTensor) {
        if (forced >= 0) {
            return forced;
        }
        cudnnConvolutionFwdAlgoPerf[] perf = new cudnnConvolutionFwdAlgoPerf[16];
        int[] returned = {0};
        System.out.println("Testing cudnnFindConvolutionForwardAlgorithm ...");
        handle(JCudnn.cudnnFindConvolutionForwardAlgorithm(CudnnHandleManager.getHandle(), srcTensor, wFilter,
                conv, dstTensor, perf.length, returned, perf));
        int count = returned[0];
        for (int k = 0; k < count; k++) {
            System.out.printf(ALGO_REPORT_FORMAT, perf[k].algo, perf[k].time, perf[k].memory,
                    "[" + checkError(perf[k].status) + "]");
        }
        for (int k = 0; k < count; k++) {
            if (perf[k].status == cudnnStatus.CUDNN_STATUS_SUCCESS) {
                return perf[k].algo;
            }
        }
        throw new IllegalStateException("cudnnFindConvolutionForwardAlgorithm returned no usable algorithm ("
                + count + " candidates)");
    }

    /**
     * Ensures {@code network.workspace} is large enough for the three selected algorithms. It grows
     * the buffer (free + malloc) only when the required size exceeds the current one.
     *
     * @throws RuntimeException if a size query or the (re)allocation fails
     */
    public void getWorkSpace() {
        if (this.network.workspace == null) {
            this.network.workspace = new Pointer();
        }
        long[] size = {0};
        long required = 0;
        handle(JCudnn.cudnnGetConvolutionForwardWorkspaceSize(CudnnHandleManager.getHandle(),
                this.yDesc, this.kernelDesc, this.convDesc, this.xDesc, this.fwAlgo, size));
        required = Math.max(required, size[0]);
        handle(JCudnn.cudnnGetConvolutionBackwardFilterWorkspaceSize(CudnnHandleManager.getHandle(),
                this.yDesc, this.xDesc, this.convDesc, this.kernelDesc, this.bkfAlgo, size));
        required = Math.max(required, size[0]);
        handle(JCudnn.cudnnGetConvolutionBackwardDataWorkspaceSize(CudnnHandleManager.getHandle(),
                this.kernelDesc, this.xDesc, this.convDesc, this.yDesc, this.bkdAlgo, size));
        required = Math.max(required, size[0]);
        if (required > this.network.workspaceSize) {
            checkCuda(JCuda.cudaFree(this.network.workspace));
            // The old buffer is gone; don't advertise its capacity if the new allocation fails.
            this.network.workspaceSize = 0;
            checkCuda(JCuda.cudaMalloc(this.network.workspace, required));
            this.network.workspaceSize = required;
        }
    }

    /**
     * Throws if {@code i} is not {@code CUDNN_STATUS_SUCCESS}.
     *
     * @param i cuDNN status code
     * @throws RuntimeException carrying the status name (also printed to stderr)
     */
    public static void handle(int i) {
        if (i != 0) {
            System.err.println(cudnnStatus.stringFor(i));
            throw new RuntimeException(cudnnStatus.stringFor(i));
        }
    }

    /** Throws if a CUDA runtime call returned anything other than {@code cudaSuccess} (0). */
    private static void checkCuda(int error) {
        if (error != 0) {
            String message = JCuda.cudaGetErrorString(error);
            System.err.println(message);
            throw new RuntimeException(message);
        }
    }

    /**
     * Renders a cuDNN status for logging.
     *
     * @param i cuDNN status code
     * @return {@code "success"} for 0, otherwise the status name
     */
    public static String checkError(int i) {
        return i != 0 ? cudnnStatus.stringFor(i) : "success";
    }

    /** Not implemented by this kernel: intentionally a no-op. */
    @Override // com.omega.engine.nn.layer.gpu.ConvBaseKernel
    public void conv(Tensor tensor, Tensor tensor2, Tensor tensor3) {
    }
}
