package org.nd4j.linalg.jcublas.ops.executioner;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import lombok.NonNull;
import org.bytedeco.javacpp.BooleanPointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.ShortPointer;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.base.Preconditions;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.tad.DeviceTADManager;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.Utf8Buffer;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ndarray.INDArrayStatistics;
import org.nd4j.linalg.api.ops.BaseReduceOp;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.aggregates.Batch;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpStatus;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.CopyOp;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.api.shape.TadPack;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.api.shape.options.ArrayType;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.LongPointerWrapper;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.Nd4jCuda;
import org.nd4j.nativeblas.OpaqueConstantDataBuffer;
import org.nd4j.nativeblas.OpaqueShapeList;
import org.nd4j.nativeblas.OpaqueTadPack;
import org.nd4j.nativeblas.OpaqueVariable;
import org.nd4j.nativeblas.OpaqueVariablesSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.class */
public class CudaExecutioner extends DefaultOpExecutioner {
    private static final Logger log = LoggerFactory.getLogger(CudaExecutioner.class);
    /** Device native-ops handle shared by every instance (obtained from {@link NativeOpsHolder}). */
    protected static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    /** Supplies the TAD shape-info pairs used when building the native "extras" pointer arrays. */
    protected static TADManager tadManager = new DeviceTADManager();
    /** NOTE(review): not referenced in the visible part of this class; presumably used/populated elsewhere. */
    protected volatile transient Properties properties;
    /** Per-thread reusable "extras" pointer array for native calls; the exec methods allocate it lazily with 32 slots. */
    protected ThreadLocal<PointerPointer> extraz = new ThreadLocal<>();
    /** Per-thread name of the last executed op; the methods visible here only set it in CUDA debug mode. */
    protected ThreadLocal<String> lastOp = new ThreadLocal<>();
    /** Custom op descriptors; explicitly initialized to {@code null}. NOTE(review): not used in the visible code. */
    protected Map<String, CustomOpDescriptor> customOps = null;
    /** Native experimental-mode flag; the constructor overwrites it with {@code nativeOps.isExperimentalEnabled()}. */
    protected AtomicBoolean experimentalMode = new AtomicBoolean(false);

    /*
     * Decompiler-generated enum switch map. It translates enum ordinals into the dense case labels
     * used by the switch statements in this class:
     *   Op.Type:  BROADCAST=1, BROADCAST_BOOL=2, REDUCE_LONG=3, REDUCE_BOOL=4, REDUCE_FLOAT=5,
     *             REDUCE_SAME=6, SCALAR=7, SCALAR_BOOL=8, TRANSFORM_BOOL=9, PAIRWISE_BOOL=10,
     *             TRANSFORM_ANY=11, TRANSFORM_FLOAT=12, TRANSFORM_SAME=13, TRANSFORM_STRICT=14
     *   DataType: FLOAT=1, DOUBLE=2, HALF=3
     * Ordinals not listed keep the int-array default of 0, which matches no case label and so
     * reaches a switch's default branch. Every assignment is wrapped in a NoSuchFieldError catch
     * so that an enum constant missing at runtime does not abort class initialization.
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner$1, reason: invalid class name */
    /* loaded from: input_file:org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner$1.class */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$org$nd4j$linalg$api$ops$Op$Type;
        static final /* synthetic */ int[] $SwitchMap$org$nd4j$linalg$api$buffer$DataType = new int[DataType.values().length];

        static {
            // DataType ordinal -> case label.
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.FLOAT.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.DOUBLE.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.HALF.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            // Op.Type ordinal -> case label.
            $SwitchMap$org$nd4j$linalg$api$ops$Op$Type = new int[Op.Type.values().length];
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.BROADCAST.ordinal()] = 1;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.BROADCAST_BOOL.ordinal()] = 2;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.REDUCE_LONG.ordinal()] = 3;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.REDUCE_BOOL.ordinal()] = 4;
            } catch (NoSuchFieldError e7) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.REDUCE_FLOAT.ordinal()] = 5;
            } catch (NoSuchFieldError e8) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.REDUCE_SAME.ordinal()] = 6;
            } catch (NoSuchFieldError e9) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.SCALAR.ordinal()] = 7;
            } catch (NoSuchFieldError e10) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.SCALAR_BOOL.ordinal()] = 8;
            } catch (NoSuchFieldError e11) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.TRANSFORM_BOOL.ordinal()] = 9;
            } catch (NoSuchFieldError e12) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.PAIRWISE_BOOL.ordinal()] = 10;
            } catch (NoSuchFieldError e13) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.TRANSFORM_ANY.ordinal()] = 11;
            } catch (NoSuchFieldError e14) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.TRANSFORM_FLOAT.ordinal()] = 12;
            } catch (NoSuchFieldError e15) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.TRANSFORM_SAME.ordinal()] = 13;
            } catch (NoSuchFieldError e16) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.TRANSFORM_STRICT.ordinal()] = 14;
            } catch (NoSuchFieldError e17) {
            }
        }
    }

    /**
     * Creates an executioner whose experimental-mode flag mirrors the native library's setting.
     */
    public CudaExecutioner() {
        boolean nativeExperimental = nativeOps.isExperimentalEnabled();
        experimentalMode.set(nativeExperimental);
    }

    /**
     * Returns the native-ops handle shared by all instances of this executioner.
     *
     * @return the device {@link NativeOps} obtained from {@link NativeOpsHolder}
     */
    public NativeOps getNativeOps() {
        return CudaExecutioner.nativeOps;
    }

    /**
     * Returns the name of the last op recorded on the calling thread.
     *
     * <p>The exec methods record op names only when CUDA debug mode is enabled, so this may
     * return {@code null}.
     *
     * @return the last recorded op name for this thread, or {@code null} if none was recorded
     */
    public String getLastOp() {
        final String recordedName = lastOp.get();
        return recordedName;
    }

    /**
     * Executes a broadcast op ({@code BROADCAST} or {@code BROADCAST_BOOL}) on the CUDA backend.
     *
     * <p>Prepares a CUDA context for z/x/y and gathers device pointers for the arrays and their
     * shape infos. It also gathers the TAD shape info of x and z along
     * {@code broadcastOp.dimensions()}. It then calls the matching native broadcast kernel and
     * registers the action with the allocator.
     *
     * @param broadcastOp the op to execute; x, y and z are all dereferenced by the native call
     * @return {@code broadcastOp.z()}
     * @throws UnsupportedOperationException if the op type is neither BROADCAST nor BROADCAST_BOOL
     * @throws RuntimeException if the native library reports a non-zero error code
     */
    public INDArray exec(BroadcastOp broadcastOp) {
        long profilingConfigurableHookIn = profilingConfigurableHookIn(broadcastOp, new DataBuffer[0]);
        checkForCompression(broadcastOp);
        int[] intVector = broadcastOp.dimensions().toIntVector();
        // Lazily allocate this thread's reusable 32-slot extras pointer array.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        CudaContext prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction(broadcastOp.z(), broadcastOp.x(), broadcastOp.y());
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(broadcastOp.opName());
        }
        // Host shape-info pointers for y and z.
        // NOTE(review): y and z are null-checked here, but both are dereferenced unconditionally in the native calls below.
        Pointer retrieveHostPointer = broadcastOp.y() == null ? null : AddressRetriever.retrieveHostPointer(broadcastOp.y().shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = broadcastOp.z() == null ? null : AddressRetriever.retrieveHostPointer(broadcastOp.z().shapeInfoDataBuffer());
        // Device pointers to x/y/z data and to x's shape info.
        Pointer pointer = AtomicAllocator.getInstance().getPointer(broadcastOp.x(), prepareAction);
        Pointer pointer2 = AtomicAllocator.getInstance().getPointer(broadcastOp.y(), prepareAction);
        Pointer pointer3 = AtomicAllocator.getInstance().getPointer(broadcastOp.z(), prepareAction);
        LongPointer pointer4 = AtomicAllocator.getInstance().getPointer(broadcastOp.x().shapeInfoDataBuffer(), prepareAction);
        // TAD info of x along the broadcast dimensions: first = TAD shape info, second = presumably TAD offsets.
        Pair tADOnlyShapeInfo = tadManager.getTADOnlyShapeInfo(broadcastOp.x(), intVector);
        Pointer retrieveHostPointer3 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo.getFirst());
        Pointer pointer5 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), prepareAction);
        Pointer pointer6 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getSecond(), prepareAction);
        // Same TAD info pair for z along the same dimensions.
        Pair tADOnlyShapeInfo2 = tadManager.getTADOnlyShapeInfo(broadcastOp.z(), intVector);
        // Extras layout passed to the kernel:
        //   [0] host x shape info, [1] context old stream, [2] device id pointer,
        //   [3..6] context allocation/reduction/scalar/special buffers,
        //   [7] host y shape info, [8] host z shape info, [9] host x TAD shape info,
        //   [10..11] device x TAD pair, [12..13] device z TAD pair.
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(broadcastOp.x().shapeInfoDataBuffer()), prepareAction.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), prepareAction.getBufferAllocation(), prepareAction.getBufferReduction(), prepareAction.getBufferScalar(), prepareAction.getBufferSpecial(), retrieveHostPointer, retrieveHostPointer2, retrieveHostPointer3, pointer5, pointer6, AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getFirst(), prepareAction), AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getSecond(), prepareAction)});
        // Result discarded; presumably kept for its side effect of making the dimension buffer available on the device - TODO confirm.
        AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(intVector), prepareAction);
        // Case labels come from AnonymousClass1: 1 = BROADCAST, 2 = BROADCAST_BOOL.
        switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[broadcastOp.getOpType().ordinal()]) {
            case 1:
                nativeOps.execBroadcast(put, broadcastOp.opNum(), (Pointer) null, AtomicAllocator.getInstance().getHostPointer(broadcastOp.x().shapeInfoDataBuffer()), pointer, pointer4, (Pointer) null, AtomicAllocator.getInstance().getHostPointer(broadcastOp.y().shapeInfoDataBuffer()), pointer2, AtomicAllocator.getInstance().getPointer(broadcastOp.y().shapeInfoDataBuffer(), prepareAction), (Pointer) null, AtomicAllocator.getInstance().getHostPointer(broadcastOp.z().shapeInfoDataBuffer()), pointer3, AtomicAllocator.getInstance().getPointer(broadcastOp.z().shapeInfoDataBuffer(), prepareAction), (Pointer) null, broadcastOp.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(broadcastOp.dimensions(), prepareAction), (LongPointer) null);
                break;
            case 2:
                nativeOps.execBroadcastBool(put, broadcastOp.opNum(), (Pointer) null, AtomicAllocator.getInstance().getHostPointer(broadcastOp.x().shapeInfoDataBuffer()), pointer, pointer4, (Pointer) null, AtomicAllocator.getInstance().getHostPointer(broadcastOp.y().shapeInfoDataBuffer()), pointer2, AtomicAllocator.getInstance().getPointer(broadcastOp.y().shapeInfoDataBuffer(), prepareAction), (Pointer) null, AtomicAllocator.getInstance().getHostPointer(broadcastOp.z().shapeInfoDataBuffer()), pointer3, AtomicAllocator.getInstance().getPointer(broadcastOp.z().shapeInfoDataBuffer(), prepareAction), (Pointer) null, broadcastOp.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(broadcastOp.dimensions(), prepareAction), (LongPointer) null);
                break;
            default:
                throw new UnsupportedOperationException("Unknown op type: " + broadcastOp.getOpType());
        }
        // The native error check precedes registerAction, so a failed op is never registered.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        AtomicAllocator.getInstance().registerAction(prepareAction, broadcastOp.z(), broadcastOp.x(), broadcastOp.y());
        profilingConfigurableHookOut(broadcastOp, profilingConfigurableHookIn);
        return broadcastOp.z();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Runs the native kernel for a reduce, reduce3 or summary-stats op.
     *
     * <p>Empty reductions are handled without a kernel: z is set to (or overwritten with) a copy
     * of x. Otherwise the kernel is chosen in this order:
     * <ul>
     *   <li>{@link Variance} ops use the summary-stats kernels.</li>
     *   <li>Ops with a y array use the reduce3 kernels: all-pairs, scalar or TAD.</li>
     *   <li>Otherwise, a scalar z selects the whole-array reduce kernels and a non-scalar z
     *       selects the TAD reduce kernels, both chosen by op type.</li>
     * </ul>
     * For non-empty reductions z is expected to be set already; {@link #exec(ReduceOp)} assigns it
     * before calling this method.
     *
     * @param reduceOp the op to execute
     * @param iArr the dimensions to reduce along; entries equal to {@code Integer.MAX_VALUE} are
     *     accepted, any other entry must be below x's rank
     * @return {@code reduceOp.z()}
     * @throws ND4JIllegalStateException if a dimension is out of range, or x/y lengths don't match
     * @throws UnsupportedOperationException if the op type has no matching reduce kernel
     * @throws RuntimeException if the native library reports a non-zero error code
     */
    public INDArray naiveExec(ReduceOp reduceOp, int... iArr) {
        DataType dataType;
        long profilingConfigurableHookIn = profilingConfigurableHookIn(reduceOp, new DataBuffer[0]);
        // NOTE(review): these empty-reduce early returns skip profilingConfigurableHookOut.
        if ((reduceOp instanceof BaseReduceOp) && ((BaseReduceOp) reduceOp).isEmptyReduce()) {
            if (reduceOp.z() == null) {
                reduceOp.setZ(reduceOp.x().dup());
                return reduceOp.z();
            }
            Preconditions.checkState(reduceOp.x().equalShapes(reduceOp.z()), "For empty reductions, result (z) array must have same shape as x shape. Got: x=%ndShape, z=%ndShape", reduceOp.x(), reduceOp.z());
            reduceOp.z().assign(reduceOp.x());
            return reduceOp.z();
        }
        // Captured up front; the isScalar() checks below that pick a kernel read this reference.
        INDArray z = reduceOp.z();
        checkForCompression(reduceOp);
        reduceOp.validateDataTypes();
        // Reject dimensions at or beyond x's rank; Integer.MAX_VALUE is let through.
        for (int i = 0; i < iArr.length; i++) {
            if (iArr[i] >= reduceOp.x().rank() && iArr[i] != Integer.MAX_VALUE) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(iArr) + " contains element that higher then rank of op.X: [" + reduceOp.x().rank() + "]");
            }
        }
        CudaContext prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction(reduceOp.z(), reduceOp.x(), reduceOp.y());
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(reduceOp.opName());
        }
        // Host shape-info pointers (null when the corresponding array is absent).
        Pointer retrieveHostPointer = reduceOp.x() == null ? null : AddressRetriever.retrieveHostPointer(reduceOp.x().shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = reduceOp.y() == null ? null : AddressRetriever.retrieveHostPointer(reduceOp.y().shapeInfoDataBuffer());
        Pointer retrieveHostPointer3 = reduceOp.z() == null ? null : AddressRetriever.retrieveHostPointer(reduceOp.z().shapeInfoDataBuffer());
        // TAD info of x along iArr: first = TAD shape info, second = presumably TAD offsets (may be null).
        Pair tADOnlyShapeInfo = tadManager.getTADOnlyShapeInfo(reduceOp.x(), iArr);
        Pointer retrieveHostPointer4 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo.getFirst());
        LongPointer pointer = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), prepareAction);
        DataBuffer dataBuffer = (DataBuffer) tADOnlyShapeInfo.getSecond();
        Pointer pointer2 = dataBuffer == null ? null : AtomicAllocator.getInstance().getPointer(dataBuffer, prepareAction);
        Pointer pointer3 = AtomicAllocator.getInstance().getPointer(reduceOp.x(), prepareAction);
        LongPointer pointer4 = AtomicAllocator.getInstance().getPointer(reduceOp.x().shapeInfoDataBuffer(), prepareAction);
        // Lazily allocate this thread's reusable 32-slot extras pointer array.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // Extras slots 0..11; slots 12 and 13 are filled below when y needs TAD pointers.
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(reduceOp.x().shapeInfoDataBuffer()), prepareAction.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), prepareAction.getBufferAllocation(), prepareAction.getBufferReduction(), prepareAction.getBufferScalar(), prepareAction.getBufferSpecial(), retrieveHostPointer2, retrieveHostPointer3, retrieveHostPointer4, pointer, pointer2});
        // y-side TAD pointers for the reduce3 kernels: pointer6 = shape info, pointer5 = second TAD buffer.
        Pointer pointer5 = null;
        Pointer pointer6 = null;
        if (reduceOp.y() != null) {
            if (iArr.length != 0 && (!(iArr.length == 1 && iArr[0] == Integer.MAX_VALUE) && reduceOp.x().tensorAlongDimension(0L, iArr).length() == reduceOp.y().length())) {
                // Each x TAD is as long as y: y's own shape info stands in for its TAD shape, and a
                // constant {0, 0} LONG buffer fills the second TAD slot.
                DataBuffer constantBuffer = Nd4j.getConstantHandler().getConstantBuffer(new int[]{0, 0}, DataType.LONG);
                pointer5 = constantBuffer == null ? null : AtomicAllocator.getInstance().getPointer(constantBuffer, prepareAction);
                pointer6 = AtomicAllocator.getInstance().getPointer(reduceOp.y().shapeInfoDataBuffer(), prepareAction);
                put.put(12L, AtomicAllocator.getInstance().getPointer(reduceOp.y().shapeInfoDataBuffer(), prepareAction));
                put.put(13L, (Pointer) null);
            } else {
                if (!reduceOp.isComplexAccumulation() && reduceOp.x().length() != reduceOp.y().length()) {
                    throw new ND4JIllegalStateException("Op.X [" + reduceOp.x().length() + "] and Op.Y [" + reduceOp.y().length() + "] lengths should match");
                }
                // For a non-scalar result, y gets its own TAD info pair along iArr.
                if (!reduceOp.z().isScalar()) {
                    Pair tADOnlyShapeInfo2 = tadManager.getTADOnlyShapeInfo(reduceOp.y(), iArr);
                    pointer6 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getFirst(), prepareAction);
                    DataBuffer dataBuffer2 = (DataBuffer) tADOnlyShapeInfo2.getSecond();
                    pointer5 = dataBuffer2 == null ? null : AtomicAllocator.getInstance().getPointer(dataBuffer2, prepareAction);
                    put.put(12L, pointer6);
                    put.put(13L, pointer5);
                }
            }
        }
        // Extra args use x's dtype for REDUCE_LONG / REDUCE_BOOL (labels 3, 4) and z's dtype otherwise.
        switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[reduceOp.getOpType().ordinal()]) {
            case 3:
            case 4:
                dataType = reduceOp.x().dataType();
                break;
            default:
                dataType = reduceOp.z().dataType();
                break;
        }
        Pointer pointer7 = reduceOp.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(reduceOp.extraArgsDataBuff(dataType), prepareAction) : null;
        // Result discarded; presumably kept for its side effect of making the dimension buffer available on the device - TODO confirm.
        AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(iArr), prepareAction);
        if (reduceOp instanceof Variance) {
            if (z.isScalar()) {
                // NOTE(review): z's shape-info pointer is fetched here without the context argument, unlike the TAD branch.
                nativeOps.execSummaryStatsScalar(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer3, pointer4, pointer7, (Pointer) null, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z(), prepareAction), AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer()), ((Variance) reduceOp).isBiasCorrected());
                AtomicAllocator.getInstance().registerAction(prepareAction, reduceOp.z(), reduceOp.x(), reduceOp.y());
            } else {
                nativeOps.execSummaryStatsTad(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer3, pointer4, pointer7, (Pointer) null, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z(), prepareAction), AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer(), prepareAction), (Pointer) null, reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(reduceOp.dimensions(), prepareAction), (LongPointer) null, ((Variance) reduceOp).isBiasCorrected(), pointer, (LongPointer) pointer2);
                AtomicAllocator.getInstance().registerAction(prepareAction, reduceOp.z(), reduceOp.x(), reduceOp.y());
            }
        } else if (reduceOp.y() != null) {
            // Reduce3 kernels: all-pairs for complex accumulation, otherwise scalar or TAD by z's shape.
            if (reduceOp.isComplexAccumulation()) {
                nativeOps.execReduce3All(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer3, pointer4, pointer7, (Pointer) null, (LongPointer) retrieveHostPointer2, AtomicAllocator.getInstance().getPointer(reduceOp.y(), prepareAction), AtomicAllocator.getInstance().getPointer(reduceOp.y().shapeInfoDataBuffer(), prepareAction), (Pointer) null, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z(), prepareAction), AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer(), prepareAction), (Pointer) null, reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(reduceOp.dimensions(), prepareAction), (LongPointer) null, pointer, new LongPointerWrapper(pointer2), (LongPointer) pointer6, new LongPointerWrapper(pointer5));
                AtomicAllocator.getInstance().registerAction(prepareAction, reduceOp.z(), reduceOp.x(), reduceOp.y());
            } else if (z.isScalar()) {
                nativeOps.execReduce3Scalar(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer3, pointer4, pointer7, (Pointer) null, (LongPointer) retrieveHostPointer2, AtomicAllocator.getInstance().getPointer(reduceOp.y(), prepareAction), AtomicAllocator.getInstance().getPointer(reduceOp.y().shapeInfoDataBuffer(), prepareAction), (Pointer) null, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z(), prepareAction), AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer(), prepareAction));
                AtomicAllocator.getInstance().registerAction(prepareAction, reduceOp.z(), reduceOp.x(), reduceOp.y());
            } else {
                nativeOps.execReduce3Tad(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer3, pointer4, pointer7, (Pointer) null, (LongPointer) retrieveHostPointer2, AtomicAllocator.getInstance().getPointer(reduceOp.y(), prepareAction), AtomicAllocator.getInstance().getPointer(reduceOp.y().shapeInfoDataBuffer(), prepareAction), (Pointer) null, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z(), prepareAction), AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer(), prepareAction), reduceOp.dimensions().data().addressPointer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(reduceOp.dimensions(), prepareAction), (LongPointer) null, pointer, (LongPointer) pointer2, (LongPointer) pointer6, (LongPointer) pointer5);
                AtomicAllocator.getInstance().registerAction(prepareAction, reduceOp.z(), reduceOp.x(), reduceOp.y());
            }
        } else if (z.isScalar()) {
            // Case labels are AnonymousClass1 switch-map indices: 3 = REDUCE_LONG, 4 = REDUCE_BOOL,
            // 5 = REDUCE_FLOAT, 6 = REDUCE_SAME. Nd4jCuda.FLOAT32 / Nd4jCuda.DOUBLE merely share the
            // numeric values 5 and 6; they do not denote data types here.
            switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[reduceOp.getOpType().ordinal()]) {
                case 3:
                    nativeOps.execReduceLong(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer3, pointer4, pointer7, (Pointer) null, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z(), prepareAction), AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer()));
                    break;
                case 4:
                    nativeOps.execReduceBool(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer3, pointer4, pointer7, (Pointer) null, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z(), prepareAction), AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer()));
                    break;
                case Nd4jCuda.FLOAT32 /* 5 */:
                    nativeOps.execReduceFloat(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer3, pointer4, pointer7, (Pointer) null, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z(), prepareAction), AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer()));
                    break;
                case Nd4jCuda.DOUBLE /* 6 */:
                    nativeOps.execReduceSame(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer3, pointer4, pointer7, (Pointer) null, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z(), prepareAction), AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer()));
                    break;
                default:
                    throw new UnsupportedOperationException();
            }
            AtomicAllocator.getInstance().registerAction(prepareAction, reduceOp.z(), reduceOp.x(), reduceOp.y());
        } else {
            // TAD reductions; same label mapping as the scalar switch above (3/4/5/6 = LONG/BOOL/FLOAT/SAME).
            switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[reduceOp.getOpType().ordinal()]) {
                case 3:
                    nativeOps.execReduceLong2(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer3, pointer4, pointer7, (Pointer) null, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z(), prepareAction), AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer(), prepareAction), reduceOp.dimensions().data().addressPointer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(reduceOp.dimensions(), prepareAction), (LongPointer) null);
                    break;
                case 4:
                    nativeOps.execReduceBool2(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer3, pointer4, pointer7, (Pointer) null, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z(), prepareAction), AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer(), prepareAction), reduceOp.dimensions().data().addressPointer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(reduceOp.dimensions(), prepareAction), (LongPointer) null);
                    break;
                case Nd4jCuda.FLOAT32 /* 5 */:
                    nativeOps.execReduceFloat2(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer3, pointer4, pointer7, (Pointer) null, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z(), prepareAction), AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer(), prepareAction), reduceOp.dimensions().data().addressPointer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(reduceOp.dimensions(), prepareAction), (LongPointer) null);
                    break;
                case Nd4jCuda.DOUBLE /* 6 */:
                    nativeOps.execReduceSame2(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer3, pointer4, pointer7, (Pointer) null, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z(), prepareAction), AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer(), prepareAction), reduceOp.dimensions().data().addressPointer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(reduceOp.dimensions(), prepareAction), (LongPointer) null);
                    break;
                default:
                    throw new UnsupportedOperationException();
            }
            AtomicAllocator.getInstance().registerAction(prepareAction, reduceOp.z(), reduceOp.x(), reduceOp.y());
        }
        // Unlike exec(BroadcastOp), the native error check here runs after registerAction.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        profilingConfigurableHookOut(reduceOp, profilingConfigurableHookIn);
        return reduceOp.z();
    }

    /**
     * Executes a variance op by routing it through the generic reduction path.
     *
     * @param variance the variance op to execute
     * @return the op's result array
     */
    public INDArray exec(Variance variance) {
        final ReduceOp asReduceOp = (ReduceOp) variance;
        return exec(asReduceOp);
    }

    /**
     * Executes a reduction, reduce3 or summary-stats op on the CUDA backend.
     *
     * <p>Empty reductions copy x into z. A reduction that would not change a vector is returned
     * as {@code reduceOp.noOp()}. When z is missing, or aliases x, a new result array is
     * allocated; a caller-supplied z is checked against the expected reduction length. The native
     * execution itself is delegated to {@link #naiveExec(ReduceOp, int...)}.
     *
     * @param reduceOp the op to execute; its {@code z} may be {@code null}
     * @return {@code reduceOp.z()}, or the result of {@code reduceOp.noOp()} for no-op reductions
     * @throws ND4JIllegalStateException if x/y TAD counts or sizes don't match, or a supplied z
     *     has the wrong length
     */
    public INDArray exec(ReduceOp reduceOp) {
        checkForCompression(reduceOp);
        if ((reduceOp instanceof BaseReduceOp) && ((BaseReduceOp) reduceOp).isEmptyReduce()) {
            // Empty reductions produce a copy of x instead of launching a kernel.
            if (reduceOp.z() == null) {
                reduceOp.setZ(reduceOp.x().dup());
                return reduceOp.z();
            }
            Preconditions.checkState(reduceOp.x().equalShapes(reduceOp.z()), "For empty reductions, result (z) array must have same shape as x shape. Got: x=%ndShape, z=%ndShape", reduceOp.x(), reduceOp.z());
            reduceOp.z().assign(reduceOp.x());
            return reduceOp.z();
        }

        int[] dimensions = reduceOp.dimensions().toIntVector();
        // Lazily allocate this thread's reusable 32-slot extras pointer array.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }

        // True when the reduction is treated as covering the whole array.
        boolean wholeArrayReduce = Shape.wholeArrayDimension(dimensions)
                || reduceOp.x().rank() == dimensions.length
                || dimensions.length == 0;

        // The result shape derives from x, or from y when y is at least as long as x.
        INDArray shapeSource;
        if (reduceOp.y() == null || reduceOp.x().length() > reduceOp.y().length()) {
            shapeSource = reduceOp.x();
        } else {
            shapeSource = reduceOp.y();
        }
        long[] reductionShape = Shape.reductionShape(shapeSource, dimensions, true, reduceOp.isKeepDims());

        // Use the long product for every length comparison: ArrayUtil.prod is int-valued and
        // could overflow for very large shapes, making the comparison against length() wrong.
        long reductionLength = ArrayUtil.prodLong(reductionShape);

        // Reducing a vector into a result of the same non-trivial length changes nothing.
        if (reduceOp.x().isVector() && reduceOp.x().length() == reductionLength && reductionLength > 1 && reduceOp.y() == null) {
            return reduceOp.noOp();
        }

        DataType resultType = reduceOp.resultType();
        if (reduceOp.z() == null || reduceOp.z() == reduceOp.x()) {
            INDArray result;
            if (reduceOp.isComplexAccumulation()) {
                // All-pairs reduce3: one output per (x TAD, y TAD) combination.
                result = Nd4j.createUninitialized(resultType, new long[]{reduceOp.x().tensorsAlongDimension(dimensions), reduceOp.y().tensorsAlongDimension(dimensions)});
            } else {
                if (reduceOp.y() != null) {
                    if (reduceOp.x().length() == reduceOp.y().length()) {
                        // Same-length operands reduced along dimensions must split into equally many TADs.
                        if (!wholeArrayReduce && reduceOp.x().tensorsAlongDimension(dimensions) != reduceOp.y().tensorsAlongDimension(dimensions)) {
                            throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " + Arrays.toString(reduceOp.x().shape()) + ", y shape = " + Arrays.toString(reduceOp.y().shape()) + ", dimension = " + Arrays.toString(dimensions) + ")");
                        }
                    } else {
                        // Different lengths: every x TAD is compared against the whole of y.
                        if (dimensions.length == 0) {
                            throw new ND4JIllegalStateException("TAD vs TAD comparison requires dimension (or other comparison mode was supposed to be used?)");
                        }
                        long xTadLength = reduceOp.x().length() / reduceOp.x().tensorsAlongDimension(dimensions);
                        if (xTadLength != reduceOp.y().length()) {
                            throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution: (x TAD size = " + xTadLength + ", y size = " + reduceOp.y().length() + ")");
                        }
                    }
                }
                result = Nd4j.create(resultType, reductionShape);
            }
            reduceOp.setZ(result);
        } else {
            // A rank-0 reduction shape means a single output element.
            long expectedLength = reductionShape.length == 0 ? 1L : reductionLength;
            if (reduceOp.z().length() != expectedLength) {
                throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(reduceOp.z().shape()) + "] doesn't match expected [" + Arrays.toString(reductionShape) + "]");
            }
        }

        long profilingStart = profilingConfigurableHookIn(reduceOp, new DataBuffer[0]);
        naiveExec(reduceOp, dimensions);
        profilingConfigurableHookOut(reduceOp, profilingStart);
        return reduceOp.z();
    }

    /**
     * Executes an {@link IndexAccumulation} op along the axes returned by its {@code dimensions()}.
     *
     * <p>If {@code z} is unset, an uninitialized {@code DataType.LONG} array with the reduction
     * shape (honouring {@code isKeepDims()}) is allocated as the result. The op is then launched
     * through {@code nativeOps.execIndexReduce} on the CUDA context obtained from the flow
     * controller, and the action is registered with the allocator.
     *
     * <p>If {@code x} is empty, a {@code Preconditions.checkArgument} failure is raised when any
     * reduced axis has extent 0.
     *
     * @param indexAccumulation the op to execute; {@code x} must be non-null
     * @return the op's {@code z} array, except on the vector shortcut below, which returns {@code x}
     * @throws RuntimeException if the native call reports a non-zero error code
     */
    public INDArray exec(IndexAccumulation indexAccumulation) {
        int[] normalizeAxis = Shape.normalizeAxis(indexAccumulation.x().rank(), indexAccumulation.dimensions().toIntVector());
        // An empty input is only accepted if none of the reduced axes has extent 0.
        if (indexAccumulation.x().isEmpty()) {
            for (int i : normalizeAxis) {
                Preconditions.checkArgument(indexAccumulation.x().shape()[i] != 0, "IndexReduce can't be issued along axis with 0 in shape");
            }
        }
        // Allocate the LONG result lazily when the caller did not supply one.
        if (indexAccumulation.z() == null) {
            indexAccumulation.setZ(Nd4j.createUninitialized(DataType.LONG, Shape.reductionShape(indexAccumulation.x(), normalizeAxis, true, indexAccumulation.isKeepDims())));
        }
        long profilingConfigurableHookIn = profilingConfigurableHookIn(indexAccumulation, new DataBuffer[0]);
        checkForCompression(indexAccumulation);
        // Lazily create the reusable 32-slot extra-pointer array used for native calls.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // NOTE(review): this shortcut returns x itself rather than the LONG result held in z, and it
        // skips profilingConfigurableHookOut; confirm callers rely on this behaviour.
        if (indexAccumulation.x().isVector() && indexAccumulation.x().length() == indexAccumulation.z().length()) {
            return indexAccumulation.x();
        }
        // An empty result needs no kernel launch (the profiling hook-out is skipped here as well).
        if (indexAccumulation.z().isEmpty()) {
            return indexAccumulation.z();
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(indexAccumulation.opName());
        }
        CudaContext prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction(indexAccumulation.z(), indexAccumulation.x(), indexAccumulation.y());
        // Host-side shape info; x was already dereferenced above, so its null check is redundant.
        Pointer retrieveHostPointer = indexAccumulation.x() == null ? null : AddressRetriever.retrieveHostPointer(indexAccumulation.x().shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = indexAccumulation.y() == null ? null : AddressRetriever.retrieveHostPointer(indexAccumulation.y().shapeInfoDataBuffer());
        Pointer retrieveHostPointer3 = indexAccumulation.z() == null ? null : AddressRetriever.retrieveHostPointer(indexAccumulation.z().shapeInfoDataBuffer());
        // Device pointers for the x and z data buffers and their shape info.
        Pointer pointer = AtomicAllocator.getInstance().getPointer(indexAccumulation.x(), prepareAction);
        LongPointer pointer2 = AtomicAllocator.getInstance().getPointer(indexAccumulation.x().shapeInfoDataBuffer(), prepareAction);
        Pointer pointer3 = AtomicAllocator.getInstance().getPointer(indexAccumulation.z(), prepareAction);
        LongPointer pointer4 = AtomicAllocator.getInstance().getPointer(indexAccumulation.z().shapeInfoDataBuffer(), prepareAction);
        // TAD shape info (first) and TAD offsets (second) for x along the normalized axes.
        Pair tADOnlyShapeInfo = tadManager.getTADOnlyShapeInfo(indexAccumulation.x(), normalizeAxis);
        Pointer retrieveHostPointer4 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo.getFirst());
        Pointer pointer5 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), prepareAction);
        DataBuffer dataBuffer = (DataBuffer) tADOnlyShapeInfo.getSecond();
        // Pack host shape info, stream/device handles, context scratch buffers and TAD pointers into
        // extraz; presumably the native side reads these by slot position, so do not reorder.
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(indexAccumulation.x().shapeInfoDataBuffer()), prepareAction.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), prepareAction.getBufferAllocation(), prepareAction.getBufferReduction(), prepareAction.getBufferScalar(), prepareAction.getBufferSpecial(), retrieveHostPointer2, retrieveHostPointer3, retrieveHostPointer4, pointer5, dataBuffer == null ? null : AtomicAllocator.getInstance().getPointer(dataBuffer, prepareAction)});
        // Extra args are materialized in x's data type, when present.
        Pointer pointer6 = indexAccumulation.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(indexAccumulation.extraArgsDataBuff(indexAccumulation.x().dataType()), prepareAction) : null;
        // Return value unused; presumably called for its side effect of making the axis constant
        // buffer available on the device -- confirm.
        AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(normalizeAxis), prepareAction);
        // Launch the native index reduction with x, optional extra args, z and the op's dimensions array.
        nativeOps.execIndexReduce(put, indexAccumulation.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer, pointer2, pointer6, (Pointer) null, (LongPointer) retrieveHostPointer3, pointer3, pointer4, indexAccumulation.dimensions().data().addressPointer(), indexAccumulation.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(indexAccumulation.dimensions(), prepareAction), (LongPointer) null);
        // Surface native-side failures as Java exceptions.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        // Record that z was written on the device and that x/y were read.
        AtomicAllocator.getInstance().registerAction(prepareAction, indexAccumulation.z(), indexAccumulation.x(), indexAccumulation.y());
        profilingConfigurableHookOut(indexAccumulation, profilingConfigurableHookIn);
        return indexAccumulation.z();
    }

    /**
     * Generic entry point: executes any supported op and returns its result array.
     *
     * <p>Copy ops are delegated to the default (host) executioner after the inputs have been
     * synchronized to the host; they always yield {@code null}. All other ops are routed to the
     * first matching specialised handler, checked in this order: transform, reduce, scalar,
     * broadcast, index-reduce, random, custom.
     *
     * @param op the op to execute
     * @return {@code null} for copy ops, otherwise the op's {@code z} array after execution
     */
    public INDArray exec(Op op) {
        checkForCompression(op);

        if (op instanceof CopyOp) {
            // Make host copies of the inputs current before the host-side copy runs.
            final INDArray source = op.x();
            if (source != null) {
                AtomicAllocator.getInstance().synchronizeHostData(source);
            }
            final INDArray other = op.y();
            if (other != null) {
                AtomicAllocator.getInstance().synchronizeHostData(other);
            }

            super.exec(op);

            // Flag the target as freshly written on the host side.
            final INDArray target = op.z();
            if (target != null) {
                AtomicAllocator.getInstance().tickHostWrite(target);
            }
            return null;
        }

        // Dispatch to the first handler whose op family matches.
        if (op instanceof TransformOp) {
            invoke((TransformOp) op);
        } else if (op instanceof ReduceOp) {
            final ReduceOp reduction = (ReduceOp) op;
            final int[] axes = reduction.dimensions().toIntVector();
            invoke(reduction, axes);
        } else if (op instanceof ScalarOp) {
            invoke((ScalarOp) op);
        } else if (op instanceof BroadcastOp) {
            invoke((BroadcastOp) op);
        } else if (op instanceof IndexAccumulation) {
            final IndexAccumulation indexReduction = (IndexAccumulation) op;
            final int[] axes = indexReduction.dimensions().toIntVector();
            invoke(indexReduction, axes);
        } else if (op instanceof RandomOp) {
            exec((RandomOp) op);
        } else if (op instanceof CustomOp) {
            exec((CustomOp) op);
        }

        final INDArray result = op.z();
        return result;
    }

    /**
     * Runs a transform op through {@code invoke} after calling {@code checkForCompression} on it,
     * and hands back the very same op instance so the caller can read its outputs.
     *
     * @param transformOp the op to execute
     * @return {@code transformOp} itself
     */
    public TransformOp execAndReturn(TransformOp transformOp) {
        final TransformOp op = transformOp;
        checkForCompression(op);
        this.invoke(op);
        return op;
    }

    /**
     * Launches a {@link BroadcastOp} on the CUDA device and registers the action with the allocator.
     *
     * <p>Decompiler note (JADX): the access modifier was originally {@code protected}.
     *
     * <p>The native entry point is selected from the op type through the decompiler-generated
     * switch map. Case 1 calls {@code execBroadcast} and case 2 calls {@code execBroadcastBool}
     * (presumably the plain and boolean broadcast op types; confirm against the switch map). Any
     * other type is rejected.
     *
     * @param broadcastOp the op to run; its {@code x}, {@code y} and {@code z} are all dereferenced
     * @return always {@code null}
     * @throws UnsupportedOperationException if the op type is not one of the two handled cases
     * @throws RuntimeException if the native call reports a non-zero error code
     */
    public CudaContext invoke(BroadcastOp broadcastOp) {
        long profilingConfigurableHookIn = profilingConfigurableHookIn(broadcastOp, new DataBuffer[0]);
        checkForCompression(broadcastOp);
        // Lazily create the reusable 32-slot extra-pointer array used for native calls.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        CudaContext prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction(broadcastOp.z(), broadcastOp.x(), broadcastOp.y());
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(broadcastOp.opName());
        }
        Pointer pointer = AtomicAllocator.getInstance().getPointer(broadcastOp.x(), prepareAction);
        LongPointer pointer2 = AtomicAllocator.getInstance().getPointer(broadcastOp.x().shapeInfoDataBuffer(), prepareAction);
        // Host-side shape info for x, y and z (x is already dereferenced above).
        Pointer retrieveHostPointer = broadcastOp.x() == null ? null : AddressRetriever.retrieveHostPointer(broadcastOp.x().shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = broadcastOp.y() == null ? null : AddressRetriever.retrieveHostPointer(broadcastOp.y().shapeInfoDataBuffer());
        Pointer retrieveHostPointer3 = broadcastOp.z() == null ? null : AddressRetriever.retrieveHostPointer(broadcastOp.z().shapeInfoDataBuffer());
        // TAD shape info (first) and offsets (second) for x, then for z, along getDimension().
        Pair tADOnlyShapeInfo = tadManager.getTADOnlyShapeInfo(broadcastOp.x(), broadcastOp.getDimension());
        Pointer retrieveHostPointer4 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo.getFirst());
        Pointer pointer3 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), prepareAction);
        Pointer pointer4 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getSecond(), prepareAction);
        Pair tADOnlyShapeInfo2 = tadManager.getTADOnlyShapeInfo(broadcastOp.z(), broadcastOp.getDimension());
        // Pack host shape info, stream/device handles, scratch buffers and the x/z TAD pointers into
        // extraz; presumably read by slot position on the native side, so do not reorder.
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(broadcastOp.x().shapeInfoDataBuffer()), prepareAction.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), prepareAction.getBufferAllocation(), prepareAction.getBufferReduction(), prepareAction.getBufferScalar(), prepareAction.getBufferSpecial(), retrieveHostPointer2, retrieveHostPointer3, retrieveHostPointer4, pointer3, pointer4, AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getFirst(), prepareAction), AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getSecond(), prepareAction)});
        // Device pointers for the y and z data buffers and their shape info.
        Pointer pointer5 = AtomicAllocator.getInstance().getPointer(broadcastOp.y(), prepareAction);
        LongPointer pointer6 = AtomicAllocator.getInstance().getPointer(broadcastOp.y().shapeInfoDataBuffer(), prepareAction);
        Pointer pointer7 = AtomicAllocator.getInstance().getPointer(broadcastOp.z(), prepareAction);
        LongPointer pointer8 = AtomicAllocator.getInstance().getPointer(broadcastOp.z().shapeInfoDataBuffer(), prepareAction);
        // Return value unused; presumably ensures the dimension constant buffer is on the device -- confirm.
        AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(broadcastOp.getDimension()), prepareAction);
        switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[broadcastOp.getOpType().ordinal()]) {
            case 1:
                nativeOps.execBroadcast(put, broadcastOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer, pointer2, (Pointer) null, (LongPointer) retrieveHostPointer2, pointer5, pointer6, (Pointer) null, (LongPointer) retrieveHostPointer3, pointer7, pointer8, (Pointer) null, broadcastOp.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(broadcastOp.dimensions(), prepareAction), (LongPointer) null);
                break;
            case 2:
                nativeOps.execBroadcastBool(put, broadcastOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer, pointer2, (Pointer) null, (LongPointer) retrieveHostPointer2, pointer5, pointer6, (Pointer) null, (LongPointer) retrieveHostPointer3, pointer7, pointer8, (Pointer) null, broadcastOp.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(broadcastOp.dimensions(), prepareAction), (LongPointer) null);
                break;
            default:
                throw new UnsupportedOperationException("Unknown opType: " + broadcastOp.getOpType());
        }
        // Surface native-side failures as Java exceptions.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        // Record that z was written on the device and that x/y were read.
        AtomicAllocator.getInstance().registerAction(prepareAction, broadcastOp.z(), broadcastOp.x(), broadcastOp.y());
        profilingConfigurableHookOut(broadcastOp, profilingConfigurableHookIn);
        return null;
    }

    /**
     * Launches an {@link IndexAccumulation} op on the CUDA device along the given axes.
     *
     * <p>Decompiler note (JADX): the access modifier was originally {@code protected}.
     *
     * <p>When no usable axis is given ({@code null}, or the single sentinel
     * {@code Integer.MAX_VALUE}) and {@code z} is missing or aliases {@code x}, a rank-0
     * {@code DataType.LONG} result is allocated. A scalar {@code z} or a whole-array request goes
     * through {@code execIndexReduceScalar}. Otherwise the axes are sorted and
     * {@code execIndexReduce} is used, and {@code z} is registered as the device-written result.
     *
     * @param indexAccumulation the op to run; {@code x} must be non-null
     * @param iArr              axes to reduce along, normalized against the rank of {@code x};
     *                          may be {@code null}
     * @return always {@code null}
     * @throws ND4JIllegalStateException if a normalized axis is not below the rank of {@code x}
     * @throws RuntimeException          if the native call reports a non-zero error code
     */
    public CudaContext invoke(IndexAccumulation indexAccumulation, int[] iArr) {
        int[] normalizeAxis = Shape.normalizeAxis(indexAccumulation.x().rank(), iArr);
        // Whole-array reduction with no separate target: allocate a rank-0 LONG result.
        if ((normalizeAxis == null || (normalizeAxis.length == 1 && normalizeAxis[0] == Integer.MAX_VALUE)) && (indexAccumulation.z() == indexAccumulation.x() || indexAccumulation.z() == null)) {
            indexAccumulation.setZ(Nd4j.createUninitialized(DataType.LONG, new long[0], 'c'));
        }
        long profilingConfigurableHookIn = profilingConfigurableHookIn(indexAccumulation, new DataBuffer[0]);
        checkForCompression(indexAccumulation);
        // Lazily create the reusable 32-slot extra-pointer array used for native calls.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // Debug mode is only read here, never changed: toggling the global configuration from an
        // individual op would silently affect every op executed afterwards.
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(indexAccumulation.opName());
        }
        if (normalizeAxis != null) {
            for (int i = 0; i < normalizeAxis.length; i++) {
                if (normalizeAxis[i] >= indexAccumulation.x().rank() && normalizeAxis[i] != Integer.MAX_VALUE) {
                    throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(normalizeAxis) + " contains element that higher then rank of op.X: [" + indexAccumulation.x().rank() + "]");
                }
            }
        }
        // A scalar z is not handed to prepareAction; a non-scalar z is.
        CudaContext prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction(indexAccumulation.z().isScalar() ? null : indexAccumulation.z(), indexAccumulation.x(), indexAccumulation.y());
        Pointer pointer = AtomicAllocator.getInstance().getPointer(indexAccumulation.x(), prepareAction);
        LongPointer pointer2 = AtomicAllocator.getInstance().getPointer(indexAccumulation.x().shapeInfoDataBuffer(), prepareAction);
        Pointer pointer3 = indexAccumulation.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(indexAccumulation.extraArgsDataBuff(indexAccumulation.x().dataType()), prepareAction) : null;
        Pointer retrieveHostPointer = indexAccumulation.x() == null ? null : AddressRetriever.retrieveHostPointer(indexAccumulation.x().shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = indexAccumulation.y() == null ? null : AddressRetriever.retrieveHostPointer(indexAccumulation.y().shapeInfoDataBuffer());
        Pointer retrieveHostPointer3 = indexAccumulation.z() == null ? null : AddressRetriever.retrieveHostPointer(indexAccumulation.z().shapeInfoDataBuffer());
        // TAD info needs some axis; fall back to axis 0 when none was given.
        int[] iArr2 = normalizeAxis;
        if (iArr2 == null) {
            iArr2 = new int[]{0};
        }
        Pair tADOnlyShapeInfo = tadManager.getTADOnlyShapeInfo(indexAccumulation.x(), iArr2);
        Pointer retrieveHostPointer4 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo.getFirst());
        Pointer pointer4 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), prepareAction);
        DataBuffer dataBuffer = (DataBuffer) tADOnlyShapeInfo.getSecond();
        Pointer pointer5 = dataBuffer == null ? null : AtomicAllocator.getInstance().getPointer(dataBuffer, prepareAction);
        Pointer pointer6 = AtomicAllocator.getInstance().getPointer(indexAccumulation.z(), prepareAction);
        LongPointer pointer7 = AtomicAllocator.getInstance().getPointer(indexAccumulation.z().shapeInfoDataBuffer(), prepareAction);
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(indexAccumulation.x().shapeInfoDataBuffer()), prepareAction.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), prepareAction.getBufferAllocation(), prepareAction.getBufferReduction(), prepareAction.getBufferScalar(), prepareAction.getBufferSpecial(), retrieveHostPointer2, retrieveHostPointer3, retrieveHostPointer4, pointer4, pointer5});
        if (indexAccumulation.z().isScalar() || normalizeAxis == null || normalizeAxis[0] == Integer.MAX_VALUE) {
            nativeOps.execIndexReduceScalar(put, indexAccumulation.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer, pointer2, pointer3, (Pointer) null, (LongPointer) retrieveHostPointer3, pointer6, pointer7);
            AtomicAllocator.getInstance().registerAction(prepareAction, null, indexAccumulation.x(), indexAccumulation.y());
        } else {
            // NOTE(review): this sorts the same array used for the TAD lookup above, after that lookup.
            Arrays.sort(normalizeAxis);
            nativeOps.execIndexReduce(put, indexAccumulation.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer, pointer2, pointer3, (Pointer) null, (LongPointer) retrieveHostPointer3, pointer6, pointer7, AtomicAllocator.getInstance().getHostPointer(AtomicAllocator.getInstance().getConstantBuffer(normalizeAxis)), indexAccumulation.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(indexAccumulation.dimensions(), prepareAction), (LongPointer) null);
            // z is non-scalar on this path and was handed to prepareAction, so register it as the
            // device-written result (as every other exec path in this class does).
            AtomicAllocator.getInstance().registerAction(prepareAction, indexAccumulation.z(), indexAccumulation.x(), indexAccumulation.y());
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        profilingConfigurableHookOut(indexAccumulation, profilingConfigurableHookIn);
        return null;
    }

    /**
     * Launches a {@link ReduceOp} on the CUDA device along the given axes.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>empty reductions: {@code z} becomes a copy of {@code x};</li>
     *   <li>axis normalization, sorting and validation;</li>
     *   <li>shape checks for pairwise ({@code y != null}) ops;</li>
     *   <li>a vector shortcut;</li>
     *   <li>allocation of a fresh {@code z};</li>
     *   <li>dispatch to a native TAD or scalar entry point, chosen by whether {@code z} is scalar,
     *       whether {@code y} is present, whether the op is a {@code Variance}, and the op type.</li>
     * </ol>
     *
     * @param reduceOp the op to run; {@code x} must be non-null
     * @param iArr     axes to reduce along, normalized against the rank of {@code x}; a
     *                 {@code null} result is replaced by {@code Nd4jCuda.MAX_DIMENSION}
     * @return the CUDA context used, or {@code null} when the vector shortcut fires
     * @throws ND4JIllegalStateException if an axis is out of range, or x/y TAD counts or sizes disagree
     * @throws UnsupportedOperationException if the op type is not one of the handled cases
     * @throws RuntimeException if the native call reports a non-zero error code
     */
    protected CudaContext invoke(ReduceOp reduceOp, int[] iArr) {
        // The context is acquired with whatever z is present on entry; z may be replaced further down.
        CudaContext prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction(reduceOp.z(), reduceOp.x(), reduceOp.y());
        // Empty reduction: the result is simply a copy of x (allocated via dup() if z is absent).
        if ((reduceOp instanceof BaseReduceOp) && ((BaseReduceOp) reduceOp).isEmptyReduce()) {
            if (reduceOp.z() == null) {
                reduceOp.setZ(reduceOp.x().dup());
                return prepareAction;
            }
            Preconditions.checkState(reduceOp.x().equalShapes(reduceOp.z()), "For empty reductions, result (z) array must have same shape as x shape. Got: x=%ndShape, z=%ndShape", reduceOp.x(), reduceOp.z());
            reduceOp.z().assign(reduceOp.x());
            return prepareAction;
        }
        long profilingConfigurableHookIn = profilingConfigurableHookIn(reduceOp, new DataBuffer[0]);
        checkForCompression(reduceOp);
        int[] normalizeAxis = Shape.normalizeAxis(reduceOp.x().rank(), iArr);
        // Lazily create the reusable 32-slot extra-pointer array used for native calls.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // No axes: fall back to the MAX_DIMENSION sentinel (presumably "reduce over everything" -- confirm).
        if (normalizeAxis == null) {
            normalizeAxis = new int[]{Nd4jCuda.MAX_DIMENSION};
        }
        if (normalizeAxis.length > 1) {
            Arrays.sort(normalizeAxis);
        }
        // Reject axes beyond x's rank, except the Integer.MAX_VALUE sentinel.
        for (int i = 0; i < normalizeAxis.length; i++) {
            if (normalizeAxis[i] >= reduceOp.x().rank() && normalizeAxis[i] != Integer.MAX_VALUE) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(normalizeAxis) + " contains element that higher then rank of op.X: [" + reduceOp.x().rank() + "]");
            }
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(reduceOp.opName());
        }
        // TAD shape info/offsets for x; an empty x uses its own data buffer as a placeholder with no offsets.
        Pair makePair = reduceOp.x().isEmpty() ? Pair.makePair(reduceOp.x().data(), (Object) null) : tadManager.getTADOnlyShapeInfo(reduceOp.x(), normalizeAxis);
        Pointer retrieveHostPointer = AddressRetriever.retrieveHostPointer((DataBuffer) makePair.getFirst());
        LongPointer pointer = AtomicAllocator.getInstance().getPointer((DataBuffer) makePair.getFirst(), prepareAction);
        DataBuffer dataBuffer = reduceOp.x().isEmpty() ? null : (DataBuffer) makePair.getSecond();
        Pointer pointer2 = dataBuffer == null ? null : AtomicAllocator.getInstance().getPointer(dataBuffer, prepareAction);
        Pointer pointer3 = AtomicAllocator.getInstance().getPointer(reduceOp.x(), prepareAction);
        LongPointer pointer4 = AtomicAllocator.getInstance().getPointer(reduceOp.x().shapeInfoDataBuffer(), prepareAction);
        long[] reductionShape = Shape.reductionShape(reduceOp.x(), normalizeAxis, true, reduceOp.isKeepDims());
        // Pairwise ops: for equal lengths x and y must have the same TAD count; otherwise y must be
        // exactly the size of one x TAD.
        if (reduceOp.y() != null) {
            if (reduceOp.x().length() != reduceOp.y().length()) {
                long length = reduceOp.x().length() / reduceOp.x().tensorsAlongDimension(normalizeAxis);
                if (length != reduceOp.y().length()) {
                    throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution: (x TAD size = " + length + ", y size = " + reduceOp.y().length());
                }
            } else if (reduceOp.x().tensorsAlongDimension(normalizeAxis) != reduceOp.y().tensorsAlongDimension(normalizeAxis)) {
                throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " + Arrays.toString(reduceOp.x().shape()) + ", y shape = " + Arrays.toString(reduceOp.y().shape()) + ", dimension = " + Arrays.toString(normalizeAxis) + ")");
            }
        }
        // NOTE(review): this vector shortcut returns null without setting z, registering the action or
        // calling profilingConfigurableHookOut -- confirm that is intended.
        if (reduceOp.x().isVector() && reduceOp.x().length() == ArrayUtil.prod(reductionShape)) {
            return null;
        }
        // z is always replaced by a fresh uninitialized array of the reduction shape; any caller-supplied
        // z is discarded here. NOTE(review): prepareAction above was obtained with the previous z.
        reduceOp.setZ(Nd4j.createUninitialized(reduceOp.resultType(), reductionShape));
        // Extra args use x's data type for BOOL results and REDUCE_LONG ops, otherwise z's data type.
        Pointer pointer5 = reduceOp.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(reduceOp.extraArgsDataBuff((reduceOp.z().dataType() == DataType.BOOL || reduceOp.getOpType() == Op.Type.REDUCE_LONG) ? reduceOp.x().dataType() : reduceOp.z().dataType()), prepareAction) : null;
        Pointer retrieveHostPointer2 = reduceOp.x() == null ? null : AddressRetriever.retrieveHostPointer(reduceOp.x().shapeInfoDataBuffer());
        Pointer retrieveHostPointer3 = reduceOp.y() == null ? null : AddressRetriever.retrieveHostPointer(reduceOp.y().shapeInfoDataBuffer());
        Pointer retrieveHostPointer4 = reduceOp.z() == null ? null : AddressRetriever.retrieveHostPointer(reduceOp.z().shapeInfoDataBuffer());
        // extraz slots: 0 x host shape info, 1-6 stream/device/scratch handles, 7/8 y/z host shape
        // info, 9-11 x TAD host shape info, device shape info and device offsets.
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(reduceOp.x().shapeInfoDataBuffer()), prepareAction.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), prepareAction.getBufferAllocation(), prepareAction.getBufferReduction(), prepareAction.getBufferScalar(), prepareAction.getBufferSpecial(), retrieveHostPointer3, retrieveHostPointer4, retrieveHostPointer, pointer, pointer2});
        Pair tADOnlyShapeInfo = reduceOp.y() == null ? null : tadManager.getTADOnlyShapeInfo(reduceOp.y(), normalizeAxis);
        Pointer pointer6 = reduceOp.y() == null ? null : AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), prepareAction);
        DataBuffer dataBuffer2 = reduceOp.y() == null ? null : (DataBuffer) tADOnlyShapeInfo.getSecond();
        Pointer pointer7 = dataBuffer2 == null ? null : AtomicAllocator.getInstance().getPointer(dataBuffer2, prepareAction);
        // Slots 12/13 carry y's TAD shape info and offsets for pairwise ops.
        if (reduceOp.y() != null) {
            put.put(12L, pointer6);
            put.put(13L, pointer7);
        }
        Pointer pointer8 = AtomicAllocator.getInstance().getPointer(reduceOp.z(), prepareAction);
        LongPointer pointer9 = AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer(), prepareAction);
        reduceOp.validateDataTypes();
        // Non-scalar result: per-TAD native entry points.
        if (!reduceOp.z().isScalar()) {
            Pointer pointer10 = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(normalizeAxis), prepareAction);
            if (reduceOp.y() != null) {
                nativeOps.execReduce3Tad(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer2, pointer3, pointer4, pointer5, (Pointer) null, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.y(), prepareAction), AtomicAllocator.getInstance().getPointer(reduceOp.y().shapeInfoDataBuffer(), prepareAction), (Pointer) null, (LongPointer) retrieveHostPointer4, pointer8, pointer9, reduceOp.dimensions().data().addressPointer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), pointer10, (LongPointer) null, pointer, (LongPointer) pointer2, (LongPointer) pointer6, (LongPointer) pointer7);
            } else if (reduceOp instanceof Variance) {
                nativeOps.execSummaryStatsTad(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer2, pointer3, pointer4, pointer5, (Pointer) null, (LongPointer) retrieveHostPointer4, pointer8, pointer9, reduceOp.dimensions().data().addressPointer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(reduceOp.dimensions(), prepareAction), (LongPointer) null, ((Variance) reduceOp).isBiasCorrected(), pointer, (LongPointer) pointer2);
            } else {
                // Case numbers come from the decompiler's Op.Type switch map; the Nd4jCuda.FLOAT32/DOUBLE
                // labels appear to be decompiler substitutions for the literals 5 and 6, not data types.
                switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[reduceOp.getOpType().ordinal()]) {
                    case 3:
                        nativeOps.execReduceLong2(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer2, pointer3, pointer4, pointer5, (Pointer) null, (LongPointer) retrieveHostPointer4, pointer8, pointer9, reduceOp.dimensions().data().addressPointer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(reduceOp.dimensions(), prepareAction), (LongPointer) null);
                        break;
                    case 4:
                        nativeOps.execReduceBool2(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer2, pointer3, pointer4, pointer5, (Pointer) null, (LongPointer) retrieveHostPointer4, pointer8, pointer9, reduceOp.dimensions().data().addressPointer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(reduceOp.dimensions(), prepareAction), (LongPointer) null);
                        break;
                    case Nd4jCuda.FLOAT32 /* 5 */:
                        nativeOps.execReduceFloat2(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer2, pointer3, pointer4, pointer5, (Pointer) null, (LongPointer) retrieveHostPointer4, pointer8, pointer9, reduceOp.dimensions().data().addressPointer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(reduceOp.dimensions(), prepareAction), (LongPointer) null);
                        break;
                    case Nd4jCuda.DOUBLE /* 6 */:
                        nativeOps.execReduceSame2(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer2, pointer3, pointer4, pointer5, (Pointer) null, (LongPointer) retrieveHostPointer4, pointer8, pointer9, reduceOp.dimensions().data().addressPointer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(reduceOp.dimensions(), prepareAction), (LongPointer) null);
                        break;
                    default:
                        throw new UnsupportedOperationException();
                }
            }
            AtomicAllocator.getInstance().registerAction(prepareAction, reduceOp.z(), reduceOp.x(), reduceOp.y());
        // Scalar result: whole-array native entry points.
        } else if (reduceOp instanceof Variance) {
            nativeOps.execSummaryStatsScalar(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer2, pointer3, pointer4, pointer5, (Pointer) null, (LongPointer) retrieveHostPointer4, pointer8, pointer9, ((Variance) reduceOp).isBiasCorrected());
            AtomicAllocator.getInstance().registerAction(prepareAction, reduceOp.z(), reduceOp.x(), reduceOp.y());
        } else if (reduceOp.y() != null) {
            nativeOps.execReduce3Scalar(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer2, pointer3, pointer4, pointer5, (Pointer) null, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.y(), prepareAction), AtomicAllocator.getInstance().getPointer(reduceOp.y().shapeInfoDataBuffer(), prepareAction), (Pointer) null, (LongPointer) retrieveHostPointer4, pointer8, pointer9);
            AtomicAllocator.getInstance().registerAction(prepareAction, reduceOp.z(), reduceOp.x(), reduceOp.y());
        } else {
            switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[reduceOp.getOpType().ordinal()]) {
                case 3:
                    nativeOps.execReduceLong(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer2, pointer3, pointer4, pointer5, (Pointer) null, (LongPointer) retrieveHostPointer4, pointer8, pointer9);
                    break;
                case 4:
                    nativeOps.execReduceBool(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer2, pointer3, pointer4, pointer5, (Pointer) null, (LongPointer) retrieveHostPointer4, pointer8, pointer9);
                    break;
                case Nd4jCuda.FLOAT32 /* 5 */:
                    nativeOps.execReduceFloat(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer2, pointer3, pointer4, pointer5, (Pointer) null, (LongPointer) retrieveHostPointer4, pointer8, pointer9);
                    break;
                case Nd4jCuda.DOUBLE /* 6 */:
                    nativeOps.execReduceSame(put, reduceOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer2, pointer3, pointer4, pointer5, (Pointer) null, (LongPointer) retrieveHostPointer4, pointer8, pointer9);
                    break;
                default:
                    throw new UnsupportedOperationException();
            }
            AtomicAllocator.getInstance().registerAction(prepareAction, reduceOp.z(), reduceOp.x(), reduceOp.y());
        }
        // Surface native-side failures as Java exceptions.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        profilingConfigurableHookOut(reduceOp, profilingConfigurableHookIn);
        // Commit the executioner before handing the context back.
        Nd4j.getExecutioner().commit();
        return prepareAction;
    }

    /**
     * Runs a {@link ScalarOp} along the given dimensions through the native {@code execScalarTad} /
     * {@code execScalarBoolTad} entry points, using TAD shape info for both {@code x} and {@code z}.
     *
     * <p>The decompiler-generated switch map selects the entry point: case 7 calls
     * {@code execScalarTad} and case 8 calls {@code execScalarBoolTad}. The
     * {@code Nd4jCuda.INT8}/{@code INT16} labels appear to be decompiler substitutions for those
     * literals, not data types.
     *
     * @param scalarOp the op to run; its {@code x}, {@code y} and {@code z} are all dereferenced
     * @param iArr     the dimensions to operate along; sorted in place, so the caller's array is
     *                 reordered
     * @return always {@code null}
     * @throws UnsupportedOperationException if the op type is not one of the two handled cases
     * @throws RuntimeException if the native call reports a non-zero error code
     */
    protected CudaContext intercept(ScalarOp scalarOp, int[] iArr) {
        long profilingConfigurableHookIn = profilingConfigurableHookIn(scalarOp, new DataBuffer[0]);
        Arrays.sort(iArr);
        // Every other native entry point in this class lazily creates the reusable extra-pointer
        // array; do the same here so a call on a thread where it is still unset cannot NPE below.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        CudaContext prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction(scalarOp.z(), scalarOp.x(), scalarOp.y());
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(scalarOp.opName());
        }
        // Host-side shape info for x, y and z.
        Pointer retrieveHostPointer = scalarOp.x() == null ? null : AddressRetriever.retrieveHostPointer(scalarOp.x().shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = scalarOp.y() == null ? null : AddressRetriever.retrieveHostPointer(scalarOp.y().shapeInfoDataBuffer());
        Pointer retrieveHostPointer3 = scalarOp.z() == null ? null : AddressRetriever.retrieveHostPointer(scalarOp.z().shapeInfoDataBuffer());
        // Device pointers for x, y and z data buffers and shape info (y is required on this path).
        Pointer pointer = AtomicAllocator.getInstance().getPointer(scalarOp.x(), prepareAction);
        Pointer pointer2 = AtomicAllocator.getInstance().getPointer(scalarOp.y(), prepareAction);
        Pointer pointer3 = AtomicAllocator.getInstance().getPointer(scalarOp.z(), prepareAction);
        LongPointer pointer4 = AtomicAllocator.getInstance().getPointer(scalarOp.x().shapeInfoDataBuffer(), prepareAction);
        LongPointer pointer5 = AtomicAllocator.getInstance().getPointer(scalarOp.y().shapeInfoDataBuffer(), prepareAction);
        LongPointer pointer6 = AtomicAllocator.getInstance().getPointer(scalarOp.z().shapeInfoDataBuffer(), prepareAction);
        // TAD shape info (first) and offsets (second) for x, then for z.
        Pair tADOnlyShapeInfo = tadManager.getTADOnlyShapeInfo(scalarOp.x(), iArr);
        Pointer retrieveHostPointer4 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo.getFirst());
        LongPointer pointer7 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), prepareAction);
        LongPointer pointer8 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getSecond(), prepareAction);
        Pair tADOnlyShapeInfo2 = tadManager.getTADOnlyShapeInfo(scalarOp.z(), iArr);
        LongPointer pointer9 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getFirst(), prepareAction);
        LongPointer pointer10 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getSecond(), prepareAction);
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(scalarOp.x().shapeInfoDataBuffer()), prepareAction.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), prepareAction.getBufferAllocation(), prepareAction.getBufferReduction(), prepareAction.getBufferScalar(), prepareAction.getBufferSpecial(), retrieveHostPointer2, retrieveHostPointer3, retrieveHostPointer4, pointer7, pointer8, pointer9, pointer10});
        // Extra args are materialized in z's data type, when present.
        Pointer pointer11 = scalarOp.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(scalarOp.extraArgsDataBuff(scalarOp.z().dataType()), prepareAction) : null;
        // Return value unused; presumably ensures the dimension constant buffer is on the device -- confirm.
        AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(iArr), prepareAction);
        switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[scalarOp.getOpType().ordinal()]) {
            case Nd4jCuda.INT8 /* 7 */:
                nativeOps.execScalarTad(put, scalarOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer, pointer4, (Pointer) null, (LongPointer) retrieveHostPointer3, pointer3, pointer6, (Pointer) null, (LongPointer) retrieveHostPointer2, pointer2, pointer5, pointer11, (Pointer) null, scalarOp.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(scalarOp.dimensions(), prepareAction), (LongPointer) null, pointer7, pointer8, pointer9, pointer10);
                break;
            case Nd4jCuda.INT16 /* 8 */:
                nativeOps.execScalarBoolTad(put, scalarOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer, pointer4, (Pointer) null, (LongPointer) retrieveHostPointer3, pointer3, pointer6, (Pointer) null, (LongPointer) retrieveHostPointer2, pointer2, pointer5, pointer11, (Pointer) null, scalarOp.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(scalarOp.dimensions(), prepareAction), (LongPointer) null, pointer7, pointer8, pointer9, pointer10);
                break;
            default:
                throw new UnsupportedOperationException();
        }
        // Surface native-side failures as Java exceptions.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        AtomicAllocator.getInstance().getFlowController().registerAction(prepareAction, scalarOp.z(), scalarOp.x(), scalarOp.y());
        profilingConfigurableHookOut(scalarOp, profilingConfigurableHookIn);
        return null;
    }

    /**
     * Runs a scalar op and hands back its output array.
     *
     * @param scalarOp op to execute
     * @return the op's {@code z} array after execution
     */
    public INDArray exec(ScalarOp scalarOp) {
        // invoke() performs the work; its CudaContext result is not needed by callers of exec().
        invoke(scalarOp);
        INDArray result = scalarOp.z();
        return result;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Executes a scalar op ({@code x <op> scalar -> z}) on the CUDA backend.
     *
     * <p>When the op carries dimensions, execution is delegated to {@code intercept(...)} with
     * those dimensions and this method returns immediately.
     *
     * @param scalarOp op to execute; {@code x} and {@code z} must have equal length
     * @return always {@code null}
     * @throws ND4JIllegalStateException if {@code x} and {@code z} lengths differ
     * @throws UnsupportedOperationException if the op type is not dispatched to
     *     {@code execScalar}/{@code execScalarBool}
     * @throws RuntimeException if the native call reports an error
     */
    public CudaContext invoke(ScalarOp scalarOp) {
        Pointer pointer;
        long profilingConfigurableHookIn = profilingConfigurableHookIn(scalarOp, new DataBuffer[0]);
        checkForCompression(scalarOp);
        // The check is x vs z; the message names the arrays actually compared (it used to say "op.Y").
        if (scalarOp.x().length() != scalarOp.z().length()) {
            throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: [" + Arrays.toString(scalarOp.x().shapeInfoDataBuffer().asInt()) + "] != [" + Arrays.toString(scalarOp.z().shapeInfoDataBuffer().asInt()) + "]");
        }
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(scalarOp.opName());
        }
        // Along-dimension scalar ops take a separate (TAD-based) path.
        if (scalarOp.dimensions() != null) {
            intercept(scalarOp, scalarOp.dimensions().toIntVector());
            return null;
        }
        CudaContext prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction(scalarOp.z(), scalarOp.x(), scalarOp.y());
        Pointer retrieveHostPointer = scalarOp.x() == null ? null : AddressRetriever.retrieveHostPointer(scalarOp.x().shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = scalarOp.scalar() == null ? null : AddressRetriever.retrieveHostPointer(scalarOp.scalar().shapeInfoDataBuffer());
        Pointer retrieveHostPointer3 = scalarOp.z() == null ? null : AddressRetriever.retrieveHostPointer(scalarOp.z().shapeInfoDataBuffer());
        Pointer pointer2 = AtomicAllocator.getInstance().getPointer(scalarOp.x(), prepareAction);
        LongPointer pointer3 = AtomicAllocator.getInstance().getPointer(scalarOp.x().shapeInfoDataBuffer(), prepareAction);
        // Bool scalar ops take extra args typed like x; all others like z.
        if (scalarOp.extraArgs() != null) {
            pointer = AtomicAllocator.getInstance().getPointer(scalarOp.extraArgsDataBuff(scalarOp.getOpType() == Op.Type.SCALAR_BOOL ? scalarOp.x().dataType() : scalarOp.z().dataType()), prepareAction);
        } else {
            pointer = null;
        }
        Pointer pointer4 = pointer;
        Pointer pointer5 = AtomicAllocator.getInstance().getPointer(scalarOp.z(), prepareAction);
        LongPointer pointer6 = AtomicAllocator.getInstance().getPointer(scalarOp.z().shapeInfoDataBuffer(), prepareAction);
        // Extras: host x shape, stream, device id, context buffers, host scalar/z shapes.
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(scalarOp.x().shapeInfoDataBuffer()), prepareAction.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), prepareAction.getBufferAllocation(), prepareAction.getBufferReduction(), prepareAction.getBufferScalar(), prepareAction.getBufferSpecial(), retrieveHostPointer2, retrieveHostPointer3, null, null});
        // NOTE(review): case labels are decompiler switch-map indices over Op.Type ordinals; the
        // Nd4jCuda constant names only coincide numerically.
        switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[scalarOp.getOpType().ordinal()]) {
            case Nd4jCuda.INT8 /* 7 */:
                nativeOps.execScalar(put, scalarOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer2, pointer3, (Pointer) null, (LongPointer) retrieveHostPointer3, pointer5, pointer6, (Pointer) null, (LongPointer) retrieveHostPointer2, AtomicAllocator.getInstance().getPointer(scalarOp.scalar(), prepareAction), AtomicAllocator.getInstance().getPointer(scalarOp.scalar().shapeInfoDataBuffer(), prepareAction), pointer4);
                break;
            case Nd4jCuda.INT16 /* 8 */:
                nativeOps.execScalarBool(put, scalarOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer2, pointer3, (Pointer) null, (LongPointer) retrieveHostPointer3, pointer5, pointer6, (Pointer) null, (LongPointer) retrieveHostPointer2, AtomicAllocator.getInstance().getPointer(scalarOp.scalar(), prepareAction), AtomicAllocator.getInstance().getPointer(scalarOp.scalar().shapeInfoDataBuffer(), prepareAction), pointer4);
                break;
            default:
                throw new UnsupportedOperationException("Unknown op type: " + scalarOp.getOpType());
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        AtomicAllocator.getInstance().registerAction(prepareAction, scalarOp.z(), scalarOp.x(), scalarOp.scalar());
        profilingConfigurableHookOut(scalarOp, profilingConfigurableHookIn);
        return null;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Executes a transform op on the CUDA backend.
     *
     * <p>When {@code y} is null, a single-input transform ({@code x -> z}) is dispatched by op type.
     * Otherwise a pairwise transform ({@code x, y -> z}) is run after checking that x, y and z
     * have equal length. If {@code z} is null, it is allocated uninitialized with the op's result
     * type and x's shape/ordering, then set on the op.
     *
     * @param transformOp op to execute
     * @return always {@code null}
     * @throws ND4JIllegalStateException if y is present and x, y, z lengths differ
     * @throws UnsupportedOperationException if y is null and the op type is not one dispatched by
     *     the single-input switch
     * @throws RuntimeException if the native call reports an error
     */
    public CudaContext invoke(TransformOp transformOp) {
        Pointer pointer;
        long profilingConfigurableHookIn = profilingConfigurableHookIn(transformOp, new DataBuffer[0]);
        checkForCompression(transformOp);
        AtomicAllocator atomicAllocator = AtomicAllocator.getInstance();
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        CudaContext prepareAction = atomicAllocator.getFlowController().prepareAction(transformOp.z(), transformOp.x(), transformOp.y());
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(transformOp.opName());
        }
        // Holds z only if this method allocates it below.
        INDArray iNDArray = null;
        // NOTE(review): x is dereferenced here, so the x == null check a few lines below is redundant.
        Pointer pointer2 = atomicAllocator.getPointer(transformOp.x(), prepareAction);
        LongPointer pointer3 = atomicAllocator.getPointer(transformOp.x().shapeInfoDataBuffer(), prepareAction);
        // Always null; only referenced in the (constant) expression for extras slot 18.
        Object[] objArr = null;
        Pointer retrieveHostPointer = transformOp.x() == null ? null : AddressRetriever.retrieveHostPointer(transformOp.x().shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = transformOp.y() == null ? null : AddressRetriever.retrieveHostPointer(transformOp.y().shapeInfoDataBuffer());
        // Allocate the output when the caller did not provide one.
        if (transformOp.z() == null) {
            iNDArray = Nd4j.createUninitialized(transformOp.resultType(), transformOp.x().shape(), transformOp.x().ordering());
            transformOp.setZ(iNDArray);
        }
        // Bool transforms take extra args typed like x; all others like z.
        if (transformOp.extraArgs() != null) {
            pointer = atomicAllocator.getPointer(transformOp.extraArgsDataBuff((transformOp.getOpType() == Op.Type.TRANSFORM_BOOL || transformOp.getOpType() == Op.Type.PAIRWISE_BOOL) ? transformOp.x().dataType() : transformOp.z().dataType()), prepareAction);
        } else {
            pointer = null;
        }
        Pointer pointer4 = pointer;
        Pointer retrieveHostPointer3 = transformOp.z() == null ? null : AddressRetriever.retrieveHostPointer(transformOp.z().shapeInfoDataBuffer());
        transformOp.validateDataTypes(this.experimentalMode.get());
        Pointer pointer5 = atomicAllocator.getPointer(transformOp.z(), prepareAction);
        LongPointer pointer6 = atomicAllocator.getPointer(transformOp.z().shapeInfoDataBuffer(), prepareAction);
        // Extras array passed to the native call: host x shape, stream, device id, context
        // buffers, host y/z shapes; remaining slots are null except slot 18.
        PointerPointer pointerPointer = this.extraz.get();
        Pointer[] pointerArr = new Pointer[20];
        pointerArr[0] = AddressRetriever.retrieveHostPointer(transformOp.x().shapeInfoDataBuffer());
        pointerArr[1] = prepareAction.getOldStream();
        pointerArr[2] = atomicAllocator.getDeviceIdPointer();
        pointerArr[3] = prepareAction.getBufferAllocation();
        pointerArr[4] = prepareAction.getBufferReduction();
        pointerArr[5] = prepareAction.getBufferScalar();
        pointerArr[6] = prepareAction.getBufferSpecial();
        pointerArr[7] = retrieveHostPointer2;
        pointerArr[8] = retrieveHostPointer3;
        pointerArr[9] = null;
        pointerArr[10] = null;
        pointerArr[11] = null;
        pointerArr[12] = null;
        pointerArr[13] = null;
        pointerArr[14] = null;
        pointerArr[15] = null;
        pointerArr[16] = null;
        pointerArr[17] = null;
        // objArr is always null, so this is always new CudaPointer(0L) (decompiler residue).
        pointerArr[18] = new CudaPointer(0 == 0 ? 0L : objArr.length);
        pointerArr[19] = null;
        PointerPointer put = pointerPointer.put(pointerArr);
        // NOTE(review): case labels below are indices from the decompiler-generated switch map over
        // Op.Type ordinals; the Nd4jCuda constant names only coincide numerically. Which Op.Type
        // each index denotes cannot be confirmed from this chunk.
        if (transformOp.y() == null) {
            // Single-input transform. Index 10 deliberately falls through to default and throws.
            switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[transformOp.getOpType().ordinal()]) {
                case Nd4jCuda.INT32 /* 9 */:
                    nativeOps.execTransformBool(put, transformOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer2, pointer3, (Pointer) null, (LongPointer) retrieveHostPointer3, pointer5, pointer6, pointer4);
                    break;
                case 10:
                default:
                    throw new UnsupportedOperationException();
                case Nd4jCuda.UINT8 /* 11 */:
                    nativeOps.execTransformAny(put, transformOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer2, pointer3, (Pointer) null, (LongPointer) retrieveHostPointer3, pointer5, pointer6, pointer4);
                    break;
                case Nd4jCuda.UINT16 /* 12 */:
                    nativeOps.execTransformFloat(put, transformOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer2, pointer3, (Pointer) null, (LongPointer) retrieveHostPointer3, pointer5, pointer6, pointer4);
                    break;
                case Nd4jCuda.UINT32 /* 13 */:
                    nativeOps.execTransformSame(put, transformOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer2, pointer3, (Pointer) null, (LongPointer) retrieveHostPointer3, pointer5, pointer6, pointer4);
                    break;
                case Nd4jCuda.UINT64 /* 14 */:
                    nativeOps.execTransformStrict(put, transformOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer2, pointer3, (Pointer) null, (LongPointer) retrieveHostPointer3, pointer5, pointer6, pointer4);
                    break;
            }
        } else {
            // Pairwise transform: x and y combined element-wise into z.
            Pointer pointer7 = atomicAllocator.getPointer(transformOp.y(), prepareAction);
            LongPointer pointer8 = atomicAllocator.getPointer(transformOp.y().shapeInfoDataBuffer(), prepareAction);
            if (transformOp.x().length() != transformOp.y().length() || transformOp.x().length() != transformOp.z().length()) {
                throw new ND4JIllegalStateException("X, Y and Z arguments should have the same length for PairwiseTransform");
            }
            switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[transformOp.getOpType().ordinal()]) {
                case Nd4jCuda.INT32 /* 9 */:
                case 10:
                    nativeOps.execPairwiseTransformBool(put, transformOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer2, pointer3, (Pointer) null, (LongPointer) retrieveHostPointer2, pointer7, pointer8, (Pointer) null, (LongPointer) retrieveHostPointer3, pointer5, pointer6, pointer4);
                    break;
                default:
                    nativeOps.execPairwiseTransform(put, transformOp.opNum(), (Pointer) null, (LongPointer) retrieveHostPointer, pointer2, pointer3, (Pointer) null, (LongPointer) retrieveHostPointer2, pointer7, pointer8, (Pointer) null, (LongPointer) retrieveHostPointer3, pointer5, pointer6, pointer4);
                    break;
            }
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        AtomicAllocator.getInstance().registerAction(prepareAction, transformOp.z(), transformOp.x(), transformOp.y());
        // NOTE(review): both calls below discard their results and look like decompiler
        // leftovers with no effect here — confirm before removing.
        if (pointer4 != null) {
            pointer4.address();
        }
        if (iNDArray != null) {
            iNDArray.elementWiseStride();
        }
        profilingConfigurableHookOut(transformOp, profilingConfigurableHookIn);
        return null;
    }

    /**
     * Allocates the (non-initialized) int parameter surface for a batch and attaches it to the batch.
     *
     * @param batch batch whose sample determines the required surface size
     * @return the newly allocated surface buffer, also set via {@code batch.setParamsSurface}
     */
    protected <T extends Aggregate> DataBuffer getBuffer(Batch<T> batch) {
        DataBuffer surface = Nd4j.getDataBufferFactory()
                .createInt(batch.getSample().getRequiredBatchMemorySize() * 4, false);
        batch.setParamsSurface(surface);
        return surface;
    }

    /**
     * Packs every aggregate of a batch into one host-side parameter surface and runs them with a
     * single native {@code execAggregateBatch} call.
     *
     * <p>Surface regions, in order: five per-aggregate counters, indexing arguments, int-array
     * arguments, real arguments, argument pointers and shape pointers. Each per-aggregate region
     * is sized from the batch sample's maxima times {@code Batch.getBatchLimit() * 16}.
     *
     * @param batch batch of aggregates; the data type of the first aggregate's first argument is
     *     the type passed to the native call
     * @throws UnsupportedOperationException if that data type has no real-argument encoding here
     * @throws RuntimeException if the native call reports an error
     */
    public <T extends Aggregate> void exec(Batch<T> batch) {
        BaseCudaDataBuffer surfaceBuffer = (BaseCudaDataBuffer) getBuffer(batch);
        surfaceBuffer.lazyAllocateHostPointer();
        CudaContext deviceContext = AtomicAllocator.getInstance().getDeviceContext();
        IntPointer intSurface = new CudaPointer(AtomicAllocator.getInstance().getHostPointer(surfaceBuffer)).asIntPointer();
        AllocationPoint surfacePoint = AtomicAllocator.getInstance().getAllocationPoint(surfaceBuffer);

        // Number of per-aggregate counters written at the head of the surface.
        final int maxTypes = 5;
        int maxIntArrays = batch.getSample().maxIntArrays();
        int maxIntArraySize = batch.getSample().maxIntArraySize();
        int slotsPerRegion = Batch.getBatchLimit() * 16;

        // Offsets of the int-typed regions (in int elements).
        int indexPos = maxTypes * slotsPerRegion;
        int intArraysPos = indexPos + batch.getSample().maxIndexArguments() * slotsPerRegion;

        // The offsets below are rescaled into the element width of the region that follows
        // (reals, then pointers).
        // NOTE(review): this scaling uses the global Nd4j.dataType(), while the real arguments are
        // written (and the native call is typed) using the first argument's dtype — these can differ.
        int realPos = (intArraysPos + maxIntArrays * maxIntArraySize * slotsPerRegion) / (Nd4j.dataType() == DataType.DOUBLE ? 2 : 1);
        if (Nd4j.dataType() == DataType.HALF) {
            realPos *= 2;
        }
        int argsPos = (realPos + batch.getSample().maxRealArguments() * slotsPerRegion) / (Nd4j.dataType() == DataType.FLOAT ? 2 : 1);
        if (Nd4j.dataType() == DataType.HALF) {
            argsPos /= 4;
        }
        int shapesPos = argsPos + batch.getSample().maxArguments() * slotsPerRegion;

        DataType dataType = null;
        for (int i = 0; i < batch.getNumAggregates(); i++) {
            Aggregate aggregate = (Aggregate) batch.getAggregates().get(i);
            if (i == 0) {
                dataType = ((INDArray) aggregate.getArguments().get(0)).dataType();
            }

            // Per-aggregate counters.
            int countsIdx = i * maxTypes;
            intSurface.put(countsIdx, aggregate.getArguments().size());
            intSurface.put(countsIdx + 1, aggregate.getShapes().size());
            intSurface.put(countsIdx + 2, aggregate.getIndexingArguments().size());
            intSurface.put(countsIdx + 3, aggregate.getRealArguments().size());
            intSurface.put(countsIdx + 4, aggregate.getIntArrayArguments().size());

            // Indexing arguments.
            int indexBase = indexPos + i * batch.getSample().maxIndexArguments();
            for (int e = 0; e < aggregate.getIndexingArguments().size(); e++) {
                intSurface.put(indexBase + e, ((Integer) aggregate.getIndexingArguments().get(e)).intValue());
            }

            // Int-array arguments: each aggregate owns maxIntArrays slots of maxIntArraySize ints.
            int intArraysPerAggregate = maxIntArrays * maxIntArraySize;
            for (int e = 0; e < aggregate.getIntArrayArguments().size(); e++) {
                int[] values = (int[]) aggregate.getIntArrayArguments().get(e);
                if (values != null) {
                    int base = intArraysPos + (i * intArraysPerAggregate) + (e * maxIntArraySize);
                    for (int k = 0; k < values.length; k++) {
                        intSurface.put(base + k, values[k]);
                    }
                }
            }

            // Real arguments, encoded according to the batch data type.
            switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$buffer$DataType[dataType.ordinal()]) {
                case 1:
                    FloatPointer floatPointer = new FloatPointer(intSurface);
                    for (int e = 0; e < aggregate.getRealArguments().size(); e++) {
                        floatPointer.put(realPos + (i * aggregate.maxRealArguments()) + e, ((Number) aggregate.getRealArguments().get(e)).floatValue());
                    }
                    break;
                case 2:
                    DoublePointer doublePointer = new DoublePointer(intSurface);
                    for (int e = 0; e < aggregate.getRealArguments().size(); e++) {
                        doublePointer.put(realPos + (i * aggregate.maxRealArguments()) + e, ((Number) aggregate.getRealArguments().get(e)).doubleValue());
                    }
                    break;
                case 3:
                    ShortPointer shortPointer = new ShortPointer(intSurface);
                    for (int e = 0; e < aggregate.getRealArguments().size(); e++) {
                        shortPointer.put(realPos + (i * aggregate.maxRealArguments()) + e, BaseDataBuffer.fromFloat(((Number) aggregate.getRealArguments().get(e)).floatValue()));
                    }
                    break;
                default:
                    throw new UnsupportedOperationException("Unknown data type");
            }

            // Device pointers of the argument arrays.
            PointerPointer pointerSurface = new PointerPointer(intSurface);
            int argsBase = argsPos + i * batch.getSample().maxArguments();
            for (int e = 0; e < aggregate.getArguments().size(); e++) {
                if (aggregate.getArguments().get(e) != null) {
                    pointerSurface.put(argsBase + e, AtomicAllocator.getInstance().getPointer((INDArray) aggregate.getArguments().get(e), deviceContext));
                    AtomicAllocator.getInstance().getAllocationPoint((INDArray) aggregate.getArguments().get(e)).tickDeviceWrite();
                }
            }

            // Device pointers of the shape buffers.
            int shapesBase = shapesPos + i * batch.getSample().maxShapes();
            for (int e = 0; e < aggregate.getShapes().size(); e++) {
                if (aggregate.getShapes().get(e) != null) {
                    pointerSurface.put(shapesBase + e, AtomicAllocator.getInstance().getPointer((DataBuffer) aggregate.getShapes().get(e), deviceContext));
                    AtomicAllocator.getInstance().getAllocationPoint((DataBuffer) aggregate.getShapes().get(e)).tickDeviceWrite();
                }
            }
        }
        // Surface was filled on the host; mark it so the next device access picks up the data.
        surfacePoint.tickHostWrite();
        PointerPointer extras = new PointerPointer(32L);
        extras.put(0L, (Pointer) null);
        extras.put(1L, deviceContext.getOldStream());
        extras.put(2L, new CudaPointer(Math.min(batch.getNumAggregates(), CudaEnvironment.getInstance().getConfiguration().getMaximumGridSize())));
        extras.put(3L, new CudaPointer(batch.getSample().getThreadsPerInstance()));
        extras.put(4L, new CudaPointer(batch.getSample().getSharedMemorySize()));
        nativeOps.execAggregateBatch(extras, batch.getNumAggregates(), batch.opNum(), batch.getSample().maxArguments(), batch.getSample().maxShapes(), batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(), batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), AtomicAllocator.getInstance().getPointer(surfaceBuffer, deviceContext), FlatBuffersMapper.getDataTypeAsByte(dataType));
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        surfacePoint.tickHostWrite();
    }

    /**
     * Executes a list of aggregates in batches of up to 8192, then synchronizes the device stream.
     * An empty list is a no-op (no synchronization either).
     *
     * @param list aggregates to execute
     */
    public void exec(List<Aggregate> list) {
        if (list.isEmpty()) {
            return;
        }
        for (Object partition : Batch.getBatches(list, 8192)) {
            exec((Batch) partition);
        }
        AtomicAllocator.getInstance().getDeviceContext().syncOldStream();
    }

    /**
     * Executes a single aggregate op through {@code nativeOps.execAggregate}.
     *
     * <p>Argument arrays, shape buffers and int-array arguments are passed as device-side arrays
     * of device addresses. Indexing and real arguments are copied into freshly created buffers.
     * The data type passed to the native call comes from the first argument array.
     *
     * @param aggregate aggregate op to execute; its first argument must be non-null
     * @throws RuntimeException if the native call reports an error
     */
    public void exec(Aggregate aggregate) {
        int size = aggregate.getArguments().size();
        int size2 = aggregate.getShapes().size();
        int size3 = aggregate.getIndexingArguments().size();
        int size4 = aggregate.getIntArrayArguments().size();
        int size5 = aggregate.getRealArguments().size();
        CudaContext deviceContext = AtomicAllocator.getInstance().getDeviceContext();
        // Extras: slot 1 stream, slot 2 a CudaPointer(1) (exec(Batch) puts a grid-size-like value
        // here), slot 3 threads per instance, slot 4 shared memory size.
        PointerPointer pointerPointer = new PointerPointer(32L);
        pointerPointer.put(0L, (Pointer) null);
        pointerPointer.put(1L, deviceContext.getOldStream());
        pointerPointer.put(2L, new CudaPointer(1L));
        pointerPointer.put(3L, new CudaPointer(aggregate.getThreadsPerInstance()));
        pointerPointer.put(4L, new CudaPointer(aggregate.getSharedMemorySize()));
        long[] jArr = new long[size];
        DataType dataType = ((INDArray) aggregate.getArguments().get(0)).dataType();
        // Device addresses of argument arrays (0 for null entries).
        for (int i = 0; i < size; i++) {
            jArr[i] = aggregate.getArguments().get(i) == null ? 0L : AtomicAllocator.getInstance().getPointer((INDArray) aggregate.getArguments().get(i), deviceContext).address();
            if (aggregate.getArguments().get(i) != null) {
                AtomicAllocator.getInstance().getAllocationPoint((INDArray) aggregate.getArguments().get(i)).tickDeviceWrite();
            }
        }
        PointerPointer pointerPointer2 = new PointerPointer(AtomicAllocator.getInstance().getPointer(AllocationUtils.getPointersBuffer(jArr), deviceContext));
        // Device addresses of shape buffers (0 for null entries).
        long[] jArr2 = new long[size2];
        for (int i2 = 0; i2 < size2; i2++) {
            jArr2[i2] = aggregate.getShapes().get(i2) == null ? 0L : AtomicAllocator.getInstance().getPointer((DataBuffer) aggregate.getShapes().get(i2), deviceContext).address();
            if (aggregate.getShapes().get(i2) != null) {
                AtomicAllocator.getInstance().getAllocationPoint((DataBuffer) aggregate.getShapes().get(i2)).tickDeviceWrite();
            }
        }
        PointerPointer pointerPointer3 = new PointerPointer(AtomicAllocator.getInstance().getPointer(AllocationUtils.getPointersBuffer(jArr2), deviceContext));
        // Device addresses of temporary int buffers holding each int-array argument.
        // NOTE(review): these temporary DataBuffers are not referenced after their address is taken;
        // if buffer memory is released on GC, the addresses could dangle before execAggregate — confirm.
        long[] jArr3 = new long[size4];
        for (int i3 = 0; i3 < size4; i3++) {
            if (aggregate.getIntArrayArguments().get(i3) != null) {
                jArr3[i3] = AtomicAllocator.getInstance().getPointer(Nd4j.getDataBufferFactory().createInt((int[]) aggregate.getIntArrayArguments().get(i3)), deviceContext).address();
            }
        }
        PointerPointer pointerPointer4 = new PointerPointer(AtomicAllocator.getInstance().getPointer(AllocationUtils.getPointersBuffer(jArr3), deviceContext));
        int[] iArr = new int[size3];
        for (int i4 = 0; i4 < size3; i4++) {
            iArr[i4] = ((Integer) aggregate.getIndexingArguments().get(i4)).intValue();
        }
        DataBuffer createInt = Nd4j.getDataBufferFactory().createInt(iArr);
        // Real arguments are staged as doubles, then stored in an array of the aggregate's dtype.
        double[] dArr = new double[size5];
        for (int i5 = 0; i5 < size5; i5++) {
            dArr[i5] = ((Number) aggregate.getRealArguments().get(i5)).doubleValue();
        }
        nativeOps.execAggregate(pointerPointer, aggregate.opNum(), pointerPointer2, size, pointerPointer3, size2, AtomicAllocator.getInstance().getPointer(createInt, deviceContext), size3, pointerPointer4, size4, AtomicAllocator.getInstance().getPointer(Nd4j.create(dArr, new long[]{dArr.length}, dataType).data(), deviceContext), size5, FlatBuffersMapper.getDataTypeAsByte(dataType));
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
    }

    /**
     * Executes a random op using the default Nd4j random generator.
     *
     * @param randomOp op to execute
     * @return the op's output array
     */
    public INDArray exec(RandomOp randomOp) {
        Random defaultRng = Nd4j.getRandom();
        return exec(randomOp, defaultRng);
    }

    /**
     * Executes a random op with the given native random generator.
     *
     * <p>The native entry point depends on which operands are present:
     * <ul>
     *   <li>x, y and z present: {@code execRandom3}</li>
     *   <li>x and z present, y absent: {@code execRandom2}</li>
     *   <li>otherwise: {@code execRandom}, using z only</li>
     * </ul>
     *
     * @param randomOp op to execute
     * @param random generator; must expose a native state pointer
     * @return the op's {@code z} array
     * @throws IllegalStateException if {@code random} has no native state pointer
     * @throws RuntimeException if the native call reports an error
     */
    public INDArray exec(RandomOp randomOp, Random random) {
        long hookStart = profilingConfigurableHookIn(randomOp, new DataBuffer[0]);
        checkForCompression(randomOp);
        if (random.getStatePointer() == null) {
            throw new IllegalStateException("You should use one of NativeRandom classes for NativeOperations execution");
        }
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(randomOp.opName());
        }

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        INDArray x = randomOp.x();
        INDArray y = randomOp.y();
        INDArray z = randomOp.z();

        CudaContext context = allocator.getFlowController().prepareAction(z, x, y);
        PointerPointer extras = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer()), context.getOldStream(), allocator.getDeviceIdPointer()});
        Pointer hostShapeX = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        Pointer hostShapeY = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
        Pointer hostShapeZ = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());

        if (x != null && y != null && z != null) {
            // Two inputs and an output.
            nativeOps.execRandom3(extras, randomOp.opNum(), random.getStatePointer(),
                    (Pointer) null, (LongPointer) hostShapeX, allocator.getPointer(x, context), allocator.getPointer(x.shapeInfoDataBuffer(), context),
                    (Pointer) null, (LongPointer) hostShapeY, allocator.getPointer(y, context), allocator.getPointer(y.shapeInfoDataBuffer(), context),
                    (Pointer) null, (LongPointer) hostShapeZ, allocator.getPointer(z, context), allocator.getPointer(z.shapeInfoDataBuffer(), context),
                    allocator.getPointer(randomOp.extraArgsDataBuff(z.dataType()), context));
        } else if (x != null && z != null) {
            // One input and an output.
            nativeOps.execRandom2(extras, randomOp.opNum(), random.getStatePointer(),
                    (Pointer) null, (LongPointer) hostShapeX, allocator.getPointer(x, context), allocator.getPointer(x.shapeInfoDataBuffer(), context),
                    (Pointer) null, (LongPointer) hostShapeZ, allocator.getPointer(z, context), allocator.getPointer(z.shapeInfoDataBuffer(), context),
                    allocator.getPointer(randomOp.extraArgsDataBuff(z.dataType()), context));
        } else {
            // Output only.
            nativeOps.execRandom(extras, randomOp.opNum(), random.getStatePointer(),
                    (Pointer) null, (LongPointer) hostShapeZ, allocator.getPointer(z, context), allocator.getPointer(z.shapeInfoDataBuffer(), context),
                    allocator.getPointer(randomOp.extraArgsDataBuff(z.dataType()), context));
        }

        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        allocator.getFlowController().registerAction(context, z, x, y);
        profilingConfigurableHookOut(randomOp, hookStart);
        return randomOp.z();
    }

    /**
     * Returns backend environment properties, extended with CUDA device information.
     *
     * <p>The first call builds and caches the properties. Later calls refresh only the per-device
     * free/total memory, {@code memory.free} and {@code memoryBandwidth} on the cached object.
     *
     * @return the cached (and refreshed) environment properties
     */
    public synchronized Properties getEnvironmentInformation() {
        if (this.properties != null) {
            @SuppressWarnings("unchecked") // populated below with per-device maps on the first call
            List<Map<String, Object>> devices = (List<Map<String, Object>>) this.properties.get("cuda.devicesInformation");
            for (int device = 0; device < nativeOps.getAvailableDevices(); device++) {
                Map<String, Object> info = devices.get(device);
                info.put("cuda.freeMemory", Long.valueOf(nativeOps.getDeviceFreeMemory(device)));
                info.put("cuda.totalMemory", Long.valueOf(nativeOps.getDeviceTotalMemory(device)));
            }
            this.properties.put("cuda.devicesInformation", devices);
            this.properties.put("memory.free", Long.valueOf(Pointer.maxBytes() - Pointer.totalBytes()));
            this.properties.put("memoryBandwidth", PerformanceTracker.getInstance().getCurrentBandwidth());
            return this.properties;
        }

        Properties info = super.getEnvironmentInformation();
        List<Map<String, Object>> devices = new ArrayList<>();
        for (int device = 0; device < nativeOps.getAvailableDevices(); device++) {
            Map<String, Object> deviceInfo = new HashMap<>();
            deviceInfo.put("cuda.deviceName", nativeOps.getDeviceName(device));
            deviceInfo.put("cuda.freeMemory", Long.valueOf(nativeOps.getDeviceFreeMemory(device)));
            deviceInfo.put("cuda.totalMemory", Long.valueOf(nativeOps.getDeviceTotalMemory(device)));
            deviceInfo.put("cuda.deviceMajor", Long.valueOf(nativeOps.getDeviceMajor(device)));
            deviceInfo.put("cuda.deviceMinor", Long.valueOf(nativeOps.getDeviceMinor(device)));
            devices.add(device, deviceInfo);
        }
        info.put("backend", "CUDA");
        info.put("cuda.availableDevices", Integer.valueOf(nativeOps.getAvailableDevices()));
        info.put("cuda.devicesInformation", devices);
        info.put("blas.vendor", Nd4j.factory().blas().getBlasVendor().toString());
        info.put("memory.free", Long.valueOf(Pointer.maxBytes() - Pointer.totalBytes()));
        info.put("memoryBandwidth", PerformanceTracker.getInstance().getCurrentBandwidth());
        this.properties = info;
        return this.properties;
    }

    /**
     * Returns the TAD manager this executioner uses for tensor-along-dimension shape info.
     *
     * @return the TAD manager
     */
    public TADManager getTADManager() {
        final TADManager manager = tadManager;
        return manager;
    }

    /**
     * Logs the base environment information, then one line per CUDA device.
     *
     * <p>Each line has the device name, compute capability, and total/free memory in bytes as
     * reported by {@link #getEnvironmentInformation()}.
     */
    public void printEnvironmentInformation() {
        super.printEnvironmentInformation();
        // Entries are per-device maps; iterate as Object because the stored list is untyped.
        List<?> devices = (List<?>) getEnvironmentInformation().get("cuda.devicesInformation");
        for (Object entry : devices) {
            Map<?, ?> device = (Map<?, ?>) entry;
            // The label says total/free, so print both values (previously only total was printed).
            log.info("Device Name: [{}]; CC: [{}.{}]; Total/free memory: [{}/{}]", new Object[]{device.get("cuda.deviceName"), device.get("cuda.deviceMajor"), device.get("cuda.deviceMinor"), device.get("cuda.totalMemory"), device.get("cuda.freeMemory")});
        }
    }

    /**
     * Synchronizes the current device context's old and special streams, in that order.
     */
    public void commit() {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        allocator.getDeviceContext().syncOldStream();
        allocator.getDeviceContext().syncSpecialStream();
    }

    /**
     * Threshold-encodes {@code iNDArray} on the device using three native passes
     * ({@code encodeThresholdP1}, {@code encodeThresholdP2Int}, {@code encodeThresholdP3}).
     *
     * <p>The returned INT buffer starts with a 4-element header written on the host:
     * [0] number of encoded elements, [1] length of the source buffer, [2] the threshold as
     * float bits, [3] the value 0 (NOTE(review): presumably a format id — confirm). The
     * remaining {@code [0]} elements are written by the P3 native pass.
     *
     * @param iNDArray array to encode; its data buffer is marked as device-written afterwards
     *                 (NOTE(review): presumably because the encoder modifies it in place — confirm)
     * @param d        threshold, passed to native code as a float
     * @param num      optional cap on the number of encoded elements; {@code null} means no cap
     * @return the encoded array (sharing {@code iNDArray}'s shape info buffer), or {@code null}
     *         when the first pass reports fewer than 2 matches
     */
    public INDArray thresholdEncode(INDArray iNDArray, double d, Integer num) {
        DataBuffer data = iNDArray.data();
        // ceil(data.length() / MAX_NUM_THREADS): number of thread blocks covering the input
        int length = (int) ((data.length() / Nd4jCuda.MAX_NUM_THREADS) + (data.length() % ((long) Nd4jCuda.MAX_NUM_THREADS) == 0 ? 0 : 1));
        CudaContext deviceContext = AtomicAllocator.getInstance().getDeviceContext();
        // length + 1 initialized ints filled by pass 1; element 0 is read back as the match count
        DataBuffer createInt = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createInt(length + 1, true) : Nd4j.getDataBufferFactory().createInt(length + 1, true, Nd4j.getMemoryManager().getCurrentWorkspace());
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // extras slot 1 carries the CUDA stream for all three passes
        PointerPointer put = this.extraz.get().put(1L, deviceContext.getOldStream());
        NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP1(put, AtomicAllocator.getInstance().getPointer(data), AtomicAllocator.getInstance().getHostPointer(iNDArray.shapeInfoDataBuffer()), data.length(), AtomicAllocator.getInstance().getPointer(createInt), (float) d);
        // mark device copy as newest, presumably so the host read below pulls pass-1 results
        AtomicAllocator.getInstance().getAllocationPoint(createInt).tickDeviceWrite();
        int i = createInt.getInt(0L);
        if (i < 2) {
            return null;
        }
        if (num != null && i > num.intValue()) {
            // clamp to the caller's limit and write it back (presumably read by later passes)
            i = num.intValue();
            createInt.put(0L, i);
        }
        // output buffer: 4-int header + i encoded elements
        DataBuffer createInt2 = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createInt(4 + i, false) : Nd4j.getDataBufferFactory().createInt(4 + i, false, Nd4j.getMemoryManager().getCurrentWorkspace());
        AtomicAllocator.getInstance().getAllocationPoint(createInt2).tickHostWrite();
        createInt2.put(0L, i);
        createInt2.put(1L, (int) data.length());
        createInt2.put(2L, Float.floatToIntBits((float) d));
        // NOTE(review): second tickHostWrite on the same buffer looks redundant
        AtomicAllocator.getInstance().getAllocationPoint(createInt2).tickHostWrite();
        createInt2.put(3L, 0);
        int i2 = length;
        int i3 = 0;
        // keeps the per-level buffers allocated below referenced
        ArrayList arrayList = new ArrayList();
        // count reduction levels; each level shrinks by a factor of 2 * 512
        // NOTE(review): this counts every iteration when length > 1, while the fill loop below
        // only fills levels with max2 > 1, so jArr keeps one trailing zero slot — confirm intent
        do {
            int max = Math.max(1, (int) Math.ceil(i2 / (2.0f * 512)));
            if (length > 1) {
                i3++;
            }
            i2 = max;
        } while (i2 > 1);
        long[] jArr = new long[i3];
        int i4 = 0;
        int i5 = length;
        // DOUBLE buffer used as raw 8-byte storage for the device addresses in jArr
        DataBuffer createDouble = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createDouble(jArr.length, false) : Nd4j.getDataBufferFactory().createDouble(jArr.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
        // allocate one int buffer per level with more than one block and record its device address
        do {
            int max2 = Math.max(1, (int) Math.ceil(i5 / (2.0f * 512)));
            if (max2 > 1) {
                DataBuffer createInt3 = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createInt(max2, false) : Nd4j.getDataBufferFactory().createInt(max2, false, Nd4j.getMemoryManager().getCurrentWorkspace());
                arrayList.add(createInt3);
                int i6 = i4;
                i4++;
                jArr[i6] = AtomicAllocator.getInstance().getPointer(createInt3).address();
            }
            i5 = max2;
        } while (i5 > 1);
        // copy the address table (8 bytes per entry) into createDouble
        AtomicAllocator.getInstance().memcpyBlocking(createDouble, new LongPointer(jArr), jArr.length * 8, 0L);
        // extras slot 2 carries the device address table for pass 2
        put.put(2L, AtomicAllocator.getInstance().getPointer(createDouble));
        // per-block ints filled by pass 2 (NOTE(review): presumably prefix-sum offsets — confirm)
        DataBuffer createInt4 = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createInt(length, true) : Nd4j.getDataBufferFactory().createInt(length, true, Nd4j.getMemoryManager().getCurrentWorkspace());
        NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP2Int(put, AtomicAllocator.getInstance().getPointer(createInt), length, AtomicAllocator.getInstance().getPointer(createInt4));
        AtomicAllocator.getInstance().getAllocationPoint(createInt4).tickDeviceWrite();
        // pass 3 writes the encoded elements into createInt2
        NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP3(put, AtomicAllocator.getInstance().getPointer(data), AtomicAllocator.getInstance().getHostPointer(iNDArray.shapeInfoDataBuffer()), AtomicAllocator.getInstance().getPointer(createInt4), data.length(), AtomicAllocator.getInstance().getPointer(createInt2));
        AtomicAllocator.getInstance().getAllocationPoint(createInt2).tickDeviceWrite();
        AtomicAllocator.getInstance().getAllocationPoint(data).tickDeviceWrite();
        return Nd4j.createArrayFromShapeBuffer(createInt2, iNDArray.shapeInfoDataBuffer());
    }

    /**
     * Threshold-encodes {@code iNDArray} with no cap on the number of encoded elements.
     *
     * @param iNDArray array to encode
     * @param d        threshold
     * @return the result of {@code thresholdEncode(iNDArray, d, null)}
     */
    public INDArray thresholdEncode(INDArray iNDArray, double d) {
        final Integer noLimit = null;
        return thresholdEncode(iNDArray, d, noLimit);
    }

    /**
     * Decodes a threshold-encoded INT array into {@code iNDArray2} via the native decoder.
     *
     * @param iNDArray  encoded array; header int 0 is the encoded element count, int 1 the original length
     * @param iNDArray2 target array; its length must equal the stored original length
     * @return {@code iNDArray2}
     * @throws UnsupportedOperationException if the encoded buffer is not INT typed
     * @throws ND4JIllegalStateException     if the target length does not match the stored length
     * @throws RuntimeException              if the native call reports an error
     */
    public INDArray thresholdDecode(INDArray iNDArray, INDArray iNDArray2) {
        DataBuffer encoded = iNDArray.data();
        if (encoded.dataType() != DataType.INT) {
            throw new UnsupportedOperationException();
        }
        long encodedCount = encoded.getInt(0L);
        long originalLength = encoded.getInt(1L);
        if (iNDArray2.length() != originalLength) {
            throw new ND4JIllegalStateException("originalLength [" + originalLength + "] stored in encoded array doesn't match target length [" + iNDArray2.length() + "]");
        }
        DataBuffer target = iNDArray2.data();
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getDeviceContext();
        if (extraz.get() == null) {
            extraz.set(new PointerPointer(32L));
        }
        // slot 1 of the extras carries the stream the decoder runs on
        PointerPointer extras = extraz.get().put(1L, context.getOldStream());
        nativeOps.decodeThreshold(extras, allocator.getPointer(encoded), encodedCount, allocator.getPointer(target), allocator.getHostPointer(iNDArray2.shapeInfoDataBuffer()));
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        allocator.getAllocationPoint(target).tickDeviceWrite();
        return iNDArray2;
    }

    /**
     * Bitmap-encodes {@code iNDArray} into the INT array {@code iNDArray2} using the native encoder.
     *
     * <p>Before the native call the first four target ints are set on the host: [0] and [1] the
     * source length, [2] the threshold's float bits, [3] the value 1
     * (NOTE(review): presumably the bitmap format id — confirm).
     *
     * @param iNDArray  source array
     * @param iNDArray2 target; its buffer length must be {@code iNDArray.length() / 16 + 5}
     * @param d         threshold, passed to native code as a float
     * @return the value returned by the native {@code encodeBitmap} call
     * @throws ND4JIllegalStateException if the target has the wrong length or is not INT typed
     * @throws RuntimeException          if the native call reports an error
     */
    public long bitmapEncode(INDArray iNDArray, INDArray iNDArray2, double d) {
        long sourceLength = iNDArray.length();
        long expectedLength = sourceLength / 16 + 5;
        DataBuffer target = iNDArray2.data();
        if (target.length() != expectedLength) {
            throw new ND4JIllegalStateException("Length of target array should be " + expectedLength);
        }
        if (target.dataType() != DataType.INT) {
            throw new ND4JIllegalStateException("Target array should have INT dataType");
        }
        float threshold = (float) d;
        target.put(0L, (int) sourceLength);
        target.put(1L, (int) sourceLength);
        target.put(2L, Float.floatToIntBits(threshold));
        target.put(3L, 1);

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(iNDArray, new INDArray[0]);
        if (extraz.get() == null) {
            extraz.set(new PointerPointer(32L));
        }
        Pointer[] extraPointers = new Pointer[]{allocator.getHostPointer(iNDArray), context.getOldStream(), context.getBufferScalar(), context.getBufferReduction()};
        long result = nativeOps.encodeBitmap(extraz.get().put(extraPointers), allocator.getPointer(iNDArray, context), allocator.getHostPointer(iNDArray.shapeInfoDataBuffer()), sourceLength, allocator.getPointer(target, context), threshold);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        allocator.getFlowController().registerAction(context, iNDArray, new INDArray[0]);
        allocator.getAllocationPoint(target).tickDeviceWrite();
        return result;
    }

    /**
     * Decodes the bitmap-encoded {@code iNDArray} into {@code iNDArray2} via the native decoder.
     *
     * @param iNDArray  encoded source
     * @param iNDArray2 target receiving decoded values
     * @return {@code iNDArray2}
     * @throws RuntimeException if the native call reports an error
     */
    public INDArray bitmapDecode(INDArray iNDArray, INDArray iNDArray2) {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(iNDArray2, new INDArray[0]);
        if (extraz.get() == null) {
            extraz.set(new PointerPointer(32L));
        }
        PointerPointer extras = extraz.get().put(new Pointer[]{allocator.getHostPointer(iNDArray2), context.getOldStream(), context.getBufferScalar(), context.getBufferReduction()});
        nativeOps.decodeBitmap(extras, allocator.getPointer(iNDArray.data(), context), iNDArray2.length(), allocator.getPointer(iNDArray2, context), allocator.getHostPointer(iNDArray2.shapeInfoDataBuffer()));
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        allocator.getFlowController().registerAction(context, iNDArray2, new INDArray[0]);
        return iNDArray2;
    }

    /**
     * Lazily loads, caches and returns descriptors of all custom ops exposed by the native library.
     *
     * <p>The native listing is a {@code ';'}-separated list of entries. Each entry is
     * {@code ':'}-separated: name, hash, numInputs, numOutputs, allowsInplace ({@code 1} = true),
     * numTArgs, numIArgs.
     *
     * @return an unmodifiable map from op name to descriptor, or an empty map if the native
     *         library lists no ops
     */
    public synchronized Map<String, CustomOpDescriptor> getCustomOperations() {
        if (customOps != null) {
            return customOps;
        }
        String listing = nativeOps.getAllCustomOps();
        if (listing == null || listing.isEmpty()) {
            log.warn("No customs ops available!");
            customOps = Collections.emptyMap();
            return customOps;
        }
        Map<String, CustomOpDescriptor> descriptors = new HashMap<>();
        for (String entry : listing.split(";")) {
            if (entry == null || entry.isEmpty()) {
                continue;
            }
            String[] fields = entry.split(":");
            String opName = fields[0];
            CustomOpDescriptor descriptor = CustomOpDescriptor.builder()
                    .hash(Long.parseLong(fields[1]))
                    .numInputs(Integer.parseInt(fields[2]))
                    .numOutputs(Integer.parseInt(fields[3]))
                    .allowsInplace(Integer.parseInt(fields[4]) == 1)
                    .numTArgs(Integer.parseInt(fields[5]))
                    .numIArgs(Integer.parseInt(fields[6]))
                    .build();
            descriptors.put(opName, descriptor);
        }
        customOps = Collections.unmodifiableMap(descriptors);
        return customOps;
    }

    /**
     * Copies a native shape-info array into Java memory and converts it to a shape descriptor.
     *
     * <p>Element 0 is read as the rank; {@code rank * 2 + 4} elements are copied.
     *
     * @param longPointer pointer to the native shape info
     * @return the descriptor built from the copied shape info
     */
    protected LongShapeDescriptor getShapeFromPointer(LongPointer longPointer) {
        int rank = (int) longPointer.get(0L);
        long[] shapeInfo = new long[rank * 2 + 4];
        int idx = 0;
        while (idx < shapeInfo.length) {
            shapeInfo[idx] = longPointer.get(idx);
            idx++;
        }
        DataType dataType = ArrayOptionsHelper.dataType(shapeInfo);
        boolean empty = ArrayOptionsHelper.arrayType(shapeInfo) == ArrayType.EMPTY;
        return LongShapeDescriptor.fromShape(Shape.shape(shapeInfo), Shape.stride(shapeInfo), Shape.elementWiseStride(shapeInfo), Shape.order(shapeInfo), dataType, empty);
    }

    /**
     * Asks the native library for the output shapes of {@code customOp} given its current inputs
     * and its int/float/boolean arguments.
     *
     * @param customOp op whose output shapes are requested
     * @return one descriptor per output; empty if the op has no inputs and its descriptor does not
     *         declare a variable input count ({@code -2})
     * @throws NullPointerException if {@code customOp} is null
     * @throws RuntimeException     if the native call reports an error or returns no shape list
     */
    public List<LongShapeDescriptor> calculateOutputShape(@NonNull CustomOp customOp) {
        if (customOp == null) {
            throw new NullPointerException("op is marked @NonNull but is null");
        }
        Nd4j.getExecutioner().commit();
        long opHash = customOp.opHash();
        if (customOp.numInputArguments() < 1 && customOp.getDescriptor().getNumInputs() != -2) {
            if (log.isTraceEnabled()) {
                log.trace("Could not calculate output shape for op {}: number of input args was 0", customOp.getClass().getName());
            }
            return Collections.emptyList();
        }

        INDArray[] inputs = customOp.inputArguments();
        // first half: host data pointers; second half: device data pointers (left unset for empty arrays)
        PointerPointer inputBuffers = new PointerPointer(inputs.length * 2);
        PointerPointer inputShapes = new PointerPointer(inputs.length);
        for (int i = 0; i < inputs.length; i++) {
            INDArray input = inputs[i];
            if (!input.isEmpty()) {
                inputBuffers.put(i, input.data().addressPointer());
                inputBuffers.put(i + inputs.length, AtomicAllocator.getInstance().getPointer(input.data()));
            }
            inputShapes.put(i, input.shapeInfoDataBuffer().addressPointer());
        }

        long[] iArgs = customOp.iArgs();
        LongPointer iArgsPointer = iArgs.length > 0 ? new LongPointer(iArgs) : null;

        // tArgs are doubles and the native side takes a DoublePointer: copy them at full precision
        // (the previous code narrowed every value through float first)
        double[] tArgs = customOp.tArgs();
        DoublePointer tArgsPointer = tArgs.length > 0 ? new DoublePointer(tArgs) : null;

        boolean[] bArgs = customOp.bArgs();
        BooleanPointer bArgsPointer = bArgs.length > 0 ? new BooleanPointer(bArgs.length) : null;
        for (int i = 0; i < bArgs.length; i++) {
            bArgsPointer.put(i, bArgs[i]);
        }

        OpaqueShapeList shapeList = nativeOps.calculateOutputShapes2((PointerPointer) null, opHash, inputBuffers, inputShapes, inputs.length, tArgsPointer, tArgs.length, iArgsPointer, iArgs.length, bArgsPointer, customOp.numBArguments());
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        if (shapeList == null) {
            throw new RuntimeException("Native shape calculation for op [" + customOp.opName() + "] returned no shape list");
        }
        List<LongShapeDescriptor> shapes = new ArrayList<>();
        try {
            long count = nativeOps.getShapeListSize(shapeList);
            for (int i = 0; i < count; i++) {
                shapes.add(getShapeFromPointer(new PagedPointer(nativeOps.getShape(shapeList, i)).asLongPointer()));
            }
        } finally {
            // release the native list even if converting a shape fails
            nativeOps.deleteShapeList(shapeList);
        }
        return shapes;
    }

    /**
     * Executes {@code customOp} through a freshly built {@link CudaOpContext}.
     *
     * <p>If the op is not in-place and has no outputs, the outputs are first allocated from
     * {@link #calculateOutputShape(CustomOp)}. The RNG states are copied into the context
     * before execution and copied back afterwards.
     *
     * @param customOp op to execute
     * @return the context's output arrays after execution
     * @throws ND4JIllegalStateException if outputs are needed but cannot be computed/allocated
     * @throws ND4JOpProfilerException   rethrown unchanged when raised during execution
     * @throws RuntimeException          wrapping any other failure during execution
     */
    public INDArray[] exec(CustomOp customOp) {
        Nd4j.getExecutioner().commit();
        if (customOp.numOutputArguments() == 0 && !customOp.isInplaceCall()) {
            try {
                List<LongShapeDescriptor> shapes = calculateOutputShape(customOp);
                if (shapes.isEmpty()) {
                    throw new ND4JIllegalStateException("Op name " + customOp.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
                }
                for (LongShapeDescriptor shape : shapes) {
                    customOp.addOutputArgument(Nd4j.create(shape));
                }
            } catch (Exception e) {
                // keep the original failure as the cause instead of discarding it
                throw new ND4JIllegalStateException("Op name " + customOp.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified", e);
            }
        }
        // result unused; NOTE(review): kept in case the lookup has side effects for this thread — confirm
        AtomicAllocator.getInstance().getDeviceContext();
        String opName = customOp.opName();
        try (CudaOpContext context = (CudaOpContext) buildContext()) {
            context.markInplace(customOp.isInplaceCall());
            context.setRngStates(Nd4j.getRandom().rootState(), Nd4j.getRandom().nodeState());
            context.setInputArrays(customOp.inputArguments());
            context.setOutputArrays(customOp.outputArguments());
            context.setBArguments(customOp.bArgs());
            context.setIArguments(customOp.iArgs());
            context.setTArguments(customOp.tArgs());
            INDArray[] result = exec(customOp, context);
            // copy the (possibly advanced) RNG states back into the global generator
            Pair<Long, Long> states = context.getRngStates();
            Nd4j.getRandom().setStates(states.getFirst(), states.getSecond());
            return result;
        } catch (ND4JOpProfilerException e) {
            // must precede the generic catch, otherwise it is unreachable and gets wrapped
            throw e;
        } catch (Exception e) {
            throw new RuntimeException("Op [" + opName + "] execution failed", e);
        }
    }

    /**
     * Stores the debug flag locally and forwards it to the native library.
     *
     * @param z {@code true} to enable debug mode
     */
    public void enableDebugMode(boolean z) {
        final boolean enabled = z;
        debug.set(enabled);
        nativeOps.enableDebugMode(enabled);
    }

    /**
     * Stores the verbose flag locally and forwards it to the native library.
     *
     * @param z {@code true} to enable verbose mode
     */
    public void enableVerboseMode(boolean z) {
        final boolean enabled = z;
        verbose.set(enabled);
        nativeOps.enableVerboseMode(enabled);
    }

    /**
     * Registers a serialized graph with the native library under id {@code j}.
     *
     * @param j       graph id
     * @param pointer pointer to the graph data
     * @throws RuntimeException if the native call reports an error
     */
    public void registerGraph(long j, Pointer pointer) {
        nativeOps.registerGraph((PointerPointer) null, j, pointer);
        if (nativeOps.lastErrorCode() == 0) {
            return;
        }
        throw new RuntimeException(nativeOps.lastErrorMessage());
    }

    /**
     * Executes a previously registered native graph with the given named inputs.
     *
     * @param j    id of the graph registered via {@link #registerGraph(long, Pointer)}
     * @param map  input arrays keyed by name
     * @param map2 native input index for each name in {@code map}
     * @return output arrays keyed by native variable name, in the order returned by native code
     * @throws NullPointerException      if either map is null
     * @throws ND4JIllegalStateException if an input name has no index in {@code map2}, or the
     *                                   graph reports a non-OK status
     * @throws RuntimeException          if a native call reports an error
     */
    public Map<String, INDArray> executeGraph(long j, @NonNull Map<String, INDArray> map, @NonNull Map<String, Integer> map2) {
        if (map == null) {
            throw new NullPointerException("map is marked @NonNull but is null");
        }
        if (map2 == null) {
            throw new NullPointerException("reverseMap is marked @NonNull but is null");
        }
        Nd4j.getExecutioner().commit();
        int numInputs = map.size();
        PointerPointer inputBuffers = new PointerPointer(numInputs * 2);
        PointerPointer inputShapes = new PointerPointer(numInputs * 2);
        IntPointer inputIndices = new IntPointer(numInputs);
        int i = 0;
        for (Map.Entry<String, INDArray> entry : map.entrySet()) {
            Integer index = map2.get(entry.getKey());
            if (index == null) {
                throw new ND4JIllegalStateException("No index found in reverseMap for input [" + entry.getKey() + "]");
            }
            INDArray input = entry.getValue();
            inputBuffers.put(i, AtomicAllocator.getInstance().getHostPointer(input));
            inputShapes.put(i, AtomicAllocator.getInstance().getHostPointer(input.shapeInfoDataBuffer()));
            inputIndices.put(i, index.intValue());
            i++;
        }

        OpaqueVariablesSet variablesSet = nativeOps.executeStoredGraph((PointerPointer) null, j, inputBuffers, inputShapes, inputIndices, numInputs);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        Map<String, INDArray> results = new LinkedHashMap<>();
        try {
            OpStatus status = OpStatus.byNumber(nativeOps.getVariablesSetStatus(variablesSet));
            if (status != OpStatus.ND4J_STATUS_OK) {
                throw new ND4JIllegalStateException("Op execution failed: " + status);
            }
            for (int v = 0; v < nativeOps.getVariablesSetSize(variablesSet); v++) {
                OpaqueVariable variable = nativeOps.getVariable(variablesSet, v);
                LongPointer shapePointer = nativeOps.getVariableShape(variable);
                Pointer buffer = nativeOps.getVariableBuffer(variable);
                // element 0 of the native shape info is the rank; total length is rank * 2 + 4
                long[] shapeInfo = new long[((int) shapePointer.get(0L)) * 2 + 4];
                for (int k = 0; k < shapeInfo.length; k++) {
                    shapeInfo[k] = shapePointer.get(k);
                }
                long[] shape = Shape.shapeOf(shapeInfo);
                INDArray output = Nd4j.create(shape, Shape.stridesOf(shapeInfo), 0L, Shape.order(shapeInfo));
                // NOTE(review): byte count uses the default data type size, matching the
                // default-typed array created above — confirm native outputs use that type
                Pointer.memcpy(AtomicAllocator.getInstance().getHostPointer(output), buffer, ArrayUtil.prod(shape) * Nd4j.sizeOfDataType());
                AtomicAllocator.getInstance().getAllocationPoint(output).tickHostWrite();
                results.put(nativeOps.getVariableName(variable), output);
            }
            if (nativeOps.lastErrorCode() != 0) {
                throw new RuntimeException(nativeOps.lastErrorMessage());
            }
        } finally {
            // free the native result set on every path, including status/error failures
            if (variablesSet != null) {
                nativeOps.deleteVariablesSet(variablesSet);
            }
        }
        return results;
    }

    /**
     * Unregisters the native graph with id {@code j}.
     *
     * @param j graph id
     * @throws RuntimeException if the native call reports an error
     */
    public void forgetGraph(long j) {
        nativeOps.unregisterGraph((PointerPointer) null, j);
        if (nativeOps.lastErrorCode() == 0) {
            return;
        }
        throw new RuntimeException(nativeOps.lastErrorMessage());
    }

    /**
     * Forwards the element threshold to the native library's {@code setElementThreshold}.
     *
     * @param i threshold value
     */
    public void setElementsThreshold(int i) {
        final int threshold = i;
        nativeOps.setElementThreshold(threshold);
    }

    /**
     * Forwards the TAD threshold to the native library's {@code setTADThreshold}.
     *
     * @param i threshold value
     */
    public void setTadThreshold(int i) {
        final int threshold = i;
        nativeOps.setTADThreshold(threshold);
    }

    /**
     * Identifies this executioner's backend.
     *
     * @return {@link OpExecutioner.ExecutionerType#CUDA}
     */
    public OpExecutioner.ExecutionerType type() {
        final OpExecutioner.ExecutionerType backend = OpExecutioner.ExecutionerType.CUDA;
        return backend;
    }

    /**
     * Reads the {@code j}-th string of a UTF-8 buffer.
     *
     * <p>The buffer's indexer yields the native address of a {@code utf8string}. Its
     * {@code _buffer()} is read for {@code _length()} bytes.
     *
     * @param utf8Buffer buffer holding the string addresses
     * @param j          string index
     * @return the decoded string
     */
    public String getString(Utf8Buffer utf8Buffer, long j) {
        // construct the wrapper once and reuse it for both the buffer and its length
        // (the decompiled version referenced an undefined variable for the length)
        Nd4jCuda.utf8string str = new Nd4jCuda.utf8string((Pointer) new PagedPointer(utf8Buffer.indexer().get(j)));
        return str._buffer().capacity(str._length()).getString();
    }

    /**
     * Reports the current value of the {@code experimentalMode} flag.
     *
     * @return {@code true} if experimental mode is set
     */
    public boolean isExperimentalMode() {
        final boolean enabled = experimentalMode.get();
        return enabled;
    }

    /**
     * Applies {@code updateOp} to the TADs of {@code iNDArray} selected by {@code iNDArray2}, using
     * the matching TADs of {@code iNDArray3}, via the native {@code scatterUpdate}.
     *
     * @param updateOp  update operation; its ordinal is passed as the native op code
     * @param iNDArray  array being updated
     * @param iNDArray2 indices; its length must equal the number of update TADs
     * @param iNDArray3 updates
     * @param iArr      dimensions used to build TADs for both the array and the updates
     * @throws NullPointerException  if any non-null argument is null
     * @throws IllegalStateException if the number of update TADs differs from the number of indices
     * @throws RuntimeException      if the native call reports an error
     */
    public void scatterUpdate(ScatterUpdate.UpdateOp updateOp, @NonNull INDArray iNDArray, @NonNull INDArray iNDArray2, @NonNull INDArray iNDArray3, @NonNull int[] iArr) {
        if (iNDArray == null) {
            throw new NullPointerException("array is marked @NonNull but is null");
        }
        if (iNDArray2 == null) {
            throw new NullPointerException("indices is marked @NonNull but is null");
        }
        if (iNDArray3 == null) {
            throw new NullPointerException("updates is marked @NonNull but is null");
        }
        if (iArr == null) {
            throw new NullPointerException("axis is marked @NonNull but is null");
        }
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(iNDArray, iNDArray2, iNDArray3);

        Pair arrayTad = tadManager.getTADOnlyShapeInfo(iNDArray, iArr);
        Pair updatesTad = tadManager.getTADOnlyShapeInfo(iNDArray3, iArr);
        DataBuffer arrayTadShape = (DataBuffer) arrayTad.getFirst();
        DataBuffer arrayTadOffsets = (DataBuffer) arrayTad.getSecond();
        DataBuffer updatesTadShape = (DataBuffer) updatesTad.getFirst();
        DataBuffer updatesTadOffsets = (DataBuffer) updatesTad.getSecond();

        // one offset per update TAD, so this count must match the number of indices
        if (updatesTadOffsets.length() != iNDArray2.length()) {
            throw new IllegalStateException("Number of updates doesn't match number of indices. Bad dimensions used?");
        }
        if (extraz.get() == null) {
            extraz.set(new PointerPointer(32L));
        }
        PointerPointer extras = extraz.get().put(new Pointer[]{null, context.getOldStream()});
        nativeOps.scatterUpdate(extras, updateOp.ordinal(), (int) iNDArray2.length(),
                (Pointer) null, allocator.getHostPointer(arrayTadShape), (LongPointer) null,
                allocator.getPointer(iNDArray, context), allocator.getPointer(arrayTadShape), allocator.getPointer(arrayTadOffsets),
                (Pointer) null, allocator.getHostPointer(updatesTadShape), (LongPointer) null,
                allocator.getPointer(iNDArray3, context), allocator.getPointer(updatesTadShape), allocator.getPointer(updatesTadOffsets),
                (IntPointer) null, allocator.getPointer(iNDArray2, context));
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        allocator.getFlowController().registerAction(context, iNDArray, iNDArray2, iNDArray3);
    }

    /**
     * Creates a new CUDA-specific op context.
     *
     * @return a fresh {@link CudaOpContext}
     */
    public OpContext buildContext() {
        OpContext context = new CudaOpContext();
        return context;
    }

    /**
     * Executes {@code customOp} with a prepared op context on the current device context's stream.
     *
     * @param customOp  op to execute (its hash selects the native op)
     * @param opContext context holding inputs/outputs/arguments; must be a {@link CudaOpContext}
     * @return the context's output arrays (empty array if none)
     * @throws RuntimeException if the native call reports an error or a non-zero status
     */
    public INDArray[] exec(CustomOp customOp, OpContext opContext) {
        long profilingStart = profilingConfigurableHookIn(customOp);
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getDeviceContext();
        ((CudaOpContext) opContext).setCudaStream(context.getOldStream(), context.getBufferReduction(), context.getBufferAllocation());

        int status = nativeOps.execCustomOp2((PointerPointer) null, customOp.opHash(), opContext.contextPointer());
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        if (status != 0) {
            throw new RuntimeException("Op [" + customOp.opName() + "] execution failed");
        }

        for (INDArray output : customOp.outputArguments()) {
            allocator.registerAction(context, output, new INDArray[0]);
        }
        allocator.registerAction(context, null, customOp.inputArguments());
        profilingConfigurableHookOut(customOp, profilingStart);

        List<INDArray> outputs = opContext.getOutputArrays();
        return outputs.toArray(new INDArray[0]);
    }

    /**
     * Collects summary statistics of {@code iNDArray} through the native {@code inspectArray}.
     *
     * @param iNDArray array to inspect; its host data is synchronized first
     * @return min/max/mean/std-dev plus counts of Inf, NaN, negative, positive and zero values
     * @throws NullPointerException if {@code iNDArray} is null
     * @throws RuntimeException     if the native call reports an error
     */
    public INDArrayStatistics inspectArray(@NonNull INDArray iNDArray) {
        if (iNDArray == null) {
            throw new NullPointerException("array is marked @NonNull but is null");
        }
        Nd4jCuda.DebugInfo info = new Nd4jCuda.DebugInfo();
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getDeviceContext();
        allocator.synchronizeHostData(iNDArray);
        if (extraz.get() == null) {
            extraz.set(new PointerPointer(32L));
        }
        Pointer[] extraPointers = new Pointer[]{null, context.getOldStream(), allocator.getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial()};
        nativeOps.inspectArray(extraz.get().put(extraPointers), allocator.getHostPointer(iNDArray), allocator.getHostPointer(iNDArray.shapeInfoDataBuffer()), allocator.getPointer(iNDArray, context), allocator.getPointer(iNDArray.shapeInfoDataBuffer()), info);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        return INDArrayStatistics.builder()
                .minValue(info._minValue())
                .maxValue(info._maxValue())
                .meanValue(info._meanValue())
                .stdDevValue(info._stdDevValue())
                .countInf(info._infCount())
                .countNaN(info._nanCount())
                .countNegative(info._negativeCount())
                .countPositive(info._positiveCount())
                .countZero(info._zeroCount())
                .build();
    }

    /**
     * Builds a shape-info buffer through the native {@code shapeBuffer} and wraps its primary and
     * special pointers in a {@link CudaLongDataBuffer}.
     *
     * @param jArr     shape
     * @param jArr2    strides
     * @param j        value forwarded as the native call's sixth argument (after the order char)
     * @param c        order
     * @param dataType data type encoded in the shape info
     * @param z        flag forwarded as the native call's last argument
     * @return buffer of length {@code Shape.shapeInfoLength(jArr.length)}
     * @throws RuntimeException if a native error is pending before, or reported by, the call
     */
    public DataBuffer createShapeInfo(long[] jArr, long[] jArr2, long j, char c, DataType dataType, boolean z) {
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        int rank = jArr.length;
        LongPointer shapePointer = new LongPointer(jArr);
        LongPointer stridePointer = new LongPointer(jArr2);
        OpaqueConstantDataBuffer constantBuffer = nativeOps.shapeBuffer(rank, shapePointer, stridePointer, dataType.toInt(), c, j, z);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        CudaLongDataBuffer result = new CudaLongDataBuffer(nativeOps.getConstantDataBufferPrimary(constantBuffer), nativeOps.getConstantDataBufferSpecial(constantBuffer), Shape.shapeInfoLength(rank));
        nativeOps.deleteShapeBuffer(constantBuffer);
        return result;
    }

    /**
     * Computes TAD shape info and offsets for {@code iNDArray} along dimensions {@code iArr} via the
     * native {@code tadOnlyShapeInfo}, copying the pointers into two {@link CudaLongDataBuffer}s.
     *
     * @param iNDArray source array
     * @param iArr     dimensions
     * @return a {@link TadPack} of (shape-info buffer, offsets buffer)
     * @throws RuntimeException if a native error is pending before, or reported by, the call
     */
    public TadPack tadShapeInfoAndOffsets(INDArray iNDArray, int[] iArr) {
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo(iNDArray.shapeInfoDataBuffer().addressPointer(), new IntPointer(iArr), iArr.length);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        Pointer primaryShape = (Pointer) nativeOps.getPrimaryShapeInfo(pack);
        Pointer specialShape = (Pointer) nativeOps.getSpecialShapeInfo(pack);
        CudaLongDataBuffer shapes = new CudaLongDataBuffer(primaryShape, specialShape, nativeOps.getShapeInfoLength(pack));
        Pointer primaryOffsets = (Pointer) nativeOps.getPrimaryOffsets(pack);
        Pointer specialOffsets = (Pointer) nativeOps.getSpecialOffsets(pack);
        CudaLongDataBuffer offsets = new CudaLongDataBuffer(primaryOffsets, specialOffsets, nativeOps.getNumberOfTads(pack));
        nativeOps.deleteTadPack(pack);
        return new TadPack(shapes, offsets);
    }

    /**
     * Creates a constant buffer of {@code dataType} from {@code long} values via the native
     * {@code constantBufferLong}.
     *
     * @param jArr     values
     * @param dataType target data type
     * @return a buffer wrapping the native primary/special pointers, marked constant
     * @throws RuntimeException if a native error is pending before, or reported by, the call
     */
    public DataBuffer createConstantBuffer(long[] jArr, DataType dataType) {
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        int count = jArr.length;
        OpaqueConstantDataBuffer constant = nativeOps.constantBufferLong(dataType.toInt(), new LongPointer(jArr), count);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        DataBuffer result = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(constant), nativeOps.getConstantDataBufferSpecial(constant), count, dataType);
        result.setConstant(true);
        return result;
    }

    /**
     * Creates a constant buffer of {@code dataType} from {@code double} values via the native
     * {@code constantBufferDouble}.
     *
     * @param dArr     values
     * @param dataType target data type
     * @return a buffer wrapping the native primary/special pointers, marked constant
     * @throws RuntimeException if a native error is pending before, or reported by, the call
     */
    public DataBuffer createConstantBuffer(double[] dArr, DataType dataType) {
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        int count = dArr.length;
        OpaqueConstantDataBuffer constant = nativeOps.constantBufferDouble(dataType.toInt(), new DoublePointer(dArr), count);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        DataBuffer result = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(constant), nativeOps.getConstantDataBufferSpecial(constant), count, dataType);
        result.setConstant(true);
        return result;
    }

    /**
     * Runs the native light benchmark suite.
     *
     * @param z flag forwarded unchanged to the native call
     * @return the string returned by the native library
     */
    public String runLightBenchmarkSuit(boolean z) {
        String report = nativeOps.runLightBenchmarkSuit(z);
        return report;
    }

    /**
     * Runs the native full benchmark suite.
     *
     * @param z flag forwarded unchanged to the native call
     * @return the string returned by the native library
     */
    public String runFullBenchmarkSuit(boolean z) {
        String report = nativeOps.runFullBenchmarkSuit(z);
        return report;
    }
}
