package com.omega.example.rnn.test;

import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.gpu.cudnn.CudnnHandleManager;
import java.io.File;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.util.Arrays;
import jcuda.Pointer;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnDropoutDescriptor;
import jcuda.jcudnn.cudnnFilterDescriptor;
import jcuda.jcudnn.cudnnHandle;
import jcuda.jcudnn.cudnnRNNDescriptor;
import jcuda.jcudnn.cudnnStatus;
import jcuda.jcudnn.cudnnTensorDescriptor;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaEvent_t;
import jcuda.runtime.cudaStream_t;

/* loaded from: input_file:com/omega/example/rnn/test/JCudnnRnnExample.class */
/**
 * Benchmark-style example that runs one forward-training pass, one backward-data pass and one
 * backward-weights pass of a cuDNN RNN through JCudnn, then prints throughput figures and
 * checksums to standard output and to {@code ./results.txt}.
 *
 * <p>NOTE: the {@code <mode>} command-line argument is parsed (so a malformed value is still
 * rejected) but it is not applied: the network always runs with mode code 0 ("RNN_RELU" in the
 * usage text) and a single direction.
 */
public class JCudnnRnnExample {
    /** Size in bytes of one element of every device buffer (all buffers are moved as {@code float[]}). */
    private static final int FLOAT_BYTES = 4;

    /** Data-type code handed to every descriptor setter; buffers are treated as 4-byte floats. */
    private static final int DATA_TYPE = 0;

    /** Layout code handed to {@code cudnnSetFilterNdDescriptor}. */
    private static final int FILTER_FORMAT = 0;

    /** Input-mode code handed to {@code cudnnSetRNNDescriptor_v6}. */
    private static final int INPUT_MODE = 0;

    /** Algorithm code handed to {@code cudnnSetRNNDescriptor_v6}. */
    private static final int RNN_ALGO = 0;

    /** RNN mode code actually used (0 = RNN_RELU according to the usage text). */
    private static final int RNN_MODE = 0;

    /** Whether the network is bidirectional; fixed to a single direction. */
    private static final boolean BIDIRECTIONAL = false;

    /** {@code cudaMemcpy} kind used when the source is a host array and the target a device pointer. */
    private static final int HOST_TO_DEVICE = 1;

    /** {@code cudaMemcpy} kind used when the source is a device pointer and the target a host array. */
    private static final int DEVICE_TO_HOST = 2;

    /**
     * Initialises the CUDA context and runs the example with a fixed argument set.
     *
     * @param strArr ignored
     * @throws Exception if the benchmark fails
     */
    public static void main(String[] strArr) throws Exception {
        CUDAModules.initContext();
        mainImpl(new String[]{"20", "1", "128", "64", "64", "2"});
    }

    /**
     * Parses the six arguments and runs the benchmark.
     *
     * @param strArr {@code seqLength, numLayers, hiddenSize, inputSize, miniBatch, mode}; the mode
     *               value must be an integer but is otherwise ignored (see class comment)
     * @throws IllegalArgumentException if a size argument is not a positive integer
     * @throws Exception if the results file cannot be opened or a CUDA/cuDNN call fails
     */
    public static void mainImpl(String[] strArr) throws Exception {
        JCuda.setExceptionsEnabled(true);
        JCudnn.setExceptionsEnabled(true);
        if (strArr.length != 6) {
            System.out.printf("Usage:\n");
            System.out.printf("./RNN <seqLength> <numLayers> <hiddenSize> <inputSize> <miniBatch> <mode>\n");
            System.out.printf("Modes: 0 = RNN_RELU, 1 = RNN_TANH, 2 = LSTM, 3 = GRU\n");
            return;
        }
        int seqLength = parsePositive(strArr[0], "seqLength");
        int numLayers = parsePositive(strArr[1], "numLayers");
        int hiddenSize = parsePositive(strArr[2], "hiddenSize");
        int inputSize = parsePositive(strArr[3], "inputSize");
        int miniBatch = parsePositive(strArr[4], "miniBatch");
        // Validated for well-formedness only; the run always uses RNN_MODE.
        Integer.parseInt(strArr[5]);

        // Open the results file only once the arguments are known to be usable, and guarantee
        // that it is flushed and closed even when a CUDA call throws.
        try (PrintWriter results = new PrintWriter(new File("./results.txt"))) {
            runBenchmark(seqLength, numLayers, hiddenSize, inputSize, miniBatch, results);
        }
    }

    /**
     * Allocates all device state, runs the three RNN passes, reports timings and checksums, and
     * releases every resource it created.
     */
    private static void runBenchmark(int seqLength, int numLayers, int hiddenSize, int inputSize,
                                     int miniBatch, PrintWriter results) {
        final int directions = BIDIRECTIONAL ? 2 : 1;
        final int inputCount = seqLength * inputSize * miniBatch;
        final int outputCount = seqLength * hiddenSize * miniBatch * directions;
        final int stateCount = numLayers * hiddenSize * miniBatch * directions;

        cudnnHandle handle = CudnnHandleManager.getHandle();

        // ---- device buffers -------------------------------------------------------------
        Pointer x = allocateFloats(inputCount);
        System.out.println(inputCount);
        Pointer hx = allocateFloats(stateCount);
        Pointer cx = allocateFloats(stateCount);
        Pointer dx = allocateFloats(inputCount);
        Pointer dhx = allocateFloats(stateCount);
        Pointer dcx = allocateFloats(stateCount);
        Pointer y = allocateFloats(outputCount);
        System.out.println(outputCount);
        Pointer hy = allocateFloats(stateCount);
        System.out.println("hy:" + stateCount);
        Pointer dy = allocateFloats(outputCount);
        Pointer dhy = allocateFloats(stateCount);
        Pointer dcy = allocateFloats(stateCount);

        // ---- tensor descriptors ---------------------------------------------------------
        cudnnTensorDescriptor[] xDesc = new cudnnTensorDescriptor[seqLength];
        cudnnTensorDescriptor[] yDesc = new cudnnTensorDescriptor[seqLength];
        cudnnTensorDescriptor[] dxDesc = new cudnnTensorDescriptor[seqLength];
        cudnnTensorDescriptor[] dyDesc = new cudnnTensorDescriptor[seqLength];
        for (int step = 0; step < seqLength; step++) {
            xDesc[step] = createTensor3d(miniBatch, inputSize, 1);
            dxDesc[step] = createTensor3d(miniBatch, inputSize, 1);
            yDesc[step] = createTensor3d(miniBatch, hiddenSize * directions, 1);
            dyDesc[step] = createTensor3d(miniBatch, hiddenSize * directions, 1);
        }
        int stateLayers = numLayers * directions;
        System.out.println(stateLayers + ":" + miniBatch + ":" + hiddenSize);
        cudnnTensorDescriptor hxDesc = createTensor3d(stateLayers, miniBatch, hiddenSize);
        cudnnTensorDescriptor cxDesc = createTensor3d(stateLayers, miniBatch, hiddenSize);
        cudnnTensorDescriptor hyDesc = createTensor3d(stateLayers, miniBatch, hiddenSize);
        cudnnTensorDescriptor dhxDesc = createTensor3d(stateLayers, miniBatch, hiddenSize);
        cudnnTensorDescriptor dcxDesc = createTensor3d(stateLayers, miniBatch, hiddenSize);
        cudnnTensorDescriptor dhyDesc = createTensor3d(stateLayers, miniBatch, hiddenSize);
        cudnnTensorDescriptor dcyDesc = createTensor3d(stateLayers, miniBatch, hiddenSize);

        // ---- dropout (probability 0) ----------------------------------------------------
        cudnnDropoutDescriptor dropoutDesc = new cudnnDropoutDescriptor();
        JCudnn.cudnnCreateDropoutDescriptor(dropoutDesc);
        long[] dropoutSizeOut = {0};
        JCudnn.cudnnDropoutGetStatesSize(handle, dropoutSizeOut);
        long dropoutStateBytes = dropoutSizeOut[0];
        Pointer dropoutStates = new Pointer();
        JCuda.cudaMalloc(dropoutStates, dropoutStateBytes);
        JCudnn.cudnnSetDropoutDescriptor(dropoutDesc, handle, 0.0f, dropoutStates, dropoutStateBytes, 1337L);

        // ---- RNN descriptor -------------------------------------------------------------
        cudnnRNNDescriptor rnnDesc = new cudnnRNNDescriptor();
        JCudnn.cudnnCreateRNNDescriptor(rnnDesc);
        JCudnn.cudnnSetRNNDescriptor_v6(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc,
                INPUT_MODE, BIDIRECTIONAL ? 1 : 0, RNN_MODE, RNN_ALGO, DATA_TYPE);

        // ---- weights and weight gradients -----------------------------------------------
        cudnnFilterDescriptor wDesc = new cudnnFilterDescriptor();
        cudnnFilterDescriptor dwDesc = new cudnnFilterDescriptor();
        JCudnn.cudnnCreateFilterDescriptor(wDesc);
        JCudnn.cudnnCreateFilterDescriptor(dwDesc);
        long[] weightsSizeOut = {0};
        JCudnn.cudnnGetRNNParamsSize(handle, rnnDesc, xDesc[0], weightsSizeOut, DATA_TYPE);
        long weightsBytes = weightsSizeOut[0];
        int weightCount = (int) (weightsBytes / FLOAT_BYTES);
        int[] weightDims = {weightCount, 1, 1};
        System.out.println("weightsSize:" + weightsBytes);
        JCudnn.cudnnSetFilterNdDescriptor(wDesc, DATA_TYPE, FILTER_FORMAT, 3, weightDims);
        JCudnn.cudnnSetFilterNdDescriptor(dwDesc, DATA_TYPE, FILTER_FORMAT, 3, weightDims);
        Pointer w = new Pointer();
        Pointer dw = new Pointer();
        JCuda.cudaMalloc(w, weightsBytes);
        JCuda.cudaMalloc(dw, weightsBytes);

        // ---- workspace and training reserve ---------------------------------------------
        long[] workSizeOut = {0};
        long[] reserveSizeOut = {0};
        JCudnn.cudnnGetRNNWorkspaceSize(handle, rnnDesc, seqLength, xDesc, workSizeOut);
        long workBytes = workSizeOut[0];
        JCudnn.cudnnGetRNNTrainingReserveSize(handle, rnnDesc, seqLength, xDesc, reserveSizeOut);
        long reserveBytes = reserveSizeOut[0];
        Pointer workspace = new Pointer();
        Pointer reserveSpace = new Pointer();
        JCuda.cudaMalloc(workspace, workBytes);
        JCuda.cudaMalloc(reserveSpace, reserveBytes);

        // ---- input data -----------------------------------------------------------------
        initGPUData(x, inputCount, 1.0f);
        initGPUData(dy, outputCount, 1.0f);
        // hx, cx, dhy and dcy are handed to cuDNN below. They used to be passed exactly as
        // cudaMalloc returned them, i.e. with indeterminate contents, which made the checksums
        // depend on stale device memory. Fill them the same way as x and dy.
        initGPUData(hx, stateCount, 1.0f);
        initGPUData(cx, stateCount, 1.0f);
        initGPUData(dhy, stateCount, 1.0f);
        initGPUData(dcy, stateCount, 1.0f);

        // ---- weight initialisation ------------------------------------------------------
        int linLayers = linearLayerCount(RNN_MODE);
        for (int layer = 0; layer < numLayers * directions; layer++) {
            for (int linLayer = 0; linLayer < linLayers; linLayer++) {
                cudnnFilterDescriptor matDesc = new cudnnFilterDescriptor();
                JCudnn.cudnnCreateFilterDescriptor(matDesc);
                Pointer mat = new Pointer();
                JCudnn.cudnnGetRNNLinLayerMatrixParams(handle, rnnDesc, layer, xDesc[0], wDesc, w,
                        linLayer, matDesc, mat);
                int matCount = filterElementCount(matDesc);
                // Each matrix entry is 1/elementCount.
                initGPUData(mat, matCount, 1.0f / matCount);
                JCudnn.cudnnDestroyFilterDescriptor(matDesc);

                cudnnFilterDescriptor biasDesc = new cudnnFilterDescriptor();
                JCudnn.cudnnCreateFilterDescriptor(biasDesc);
                Pointer bias = new Pointer();
                JCudnn.cudnnGetRNNLinLayerBiasParams(handle, rnnDesc, layer, xDesc[0], wDesc, w,
                        linLayer, biasDesc, bias);
                initGPUData(bias, filterElementCount(biasDesc), 1.0f);
                JCudnn.cudnnDestroyFilterDescriptor(biasDesc);
            }
        }
        JCuda.cudaDeviceSynchronize();

        // ---- timed passes ---------------------------------------------------------------
        cudaEvent_t start = new cudaEvent_t();
        cudaEvent_t stop = new cudaEvent_t();
        JCuda.cudaEventCreate(start);
        JCuda.cudaEventCreate(stop);

        // Diagnostic output is emitted before the start event so it is not part of the timing.
        System.out.println(workBytes + ":" + reserveBytes + ":" + seqLength);
        System.out.println(rnnDesc);

        JCuda.cudaEventRecord(start, (cudaStream_t) null);
        handle(JCudnn.cudnnRNNForwardTraining(handle, rnnDesc, seqLength,
                xDesc, x,
                hxDesc, hx,
                (cudnnTensorDescriptor) null, (Pointer) null,
                wDesc, w,
                yDesc, y,
                hyDesc, hy,
                (cudnnTensorDescriptor) null, (Pointer) null,
                workspace, workBytes, reserveSpace, reserveBytes));
        float forwardMs = stopAndMeasure(start, stop);

        JCuda.cudaEventRecord(start, (cudaStream_t) null);
        handle(JCudnn.cudnnRNNBackwardData(handle, rnnDesc, seqLength,
                yDesc, y,
                dyDesc, dy,
                dhyDesc, dhy,
                dcyDesc, dcy,
                wDesc, w,
                hxDesc, hx,
                cxDesc, cx,
                dxDesc, dx,
                dhxDesc, dhx,
                dcxDesc, dcx,
                workspace, workBytes, reserveSpace, reserveBytes));
        float backwardDataMs = stopAndMeasure(start, stop);

        JCuda.cudaEventRecord(start, (cudaStream_t) null);
        // dw is cleared as part of the timed region, before the weight gradients are computed.
        JCuda.cudaMemset(dw, 0, weightsBytes);
        // The y arguments must be the forward outputs; the output gradients (dyDesc/dy) were
        // passed here before, which fed the wrong tensor into the weight-gradient computation.
        handle(JCudnn.cudnnRNNBackwardWeights(handle, rnnDesc, seqLength,
                xDesc, x,
                hxDesc, hx,
                yDesc, y,
                workspace, workBytes,
                dwDesc, dw,
                reserveSpace, reserveBytes));
        float backwardWeightsMs = stopAndMeasure(start, stop);

        // ---- throughput report ----------------------------------------------------------
        double singlePassWork = linLayers * 2.0d * directions * hiddenSize * hiddenSize
                * seqLength * miniBatch * numLayers;
        double bothBackwardWork = linLayers * 4.0d * directions * hiddenSize * hiddenSize
                * seqLength * miniBatch * numLayers;
        report(results, "Forward: %3.0f GFLOPS\n", singlePassWork / (1000000.0d * forwardMs));
        report(results, "Backward: %3.0f GFLOPS, ",
                bothBackwardWork / (1000000.0d * (backwardDataMs + backwardWeightsMs)));
        report(results, "(%3.0f GFLOPS), ", singlePassWork / (1000000.0d * backwardDataMs));
        report(results, "(%3.0f GFLOPS)\n", singlePassWork / (1000000.0d * backwardWeightsMs));

        JCuda.cudaDeviceSynchronize();

        // ---- forward checksums (double accumulators) --------------------------------------
        float[] yHost = copyToHost(y, outputCount);
        float[] hyHost = copyToHost(hy, stateCount);
        double outputChecksum = 0.0d;
        double hiddenChecksum = 0.0d;
        for (int batch = 0; batch < miniBatch; batch++) {
            outputChecksum += sumForBatch(yHost, batch, seqLength, miniBatch, hiddenSize * directions);
            hiddenChecksum += sumForBatch(hyHost, batch, stateLayers, miniBatch, hiddenSize);
        }
        report(results, "i checksum %E     ", outputChecksum);
        report(results, "h checksum %E\n", hiddenChecksum);

        // ---- backward-data checksums (float accumulators, as before) ----------------------
        float[] dxHost = copyToHost(dx, inputCount);
        float[] dhxHost = copyToHost(dhx, stateCount);
        float dxChecksum = 0.0f;
        float dhxChecksum = 0.0f;
        for (int batch = 0; batch < miniBatch; batch++) {
            dxChecksum = (float) (dxChecksum + sumForBatch(dxHost, batch, seqLength, miniBatch, inputSize));
            dhxChecksum = (float) (dhxChecksum + sumForBatch(dhxHost, batch, stateLayers, miniBatch, hiddenSize));
        }
        report(results, "di checksum %E    ", dxChecksum);
        report(results, "dh checksum %E\n", dhxChecksum);

        // ---- weight-gradient checksum -----------------------------------------------------
        float[] dwHost = new float[weightCount];
        JCuda.cudaMemcpy(Pointer.to(dwHost), dw, weightsBytes, DEVICE_TO_HOST);
        double dwChecksum = 0.0d;
        for (float gradient : dwHost) {
            dwChecksum += gradient;
        }
        report(results, "dw checksum %E\n", dwChecksum);
        results.flush();

        // ---- cleanup ----------------------------------------------------------------------
        JCuda.cudaEventDestroy(start);
        JCuda.cudaEventDestroy(stop);
        Pointer[] deviceBuffers = {
            x, hx, cx, y, hy, dx, dhx, dcx, dy, dhy, dcy, workspace, reserveSpace, w, dw, dropoutStates
        };
        for (Pointer buffer : deviceBuffers) {
            JCuda.cudaFree(buffer);
        }
        for (int step = 0; step < seqLength; step++) {
            JCudnn.cudnnDestroyTensorDescriptor(xDesc[step]);
            JCudnn.cudnnDestroyTensorDescriptor(yDesc[step]);
            JCudnn.cudnnDestroyTensorDescriptor(dxDesc[step]);
            JCudnn.cudnnDestroyTensorDescriptor(dyDesc[step]);
        }
        cudnnTensorDescriptor[] stateDescs = {hxDesc, cxDesc, hyDesc, dhxDesc, dcxDesc, dhyDesc, dcyDesc};
        for (cudnnTensorDescriptor stateDesc : stateDescs) {
            JCudnn.cudnnDestroyTensorDescriptor(stateDesc);
        }
        JCudnn.cudnnDestroyRNNDescriptor(rnnDesc);
        JCudnn.cudnnDestroyDropoutDescriptor(dropoutDesc);
        JCudnn.cudnnDestroyFilterDescriptor(wDesc);
        JCudnn.cudnnDestroyFilterDescriptor(dwDesc);
        // NOTE(review): the handle is obtained from CudnnHandleManager; destroying it here is
        // kept from the original code but may invalidate a handle other code still uses —
        // confirm who owns it.
        JCudnn.cudnnDestroy(handle);
    }

    /** Parses {@code text} as an int and rejects values that are not strictly positive. */
    private static int parsePositive(String text, String name) {
        int value = Integer.parseInt(text);
        if (value <= 0) {
            throw new IllegalArgumentException(name + " must be positive, got " + value);
        }
        return value;
    }

    /** Number of linear layers initialised per RNN layer for a mode code (see usage text). */
    private static int linearLayerCount(int mode) {
        switch (mode) {
            case 0:
            case 1:
                return 2;
            case 2:
                return 8;
            case 3:
                return 6;
            default:
                return 0;
        }
    }

    /** Allocates device memory for {@code elementCount} floats; the size is computed as a long. */
    private static Pointer allocateFloats(long elementCount) {
        Pointer devicePtr = new Pointer();
        JCuda.cudaMalloc(devicePtr, elementCount * FLOAT_BYTES);
        return devicePtr;
    }

    /** Creates a 3-d tensor descriptor with dense strides {@code {dim1*dim2, dim2, 1}}. */
    private static cudnnTensorDescriptor createTensor3d(int dim0, int dim1, int dim2) {
        int[] dims = {dim0, dim1, dim2};
        int[] strides = {dim2 * dim1, dim2, 1};
        cudnnTensorDescriptor descriptor = new cudnnTensorDescriptor();
        JCudnn.cudnnCreateTensorDescriptor(descriptor);
        JCudnn.cudnnSetTensorNdDescriptor(descriptor, DATA_TYPE, 3, dims, strides);
        return descriptor;
    }

    /** Returns the product of the three dimensions reported for a filter descriptor. */
    private static int filterElementCount(cudnnFilterDescriptor descriptor) {
        int[] dataType = {0};
        int[] format = {0};
        int[] dimCount = {0};
        int[] dims = new int[3];
        JCudnn.cudnnGetFilterNdDescriptor(descriptor, 3, dataType, format, dimCount, dims);
        return dims[0] * dims[1] * dims[2];
    }

    /** Records {@code stop}, waits for it, and returns the milliseconds elapsed since {@code start}. */
    private static float stopAndMeasure(cudaEvent_t start, cudaEvent_t stop) {
        JCuda.cudaEventRecord(stop, (cudaStream_t) null);
        JCuda.cudaEventSynchronize(stop);
        float[] elapsed = {0.0f};
        JCuda.cudaEventElapsedTime(elapsed, start, stop);
        return elapsed[0];
    }

    /** Copies {@code elementCount} floats from a device buffer into a new host array. */
    private static float[] copyToHost(Pointer devicePtr, int elementCount) {
        float[] host = new float[elementCount];
        JCuda.cudaMemcpy(Pointer.to(host), devicePtr, (long) elementCount * FLOAT_BYTES, DEVICE_TO_HOST);
        return host;
    }

    /**
     * Sums, in double precision, the entries belonging to one batch index of an array laid out
     * as {@code [outer][miniBatch][width]}.
     */
    private static double sumForBatch(float[] data, int batch, int outerCount, int miniBatch, int width) {
        double sum = 0.0d;
        for (int outer = 0; outer < outerCount; outer++) {
            int base = (outer * miniBatch * width) + (batch * width);
            for (int inner = 0; inner < width; inner++) {
                sum += data[base + inner];
            }
        }
        return sum;
    }

    /** Writes the same formatted text to standard output and to the results file. */
    private static void report(PrintWriter results, String format, Object... args) {
        System.out.printf(format, args);
        results.printf(format, args);
    }

    /**
     * Fills a device buffer with a constant.
     *
     * @param pointer device buffer to write
     * @param i       number of floats to write
     * @param f       value stored in every element
     */
    private static void initGPUData(Pointer pointer, int i, float f) {
        float[] hostValues = new float[i];
        Arrays.fill(hostValues, f);
        // Byte count is widened to long so large buffers cannot overflow int arithmetic.
        JCuda.cudaMemcpy(pointer, Pointer.to(hostValues), (long) i * FLOAT_BYTES, HOST_TO_DEVICE);
    }

    /**
     * Fails when a cuDNN call returned a non-zero status.
     *
     * @param i status code returned by a cuDNN call
     * @throws RuntimeException carrying the status name when {@code i} is not 0
     */
    public static void handle(int i) {
        if (i == 0) {
            return;
        }
        String statusName = cudnnStatus.stringFor(i);
        System.err.println(statusName);
        throw new RuntimeException(statusName);
    }
}
