/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.normalization;

import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.cuda;
import org.bytedeco.javacpp.cudnn;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;

/**
 * {@link BatchNormalizationHelper} implemented on top of cuDNN's batch normalization routines.
 *
 * <p>An instance owns one cuDNN handle with four reusable tensor descriptors ({@link CudnnContext})
 * and two device buffers ({@link Cache}). The training forward pass asks cuDNN to save the batch
 * mean / inverse variance into those buffers, and {@link #backpropGradient} hands them back to cuDNN,
 * so both calls mutate per-instance state. NOTE(review): nothing here synchronizes that state;
 * presumably each layer owns its own helper — confirm before sharing an instance across threads.
 */
public class CudnnBatchNormalizationHelper
implements BatchNormalizationHelper {
    CudnnContext cudnnContext = new CudnnContext();
    // Device buffers for the statistics saved by the last training forward pass.
    Cache meanCache = new Cache();
    Cache varCache = new Cache();
    // NOTE(review): every non-DOUBLE Nd4j data type is described to cuDNN as FLOAT.
    int dataType = Nd4j.dataType() == DataBuffer.Type.DOUBLE ? cudnn.CUDNN_DATA_DOUBLE : cudnn.CUDNN_DATA_FLOAT;
    int tensorFormat = cudnn.CUDNN_TENSOR_NCHW;
    int batchNormMode = cudnn.CUDNN_BATCHNORM_SPATIAL;
    // cuDNN dereferences the alpha/beta scaling factors as double for DOUBLE tensors and as float
    // otherwise, so a float-sized factor would be misread when dataType is DOUBLE.
    Pointer alpha = this.dataType == cudnn.CUDNN_DATA_DOUBLE
            ? new DoublePointer(new double[] {1.0}) : new FloatPointer(new float[] {1.0f});
    Pointer beta = this.dataType == cudnn.CUDNN_DATA_DOUBLE
            ? new DoublePointer(new double[] {0.0}) : new FloatPointer(new float[] {0.0f});

    /** Throws a {@link RuntimeException} if a CUDA runtime call did not return {@code cudaSuccess}. */
    static void checkCuda(int error) {
        if (error != cuda.cudaSuccess) {
            throw new RuntimeException("CUDA error = " + error);
        }
    }

    /** Throws a {@link RuntimeException} if a cuDNN call did not return {@code CUDNN_STATUS_SUCCESS}. */
    static void checkCudnn(int status) {
        if (status != cudnn.CUDNN_STATUS_SUCCESS) {
            throw new RuntimeException("cuDNN status = " + status);
        }
    }

    /**
     * Back-propagates through batch normalization of a 4-d input using cuDNN.
     *
     * @param input      the forward-pass input; its dims are read as (miniBatch, depth, h, w)
     * @param epsilon    gradient w.r.t. this layer's output; duplicated first when its strides fail
     *                   {@link Shape#strideDescendingCAscendingF}
     * @param shape      shape of the gamma/beta tensor; dims beyond the second default to 1
     * @param gamma      scale parameters
     * @param dGammaView array that cuDNN writes the gamma gradient into
     * @param dBetaView  array that cuDNN writes the beta gradient into
     * @param eps        epsilon handed to cuDNN; should match the value used in the forward pass
     * @return the "gamma"/"beta" gradients paired with a new c-ordered array holding the gradient
     *         w.r.t. {@code input}
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] shape, INDArray gamma, INDArray dGammaView, INDArray dBetaView, double eps) {
        int miniBatch = input.size(0);
        int depth = input.size(1);
        int inH = input.size(2);
        int inW = input.size(3);

        if (!Shape.strideDescendingCAscendingF(epsilon)) {
            // Make a contiguous copy when epsilon's stride layout is not the expected one.
            epsilon = epsilon.dup();
        }
        INDArray nextEpsilon = Nd4j.createUninitialized(new int[] {miniBatch, depth, inH, inW}, 'c');

        setTensor4d(cudnnContext.srcTensorDesc, miniBatch, depth, inH, inW, input.stride());
        setTensor4d(cudnnContext.deltaTensorDesc, miniBatch, depth, inH, inW, epsilon.stride());
        setTensor4d(cudnnContext.dstTensorDesc, miniBatch, depth, inH, inW, nextEpsilon.stride());
        setGammaBetaDescriptor(shape);

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(input, new INDArray[] {epsilon, nextEpsilon, gamma, dGammaView, dBetaView});
        Pointer srcData = allocator.getPointer(input, context);
        Pointer epsData = allocator.getPointer(epsilon, context);
        Pointer dstData = allocator.getPointer(nextEpsilon, context);
        Pointer gammaData = allocator.getPointer(gamma, context);
        Pointer dGammaData = allocator.getPointer(dGammaView, context);
        Pointer dBetaData = allocator.getPointer(dBetaView, context);
        bindStream(context);

        // NOTE(review): the parameter-gradient blend factors are (alpha, alpha) = (1, 1), which under
        // cuDNN's result = a*computed + b*prior convention accumulates into dGammaView/dBetaView rather
        // than overwriting them — confirm the caller expects that.
        checkCudnn(cudnn.cudnnBatchNormalizationBackward(cudnnContext, batchNormMode,
                this.alpha, this.beta, this.alpha, this.alpha,
                cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData,
                cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData,
                dGammaData, dBetaData, eps, meanCache, varCache));
        allocator.registerAction(context, input, new INDArray[] {epsilon, nextEpsilon, gamma, dGammaView, dBetaView});

        DefaultGradient retGradient = new DefaultGradient();
        retGradient.setGradientFor("gamma", dGammaView);
        retGradient.setGradientFor("beta", dBetaView);
        return new Pair<Gradient, INDArray>(retGradient, nextEpsilon);
    }

    /**
     * Runs the batch normalization forward pass of a 4-d input using cuDNN.
     *
     * @param x        input; its dims are read as (miniBatch, depth, h, w)
     * @param training when true, calls cuDNN's training routine (which also updates {@code mean}/{@code var}
     *                 and saves batch statistics for backprop); otherwise the inference routine
     * @param shape    shape of the gamma/beta tensor; dims beyond the second default to 1
     * @param gamma    scale parameters
     * @param beta     shift parameters
     * @param mean     running mean, read at inference and passed as the running-mean buffer in training
     * @param var      running variance, used the same way as {@code mean}
     * @param decay    passed to cuDNN's training routine as its moving-average factor
     * @param eps      epsilon handed to cuDNN
     * @return a new c-ordered array with the normalized activations
     */
    @Override
    public INDArray preOutput(INDArray x, boolean training, int[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray var, double decay, double eps) {
        int miniBatch = x.size(0);
        int inDepth = x.size(1);
        int inH = x.size(2);
        int inW = x.size(3);
        INDArray activations = Nd4j.createUninitialized(new int[] {miniBatch, inDepth, inH, inW}, 'c');

        setTensor4d(cudnnContext.srcTensorDesc, miniBatch, inDepth, inH, inW, x.stride());
        setTensor4d(cudnnContext.dstTensorDesc, miniBatch, inDepth, inH, inW, activations.stride());
        setGammaBetaDescriptor(shape);

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(x, new INDArray[] {activations, gamma, beta, mean, var});
        Pointer srcData = allocator.getPointer(x, context);
        Pointer dstData = allocator.getPointer(activations, context);
        Pointer gammaData = allocator.getPointer(gamma, context);
        Pointer betaData = allocator.getPointer(beta, context);
        Pointer meanData = allocator.getPointer(mean, context);
        Pointer varData = allocator.getPointer(var, context);
        bindStream(context);

        if (training) {
            // Grow (never shrink) the saved-statistics buffers to the byte size of mean / var.
            meanCache = ensureCapacity(meanCache, mean.data().length() * (long) mean.data().getElementSize());
            varCache = ensureCapacity(varCache, var.data().length() * (long) var.data().getElementSize());
            // NOTE(review): cuDNN documents this argument as exponentialAverageFactor, i.e.
            // running = running * (1 - f) + batch * f — confirm `decay` already carries that meaning.
            checkCudnn(cudnn.cudnnBatchNormalizationForwardTraining(cudnnContext, batchNormMode,
                    this.alpha, this.beta, cudnnContext.srcTensorDesc, srcData,
                    cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, betaData,
                    decay, meanData, varData, eps, meanCache, varCache));
        } else {
            checkCudnn(cudnn.cudnnBatchNormalizationForwardInference(cudnnContext, batchNormMode,
                    this.alpha, this.beta, cudnnContext.srcTensorDesc, srcData,
                    cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, betaData,
                    meanData, varData, eps));
        }
        allocator.registerAction(context, x, new INDArray[] {activations, gamma, beta, mean, var});
        return activations;
    }

    /** Configures {@code desc} as an n x c x h x w tensor of {@link #dataType} with explicit strides. */
    private void setTensor4d(cudnn.cudnnTensorStruct desc, int n, int c, int h, int w, int[] stride) {
        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(desc, dataType, n, c, h, w,
                stride[0], stride[1], stride[2], stride[3]));
    }

    /** Configures the gamma/beta descriptor from {@code shape}, using 1 for missing third/fourth dims. */
    private void setGammaBetaDescriptor(int[] shape) {
        checkCudnn(cudnn.cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, tensorFormat, dataType,
                shape[0], shape[1], shape.length > 2 ? shape[2] : 1, shape.length > 3 ? shape[3] : 1));
    }

    /** Makes the cuDNN handle issue its work on the stream of {@code context}. */
    private void bindStream(CudaContext context) {
        checkCudnn(cudnn.cudnnSetStream(cudnnContext, new cuda.CUstream_st((Pointer) context.getOldStream())));
    }

    /** Returns {@code cache} if it holds at least {@code bytes}; otherwise frees it and returns a new allocation. */
    private static Cache ensureCapacity(Cache cache, long bytes) {
        if (cache.capacity() >= bytes) {
            return cache;
        }
        cache.deallocate();
        return new Cache(bytes);
    }

    /**
     * Device memory allocated with {@code cudaMalloc} and released with {@code cudaFree} by its
     * deallocator. The no-arg constructor yields an empty (null) pointer with no deallocator.
     */
    static class Cache
    extends Pointer {
        Cache() {
        }

        /** Allocates {@code size} bytes of device memory. */
        Cache(long size) {
            checkCuda(cuda.cudaMalloc(this, size));
            this.limit = this.capacity = size;
            this.deallocator(new Deallocator(this));
        }

        /** Copies the address (not the memory) of {@code c}; used by the deallocator. */
        Cache(Cache c) {
            super(c);
        }

        /** Frees the device memory of the {@link Cache} it was created from. */
        static class Deallocator
        extends Cache
        implements Pointer.Deallocator {
            Deallocator(Cache c) {
                super(c);
            }

            @Override
            public void deallocate() {
                checkCuda(cuda.cudaFree(this));
                this.setNull();
            }
        }
    }

    /**
     * A cuDNN handle bundled with the four tensor descriptors used by this helper. Handles are
     * created in the constructor and destroyed by the registered deallocator.
     */
    static class CudnnContext
    extends cudnn.cudnnContext {
        cudnn.cudnnTensorStruct srcTensorDesc = new cudnn.cudnnTensorStruct();
        cudnn.cudnnTensorStruct dstTensorDesc = new cudnn.cudnnTensorStruct();
        cudnn.cudnnTensorStruct deltaTensorDesc = new cudnn.cudnnTensorStruct();
        cudnn.cudnnTensorStruct gammaBetaTensorDesc = new cudnn.cudnnTensorStruct();

        CudnnContext() {
            this.createHandles();
            this.deallocator(new Deallocator(this));
        }

        /** Copies the handle and descriptor addresses of {@code c}; used by the deallocator. */
        CudnnContext(CudnnContext c) {
            super((Pointer) c);
            this.srcTensorDesc = new cudnn.cudnnTensorStruct(c.srcTensorDesc);
            this.dstTensorDesc = new cudnn.cudnnTensorStruct(c.dstTensorDesc);
            this.deltaTensorDesc = new cudnn.cudnnTensorStruct(c.deltaTensorDesc);
            this.gammaBetaTensorDesc = new cudnn.cudnnTensorStruct(c.gammaBetaTensorDesc);
        }

        void createHandles() {
            checkCudnn(cudnn.cudnnCreate(this));
            checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.srcTensorDesc));
            checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.dstTensorDesc));
            checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.deltaTensorDesc));
            checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.gammaBetaTensorDesc));
        }

        void destroyHandles() {
            checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.srcTensorDesc));
            checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.dstTensorDesc));
            checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.deltaTensorDesc));
            checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.gammaBetaTensorDesc));
            checkCudnn(cudnn.cudnnDestroy(this));
        }

        /** Destroys the descriptors and handle of the {@link CudnnContext} it was created from. */
        static class Deallocator
        extends CudnnContext
        implements Pointer.Deallocator {
            Deallocator(CudnnContext c) {
                super(c);
            }

            @Override
            public void deallocate() {
                this.destroyHandles();
            }
        }
    }
}

