package com.omega.engine.gpu.cudnn;

import com.omega.common.data.Tensor;
import com.omega.common.utils.MatrixUtils;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.nn.layer.gpu.RNNBaseKernel;
import com.omega.engine.nn.network.RunModel;
import jcuda.Pointer;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnDropoutDescriptor;
import jcuda.jcudnn.cudnnRNNDataDescriptor;
import jcuda.jcudnn.cudnnRNNDescriptor;
import jcuda.jcudnn.cudnnStatus;
import jcuda.jcudnn.cudnnTensorDescriptor;
import jcuda.jcurand.JCurand;
import jcuda.jcurand.curandGenerator;
import jcuda.runtime.JCuda;

/* loaded from: input_file:com/omega/engine/gpu/cudnn/RNNCudnnKernelV8.class */
/**
 * RNN kernel backed by the cuDNN v8 RNN API ({@code cudnnSetRNNDescriptor_v8}, {@code cudnnRNNForward},
 * {@code cudnnRNNBackwardData_v8}, {@code cudnnRNNBackwardWeights_v8}).
 *
 * <p>Native descriptors and the dropout state buffer are created lazily on the first call to
 * {@link #init(int, int)} and reused afterwards. The sequence/state descriptors, the device-side
 * sequence-length array and the workspace/reserve buffers are re-configured whenever the
 * {@code (N, seqLength)} shape changes, and the buffers sized for the previous shape are released first.
 */
public class RNNCudnnKernelV8 extends RNNBaseKernel {
    // Numeric values of the cuDNN / CUDA runtime / cuRAND enums used by this class.
    private static final int CUDNN_DATA_FLOAT = 0;
    private static final int CUDNN_DATA_DOUBLE = 1;
    private static final int CUDNN_DATA_HALF = 2;
    private static final int CUDNN_DEFAULT_MATH = 0;
    private static final int CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION = 2;
    private static final int CUDNN_RNN_NO_BIAS = 0;
    private static final int CUDNN_RNN_DOUBLE_BIAS = 2;
    private static final int CUDNN_LINEAR_INPUT = 0;
    private static final int CUDNN_RNN_PADDED_IO_DISABLED = 0;
    private static final int CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED = 1;
    private static final int CUDNN_FWD_MODE_INFERENCE = 0;
    private static final int CUDNN_FWD_MODE_TRAINING = 1;
    private static final int CUDNN_WGRAD_MODE_ADD = 0;
    private static final int CURAND_RNG_PSEUDO_DEFAULT = 100;
    private static final int CUDA_MEMCPY_HOST_TO_DEVICE = 1;
    private static final int CUDA_MEMCPY_DEVICE_TO_DEVICE = 3;
    /** Fixed seed shared by the dropout state and the weight initializer. */
    private static final long RANDOM_SEED = 1337L;

    public int layerNum = 1;
    /** cuDNN cell mode: 0 = RNN_RELU, 1 = RNN_TANH, 2 = LSTM, 3 = GRU. */
    public int rnnMode;
    public int inputSize;
    public int hiddenSize;
    /** Total row count last passed to {@link #init(int, int)}; batch size is {@code N / seqLength}. */
    public int N;
    private boolean bidirectional;
    private Pointer workspace;
    private Pointer reserveSpace;
    private cudnnRNNDescriptor rnnDesc;
    private cudnnRNNDataDescriptor xDesc;
    private cudnnRNNDataDescriptor yDesc;
    private cudnnTensorDescriptor hDesc;
    private cudnnTensorDescriptor cDesc;
    private cudnnDropoutDescriptor dropoutDesc;
    private float dropout = 0.0f;
    private boolean hasBias = true;
    private int dataType;
    private int mathPrec;
    private int mathType;
    private int bidirectionalScale = 1;
    private int bidmod = 0;
    private int hidTensorSz;
    public int persistent = 0;
    private int CUDARNNAlgo = 0;
    private long workSize = 0;
    private long reserveSize = 0;
    /** Device array holding one int32 sequence length per batch entry. */
    private Pointer seqP = null;
    private long[] weightSpaceSize = {0};
    /** Device buffer backing the dropout descriptor; must outlive {@link #dropoutDesc}. */
    private Pointer dropoutStates;
    /** Whether the native descriptors and dropout state have been created. */
    private boolean descriptorsCreated = false;

    /**
     * Creates the kernel configuration. Native cuDNN objects are not created until {@link #init(int, int)}.
     *
     * @param seqLength     time steps per sequence
     * @param layerNum      number of stacked RNN layers
     * @param inputSize     features per input time step
     * @param hiddenSize    hidden units per direction
     * @param bidirectional whether to run a bidirectional RNN
     * @param rnnMode       0:rnn_relu, 1:rnn_tanh, 2:lstm, 3:gru
     * @param dropout       dropout probability handed to the cuDNN dropout descriptor
     * @param hasBias       whether the cells use (double) bias terms
     * @throws RuntimeException if {@code rnnMode} is outside 0..3
     */
    public RNNCudnnKernelV8(int seqLength, int layerNum, int inputSize, int hiddenSize, boolean bidirectional,
            int rnnMode, float dropout, boolean hasBias) {
        if (rnnMode < 0 || rnnMode > 3) {
            throw new RuntimeException(
                    "RNN mode only supports 0:rnn_relu, 1:rnn_tanh, 2:lstm, 3:gru (got " + rnnMode + ")");
        }
        this.hasBias = hasBias;
        this.seqLength = seqLength;
        this.dropout = dropout;
        this.inputSize = inputSize;
        this.hiddenSize = hiddenSize;
        this.layerNum = layerNum;
        this.bidirectional = bidirectional;
        this.rnnMode = rnnMode;
        if (this.bidirectional) {
            this.bidirectionalScale = 2;
            this.bidmod = 1; // CUDNN_BIDIRECTIONAL
        }
        // NOTE(review): init() is public and non-final, so a subclass override runs before its own fields are set.
        init();
    }

    /** Allocates the Java-side handle objects; the native descriptors are created later in {@link #init(int, int)}. */
    public void init() {
        this.xDesc = new cudnnRNNDataDescriptor();
        this.yDesc = new cudnnRNNDataDescriptor();
        this.hDesc = new cudnnTensorDescriptor();
        this.cDesc = new cudnnTensorDescriptor();
        this.dropoutDesc = new cudnnDropoutDescriptor();
        this.rnnDesc = new cudnnRNNDescriptor();
        this.workspace = new Pointer();
        this.reserveSpace = new Pointer();
    }

    /**
     * Configures the kernel for {@code n} input rows split into sequences of {@code seqLen} steps.
     * It does nothing if the shape is unchanged since the last call.
     *
     * @param n      total rows; must be a positive multiple of {@code seqLen}
     * @param seqLen time steps per sequence
     * @throws IllegalArgumentException if the shape is not a positive multiple
     */
    @Override // com.omega.engine.nn.layer.gpu.RNNBaseKernel
    public void init(int n, int seqLen) {
        if (this.N == n && this.seqLength == seqLen) {
            return;
        }
        if (n <= 0 || seqLen <= 0 || n % seqLen != 0) {
            throw new IllegalArgumentException(
                    "N (" + n + ") must be a positive multiple of seqLength (" + seqLen + ")");
        }
        createDescriptorsOnce();
        this.N = n;
        this.seqLength = seqLen;
        int batchSize = n / seqLen;
        this.hidTensorSz = this.layerNum * batchSize * this.hiddenSize * this.bidirectionalScale;
        configureDataDescriptors(batchSize);
        allocateTempSpaces();
        JCuda.cudaDeviceSynchronize();
    }

    /** Creates all native descriptors, the dropout state and the RNN descriptor exactly once. */
    private void createDescriptorsOnce() {
        if (this.descriptorsCreated) {
            return;
        }
        // Must be set before any descriptor is configured with this.dataType.
        this.dataType = CUDNN_DATA_FLOAT;
        this.mathPrec = CUDNN_DATA_FLOAT;
        this.mathType = CUDNN_DEFAULT_MATH;
        checkMathSettings();

        handle(JCudnn.cudnnCreateRNNDataDescriptor(this.xDesc));
        handle(JCudnn.cudnnCreateRNNDataDescriptor(this.yDesc));
        handle(JCudnn.cudnnCreateTensorDescriptor(this.hDesc));
        handle(JCudnn.cudnnCreateTensorDescriptor(this.cDesc));

        handle(JCudnn.cudnnCreateDropoutDescriptor(this.dropoutDesc));
        long[] stateBytes = {0};
        handle(JCudnn.cudnnDropoutGetStatesSize(CudnnHandleManager.getHandle(), stateBytes));
        // Kept in a field: the dropout descriptor references this buffer for its whole lifetime.
        this.dropoutStates = new Pointer();
        checkCuda(JCuda.cudaMalloc(this.dropoutStates, stateBytes[0]), "cudaMalloc(dropout states)");
        handle(JCudnn.cudnnSetDropoutDescriptor(this.dropoutDesc, CudnnHandleManager.getHandle(), this.dropout,
                this.dropoutStates, stateBytes[0], RANDOM_SEED));

        handle(JCudnn.cudnnCreateRNNDescriptor(this.rnnDesc));
        int biasMode = this.hasBias ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS;
        // projSize == hiddenSize, i.e. no LSTM projection.
        handle(JCudnn.cudnnSetRNNDescriptor_v8(this.rnnDesc, this.CUDARNNAlgo, this.rnnMode, biasMode, this.bidmod,
                CUDNN_LINEAR_INPUT, this.dataType, this.mathPrec, this.mathType, this.inputSize, this.hiddenSize,
                this.hiddenSize, this.layerNum, this.dropoutDesc, CUDNN_RNN_PADDED_IO_DISABLED));
        this.descriptorsCreated = true;
    }

    /** Reports (on stderr) combinations of dataType / mathPrec / mathType that cuDNN rejects. */
    private void checkMathSettings() {
        boolean precisionMismatch =
                (this.dataType == CUDNN_DATA_HALF && this.mathPrec != CUDNN_DATA_HALF && this.mathPrec != CUDNN_DATA_FLOAT)
                        || (this.dataType == CUDNN_DATA_FLOAT && this.mathPrec != CUDNN_DATA_FLOAT)
                        || (this.dataType == CUDNN_DATA_DOUBLE && this.mathPrec != CUDNN_DATA_DOUBLE);
        if (precisionMismatch) {
            System.err.println("[ERROR] Inconsistent parameter: dataType does not match mathPrecision!");
        }
        boolean mathTypeMismatch =
                (this.dataType == CUDNN_DATA_FLOAT && this.mathType != CUDNN_DEFAULT_MATH
                        && this.mathType != CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION)
                        || (this.dataType == CUDNN_DATA_DOUBLE && this.mathType != CUDNN_DEFAULT_MATH);
        if (mathTypeMismatch) {
            System.err.println("[ERROR] Inconsistent parameter: dataType does not match mathType!");
        }
    }

    /**
     * (Re)configures the x/y data descriptors, the h/c state descriptors and the device sequence-length
     * array for {@code batchSize} sequences of {@code this.seqLength} steps each.
     */
    private void configureDataDescriptors(int batchSize) {
        int[] seqLengths = new int[batchSize];
        MatrixUtils.fill(seqLengths, 0, batchSize, this.seqLength);
        long seqBytes = (long) batchSize * Integer.BYTES;
        if (this.seqP != null) {
            JCuda.cudaFree(this.seqP); // release the array sized for the previous batch
        }
        this.seqP = new Pointer();
        checkCuda(JCuda.cudaMalloc(this.seqP, seqBytes), "cudaMalloc(seqLengths)");
        checkCuda(JCuda.cudaMemcpy(this.seqP, Pointer.to(seqLengths), seqBytes, CUDA_MEMCPY_HOST_TO_DEVICE),
                "cudaMemcpy(seqLengths)");

        Pointer paddingFill = Pointer.to(new float[] {0.0f});
        handle(JCudnn.cudnnSetRNNDataDescriptor(this.xDesc, this.dataType, CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
                this.seqLength, batchSize, this.inputSize, seqLengths, paddingFill));
        handle(JCudnn.cudnnSetRNNDataDescriptor(this.yDesc, this.dataType, CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
                this.seqLength, batchSize, this.hiddenSize * this.bidirectionalScale, seqLengths, paddingFill));

        // Hidden/cell state tensors: [layers * directions, batch, hidden], fully packed strides.
        int[] dims = {this.layerNum * this.bidirectionalScale, batchSize, this.hiddenSize};
        int[] strides = {dims[1] * dims[2], dims[2], 1};
        handle(JCudnn.cudnnSetTensorNdDescriptor(this.hDesc, this.dataType, 3, dims, strides));
        handle(JCudnn.cudnnSetTensorNdDescriptor(this.cDesc, this.dataType, 3, dims, strides));
    }

    /** Queries the training-mode workspace/reserve sizes for the current xDesc and (re)allocates both buffers. */
    private void allocateTempSpaces() {
        long[] workBytes = {0};
        long[] reserveBytes = {0};
        handle(JCudnn.cudnnGetRNNTempSpaceSizes(CudnnHandleManager.getHandle(), this.rnnDesc,
                CUDNN_FWD_MODE_TRAINING, this.xDesc, workBytes, reserveBytes));
        // Free the buffers sized for the previous shape; cudaFree on a null device pointer is a no-op.
        JCuda.cudaFree(this.workspace);
        JCuda.cudaFree(this.reserveSpace);
        this.workSize = workBytes[0];
        this.reserveSize = reserveBytes[0];
        checkCuda(JCuda.cudaMalloc(this.workspace, this.workSize), "cudaMalloc(workspace)");
        checkCuda(JCuda.cudaMalloc(this.reserveSpace, this.reserveSize), "cudaMalloc(reserveSpace)");
    }

    /**
     * Queries cuDNN for the weight-space size of the configured RNN, caches it in
     * {@link #getWeightSpaceSize()} and returns it. Requires a prior {@link #init(int, int)}.
     *
     * @return weight-space size in bytes
     */
    @Override // com.omega.engine.nn.layer.gpu.RNNBaseKernel
    public long weightSize() {
        handle(JCudnn.cudnnGetRNNWeightSpaceSize(CudnnHandleManager.getHandle(), this.rnnDesc, getWeightSpaceSize()));
        return getWeightSpaceSize()[0];
    }

    /**
     * Fills the tensor's GPU buffer with N(0, sqrt(2 / (inputSize + hiddenSize))) samples from a
     * fixed-seed cuRAND generator. The generator is always destroyed afterwards.
     *
     * @param tensor weight tensor whose GPU data is overwritten
     */
    @Override // com.omega.engine.nn.layer.gpu.RNNBaseKernel
    public void initWeights(Tensor tensor) {
        float std = (float) Math.sqrt(2.0d / (this.inputSize + this.hiddenSize));
        long count = tensor.getDataLength();
        curandGenerator generator = new curandGenerator();
        checkCurand(JCurand.curandCreateGenerator(generator, CURAND_RNG_PSEUDO_DEFAULT), "curandCreateGenerator");
        try {
            checkCurand(JCurand.curandSetPseudoRandomGeneratorSeed(generator, RANDOM_SEED),
                    "curandSetPseudoRandomGeneratorSeed");
            if (count % 2 == 0) {
                checkCurand(JCurand.curandGenerateNormal(generator, tensor.getGpuData(), count, 0.0f, std),
                        "curandGenerateNormal");
            } else {
                // Pseudo-random normal generation requires an even count: sample one extra value into scratch.
                Pointer scratch = new Pointer();
                checkCuda(JCuda.cudaMalloc(scratch, (count + 1) * Float.BYTES), "cudaMalloc(init scratch)");
                try {
                    checkCurand(JCurand.curandGenerateNormal(generator, scratch, count + 1, 0.0f, std),
                            "curandGenerateNormal");
                    checkCuda(JCuda.cudaMemcpy(tensor.getGpuData(), scratch, count * Float.BYTES,
                            CUDA_MEMCPY_DEVICE_TO_DEVICE), "cudaMemcpy(init scratch)");
                } finally {
                    JCuda.cudaFree(scratch);
                }
            }
        } finally {
            JCurand.curandDestroyGenerator(generator);
        }
    }

    /**
     * Runs {@code cudnnRNNForward}, in training mode for {@link RunModel#TRAIN} and inference mode otherwise.
     *
     * @param runModel run mode selecting the cuDNN forward mode
     * @param x        input sequence data
     * @param hx       initial hidden state
     * @param cx       initial cell state
     * @param weights  packed weight space; its size is taken from {@link #getWeightSpaceSize()}
     * @param y        output sequence data
     * @param hy       final hidden state
     * @param cy       final cell state
     */
    @Override // com.omega.engine.nn.layer.gpu.RNNBaseKernel
    public void forward(RunModel runModel, Tensor x, Tensor hx, Tensor cx, Tensor weights, Tensor y, Tensor hy,
            Tensor cy) {
        int fwdMode = runModel == RunModel.TRAIN ? CUDNN_FWD_MODE_TRAINING : CUDNN_FWD_MODE_INFERENCE;
        handle(JCudnn.cudnnRNNForward(CudnnHandleManager.getHandle(), this.rnnDesc, fwdMode, this.seqP,
                this.xDesc, x.getGpuData(), this.yDesc, y.getGpuData(),
                this.hDesc, hx.getGpuData(), hy.getGpuData(),
                this.cDesc, cx.getGpuData(), cy.getGpuData(),
                getWeightSpaceSize()[0], weights.getGpuData(),
                this.workSize, this.workspace, this.reserveSize, this.reserveSpace));
    }

    /**
     * Computes weight gradients with {@code cudnnRNNBackwardWeights_v8}. {@code dWeights} is cleared
     * first, so the ADD gradient mode effectively overwrites it.
     *
     * @param unused   not read by this implementation
     * @param y        forward output
     * @param x        forward input
     * @param hx       initial hidden state used in the forward pass
     * @param dWeights receives the weight gradients
     */
    @Override // com.omega.engine.nn.layer.gpu.RNNBaseKernel
    public void dw(Tensor unused, Tensor y, Tensor x, Tensor hx, Tensor dWeights) {
        dWeights.clearGPU();
        handle(JCudnn.cudnnRNNBackwardWeights_v8(CudnnHandleManager.getHandle(), this.rnnDesc, CUDNN_WGRAD_MODE_ADD,
                this.seqP, this.xDesc, x.getGpuData(), this.hDesc, hx.getGpuData(), this.yDesc, y.getGpuData(),
                (long) dWeights.getDataLength() * Float.BYTES, dWeights.getGpuData(),
                this.workSize, this.workspace, this.reserveSize, this.reserveSpace));
    }

    /**
     * Computes data gradients with {@code cudnnRNNBackwardData_v8}.
     *
     * @param dy      gradient w.r.t. the output sequence
     * @param dhy     gradient w.r.t. the final hidden state; may be {@code null}
     * @param dcy     gradient w.r.t. the final cell state
     * @param y       forward output
     * @param hx      initial hidden state
     * @param cx      initial cell state
     * @param weights packed weight space (size taken from its data length)
     * @param dInput  receives the gradient w.r.t. the input sequence
     * @param dhx     receives the gradient w.r.t. the initial hidden state
     * @param dcx     receives the gradient w.r.t. the initial cell state
     */
    @Override // com.omega.engine.nn.layer.gpu.RNNBaseKernel
    public void dx(Tensor dy, Tensor dhy, Tensor dcy, Tensor y, Tensor hx, Tensor cx, Tensor weights, Tensor dInput,
            Tensor dhx, Tensor dcx) {
        Pointer dhyData = dhy != null ? dhy.getGpuData() : null;
        handle(JCudnn.cudnnRNNBackwardData_v8(CudnnHandleManager.getHandle(), this.rnnDesc, this.seqP,
                this.yDesc, y.getGpuData(), dy.getGpuData(), this.xDesc, dInput.getGpuData(),
                this.hDesc, hx.getGpuData(), dhyData, dhx.getGpuData(),
                this.cDesc, cx.getGpuData(), dcy.getGpuData(), dcx.getGpuData(),
                (long) weights.getDataLength() * Float.BYTES, weights.getGpuData(),
                this.workSize, this.workspace, this.reserveSize, this.reserveSpace));
    }

    /**
     * Throws if a cuDNN call did not return CUDNN_STATUS_SUCCESS (0), after printing the status to stderr.
     *
     * @param i cuDNN status code
     * @throws RuntimeException carrying the status name
     */
    public static void handle(int i) {
        if (i != 0) {
            System.err.println(cudnnStatus.stringFor(i));
            throw new RuntimeException(cudnnStatus.stringFor(i));
        }
    }

    /**
     * Returns the cuDNN status name for a non-zero code, or {@code "success"} for 0.
     *
     * @param i cuDNN status code
     * @return human-readable status
     */
    public static String checkError(int i) {
        return i != 0 ? cudnnStatus.stringFor(i) : "success";
    }

    /** Throws if a CUDA runtime call failed. */
    private static void checkCuda(int status, String call) {
        if (status != 0) {
            throw new RuntimeException(call + " failed: " + JCuda.cudaGetErrorString(status));
        }
    }

    /** Throws if a cuRAND call failed. */
    private static void checkCurand(int status, String call) {
        if (status != 0) {
            throw new RuntimeException(call + " failed: " + jcuda.jcurand.curandStatus.stringFor(status));
        }
    }

    /**
     * Returns the internal one-element array holding the weight-space size in bytes, as last filled
     * by {@link #weightSize()}.
     *
     * @return the live (mutable) size array
     */
    public long[] getWeightSpaceSize() {
        return this.weightSpaceSize;
    }

    /**
     * Copies {@code RandomUtils.order(i, f, f2)} into a device buffer as {@code i} floats.
     * It is not referenced elsewhere in this class.
     */
    private static void initGPUData(Pointer pointer, int i, float f, float f2) {
        JCuda.cudaMemcpy(pointer, Pointer.to(RandomUtils.order(i, f, f2)), (long) i * Float.BYTES,
                CUDA_MEMCPY_HOST_TO_DEVICE);
    }
}
