package org.deeplearning4j.nn.layers.convolution.subsampling;

import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.ShortPointer;
import org.bytedeco.javacpp.cuda;
import org.bytedeco.javacpp.cudnn;
import org.bytedeco.javacpp.indexer.HalfIndexer;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;

/* loaded from: input_file:org/deeplearning4j/nn/layers/convolution/subsampling/CudnnSubsamplingHelper.class */
public class CudnnSubsamplingHelper implements SubsamplingHelper {
    // cuDNN handle together with the tensor/pooling descriptors that activate() and
    // backpropGradient() reconfigure on every call.
    CudnnContext cudnnContext = new CudnnContext();
    // Data-type code handed to cudnnSetTensor4dDescriptorEx; the constructor sets it to
    // 1 for DOUBLE, 0 for FLOAT and 2 for any other global ND4J data type.
    int dataType;
    // Set to 0 by the constructor and not read by any method of this class.
    // NOTE(review): 0 presumably stands for CUDNN_TENSOR_NCHW — confirm against cudnn.h.
    int tensorFormat;
    // One-element native scalar holding 1 in the global data type; passed as the alpha
    // argument of cudnnPoolingForward / cudnnPoolingBackward.
    Pointer alpha;
    // One-element native scalar holding 0 in the global data type; passed as the beta
    // argument of cudnnPoolingForward / cudnnPoolingBackward.
    Pointer beta;
    // Output array written by the most recent activate() call that reached cuDNN;
    // backpropGradient() reads it back. Starts out null.
    INDArray reduced;

    /* renamed from: org.deeplearning4j.nn.layers.convolution.subsampling.CudnnSubsamplingHelper$1, reason: invalid class name */
    /* loaded from: input_file:org/deeplearning4j/nn/layers/convolution/subsampling/CudnnSubsamplingHelper$1.class */
    /**
     * Lookup table that maps {@code SubsamplingLayer.PoolingType} ordinals to the case
     * numbers used by the switches in the enclosing class: AVG is 1, MAX is 2, NONE is 3.
     * Slots that are never written keep the array default of 0, which matches none of
     * those case numbers and therefore selects the {@code default} branch.
     */
    static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$org$deeplearning4j$nn$conf$layers$SubsamplingLayer$PoolingType;

        static {
            // Fill a local table first and publish it in a single assignment at the end.
            final int[] caseByOrdinal = new int[SubsamplingLayer.PoolingType.values().length];
            try {
                caseByOrdinal[SubsamplingLayer.PoolingType.AVG.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
                // Constant missing at run time: its slot simply stays 0.
            }
            try {
                caseByOrdinal[SubsamplingLayer.PoolingType.MAX.ordinal()] = 2;
            } catch (NoSuchFieldError ignored) {
                // Constant missing at run time: its slot simply stays 0.
            }
            try {
                caseByOrdinal[SubsamplingLayer.PoolingType.NONE.ordinal()] = 3;
            } catch (NoSuchFieldError ignored) {
                // Constant missing at run time: its slot simply stays 0.
            }
            $SwitchMap$org$deeplearning4j$nn$conf$layers$SubsamplingLayer$PoolingType = caseByOrdinal;
        }
    }

    /* loaded from: input_file:org/deeplearning4j/nn/layers/convolution/subsampling/CudnnSubsamplingHelper$CudnnContext.class */
    /**
     * A cuDNN handle (this object itself is passed to {@code cudnnCreate}) bundled with the
     * three tensor descriptors and the pooling descriptor used by the enclosing helper.
     */
    static class CudnnContext extends cudnn.cudnnContext {
        // Initialised inline, so every constructor starts from fresh descriptor objects
        // before its own body runs.
        cudnn.cudnnTensorStruct srcTensorDesc = new cudnn.cudnnTensorStruct();
        cudnn.cudnnTensorStruct dstTensorDesc = new cudnn.cudnnTensorStruct();
        cudnn.cudnnTensorStruct deltaTensorDesc = new cudnn.cudnnTensorStruct();
        cudnn.cudnnPoolingStruct poolingDesc = new cudnn.cudnnPoolingStruct();

        /**
         * Creates the native handle and descriptors and registers a {@link Deallocator}
         * built from a copy of this context.
         */
        CudnnContext() {
            // The created array is discarded.
            // NOTE(review): presumably this forces the ND4J/CUDA backend to initialise
            // before any cuDNN handle is created — confirm.
            Nd4j.create(1);
            createHandles();
            deallocator(new Deallocator(this));
        }

        /**
         * Copy constructor: wraps the same handle and descriptors as {@code cudnnContext}
         * without creating new native objects.
         */
        CudnnContext(CudnnContext cudnnContext) {
            super(cudnnContext);
            srcTensorDesc = new cudnn.cudnnTensorStruct(cudnnContext.srcTensorDesc);
            dstTensorDesc = new cudnn.cudnnTensorStruct(cudnnContext.dstTensorDesc);
            deltaTensorDesc = new cudnn.cudnnTensorStruct(cudnnContext.deltaTensorDesc);
            poolingDesc = new cudnn.cudnnPoolingStruct(cudnnContext.poolingDesc);
        }

        /** Creates the cuDNN handle, then the tensor descriptors, then the pooling descriptor. */
        void createHandles() {
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnCreate(this));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(srcTensorDesc));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(dstTensorDesc));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(deltaTensorDesc));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnCreatePoolingDescriptor(poolingDesc));
        }

        /**
         * Destroys the pooling descriptor, the tensor descriptors and finally the handle.
         * The first non-zero status aborts the sequence with a RuntimeException.
         */
        void destroyHandles() {
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnDestroyPoolingDescriptor(poolingDesc));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(srcTensorDesc));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(dstTensorDesc));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(deltaTensorDesc));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnDestroy(this));
        }

        /**
         * Deallocator registered by the no-argument constructor; it is a copy of the
         * context it was built from and releases the native objects via
         * {@link #destroyHandles()}.
         */
        static class Deallocator extends CudnnContext implements Pointer.Deallocator {
            Deallocator(CudnnContext cudnnContext) {
                super(cudnnContext);
            }

            public void deallocate() {
                destroyHandles();
            }
        }
    }

    public CudnnSubsamplingHelper() {
        this.dataType = Nd4j.dataType() == DataBuffer.Type.DOUBLE ? 1 : Nd4j.dataType() == DataBuffer.Type.FLOAT ? 0 : 2;
        this.tensorFormat = 0;
        this.alpha = Nd4j.dataType() == DataBuffer.Type.DOUBLE ? new DoublePointer(new double[]{1.0d}) : Nd4j.dataType() == DataBuffer.Type.FLOAT ? new FloatPointer(new float[]{1.0f}) : new ShortPointer(new short[]{(short) HalfIndexer.fromFloat(1.0f)});
        this.beta = Nd4j.dataType() == DataBuffer.Type.DOUBLE ? new DoublePointer(new double[]{0.0d}) : Nd4j.dataType() == DataBuffer.Type.FLOAT ? new FloatPointer(new float[]{0.0f}) : new ShortPointer(new short[]{(short) HalfIndexer.fromFloat(0.0f)});
        this.reduced = null;
    }

    /**
     * Treats {@code i} as a CUDA status code: zero is accepted silently, anything else is
     * reported as a failure.
     *
     * @param i status code to check
     * @throws RuntimeException with message {@code "CUDA error = " + i} when {@code i} is non-zero
     */
    static void checkCuda(int i) {
        final boolean succeeded = (i == 0);
        if (succeeded) {
            return;
        }
        throw new RuntimeException("CUDA error = " + i);
    }

    /**
     * Treats {@code i} as a cuDNN status code: zero is accepted silently, anything else is
     * reported as a failure. Every cuDNN call in this class is wrapped with this check.
     *
     * @param i status code returned by a cuDNN call
     * @throws RuntimeException with message {@code "cuDNN status = " + i} when {@code i} is non-zero
     */
    static void checkCudnn(int i) {
        final boolean succeeded = (i == 0);
        if (succeeded) {
            return;
        }
        throw new RuntimeException("cuDNN status = " + i);
    }

    /**
     * Runs the pooling backward pass through cuDNN.
     *
     * <p>The pooled output stored by the preceding {@code activate} call ({@code reduced})
     * is passed to {@code cudnnPoolingBackward} together with the incoming gradient and
     * the layer input, so a forward pass must have run on this helper first.
     *
     * @param iNDArray    layer input; its four dimensions are passed to cuDNN as n, c, h, w
     * @param iNDArray2   gradient with respect to the pooled output; it is duplicated before
     *                    use when {@code Shape.strideDescendingCAscendingF} rejects it
     * @param iArr        pooling window as {height, width}
     * @param iArr2       strides as {vertical, horizontal}
     * @param iArr3       padding as {vertical, horizontal}
     * @param poolingType AVG and MAX are executed by cuDNN; NONE passes the gradient through
     * @return a pair of a freshly created {@code DefaultGradient} (never populated here)
     *         and the gradient with respect to the input; for NONE the second element is
     *         {@code iNDArray2} itself; {@code null} for any other pooling type
     * @throws IllegalStateException if no pooled output from {@code activate} is available
     * @throws RuntimeException if a cuDNN call returns a non-zero status
     */
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray, INDArray iNDArray2, int[] iArr, int[] iArr2, int[] iArr3, SubsamplingLayer.PoolingType poolingType) {
        final int batch = iNDArray.size(0);
        final int channels = iNDArray.size(1);
        final int inHeight = iNDArray.size(2);
        final int inWidth = iNDArray.size(3);
        final int outHeight = Convolution.outSize(inHeight, iArr[0], iArr2[0], iArr3[0], false);
        final int outWidth = Convolution.outSize(inWidth, iArr[1], iArr2[1], iArr3[1], false);

        // Nothing is ever added to this gradient in this method.
        DefaultGradient retGradient = new DefaultGradient();

        // NOTE(review): modes 1 and 0 presumably equal
        // CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING and CUDNN_POOLING_MAX — confirm.
        final int poolingMode;
        switch (AnonymousClass1.$SwitchMap$org$deeplearning4j$nn$conf$layers$SubsamplingLayer$PoolingType[poolingType.ordinal()]) {
            case 1: // AVG
                poolingMode = 1;
                break;
            case 2: // MAX
                poolingMode = 0;
                break;
            case 3: // NONE
                return new Pair<>(retGradient, iNDArray2);
            default:
                return null;
        }

        // Fail fast with a clear message instead of a null dereference after the cuDNN
        // descriptors have already been reconfigured.
        final INDArray pooled = this.reduced;
        if (pooled == null) {
            throw new IllegalStateException("backpropGradient requires the pooled output of a preceding activate call, but none is available");
        }

        INDArray epsilon = iNDArray2;
        if (!Shape.strideDescendingCAscendingF(epsilon)) {
            epsilon = epsilon.dup();
        }

        final int[] inStrides = iNDArray.stride();
        final int[] epsStrides = epsilon.stride();
        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.srcTensorDesc, this.dataType, batch, channels, inHeight, inWidth, inStrides[0], inStrides[1], inStrides[2], inStrides[3]));
        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.deltaTensorDesc, this.dataType, batch, channels, outHeight, outWidth, epsStrides[0], epsStrides[1], epsStrides[2], epsStrides[3]));
        // NOTE(review): the literal 1 is the NaN-propagation option, presumably CUDNN_PROPAGATE_NAN — confirm.
        checkCudnn(cudnn.cudnnSetPooling2dDescriptor(this.cudnnContext.poolingDesc, poolingMode, 1, iArr[0], iArr[1], iArr3[0], iArr3[1], iArr2[0], iArr2[1]));

        final INDArray outEpsilon = Nd4j.create(new int[]{batch, channels, inHeight, inWidth}, 'c');
        final int[] outStrides = outEpsilon.stride();
        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.dstTensorDesc, this.dataType, batch, channels, inHeight, inWidth, outStrides[0], outStrides[1], outStrides[2], outStrides[3]));

        final AtomicAllocator allocator = AtomicAllocator.getInstance();
        final CudaContext context = allocator.getFlowController().prepareAction(iNDArray, new INDArray[]{epsilon, pooled, outEpsilon});
        final Pointer inputData = allocator.getPointer(iNDArray, context);
        final Pointer epsilonData = allocator.getPointer(epsilon, context);
        final Pointer pooledData = allocator.getPointer(pooled, context);
        final Pointer outEpsilonData = allocator.getPointer(outEpsilon, context);

        checkCudnn(cudnn.cudnnSetStream(this.cudnnContext, new cuda.CUstream_st(context.getOldStream())));
        // NOTE(review): deltaTensorDesc is built from epsilon's strides but is also passed
        // as the descriptor for the pooled output; presumably both arrays are expected to
        // share the same layout — confirm.
        checkCudnn(cudnn.cudnnPoolingBackward(this.cudnnContext, this.cudnnContext.poolingDesc, this.alpha, this.cudnnContext.deltaTensorDesc, pooledData, this.cudnnContext.deltaTensorDesc, epsilonData, this.cudnnContext.srcTensorDesc, inputData, this.beta, this.cudnnContext.dstTensorDesc, outEpsilonData));

        allocator.registerAction(context, iNDArray, new INDArray[]{epsilon, pooled, outEpsilon});
        return new Pair<>(retGradient, outEpsilon);
    }

    /**
     * Runs the pooling forward pass through cuDNN and remembers the result in
     * {@code reduced} for a later {@code backpropGradient} call.
     *
     * @param iNDArray    layer input; its four dimensions are passed to cuDNN as n, c, h, w
     * @param z           not used by this method
     * @param iArr        pooling window as {height, width}
     * @param iArr2       strides as {vertical, horizontal}
     * @param iArr3       padding as {vertical, horizontal}
     * @param poolingType AVG and MAX are executed by cuDNN; NONE returns the input as is
     * @return the pooled output (a new 'c'-ordered array), {@code iNDArray} itself for
     *         NONE, or {@code null} for any other pooling type
     * @throws RuntimeException if a cuDNN call returns a non-zero status
     */
    public INDArray activate(INDArray iNDArray, boolean z, int[] iArr, int[] iArr2, int[] iArr3, SubsamplingLayer.PoolingType poolingType) {
        final int batch = iNDArray.size(0);
        final int channels = iNDArray.size(1);
        final int inHeight = iNDArray.size(2);
        final int inWidth = iNDArray.size(3);
        final int outHeight = Convolution.outSize(inHeight, iArr[0], iArr2[0], iArr3[0], false);
        final int outWidth = Convolution.outSize(inWidth, iArr[1], iArr2[1], iArr3[1], false);

        // Case numbers from the lookup table: 1 = AVG, 2 = MAX, 3 = NONE, anything else unsupported.
        final int caseIndex = AnonymousClass1.$SwitchMap$org$deeplearning4j$nn$conf$layers$SubsamplingLayer$PoolingType[poolingType.ordinal()];
        if (caseIndex == 3) {
            return iNDArray;
        }
        if (caseIndex != 1 && caseIndex != 2) {
            return null;
        }
        // NOTE(review): modes 1 and 0 presumably equal
        // CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING and CUDNN_POOLING_MAX — confirm.
        final int poolingMode = (caseIndex == 1) ? 1 : 0;

        final int[] inStrides = iNDArray.stride();
        // NOTE(review): the literal 1 is the NaN-propagation option, presumably CUDNN_PROPAGATE_NAN — confirm.
        checkCudnn(cudnn.cudnnSetPooling2dDescriptor(this.cudnnContext.poolingDesc, poolingMode, 1, iArr[0], iArr[1], iArr3[0], iArr3[1], iArr2[0], iArr2[1]));
        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.srcTensorDesc, this.dataType, batch, channels, inHeight, inWidth, inStrides[0], inStrides[1], inStrides[2], inStrides[3]));

        // Store the output on the helper straight away; the backward pass reads it from there.
        final INDArray pooled = Nd4j.createUninitialized(new int[]{batch, channels, outHeight, outWidth}, 'c');
        this.reduced = pooled;
        final int[] outStrides = pooled.stride();
        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.dstTensorDesc, this.dataType, batch, channels, outHeight, outWidth, outStrides[0], outStrides[1], outStrides[2], outStrides[3]));

        final AtomicAllocator allocator = AtomicAllocator.getInstance();
        final CudaContext context = allocator.getFlowController().prepareAction(iNDArray, new INDArray[]{pooled});
        final Pointer inputData = allocator.getPointer(iNDArray, context);
        final Pointer pooledData = allocator.getPointer(pooled, context);

        checkCudnn(cudnn.cudnnSetStream(this.cudnnContext, new cuda.CUstream_st(context.getOldStream())));
        checkCudnn(cudnn.cudnnPoolingForward(this.cudnnContext, this.cudnnContext.poolingDesc, this.alpha, this.cudnnContext.srcTensorDesc, inputData, this.beta, this.cudnnContext.dstTensorDesc, pooledData));

        allocator.registerAction(context, iNDArray, new INDArray[]{pooled});
        return pooled;
    }
}
