package org.deeplearning4j.nn.layers.normalization;

import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.cuda;
import org.bytedeco.javacpp.cudnn;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseCudnnHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;

/* loaded from: input_file:org/deeplearning4j/nn/layers/normalization/CudnnLocalResponseNormalizationHelper.class */
public class CudnnLocalResponseNormalizationHelper extends BaseCudnnHelper implements LocalResponseNormalizationHelper {
    private static final Logger log = LoggerFactory.getLogger(CudnnLocalResponseNormalizationHelper.class);
    private CudnnLocalResponseNormalizationContext cudnnContext = new CudnnLocalResponseNormalizationContext();
    private INDArray activations = null;

    /* loaded from: input_file:org/deeplearning4j/nn/layers/normalization/CudnnLocalResponseNormalizationHelper$CudnnLocalResponseNormalizationContext.class */
    private static class CudnnLocalResponseNormalizationContext extends BaseCudnnHelper.CudnnContext {
        private cudnn.cudnnTensorStruct srcTensorDesc;
        private cudnn.cudnnTensorStruct dstTensorDesc;
        private cudnn.cudnnTensorStruct deltaTensorDesc;
        private cudnn.cudnnLRNStruct lrnDesc;

        /* loaded from: input_file:org/deeplearning4j/nn/layers/normalization/CudnnLocalResponseNormalizationHelper$CudnnLocalResponseNormalizationContext$Deallocator.class */
        private static class Deallocator extends CudnnLocalResponseNormalizationContext implements Pointer.Deallocator {
            Deallocator(CudnnLocalResponseNormalizationContext cudnnLocalResponseNormalizationContext) {
                super(cudnnLocalResponseNormalizationContext);
            }

            public void deallocate() {
                destroyHandles();
            }
        }

        public CudnnLocalResponseNormalizationContext() {
            this.srcTensorDesc = new cudnn.cudnnTensorStruct();
            this.dstTensorDesc = new cudnn.cudnnTensorStruct();
            this.deltaTensorDesc = new cudnn.cudnnTensorStruct();
            this.lrnDesc = new cudnn.cudnnLRNStruct();
            createHandles();
            deallocator(new Deallocator(this));
        }

        public CudnnLocalResponseNormalizationContext(CudnnLocalResponseNormalizationContext cudnnLocalResponseNormalizationContext) {
            super(cudnnLocalResponseNormalizationContext);
            this.srcTensorDesc = new cudnn.cudnnTensorStruct();
            this.dstTensorDesc = new cudnn.cudnnTensorStruct();
            this.deltaTensorDesc = new cudnn.cudnnTensorStruct();
            this.lrnDesc = new cudnn.cudnnLRNStruct();
            this.srcTensorDesc = new cudnn.cudnnTensorStruct(cudnnLocalResponseNormalizationContext.srcTensorDesc);
            this.dstTensorDesc = new cudnn.cudnnTensorStruct(cudnnLocalResponseNormalizationContext.dstTensorDesc);
            this.deltaTensorDesc = new cudnn.cudnnTensorStruct(cudnnLocalResponseNormalizationContext.deltaTensorDesc);
            this.lrnDesc = new cudnn.cudnnLRNStruct(cudnnLocalResponseNormalizationContext.lrnDesc);
        }

        /* JADX INFO: Access modifiers changed from: protected */
        @Override // org.deeplearning4j.nn.layers.BaseCudnnHelper.CudnnContext
        public void createHandles() {
            super.createHandles();
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.srcTensorDesc));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.dstTensorDesc));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.deltaTensorDesc));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnCreateLRNDescriptor(this.lrnDesc));
        }

        /* JADX INFO: Access modifiers changed from: protected */
        @Override // org.deeplearning4j.nn.layers.BaseCudnnHelper.CudnnContext
        public void destroyHandles() {
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnDestroyLRNDescriptor(this.lrnDesc));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.srcTensorDesc));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.dstTensorDesc));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.deltaTensorDesc));
            super.destroyHandles();
        }
    }

    public boolean checkSupported(double d, double d2, double d3, double d4) {
        boolean checkSupported = checkSupported();
        if (d2 < 1.0d) {
            checkSupported = false;
            log.warn("Not supported: n < CUDNN_LRN_MIN_N (" + d2 + " < 1)");
        }
        if (d2 > 16.0d) {
            checkSupported = false;
            log.warn("Not supported: n > CUDNN_LRN_MAX_N (" + d2 + " > 16)");
        }
        if (d < 1.0E-5d) {
            checkSupported = false;
            log.warn("Not supported: k < CUDNN_LRN_MIN_K (" + d + " < 1.0E-5)");
        }
        if (d4 < 0.01d) {
            checkSupported = false;
            log.warn("Not supported: beta < CUDNN_LRN_MIN_BETA (" + d4 + " < 0.01)");
        }
        return checkSupported;
    }

    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray, INDArray iNDArray2, double d, double d2, double d3, double d4, LayerWorkspaceMgr layerWorkspaceMgr) {
        int size = iNDArray.size(0);
        int size2 = iNDArray.size(1);
        int size3 = iNDArray.size(2);
        int size4 = iNDArray.size(3);
        DefaultGradient defaultGradient = new DefaultGradient();
        if (!Shape.hasDefaultStridesForShape(iNDArray2)) {
            iNDArray2 = iNDArray2.dup('c');
        }
        int[] stride = iNDArray.stride();
        int[] stride2 = iNDArray2.stride();
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.srcTensorDesc, this.dataType, size, size2, size3, size4, stride[0], stride[1], stride[2], stride[3]));
        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.deltaTensorDesc, this.dataType, size, size2, size3, size4, stride2[0], stride2[1], stride2[2], stride2[3]));
        checkCudnn(cudnn.cudnnSetLRNDescriptor(this.cudnnContext.lrnDesc, (int) d2, d3, d4, d));
        INDArray createUninitialized = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new int[]{size, size2, size3, size4}, 'c');
        int[] stride3 = createUninitialized.stride();
        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.dstTensorDesc, this.dataType, size, size2, size3, size4, stride3[0], stride3[1], stride3[2], stride3[3]));
        AtomicAllocator atomicAllocator = AtomicAllocator.getInstance();
        CudaContext prepareActionAllWrite = atomicAllocator.getFlowController().prepareActionAllWrite(new INDArray[]{iNDArray, iNDArray2, this.activations, createUninitialized});
        Pointer pointer = atomicAllocator.getPointer(iNDArray, prepareActionAllWrite);
        Pointer pointer2 = atomicAllocator.getPointer(iNDArray2, prepareActionAllWrite);
        Pointer pointer3 = atomicAllocator.getPointer(this.activations, prepareActionAllWrite);
        Pointer pointer4 = atomicAllocator.getPointer(createUninitialized, prepareActionAllWrite);
        checkCudnn(cudnn.cudnnSetStream(this.cudnnContext, new cuda.CUstream_st(prepareActionAllWrite.getOldStream())));
        checkCudnn(cudnn.cudnnLRNCrossChannelBackward(this.cudnnContext, this.cudnnContext.lrnDesc, 0, this.alpha, this.cudnnContext.deltaTensorDesc, pointer3, this.cudnnContext.deltaTensorDesc, pointer2, this.cudnnContext.srcTensorDesc, pointer, this.beta, this.cudnnContext.dstTensorDesc, pointer4));
        atomicAllocator.getFlowController().registerActionAllWrite(prepareActionAllWrite, new INDArray[]{iNDArray, iNDArray2, this.activations, createUninitialized});
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            prepareActionAllWrite.syncOldStream();
        }
        return new Pair<>(defaultGradient, createUninitialized);
    }

    public INDArray activate(INDArray iNDArray, boolean z, double d, double d2, double d3, double d4, LayerWorkspaceMgr layerWorkspaceMgr) {
        int size = iNDArray.size(0);
        int size2 = iNDArray.size(1);
        int size3 = iNDArray.size(2);
        int size4 = iNDArray.size(3);
        if (!Shape.hasDefaultStridesForShape(iNDArray)) {
            iNDArray = iNDArray.dup('c');
        }
        int[] stride = iNDArray.stride();
        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.srcTensorDesc, this.dataType, size, size2, size3, size4, stride[0], stride[1], stride[2], stride[3]));
        this.activations = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new int[]{size, size2, size3, size4}, 'c');
        int[] stride2 = this.activations.stride();
        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.dstTensorDesc, this.dataType, size, size2, size3, size4, stride2[0], stride2[1], stride2[2], stride2[3]));
        checkCudnn(cudnn.cudnnSetLRNDescriptor(this.cudnnContext.lrnDesc, (int) d2, d3, d4, d));
        AtomicAllocator atomicAllocator = AtomicAllocator.getInstance();
        CudaContext prepareActionAllWrite = atomicAllocator.getFlowController().prepareActionAllWrite(new INDArray[]{iNDArray, this.activations});
        Pointer pointer = atomicAllocator.getPointer(iNDArray, prepareActionAllWrite);
        Pointer pointer2 = atomicAllocator.getPointer(this.activations, prepareActionAllWrite);
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        checkCudnn(cudnn.cudnnSetStream(this.cudnnContext, new cuda.CUstream_st(prepareActionAllWrite.getOldStream())));
        checkCudnn(cudnn.cudnnLRNCrossChannelForward(this.cudnnContext, this.cudnnContext.lrnDesc, 0, this.alpha, this.cudnnContext.srcTensorDesc, pointer, this.beta, this.cudnnContext.dstTensorDesc, pointer2));
        atomicAllocator.getFlowController().registerActionAllWrite(prepareActionAllWrite, new INDArray[]{iNDArray, this.activations});
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            prepareActionAllWrite.syncOldStream();
        }
        return this.activations;
    }
}
