/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.convolution.subsampling;

import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.cuda;
import org.bytedeco.javacpp.cudnn;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingHelper;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;

/**
 * {@link SubsamplingHelper} implementation that runs the pooling forward and
 * backward passes through cuDNN ({@code cudnnPoolingForward} /
 * {@code cudnnPoolingBackward}).
 *
 * <p>The instance keeps the output of the most recent {@link #activate} call in
 * {@link #reduced}; {@link #backpropGradient} feeds that array back to cuDNN, so
 * the two methods are expected to be called as a forward/backward pair on the
 * same instance.
 */
public class CudnnSubsamplingHelper
implements SubsamplingHelper {
    // Numeric values of the cuDNN enums used below (see the cuDNN API reference).
    /** cudnnDataType_t: CUDNN_DATA_FLOAT. */
    private static final int DATA_TYPE_FLOAT = 0;
    /** cudnnDataType_t: CUDNN_DATA_DOUBLE. */
    private static final int DATA_TYPE_DOUBLE = 1;
    /** cudnnTensorFormat_t: CUDNN_TENSOR_NCHW. */
    private static final int TENSOR_FORMAT_NCHW = 0;
    /** cudnnPoolingMode_t: CUDNN_POOLING_MAX. */
    private static final int POOLING_MODE_MAX = 0;
    /** cudnnPoolingMode_t: CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING. */
    private static final int POOLING_MODE_AVERAGE = 1;
    /** cudnnNanPropagation_t: CUDNN_PROPAGATE_NAN. */
    private static final int NAN_PROPAGATE = 1;

    /** cuDNN handle plus the tensor/pooling descriptors reused across calls. */
    CudnnContext cudnnContext = new CudnnContext();
    /** cuDNN data type, chosen once from the ND4J global data type. */
    int dataType = Nd4j.dataType() == DataBuffer.Type.DOUBLE ? DATA_TYPE_DOUBLE : DATA_TYPE_FLOAT;
    int tensorFormat = TENSOR_FORMAT_NCHW;
    /**
     * Host-side scaling factor 1 passed as {@code alpha}. cuDNN reads the
     * scaling factors as {@code double} for double tensors and as {@code float}
     * otherwise, so the pointer type must follow {@link #dataType}.
     */
    Pointer alpha = scalingFactor(this.dataType, 1.0);
    /** Host-side scaling factor 0 passed as {@code beta}; same typing rule as {@link #alpha}. */
    Pointer beta = scalingFactor(this.dataType, 0.0);
    /** Output of the last {@link #activate} call; {@code null} until then. */
    INDArray reduced = null;

    /**
     * Throws if a CUDA call reported a non-zero error code.
     *
     * @param error return code of a CUDA runtime call
     * @throws RuntimeException if {@code error != 0}
     */
    static void checkCuda(int error) {
        if (error != 0) {
            throw new RuntimeException("CUDA error = " + error);
        }
    }

    /**
     * Throws if a cuDNN call reported a non-zero status.
     *
     * @param status return code of a cuDNN call
     * @throws RuntimeException if {@code status != 0}
     */
    static void checkCudnn(int status) {
        if (status != 0) {
            throw new RuntimeException("cuDNN status = " + status);
        }
    }

    /** Allocates a one-element host scalar whose width matches the given cuDNN data type. */
    private static Pointer scalingFactor(int cudnnDataType, double value) {
        if (cudnnDataType == DATA_TYPE_DOUBLE) {
            return new DoublePointer(new double[]{value});
        }
        return new FloatPointer(new float[]{(float)value});
    }

    /** Configures a 4d tensor descriptor with explicit strides for the current data type. */
    private void describeTensor(cudnn.cudnnTensorStruct desc, int n, int c, int h, int w, int[] stride) {
        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(desc, this.dataType, n, c, h, w, stride[0], stride[1], stride[2], stride[3]));
    }

    /** Configures the shared pooling descriptor from kernel size, padding and strides. */
    private void describePooling(int poolingMode, int[] kernel, int[] strides, int[] pad) {
        checkCudnn(cudnn.cudnnSetPooling2dDescriptor(this.cudnnContext.poolingDesc, poolingMode, NAN_PROPAGATE, kernel[0], kernel[1], pad[0], pad[1], strides[0], strides[1]));
    }

    /**
     * Computes the gradient with respect to the layer input using
     * {@code cudnnPoolingBackward}.
     *
     * @param input       the array that was passed to {@link #activate}; dimensions 0..3 are read
     *                    as mini-batch, depth, height and width
     * @param epsilon     gradient with respect to the pooled output
     * @param kernel      pooling window size, {@code [height, width]}
     * @param strides     pooling strides, {@code [height, width]}
     * @param pad         padding, {@code [height, width]}
     * @param poolingType pooling variant
     * @return an empty gradient paired with the input-shaped epsilon; for {@code NONE} the
     *         incoming {@code epsilon} is returned unchanged; {@code null} for any pooling type
     *         other than AVG/MAX/NONE (presumably so the caller can fall back to another
     *         implementation - confirm against the caller)
     * @throws IllegalStateException if AVG/MAX pooling is requested before {@link #activate}
     *                               has produced {@link #reduced}
     * @throws RuntimeException      if any cuDNN call returns a non-zero status
     */
    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides, int[] pad, SubsamplingLayer.PoolingType poolingType) {
        int poolingMode;
        int miniBatch = input.size(0);
        int depth = input.size(1);
        int inH = input.size(2);
        int inW = input.size(3);
        int outH = Convolution.outSize(inH, kernel[0], strides[0], pad[0], false);
        int outW = Convolution.outSize(inW, kernel[1], strides[1], pad[1], false);
        // Pooling has no trainable parameters, so the gradient container stays empty.
        Gradient retGradient = new DefaultGradient();
        switch (poolingType) {
            case AVG: {
                poolingMode = POOLING_MODE_AVERAGE;
                break;
            }
            case MAX: {
                poolingMode = POOLING_MODE_MAX;
                break;
            }
            case NONE: {
                return new Pair<Gradient, INDArray>(retGradient, epsilon);
            }
            default: {
                return null;
            }
        }
        // cudnnPoolingBackward needs the forward output; fail with a clear message
        // instead of handing a null array to the allocator.
        if (this.reduced == null) {
            throw new IllegalStateException("activate(...) must be called before backpropGradient(...): no forward pooling output is available");
        }
        if (!Shape.strideDescendingCAscendingF(epsilon)) {
            epsilon = epsilon.dup();
        }
        describeTensor(this.cudnnContext.srcTensorDesc, miniBatch, depth, inH, inW, input.stride());
        describeTensor(this.cudnnContext.deltaTensorDesc, miniBatch, depth, outH, outW, epsilon.stride());
        describePooling(poolingMode, kernel, strides, pad);

        INDArray outEpsilon = Nd4j.create(new int[]{miniBatch, depth, inH, inW}, 'c');
        describeTensor(this.cudnnContext.dstTensorDesc, miniBatch, depth, inH, inW, outEpsilon.stride());

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(input, new INDArray[]{epsilon, this.reduced, outEpsilon});
        Pointer srcData = allocator.getPointer(input, context);
        Pointer epsData = allocator.getPointer(epsilon, context);
        Pointer zData = allocator.getPointer(this.reduced, context);
        Pointer dstData = allocator.getPointer(outEpsilon, context);

        checkCudnn(cudnn.cudnnSetStream(this.cudnnContext, new cuda.CUstream_st(context.getOldStream())));
        // NOTE(review): deltaTensorDesc (built from epsilon's strides) also describes the
        // forward output `reduced`, so this assumes both arrays share one layout - confirm.
        checkCudnn(cudnn.cudnnPoolingBackward(this.cudnnContext, this.cudnnContext.poolingDesc, this.alpha, this.cudnnContext.deltaTensorDesc, zData, this.cudnnContext.deltaTensorDesc, epsData, this.cudnnContext.srcTensorDesc, srcData, this.beta, this.cudnnContext.dstTensorDesc, dstData));
        allocator.registerAction(context, input, new INDArray[]{epsilon, this.reduced, outEpsilon});
        return new Pair<Gradient, INDArray>(retGradient, outEpsilon);
    }

    /**
     * Runs the pooling forward pass with {@code cudnnPoolingForward} and stores the result in
     * {@link #reduced} for the following backward pass.
     *
     * @param input       array whose dimensions 0..3 are read as mini-batch, depth, height, width
     * @param training    not used by this implementation
     * @param kernel      pooling window size, {@code [height, width]}
     * @param strides     pooling strides, {@code [height, width]}
     * @param pad         padding, {@code [height, width]}
     * @param poolingType pooling variant
     * @return the pooled array; {@code input} itself for {@code NONE}; {@code null} for any
     *         pooling type other than AVG/MAX/NONE
     * @throws RuntimeException if any cuDNN call returns a non-zero status
     */
    public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad, SubsamplingLayer.PoolingType poolingType) {
        int poolingMode;
        int miniBatch = input.size(0);
        int inDepth = input.size(1);
        int inH = input.size(2);
        int inW = input.size(3);
        int outH = Convolution.outSize(inH, kernel[0], strides[0], pad[0], false);
        int outW = Convolution.outSize(inW, kernel[1], strides[1], pad[1], false);
        switch (poolingType) {
            case AVG: {
                poolingMode = POOLING_MODE_AVERAGE;
                break;
            }
            case MAX: {
                poolingMode = POOLING_MODE_MAX;
                break;
            }
            case NONE: {
                return input;
            }
            default: {
                return null;
            }
        }
        int[] srcStride = input.stride();
        describePooling(poolingMode, kernel, strides, pad);
        describeTensor(this.cudnnContext.srcTensorDesc, miniBatch, inDepth, inH, inW, srcStride);

        // Uninitialized is fine here: beta is 0, so cuDNN overwrites rather than blends.
        this.reduced = Nd4j.createUninitialized(new int[]{miniBatch, inDepth, outH, outW}, 'c');
        describeTensor(this.cudnnContext.dstTensorDesc, miniBatch, inDepth, outH, outW, this.reduced.stride());

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(input, new INDArray[]{this.reduced});
        Pointer srcData = allocator.getPointer(input, context);
        Pointer dstData = allocator.getPointer(this.reduced, context);

        checkCudnn(cudnn.cudnnSetStream(this.cudnnContext, new cuda.CUstream_st(context.getOldStream())));
        checkCudnn(cudnn.cudnnPoolingForward(this.cudnnContext, this.cudnnContext.poolingDesc, this.alpha, this.cudnnContext.srcTensorDesc, srcData, this.beta, this.cudnnContext.dstTensorDesc, dstData));
        allocator.registerAction(context, input, new INDArray[]{this.reduced});
        return this.reduced;
    }

    /**
     * cuDNN handle bundled with the three tensor descriptors and the pooling descriptor used by
     * the helper. The no-arg constructor creates the native handles and registers a
     * {@link Deallocator} that destroys them.
     */
    static class CudnnContext
    extends cudnn.cudnnContext {
        cudnn.cudnnTensorStruct srcTensorDesc = new cudnn.cudnnTensorStruct();
        cudnn.cudnnTensorStruct dstTensorDesc = new cudnn.cudnnTensorStruct();
        cudnn.cudnnTensorStruct deltaTensorDesc = new cudnn.cudnnTensorStruct();
        cudnn.cudnnPoolingStruct poolingDesc = new cudnn.cudnnPoolingStruct();

        /** Creates the cuDNN handle and descriptors and registers their deallocator. */
        CudnnContext() {
            this.createHandles();
            this.deallocator(new Deallocator(this));
        }

        /**
         * Creates a second view onto the handles owned by {@code c}; no new native objects are
         * created. Used by {@link Deallocator}.
         */
        CudnnContext(CudnnContext c) {
            super(c);
            this.srcTensorDesc = new cudnn.cudnnTensorStruct(c.srcTensorDesc);
            this.dstTensorDesc = new cudnn.cudnnTensorStruct(c.dstTensorDesc);
            this.deltaTensorDesc = new cudnn.cudnnTensorStruct(c.deltaTensorDesc);
            this.poolingDesc = new cudnn.cudnnPoolingStruct(c.poolingDesc);
        }

        /** Creates the cuDNN handle, the three tensor descriptors and the pooling descriptor. */
        void createHandles() {
            checkCudnn(cudnn.cudnnCreate(this));
            checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.srcTensorDesc));
            checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.dstTensorDesc));
            checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.deltaTensorDesc));
            checkCudnn(cudnn.cudnnCreatePoolingDescriptor(this.poolingDesc));
        }

        /** Destroys the descriptors first and the cuDNN handle last. */
        void destroyHandles() {
            checkCudnn(cudnn.cudnnDestroyPoolingDescriptor(this.poolingDesc));
            checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.srcTensorDesc));
            checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.dstTensorDesc));
            checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.deltaTensorDesc));
            checkCudnn(cudnn.cudnnDestroy(this));
        }

        /** Deallocator that destroys the handles of the context it was created from. */
        static class Deallocator
        extends CudnnContext
        implements Pointer.Deallocator {
            Deallocator(CudnnContext c) {
                super(c);
            }

            public void deallocate() {
                this.destroyHandles();
            }
        }
    }
}

