/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.normalization;

import java.util.Arrays;

import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.cuda;
import org.bytedeco.javacpp.cudnn;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalizationHelper;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;

public class CudnnLocalResponseNormalizationHelper
implements LocalResponseNormalizationHelper {
    CudnnContext cudnnContext = new CudnnContext();
    int dataType = Nd4j.dataType() == DataBuffer.Type.DOUBLE ? 1 : 0;
    int tensorFormat = 0;
    FloatPointer alpha = new FloatPointer(new float[]{1.0f});
    FloatPointer beta = new FloatPointer(new float[]{0.0f});
    INDArray activations = null;

    static void checkCuda(int error) {
        if (error != 0) {
            throw new RuntimeException("CUDA error = " + error);
        }
    }

    static void checkCudnn(int status) {
        if (status != 0) {
            throw new RuntimeException("cuDNN status = " + status);
        }
    }

    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, double k, double n, double alpha, double beta) {
        int miniBatch = input.size(0);
        int depth = input.size(1);
        int inH = input.size(2);
        int inW = input.size(3);
        DefaultGradient retGradient = new DefaultGradient();
        if (!Shape.strideDescendingCAscendingF((INDArray)epsilon)) {
            epsilon = epsilon.dup();
        }
        int[] srcStride = input.stride();
        int[] deltaStride = epsilon.stride();
        CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx((cudnn.cudnnTensorStruct)this.cudnnContext.srcTensorDesc, (int)this.dataType, (int)miniBatch, (int)depth, (int)inH, (int)inW, (int)srcStride[0], (int)srcStride[1], (int)srcStride[2], (int)srcStride[3]));
        CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx((cudnn.cudnnTensorStruct)this.cudnnContext.deltaTensorDesc, (int)this.dataType, (int)miniBatch, (int)depth, (int)inH, (int)inW, (int)deltaStride[0], (int)deltaStride[1], (int)deltaStride[2], (int)deltaStride[3]));
        CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnSetLRNDescriptor((cudnn.cudnnLRNStruct)this.cudnnContext.lrnDesc, (int)((int)n), (double)alpha, (double)beta, (double)k));
        INDArray nextEpsilon = Nd4j.createUninitialized((int[])new int[]{miniBatch, depth, inH, inW}, (char)'c');
        int[] dstStride = nextEpsilon.stride();
        CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx((cudnn.cudnnTensorStruct)this.cudnnContext.dstTensorDesc, (int)this.dataType, (int)miniBatch, (int)depth, (int)inH, (int)inW, (int)dstStride[0], (int)dstStride[1], (int)dstStride[2], (int)dstStride[3]));
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(input, new INDArray[]{epsilon, this.activations, nextEpsilon});
        Pointer srcData = allocator.getPointer(input, context);
        Pointer epsData = allocator.getPointer(epsilon, context);
        Pointer zData = allocator.getPointer(this.activations, context);
        Pointer dstData = allocator.getPointer(nextEpsilon, context);
        CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnSetStream((cudnn.cudnnContext)this.cudnnContext, (cuda.CUstream_st)new cuda.CUstream_st((Pointer)context.getOldStream())));
        CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnLRNCrossChannelBackward((cudnn.cudnnContext)this.cudnnContext, (cudnn.cudnnLRNStruct)this.cudnnContext.lrnDesc, (int)0, (Pointer)this.alpha, (cudnn.cudnnTensorStruct)this.cudnnContext.deltaTensorDesc, (Pointer)zData, (cudnn.cudnnTensorStruct)this.cudnnContext.deltaTensorDesc, (Pointer)epsData, (cudnn.cudnnTensorStruct)this.cudnnContext.srcTensorDesc, (Pointer)srcData, (Pointer)this.beta, (cudnn.cudnnTensorStruct)this.cudnnContext.dstTensorDesc, (Pointer)dstData));
        allocator.registerAction(context, input, new INDArray[]{epsilon, this.activations, nextEpsilon});
        return new Pair((Object)retGradient, (Object)nextEpsilon);
    }

    public INDArray activate(INDArray input, boolean training, double k, double n, double alpha, double beta) {
        int miniBatch = input.size(0);
        int inDepth = input.size(1);
        int inH = input.size(2);
        int inW = input.size(3);
        int[] srcStride = input.stride();
        CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx((cudnn.cudnnTensorStruct)this.cudnnContext.srcTensorDesc, (int)this.dataType, (int)miniBatch, (int)inDepth, (int)inH, (int)inW, (int)srcStride[0], (int)srcStride[1], (int)srcStride[2], (int)srcStride[3]));
        this.activations = Nd4j.createUninitialized((int[])new int[]{miniBatch, inDepth, inH, inW}, (char)'c');
        int[] dstStride = this.activations.stride();
        CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx((cudnn.cudnnTensorStruct)this.cudnnContext.dstTensorDesc, (int)this.dataType, (int)miniBatch, (int)inDepth, (int)inH, (int)inW, (int)dstStride[0], (int)dstStride[1], (int)dstStride[2], (int)dstStride[3]));
        CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnSetLRNDescriptor((cudnn.cudnnLRNStruct)this.cudnnContext.lrnDesc, (int)((int)n), (double)alpha, (double)beta, (double)k));
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(input, new INDArray[]{this.activations});
        Pointer srcData = allocator.getPointer(input, context);
        Pointer dstData = allocator.getPointer(this.activations, context);
        CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnSetStream((cudnn.cudnnContext)this.cudnnContext, (cuda.CUstream_st)new cuda.CUstream_st((Pointer)context.getOldStream())));
        CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnLRNCrossChannelForward((cudnn.cudnnContext)this.cudnnContext, (cudnn.cudnnLRNStruct)this.cudnnContext.lrnDesc, (int)0, (Pointer)this.alpha, (cudnn.cudnnTensorStruct)this.cudnnContext.srcTensorDesc, (Pointer)srcData, (Pointer)this.beta, (cudnn.cudnnTensorStruct)this.cudnnContext.dstTensorDesc, (Pointer)dstData));
        allocator.registerAction(context, input, new INDArray[]{this.activations});
        return this.activations;
    }

    static class CudnnContext
    extends cudnn.cudnnContext {
        cudnn.cudnnTensorStruct srcTensorDesc = new cudnn.cudnnTensorStruct();
        cudnn.cudnnTensorStruct dstTensorDesc = new cudnn.cudnnTensorStruct();
        cudnn.cudnnTensorStruct deltaTensorDesc = new cudnn.cudnnTensorStruct();
        cudnn.cudnnLRNStruct lrnDesc = new cudnn.cudnnLRNStruct();

        CudnnContext() {
            this.createHandles();
            this.deallocator(new Deallocator(this));
        }

        CudnnContext(CudnnContext c) {
            super((Pointer)c);
            this.srcTensorDesc = new cudnn.cudnnTensorStruct((Pointer)c.srcTensorDesc);
            this.dstTensorDesc = new cudnn.cudnnTensorStruct((Pointer)c.dstTensorDesc);
            this.deltaTensorDesc = new cudnn.cudnnTensorStruct((Pointer)c.deltaTensorDesc);
            this.lrnDesc = new cudnn.cudnnLRNStruct((Pointer)c.lrnDesc);
        }

        void createHandles() {
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnCreate((cudnn.cudnnContext)this));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnn.cudnnTensorStruct)this.srcTensorDesc));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnn.cudnnTensorStruct)this.dstTensorDesc));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnn.cudnnTensorStruct)this.deltaTensorDesc));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnCreateLRNDescriptor((cudnn.cudnnLRNStruct)this.lrnDesc));
        }

        void destroyHandles() {
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnDestroyLRNDescriptor((cudnn.cudnnLRNStruct)this.lrnDesc));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnn.cudnnTensorStruct)this.srcTensorDesc));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnn.cudnnTensorStruct)this.dstTensorDesc));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnn.cudnnTensorStruct)this.deltaTensorDesc));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnDestroy((cudnn.cudnnContext)this));
        }

        static class Deallocator
        extends CudnnContext
        implements Pointer.Deallocator {
            Deallocator(CudnnContext c) {
                super(c);
            }

            public void deallocate() {
                this.destroyHandles();
            }
        }
    }
}

