package com.omega.engine.gpu.cudnn;

import com.omega.common.data.Tensor;
import com.omega.common.utils.JsonUtils;
import com.omega.common.utils.MatrixUtils;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.gpu.CUDAModules;
import jcuda.Pointer;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnConvolutionBwdDataAlgoPerf;
import jcuda.jcudnn.cudnnConvolutionBwdFilterAlgoPerf;
import jcuda.jcudnn.cudnnConvolutionDescriptor;
import jcuda.jcudnn.cudnnConvolutionFwdAlgoPerf;
import jcuda.jcudnn.cudnnFilterDescriptor;
import jcuda.jcudnn.cudnnHandle;
import jcuda.jcudnn.cudnnStatus;
import jcuda.jcudnn.cudnnTensorDescriptor;
import jcuda.runtime.JCuda;

/* loaded from: input_file:com/omega/engine/gpu/cudnn/TestJCudnn.class */
/**
 * Hand-run smoke tests for JCudnn: a regular convolution, a transposed convolution and 1D/2D
 * batch normalization. Nothing is asserted; every intermediate tensor and status is printed.
 *
 * <p>Status handling convention: {@link #handle(int)} prints the cuDNN status string to stderr on
 * failure and {@code "success."} to stdout otherwise; it never throws.
 */
public class TestJCudnn {
    /** cudnnTensorFormat_t value CUDNN_TENSOR_NCHW. */
    private static final int TENSOR_NCHW = 0;
    /** cudnnDataType_t value CUDNN_DATA_FLOAT (also used as the convolution compute type). */
    private static final int DATA_FLOAT = 0;
    /** cudnnConvolutionMode_t value CUDNN_CROSS_CORRELATION. */
    private static final int CROSS_CORRELATION = 1;
    /** cudnnBatchNormMode_t value CUDNN_BATCHNORM_SPATIAL. */
    private static final int BN_SPATIAL = 1;
    /** cudnnBatchNormMode_t value CUDNN_BATCHNORM_SPATIAL_PERSISTENT. */
    private static final int BN_SPATIAL_PERSISTENT = 2;
    /** Algorithm id that makes the workspace helpers skip the corresponding size query. */
    private static final int SKIP_ALGO = 9999;

    /** Shared cuDNN handle; null until {@link #GpuHandle(int)} creates it (or after a negative id). */
    private static cudnnHandle cudnnHandle;

    /**
     * Entry point: initialises the CUDA context, prints the runtime cuDNN version next to the
     * hard-coded CUDNN_VERSION 8401 (presumably inlined from cudnn.h at compile time), creates
     * the handle for device 0 and runs {@link #conv_transpose()}.
     */
    public static void main(String[] strArr) {
        CUDAModules.initContext();
        System.out.printf("cudnnGetVersion() : %d , CUDNN_VERSION from cudnn.h : %d\n", Integer.valueOf((int) JCudnn.cudnnGetVersion()), 8401);
        GpuHandle(0);
        conv_transpose();
    }

    /**
     * Runs a transposed convolution on a tiny hard-coded problem and prints all tensors.
     *
     * <p>The input x is 1x3x1x1 and the filter descriptor is 3x2x4x4 (K=3, C=2). The output y is
     * 1x2xHxW with H/W from {@link #transposedOutputSize}, i.e. 4x4 here. The transposed
     * convolution is cuDNN's backward-data pass. Its gradients are a forward convolution (dx) and
     * a backward-filter pass (dw).
     */
    public static void conv_transpose() {
        int[] inputShape = {1, 3, 1, 1};
        int[] filterShape = {3, 2, 4, 4};
        int[] pad = {0, 0};
        int[] stride = {1, 1};
        int[] dilation = {1, 1};
        int outputPadding = 0;

        cudnnTensorDescriptor xDesc = new cudnnTensorDescriptor();
        cudnnFilterDescriptor wDesc = new cudnnFilterDescriptor();
        cudnnTensorDescriptor yDesc = new cudnnTensorDescriptor();
        cudnnConvolutionDescriptor convDesc = new cudnnConvolutionDescriptor();
        JCudnn.cudnnCreateTensorDescriptor(xDesc);
        JCudnn.cudnnCreateFilterDescriptor(wDesc);
        JCudnn.cudnnCreateTensorDescriptor(yDesc);
        JCudnn.cudnnCreateConvolutionDescriptor(convDesc);
        Pointer workspace = new Pointer();
        boolean workspaceAllocated = false;
        try {
            JCudnn.cudnnSetTensor4dDescriptor(xDesc, TENSOR_NCHW, DATA_FLOAT, inputShape[0], inputShape[1], inputShape[2], inputShape[3]);
            JCudnn.cudnnSetFilterNdDescriptor(wDesc, DATA_FLOAT, TENSOR_NCHW, filterShape.length, filterShape);
            JCudnn.cudnnSetConvolutionNdDescriptor(convDesc, pad.length, pad, stride, dilation, CROSS_CORRELATION, DATA_FLOAT);

            // The filter's C dimension becomes the channel count of the transposed output.
            int batch = inputShape[0];
            int outChannels = filterShape[1];
            int outH = transposedOutputSize(inputShape[2], filterShape[2], pad[0], stride[0], dilation[0], outputPadding);
            int outW = transposedOutputSize(inputShape[3], filterShape[3], pad[1], stride[1], dilation[1], outputPadding);
            System.out.println(batch + ":" + outChannels + ":" + outH + ":" + outW);
            JCudnn.cudnnSetTensor4dDescriptor(yDesc, TENSOR_NCHW, DATA_FLOAT, batch, outChannels, outH, outW);

            // Roles are swapped relative to a regular convolution: y plays the convolution
            // "input" and x plays the convolution "output".
            int fwdAlgo = getForwardAlgorithm(-1, yDesc, wDesc, convDesc, xDesc);
            int bwdFilterAlgo = getBKFGO(-1, yDesc, xDesc, wDesc, convDesc);
            int bwdDataAlgo = getBKDGO(-1, yDesc, xDesc, wDesc, convDesc);
            long workspaceSize = getConvTransposeWorkSpace(xDesc, wDesc, convDesc, yDesc, fwdAlgo, bwdFilterAlgo, bwdDataAlgo);
            if (workspaceSize != 0) {
                workspaceAllocated = handleCuda(JCuda.cudaMalloc(workspace, workspaceSize));
            }
            System.out.println(workspaceSize);

            Pointer alpha = Pointer.to(new float[]{1.0f});
            Pointer beta = Pointer.to(new float[]{0.0f});
            float[] xData = MatrixUtils.order(inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3], 1, 1);
            float[] dyData = MatrixUtils.order(batch * outChannels * outH * outW, 1, 1);
            float[] wData = MatrixUtils.order(2 * 3 * 4 * 4, 1, 1);
            Tensor x = new Tensor(inputShape[0], inputShape[1], inputShape[2], inputShape[3], xData, true);
            // NOTE(review): these Tensors are shaped 2x3x4x4 while wDesc is 3x2x4x4. The element
            // count (96) matches, so presumably only the showDM() layout differs - confirm.
            Tensor weight = new Tensor(2, 3, 4, 4, wData, true);
            Tensor y = new Tensor(batch, outChannels, outH, outW, true);
            Tensor dy = new Tensor(batch, outChannels, outH, outW, dyData, true);
            Tensor dWeight = new Tensor(2, 3, 4, 4, true);
            Tensor dx = new Tensor(inputShape[0], inputShape[1], inputShape[2], inputShape[3], true);

            // y = conv_transpose(x, w) is the backward-data pass of a regular convolution.
            handle(JCudnn.cudnnConvolutionBackwardData(cudnnHandle, alpha, wDesc, weight.getGpuData(), xDesc, x.getGpuData(), convDesc, bwdDataAlgo, workspace, workspaceSize, beta, yDesc, y.getGpuData()));
            // dx = conv(dy, w) maps the output gradient back to the input shape.
            handle(JCudnn.cudnnConvolutionForward(cudnnHandle, alpha, yDesc, dy.getGpuData(), wDesc, weight.getGpuData(), convDesc, fwdAlgo, workspace, workspaceSize, beta, xDesc, dx.getGpuData()));
            // dw uses dy as the convolution "input" and x as its "output gradient".
            handle(JCudnn.cudnnConvolutionBackwardFilter(cudnnHandle, alpha, yDesc, dy.getGpuData(), xDesc, x.getGpuData(), convDesc, bwdFilterAlgo, workspace, workspaceSize, beta, wDesc, dWeight.getGpuData()));

            x.showDM();
            weight.showDM();
            System.out.println("output:");
            y.showDM();
            System.out.println("delta:");
            dy.showDM();
            System.out.println("dx:");
            dx.showDM();
            System.out.println("dw:");
            dWeight.showDM();
        } finally {
            if (workspaceAllocated) {
                JCuda.cudaFree(workspace);
            }
            JCudnn.cudnnDestroyConvolutionDescriptor(convDesc);
            JCudnn.cudnnDestroyTensorDescriptor(yDesc);
            JCudnn.cudnnDestroyFilterDescriptor(wDesc);
            JCudnn.cudnnDestroyTensorDescriptor(xDesc);
        }
    }

    /**
     * Computes the output extent of a transposed convolution along one spatial axis.
     *
     * @return {@code (in - 1) * stride - 2 * pad + dilation * (kernel - 1) + outputPadding + 1}
     */
    private static int transposedOutputSize(int in, int kernel, int pad, int stride, int dilation, int outputPadding) {
        return (in - 1) * stride - 2 * pad + dilation * (kernel - 1) + outputPadding + 1;
    }

    /**
     * Runs a forward convolution, backward-filter and backward-data pass on random data.
     *
     * <p>The input is 128x64x32x32, the filter is 128x64x3x3, padding is 2 and stride and
     * dilation are 1. The output shape is queried from cuDNN and printed as {@code n:c:h:w}.
     */
    public static void conv() {
        int[] inputShape = {128, 64, 32, 32};
        int[] filterShape = {128, 64, 3, 3};
        // Initial contents are overwritten by cudnnGetConvolutionNdForwardOutputDim on success.
        int[] outputShape = {128, 64, 32, 32};
        cudnnTensorDescriptor xDesc = new cudnnTensorDescriptor();
        cudnnFilterDescriptor wDesc = new cudnnFilterDescriptor();
        cudnnTensorDescriptor yDesc = new cudnnTensorDescriptor();
        cudnnConvolutionDescriptor convDesc = new cudnnConvolutionDescriptor();
        cudnnTensorDescriptor dxDesc = new cudnnTensorDescriptor();
        cudnnTensorDescriptor dyDesc = new cudnnTensorDescriptor();
        cudnnFilterDescriptor dwDesc = new cudnnFilterDescriptor();
        JCudnn.cudnnCreateTensorDescriptor(xDesc);
        JCudnn.cudnnCreateFilterDescriptor(wDesc);
        JCudnn.cudnnCreateTensorDescriptor(yDesc);
        JCudnn.cudnnCreateConvolutionDescriptor(convDesc);
        JCudnn.cudnnCreateTensorDescriptor(dxDesc);
        JCudnn.cudnnCreateTensorDescriptor(dyDesc);
        JCudnn.cudnnCreateFilterDescriptor(dwDesc);
        Pointer workspace = new Pointer();
        boolean workspaceAllocated = false;
        try {
            JCudnn.cudnnSetTensor4dDescriptor(xDesc, TENSOR_NCHW, DATA_FLOAT, inputShape[0], inputShape[1], inputShape[2], inputShape[3]);
            JCudnn.cudnnSetFilterNdDescriptor(wDesc, DATA_FLOAT, TENSOR_NCHW, filterShape.length, filterShape);
            JCudnn.cudnnSetTensor4dDescriptor(dxDesc, TENSOR_NCHW, DATA_FLOAT, inputShape[0], inputShape[1], inputShape[2], inputShape[3]);
            JCudnn.cudnnSetFilterNdDescriptor(dwDesc, DATA_FLOAT, TENSOR_NCHW, filterShape.length, filterShape);
            JCudnn.cudnnSetConvolutionNdDescriptor(convDesc, 2, new int[]{2, 2}, new int[]{1, 1}, new int[]{1, 1}, CROSS_CORRELATION, DATA_FLOAT);
            handle(JCudnn.cudnnGetConvolutionNdForwardOutputDim(convDesc, xDesc, wDesc, outputShape.length, outputShape));
            int outN = outputShape[0];
            int outC = outputShape[1];
            int outH = outputShape[2];
            int outW = outputShape[3];
            System.out.println(outN + ":" + outC + ":" + outH + ":" + outW);
            JCudnn.cudnnSetTensor4dDescriptor(yDesc, TENSOR_NCHW, DATA_FLOAT, outN, outC, outH, outW);
            JCudnn.cudnnSetTensor4dDescriptor(dyDesc, TENSOR_NCHW, DATA_FLOAT, outN, outC, outH, outW);

            int fwdAlgo = getForwardAlgorithm(-1, xDesc, wDesc, convDesc, yDesc);
            int bwdFilterAlgo = getBKFGO(-1, xDesc, dyDesc, dwDesc, convDesc);
            int bwdDataAlgo = getBKDGO(-1, dxDesc, dyDesc, wDesc, convDesc);
            long workspaceSize = getWorkSpace(xDesc, dyDesc, wDesc, dwDesc, convDesc, yDesc, dxDesc, fwdAlgo, bwdFilterAlgo, bwdDataAlgo);
            if (workspaceSize != 0) {
                workspaceAllocated = handleCuda(JCuda.cudaMalloc(workspace, workspaceSize));
            }
            System.out.println(workspaceSize);

            Pointer alpha = Pointer.to(new float[]{1.0f});
            Pointer beta = Pointer.to(new float[]{0.0f});
            // Keep the generator call order: input, output gradient, then filter.
            float[] xData = RandomUtils.gaussianRandom(inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3], 0.1f);
            float[] dyData = RandomUtils.gaussianRandom(outN * outC * outH * outW, 0.1f);
            float[] wData = RandomUtils.gaussianRandom(filterShape[0] * filterShape[1] * filterShape[2] * filterShape[3], 0.1f);
            Tensor x = new Tensor(inputShape[0], inputShape[1], inputShape[2], inputShape[3], xData, true);
            Tensor weight = new Tensor(filterShape[0], filterShape[1], filterShape[2], filterShape[3], wData, true);
            Tensor y = new Tensor(outN, outC, outH, outW, true);
            Tensor dy = new Tensor(outN, outC, outH, outW, dyData, true);
            Tensor dWeight = new Tensor(filterShape[0], filterShape[1], filterShape[2], filterShape[3], true);
            Tensor dx = new Tensor(inputShape[0], inputShape[1], inputShape[2], inputShape[3], true);

            handle(JCudnn.cudnnConvolutionForward(cudnnHandle, alpha, xDesc, x.getGpuData(), wDesc, weight.getGpuData(), convDesc, fwdAlgo, workspace, workspaceSize, beta, yDesc, y.getGpuData()));
            handle(JCudnn.cudnnConvolutionBackwardFilter(cudnnHandle, alpha, xDesc, x.getGpuData(), dyDesc, dy.getGpuData(), convDesc, bwdFilterAlgo, workspace, workspaceSize, beta, dwDesc, dWeight.getGpuData()));
            handle(JCudnn.cudnnConvolutionBackwardData(cudnnHandle, alpha, wDesc, weight.getGpuData(), dyDesc, dy.getGpuData(), convDesc, bwdDataAlgo, workspace, workspaceSize, beta, dxDesc, dx.getGpuData()));
        } finally {
            if (workspaceAllocated) {
                JCuda.cudaFree(workspace);
            }
            JCudnn.cudnnDestroyFilterDescriptor(dwDesc);
            JCudnn.cudnnDestroyTensorDescriptor(dyDesc);
            JCudnn.cudnnDestroyTensorDescriptor(dxDesc);
            JCudnn.cudnnDestroyConvolutionDescriptor(convDesc);
            JCudnn.cudnnDestroyTensorDescriptor(yDesc);
            JCudnn.cudnnDestroyFilterDescriptor(wDesc);
            JCudnn.cudnnDestroyTensorDescriptor(xDesc);
        }
    }

    /**
     * Returns the largest workspace needed by the forward, backward-filter and backward-data
     * passes of a regular convolution. Each size is printed as it is queried.
     *
     * <p>Passing {@code 9999} for an algorithm skips that pass's query.
     *
     * @return the maximum workspace size in bytes, or 0 if none is needed
     */
    public static long getWorkSpace(cudnnTensorDescriptor xDesc, cudnnTensorDescriptor dyDesc, cudnnFilterDescriptor wDesc, cudnnFilterDescriptor dwDesc, cudnnConvolutionDescriptor convDesc, cudnnTensorDescriptor yDesc, cudnnTensorDescriptor dxDesc, int fwdAlgo, int bwdFilterAlgo, int bwdDataAlgo) {
        long maxSize = 0;
        long[] size = {0};
        System.out.println("fw_algo:" + fwdAlgo);
        if (fwdAlgo != SKIP_ALGO) {
            handle(JCudnn.cudnnGetConvolutionForwardWorkspaceSize(cudnnHandle, xDesc, wDesc, convDesc, yDesc, fwdAlgo, size));
            System.out.println(size[0]);
            maxSize = Math.max(maxSize, size[0]);
        }
        System.out.println("bkf_algo:" + bwdFilterAlgo);
        if (bwdFilterAlgo != SKIP_ALGO) {
            handle(JCudnn.cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnHandle, xDesc, dyDesc, convDesc, dwDesc, bwdFilterAlgo, size));
            System.out.println(size[0]);
            maxSize = Math.max(maxSize, size[0]);
        }
        System.out.println("bkd_algo:" + bwdDataAlgo);
        if (bwdDataAlgo != SKIP_ALGO) {
            handle(JCudnn.cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnHandle, wDesc, dyDesc, convDesc, dxDesc, bwdDataAlgo, size));
            System.out.println(size[0]);
            maxSize = Math.max(maxSize, size[0]);
        }
        return maxSize;
    }

    /**
     * Like {@link #getWorkSpace}, but for a transposed convolution.
     *
     * <p>Here {@code yDesc} (the transposed output) is the convolution "input" and {@code xDesc}
     * is the convolution "output". Passing {@code 9999} for an algorithm skips that query.
     *
     * @return the maximum workspace size in bytes, or 0 if none is needed
     */
    public static long getConvTransposeWorkSpace(cudnnTensorDescriptor xDesc, cudnnFilterDescriptor wDesc, cudnnConvolutionDescriptor convDesc, cudnnTensorDescriptor yDesc, int fwdAlgo, int bwdFilterAlgo, int bwdDataAlgo) {
        long maxSize = 0;
        long[] size = {0};
        System.out.println("fw_algo:" + fwdAlgo);
        if (fwdAlgo != SKIP_ALGO) {
            handle(JCudnn.cudnnGetConvolutionForwardWorkspaceSize(cudnnHandle, yDesc, wDesc, convDesc, xDesc, fwdAlgo, size));
            System.out.println(size[0]);
            maxSize = Math.max(maxSize, size[0]);
        }
        System.out.println("bkf_algo:" + bwdFilterAlgo);
        if (bwdFilterAlgo != SKIP_ALGO) {
            handle(JCudnn.cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnHandle, yDesc, xDesc, convDesc, wDesc, bwdFilterAlgo, size));
            System.out.println(size[0]);
            maxSize = Math.max(maxSize, size[0]);
        }
        System.out.println("bkd_algo:" + bwdDataAlgo);
        if (bwdDataAlgo != SKIP_ALGO) {
            handle(JCudnn.cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnHandle, wDesc, xDesc, convDesc, yDesc, bwdDataAlgo, size));
            System.out.println(size[0]);
            maxSize = Math.max(maxSize, size[0]);
        }
        return maxSize;
    }

    /**
     * Picks a backward-data algorithm.
     *
     * <p>Returns {@code preferredAlgo} if it is non-negative. Otherwise benchmarks up to 6
     * candidates with cudnnFindConvolutionBackwardDataAlgorithm, prints each one, and returns the
     * first whose status is success.
     *
     * @throws IllegalStateException if the Find call fails or no candidate succeeded
     */
    public static int getBKDGO(int preferredAlgo, cudnnTensorDescriptor dxDesc, cudnnTensorDescriptor dyDesc, cudnnFilterDescriptor wDesc, cudnnConvolutionDescriptor convDesc) {
        if (preferredAlgo >= 0) {
            return preferredAlgo;
        }
        int[] returnedCount = {-1};
        cudnnConvolutionBwdDataAlgoPerf[] perf = new cudnnConvolutionBwdDataAlgoPerf[12];
        System.out.println("Testing cudnnFindConvolutionBackwardDataAlgorithm ...");
        requireSuccess(JCudnn.cudnnFindConvolutionBackwardDataAlgorithm(cudnnHandle, wDesc, dyDesc, convDesc, dxDesc, 6, returnedCount, perf), "cudnnFindConvolutionBackwardDataAlgorithm");
        int count = returnedCount[0];
        for (int k = 0; k < count; k++) {
            printAlgoPerf(perf[k].algo, perf[k].time, perf[k].memory, perf[k].status);
        }
        for (int k = 0; k < count; k++) {
            if (perf[k].status == cudnnStatus.CUDNN_STATUS_SUCCESS) {
                return perf[k].algo;
            }
        }
        throw new IllegalStateException("cudnnFindConvolutionBackwardDataAlgorithm returned no usable algorithm (" + count + " candidates)");
    }

    /**
     * Picks a backward-filter algorithm.
     *
     * <p>Returns {@code preferredAlgo} if it is non-negative. Otherwise benchmarks up to 6
     * candidates with cudnnFindConvolutionBackwardFilterAlgorithm, prints each one, and returns
     * the first whose status is success.
     *
     * @throws IllegalStateException if the Find call fails or no candidate succeeded
     */
    public static int getBKFGO(int preferredAlgo, cudnnTensorDescriptor xDesc, cudnnTensorDescriptor dyDesc, cudnnFilterDescriptor dwDesc, cudnnConvolutionDescriptor convDesc) {
        if (preferredAlgo >= 0) {
            return preferredAlgo;
        }
        int[] returnedCount = {-1};
        cudnnConvolutionBwdFilterAlgoPerf[] perf = new cudnnConvolutionBwdFilterAlgoPerf[12];
        System.out.println("Testing cudnnFindConvolutionBackwardFilterAlgorithm ...");
        requireSuccess(JCudnn.cudnnFindConvolutionBackwardFilterAlgorithm(cudnnHandle, xDesc, dyDesc, convDesc, dwDesc, 6, returnedCount, perf), "cudnnFindConvolutionBackwardFilterAlgorithm");
        int count = returnedCount[0];
        for (int k = 0; k < count; k++) {
            printAlgoPerf(perf[k].algo, perf[k].time, perf[k].memory, perf[k].status);
        }
        for (int k = 0; k < count; k++) {
            if (perf[k].status == cudnnStatus.CUDNN_STATUS_SUCCESS) {
                return perf[k].algo;
            }
        }
        throw new IllegalStateException("cudnnFindConvolutionBackwardFilterAlgorithm returned no usable algorithm (" + count + " candidates)");
    }

    /**
     * Picks a forward algorithm.
     *
     * <p>Returns {@code preferredAlgo} if it is non-negative. Otherwise benchmarks up to 8
     * candidates with cudnnFindConvolutionForwardAlgorithm, prints each one, and returns the
     * first whose status is success.
     *
     * @throws IllegalStateException if the Find call fails or no candidate succeeded
     */
    public static int getForwardAlgorithm(int preferredAlgo, cudnnTensorDescriptor xDesc, cudnnFilterDescriptor wDesc, cudnnConvolutionDescriptor convDesc, cudnnTensorDescriptor yDesc) {
        if (preferredAlgo >= 0) {
            return preferredAlgo;
        }
        int[] returnedCount = {-1};
        cudnnConvolutionFwdAlgoPerf[] perf = new cudnnConvolutionFwdAlgoPerf[16];
        System.out.println("Testing cudnnFindConvolutionForwardAlgorithm ...");
        requireSuccess(JCudnn.cudnnFindConvolutionForwardAlgorithm(cudnnHandle, xDesc, wDesc, convDesc, yDesc, 8, returnedCount, perf), "cudnnFindConvolutionForwardAlgorithm");
        int count = returnedCount[0];
        for (int k = 0; k < count; k++) {
            printAlgoPerf(perf[k].algo, perf[k].time, perf[k].memory, perf[k].status);
        }
        for (int k = 0; k < count; k++) {
            if (perf[k].status == cudnnStatus.CUDNN_STATUS_SUCCESS) {
                return perf[k].algo;
            }
        }
        throw new IllegalStateException("cudnnFindConvolutionForwardAlgorithm returned no usable algorithm (" + count + " candidates)");
    }

    /** Prints one benchmark row produced by a cudnnFind*Algorithm call. */
    private static void printAlgoPerf(int algo, float time, long memory, int status) {
        System.out.printf("^^^^ for Algo %d: %f time requiring %d memory %s \n", algo, time, memory, "[" + checkError(status) + "]");
    }

    /**
     * Runs 2D batch normalization (spatial-persistent mode) on a 2x3x5x5 input.
     *
     * <p>Covers the training forward pass, the backward pass with an all-ones output gradient,
     * and inference using the running statistics. Every result is printed as JSON.
     */
    public static void testBN2D() {
        Pointer one = Pointer.to(new float[]{1.0f});
        Pointer zero = Pointer.to(new float[]{0.0f});
        float[] fArr = {0.9827f, 0.5268f, 0.4057f, 0.2853f, 0.1708f, 0.4791f, 0.5626f, 0.129f, 0.954f, 0.7471f, 0.5806f, 0.8789f, 0.9766f, 0.8142f, 0.9557f, 0.2814f, 0.7667f, 0.5963f, 0.0016f, 0.5944f, 0.4617f, 0.0975f, 0.3558f, 0.3318f, 0.5196f, 0.7558f, 0.7438f, 0.4061f, 0.2737f, 0.1826f, 0.76f, 0.3608f, 0.3924f, 0.2537f, 0.7536f, 0.798f, 0.5246f, 0.6428f, 0.0571f, 0.9973f, 0.7106f, 0.5854f, 0.3122f, 0.2741f, 0.2868f, 0.4628f, 0.2696f, 0.0436f, 0.1222f, 0.4933f, 0.5372f, 0.4992f, 0.2837f, 0.8462f, 0.2095f, 0.1916f, 0.183f, 0.1934f, 0.8305f, 0.0776f, 0.9014f, 0.1835f, 0.7673f, 0.0999f, 0.5783f, 0.7816f, 0.2961f, 0.923f, 0.3454f, 0.603f, 0.4821f, 0.0113f, 0.9629f, 0.8698f, 0.844f, 0.9763f, 0.7661f, 0.2085f, 0.4248f, 0.7407f, 0.5092f, 0.5272f, 0.8521f, 0.1649f, 0.9759f, 0.9084f, 0.3206f, 0.3061f, 0.9648f, 0.3377f, 0.6753f, 0.6662f, 0.457f, 0.9556f, 0.0918f, 0.8788f, 0.6432f, 0.4928f, 0.8778f, 0.5665f, 0.7979f, 0.5639f, 0.597f, 0.4987f, 0.1227f, 0.4963f, 0.6865f, 0.5728f, 0.1927f, 0.1199f, 0.5015f, 0.0221f, 0.0826f, 0.0077f, 0.0568f, 0.7569f, 0.7684f, 0.1536f, 0.4406f, 0.2919f, 0.3006f, 0.9501f, 0.1994f, 0.3314f, 0.5612f, 0.3303f, 0.8773f, 0.3262f, 0.1926f, 0.8667f, 0.336f, 0.5357f, 0.3332f, 0.2044f, 0.5538f, 0.0607f, 0.2203f, 0.7994f, 0.6357f, 0.6469f, 0.8163f, 0.7764f, 0.6821f, 0.6798f, 0.0553f, 0.0609f, 0.2305f, 0.7183f, 0.8135f, 0.7688f};
        Tensor x = new Tensor(2, 3, 5, 5, fArr, true);
        Tensor y = new Tensor(2, 3, 5, 5, true);
        Tensor gamma = new Tensor(1, 3, 1, 1, new float[]{1.0f, 1.0f, 1.0f}, true);
        Tensor bias = new Tensor(1, 3, 1, 1, true);
        Tensor savedMean = new Tensor(1, 3, 1, 1, true);
        // cuDNN's "saved inverse variance" cache, not the variance itself, despite the "var:" label.
        Tensor savedInvVar = new Tensor(1, 3, 1, 1, true);
        Tensor runningMean = new Tensor(1, 3, 1, 1, true);
        Tensor runningVar = new Tensor(1, 3, 1, 1, true);
        cudnnTensorDescriptor xDesc = new cudnnTensorDescriptor();
        cudnnTensorDescriptor yDesc = new cudnnTensorDescriptor();
        cudnnTensorDescriptor bnDesc = new cudnnTensorDescriptor();
        cudnnTensorDescriptor dyDesc = new cudnnTensorDescriptor();
        cudnnTensorDescriptor dxDesc = new cudnnTensorDescriptor();
        cudnnTensorDescriptor dBnDesc = new cudnnTensorDescriptor();
        JCudnn.cudnnCreateTensorDescriptor(xDesc);
        JCudnn.cudnnCreateTensorDescriptor(yDesc);
        JCudnn.cudnnCreateTensorDescriptor(bnDesc);
        JCudnn.cudnnCreateTensorDescriptor(dyDesc);
        JCudnn.cudnnCreateTensorDescriptor(dxDesc);
        JCudnn.cudnnCreateTensorDescriptor(dBnDesc);
        try {
            JCudnn.cudnnSetTensor4dDescriptor(xDesc, TENSOR_NCHW, DATA_FLOAT, 2, 3, 5, 5);
            JCudnn.cudnnSetTensor4dDescriptor(yDesc, TENSOR_NCHW, DATA_FLOAT, 2, 3, 5, 5);
            JCudnn.cudnnSetTensor4dDescriptor(bnDesc, TENSOR_NCHW, DATA_FLOAT, 1, 3, 1, 1);
            handle(JCudnn.cudnnBatchNormalizationForwardTraining(cudnnHandle, BN_SPATIAL_PERSISTENT, one, zero, xDesc, x.getGpuData(), yDesc, y.getGpuData(), bnDesc, gamma.getGpuData(), bias.getGpuData(), 0.8999999761581421d, runningMean.getGpuData(), runningVar.getGpuData(), 1.0E-5d, savedMean.getGpuData(), savedInvVar.getGpuData()));
            System.out.println("mean:" + JsonUtils.toJson(savedMean.syncHost()));
            System.out.println("var:" + JsonUtils.toJson(savedInvVar.syncHost()));
            System.out.println("runingMean:" + JsonUtils.toJson(runningMean.syncHost()));
            System.out.println("runingVar:" + JsonUtils.toJson(runningVar.syncHost()));
            System.out.println("output:" + JsonUtils.toJson(y.syncHost()));

            Tensor dy = new Tensor(2, 3, 5, 5, MatrixUtils.one(fArr.length), true);
            Tensor dx = new Tensor(2, 3, 5, 5, true);
            Tensor dGamma = new Tensor(1, 3, 1, 1, true);
            Tensor dBeta = new Tensor(1, 3, 1, 1, true);
            JCudnn.cudnnSetTensor4dDescriptor(dyDesc, TENSOR_NCHW, DATA_FLOAT, 2, 3, 5, 5);
            JCudnn.cudnnSetTensor4dDescriptor(dxDesc, TENSOR_NCHW, DATA_FLOAT, 2, 3, 5, 5);
            JCudnn.cudnnSetTensor4dDescriptor(dBnDesc, TENSOR_NCHW, DATA_FLOAT, 1, 3, 1, 1);
            // Scales are (alphaDataDiff, betaDataDiff, alphaParamDiff, betaParamDiff) = (1, 0, 1, 0),
            // so dx, dgamma and dbeta overwrite their outputs.
            handle(JCudnn.cudnnBatchNormalizationBackward(cudnnHandle, BN_SPATIAL_PERSISTENT, one, zero, one, zero, xDesc, x.getGpuData(), dyDesc, dy.getGpuData(), dxDesc, dx.getGpuData(), dBnDesc, gamma.getGpuData(), dGamma.getGpuData(), dBeta.getGpuData(), 1.0E-5d, savedMean.getGpuData(), savedInvVar.getGpuData()));
            System.out.println("delta:" + JsonUtils.toJson(dy.syncHost()));
            System.out.println("dgamma:" + JsonUtils.toJson(dGamma.syncHost()));
            System.out.println("dbeta:" + JsonUtils.toJson(dBeta.syncHost()));
            System.out.println("dx:" + JsonUtils.toJson(dx.syncHost()));

            Tensor inferenceOut = new Tensor(2, 3, 5, 5, true);
            handle(JCudnn.cudnnBatchNormalizationForwardInference(cudnnHandle, BN_SPATIAL_PERSISTENT, one, zero, xDesc, x.getGpuData(), yDesc, inferenceOut.getGpuData(), bnDesc, gamma.getGpuData(), bias.getGpuData(), runningMean.getGpuData(), runningVar.getGpuData(), 1.0E-5d));
            System.out.println("test-output:" + JsonUtils.toJson(inferenceOut.syncHost()));
        } finally {
            JCudnn.cudnnDestroyTensorDescriptor(dBnDesc);
            JCudnn.cudnnDestroyTensorDescriptor(dxDesc);
            JCudnn.cudnnDestroyTensorDescriptor(dyDesc);
            JCudnn.cudnnDestroyTensorDescriptor(bnDesc);
            JCudnn.cudnnDestroyTensorDescriptor(yDesc);
            JCudnn.cudnnDestroyTensorDescriptor(xDesc);
        }
    }

    /**
     * Runs 1D batch normalization (spatial mode) on 2 samples x 10 features.
     *
     * <p>The cuDNN descriptors are 2x10x1x1; the Tensors are created as 2x1x1x10 with the same
     * element count. Covers the training forward pass and the backward pass with an all-ones
     * output gradient, printing results as JSON.
     */
    public static void testBN1D() {
        Pointer one = Pointer.to(new float[]{1.0f});
        Pointer zero = Pointer.to(new float[]{0.0f});
        float[] fArr = {56.773f, -7.231f, 39.634f, 24.728f, -17.959f, 55.251f, -52.316f, -36.322f, -29.619f, 55.24f, 26.773f, -1.231f, 19.634f, 4.728f, 7.958f, -65.251f, 52.316f, -36.322f, -23.619f, -5.247f};
        Tensor x = new Tensor(2, 1, 1, 10, fArr, true);
        Tensor y = new Tensor(2, 1, 1, 10, true);
        Tensor gamma = new Tensor(1, 1, 1, 10, new float[]{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, true);
        Tensor bias = new Tensor(1, 1, 1, 10, true);
        Tensor savedMean = new Tensor(1, 1, 1, 10, true);
        // cuDNN's "saved inverse variance" cache, not the variance itself, despite the "var:" label.
        Tensor savedInvVar = new Tensor(1, 1, 1, 10, true);
        Tensor runningMean = new Tensor(1, 1, 1, 10, true);
        Tensor runningVar = new Tensor(1, 1, 1, 10, true);
        cudnnTensorDescriptor xDesc = new cudnnTensorDescriptor();
        cudnnTensorDescriptor yDesc = new cudnnTensorDescriptor();
        cudnnTensorDescriptor bnDesc = new cudnnTensorDescriptor();
        cudnnTensorDescriptor dyDesc = new cudnnTensorDescriptor();
        cudnnTensorDescriptor dxDesc = new cudnnTensorDescriptor();
        cudnnTensorDescriptor dBnDesc = new cudnnTensorDescriptor();
        JCudnn.cudnnCreateTensorDescriptor(xDesc);
        JCudnn.cudnnCreateTensorDescriptor(yDesc);
        JCudnn.cudnnCreateTensorDescriptor(bnDesc);
        JCudnn.cudnnCreateTensorDescriptor(dyDesc);
        JCudnn.cudnnCreateTensorDescriptor(dxDesc);
        JCudnn.cudnnCreateTensorDescriptor(dBnDesc);
        try {
            JCudnn.cudnnSetTensor4dDescriptor(xDesc, TENSOR_NCHW, DATA_FLOAT, 2, 10, 1, 1);
            JCudnn.cudnnSetTensor4dDescriptor(yDesc, TENSOR_NCHW, DATA_FLOAT, 2, 10, 1, 1);
            JCudnn.cudnnSetTensor4dDescriptor(bnDesc, TENSOR_NCHW, DATA_FLOAT, 1, 10, 1, 1);
            handle(JCudnn.cudnnBatchNormalizationForwardTraining(cudnnHandle, BN_SPATIAL, one, zero, xDesc, x.getGpuData(), yDesc, y.getGpuData(), bnDesc, gamma.getGpuData(), bias.getGpuData(), 0.8999999761581421d, runningMean.getGpuData(), runningVar.getGpuData(), 1.0E-5d, savedMean.getGpuData(), savedInvVar.getGpuData()));
            System.out.println("mean:" + JsonUtils.toJson(savedMean.syncHost()));
            System.out.println("var:" + JsonUtils.toJson(savedInvVar.syncHost()));
            System.out.println("runingMean:" + JsonUtils.toJson(runningMean.syncHost()));
            System.out.println("runingVar:" + JsonUtils.toJson(runningVar.syncHost()));
            System.out.println("output:" + JsonUtils.toJson(y.syncHost()));

            Tensor dy = new Tensor(2, 1, 1, 10, MatrixUtils.one(fArr.length), true);
            Tensor dx = new Tensor(2, 1, 1, 10, true);
            Tensor dGamma = new Tensor(1, 1, 1, 10, true);
            Tensor dBeta = new Tensor(1, 1, 1, 10, true);
            JCudnn.cudnnSetTensor4dDescriptor(dyDesc, TENSOR_NCHW, DATA_FLOAT, 2, 10, 1, 1);
            JCudnn.cudnnSetTensor4dDescriptor(dxDesc, TENSOR_NCHW, DATA_FLOAT, 2, 10, 1, 1);
            JCudnn.cudnnSetTensor4dDescriptor(dBnDesc, TENSOR_NCHW, DATA_FLOAT, 1, 10, 1, 1);
            // betaParamDiff must be 0 (as in testBN2D). A value of 1 would add dgamma/dbeta onto
            // whatever the freshly allocated result buffers already contain.
            handle(JCudnn.cudnnBatchNormalizationBackward(cudnnHandle, BN_SPATIAL, one, zero, one, zero, xDesc, x.getGpuData(), dyDesc, dy.getGpuData(), dxDesc, dx.getGpuData(), dBnDesc, gamma.getGpuData(), dGamma.getGpuData(), dBeta.getGpuData(), 1.0E-5d, savedMean.getGpuData(), savedInvVar.getGpuData()));
            System.out.println("delta:" + JsonUtils.toJson(dy.syncHost()));
            System.out.println("dgamma:" + JsonUtils.toJson(dGamma.syncHost()));
            System.out.println("dbeta:" + JsonUtils.toJson(dBeta.syncHost()));
            System.out.println("dx:" + JsonUtils.toJson(dx.syncHost()));
        } finally {
            JCudnn.cudnnDestroyTensorDescriptor(dBnDesc);
            JCudnn.cudnnDestroyTensorDescriptor(dxDesc);
            JCudnn.cudnnDestroyTensorDescriptor(dyDesc);
            JCudnn.cudnnDestroyTensorDescriptor(bnDesc);
            JCudnn.cudnnDestroyTensorDescriptor(yDesc);
            JCudnn.cudnnDestroyTensorDescriptor(xDesc);
        }
    }

    /**
     * Reports a cuDNN status code without throwing.
     *
     * <p>On failure the {@code cudnnStatus} name is printed to stderr; on success
     * {@code "success."} is printed to stdout.
     */
    public static void handle(int i) {
        if (i != 0) {
            System.err.println(cudnnStatus.stringFor(i));
        } else {
            System.out.println("success.");
        }
    }

    /**
     * Reports a CUDA runtime error code, decoded with the CUDA error table rather than cuDNN's.
     *
     * <p>Prints {@code "success."} to stdout when the code is 0 and the error string to stderr
     * otherwise.
     *
     * @return true if the call succeeded
     */
    private static boolean handleCuda(int error) {
        if (error != 0) {
            System.err.println(JCuda.cudaGetErrorString(error));
            return false;
        }
        System.out.println("success.");
        return true;
    }

    /** Throws if a cuDNN call did not return CUDNN_STATUS_SUCCESS. */
    private static void requireSuccess(int status, String call) {
        if (status != cudnnStatus.CUDNN_STATUS_SUCCESS) {
            throw new IllegalStateException(call + " failed: " + cudnnStatus.stringFor(status));
        }
    }

    /** Returns "success" for status 0, otherwise the cudnnStatus name of {@code i}. */
    public static String checkError(int i) {
        return i != 0 ? cudnnStatus.stringFor(i) : "success";
    }

    /**
     * Creates the shared cuDNN handle. A negative {@code i} clears the handle to null instead.
     *
     * <p>NOTE(review): {@link #initThread()} always targets device 0, whatever {@code i} is.
     */
    public static void GpuHandle(int i) {
        if (0 > i) {
            cudnnHandle = null;
            return;
        }
        initThread();
        cudnnHandle = new cudnnHandle();
        JCudnn.cudnnCreate(cudnnHandle);
    }

    /** Binds the current thread to device 0 via {@link #setDevice(int)}. */
    public static void initThread() {
        setDevice(0);
    }

    /**
     * Calls cudaSetDevice and prints its raw return code, unless the thread is already on
     * device {@code i}.
     *
     * <p>NOTE(review): {@link #getThreadDeviceId()} is hard-wired to 0, so setDevice(0) returns
     * early and never reaches cudaSetDevice.
     *
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public static void setDevice(int i) {
        if (i < 0) {
            throw new IllegalArgumentException("cudaDeviceId=" + i);
        }
        if (isThreadDeviceId(i)) {
            return;
        }
        System.out.println(JCuda.cudaSetDevice(i));
    }

    /** Returns true when {@link #getThreadDeviceId()} is non-null and equal to {@code i}. */
    public static boolean isThreadDeviceId(int i) {
        Integer threadDeviceId = getThreadDeviceId();
        return threadDeviceId != null && i == threadDeviceId.intValue();
    }

    /** Stub for the current thread's device id; always returns 0. */
    public static Integer getThreadDeviceId() {
        return 0;
    }
}
