package org.nd4j.linalg.jcublas.ops.executioner;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import lombok.NonNull;
import org.bytedeco.javacpp.BooleanPointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.AtomicBoolean;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.tad.DeviceTADManager;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ndarray.INDArrayStatistics;
import org.nd4j.linalg.api.ops.BaseReduceBoolOp;
import org.nd4j.linalg.api.ops.BaseReduceOp;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.aggregates.Batch;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpStatus;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.api.shape.TadPack;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.api.shape.options.ArrayType;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaUtf8Buffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.LongPointerWrapper;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.Nd4jCuda;
import org.nd4j.nativeblas.OpaqueConstantDataBuffer;
import org.nd4j.nativeblas.OpaqueConstantShapeBuffer;
import org.nd4j.nativeblas.OpaqueDataBuffer;
import org.nd4j.nativeblas.OpaqueShapeList;
import org.nd4j.nativeblas.OpaqueTadPack;
import org.nd4j.nativeblas.OpaqueVariable;
import org.nd4j.nativeblas.OpaqueVariablesSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.class */
public class CudaExecutioner extends DefaultOpExecutioner {
    /** Logger for this executioner. */
    private static final Logger log = LoggerFactory.getLogger(CudaExecutioner.class);
    /** Device-side native ops binding, fetched once from {@link NativeOpsHolder} and shared by all instances. */
    protected static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    /** Shared provider of TAD shape-info / offset buffer pairs, backed by a {@link DeviceTADManager}. */
    protected static TADManager tadManager = new DeviceTADManager();
    // NOTE(review): not read or written in the visible part of this class — confirm usage elsewhere.
    protected volatile transient Properties properties;
    /**
     * Per-thread reusable 32-slot pointer array. The exec methods fill it with host shape info,
     * the CUDA stream, the device id, the context's scratch buffers and TAD pointers, and pass it
     * as the first argument of every native call.
     */
    protected ThreadLocal<PointerPointer> extraz = new ThreadLocal<>();
    /** Per-thread name of the most recently executed op; only recorded when CUDA debug mode is enabled. */
    protected ThreadLocal<String> lastOp = new ThreadLocal<>();
    // NOTE(review): starts out null; presumably populated lazily elsewhere in the class — confirm.
    protected Map<String, CustomOpDescriptor> customOps = null;
    /** Mirrors {@code nativeOps.isExperimentalEnabled()}; the value is captured in the constructor. */
    protected AtomicBoolean experimentalMode = new AtomicBoolean(false);

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner$1, reason: invalid class name */
    /* loaded from: input_file:org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner$1.class */
    /**
     * Decompiled form of the compiler-generated enum switch map for {@link Op.Type}.
     *
     * <p>Every {@code switch} on {@code getOpType()} in this class indexes this array by the enum
     * constant's ordinal and switches on the resulting label:
     * 1 = BROADCAST, 2 = BROADCAST_BOOL, 3 = REDUCE_LONG, 4 = REDUCE_BOOL, 5 = REDUCE_FLOAT,
     * 6 = REDUCE_SAME, 7 = SCALAR, 8 = SCALAR_BOOL, 9 = TRANSFORM_BOOL, 10 = PAIRWISE_BOOL,
     * 11 = TRANSFORM_ANY, 12 = TRANSFORM_FLOAT, 13 = TRANSFORM_SAME, 14 = TRANSFORM_STRICT.
     * Constants not listed keep the array's default value 0, which matches no case label and so
     * falls into the switch's {@code default} branch.
     *
     * <p>Each assignment sits in its own {@code try} so that an {@link Op.Type} constant missing at
     * runtime ({@link NoSuchFieldError}) only drops that one mapping instead of failing class
     * initialization.
     */
    public static /* synthetic */ class AnonymousClass1 {
        // Indexed by Op.Type ordinal; values are the case labels listed above.
        static final /* synthetic */ int[] $SwitchMap$org$nd4j$linalg$api$ops$Op$Type = new int[Op.Type.values().length];

        static {
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.BROADCAST.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.BROADCAST_BOOL.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.REDUCE_LONG.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.REDUCE_BOOL.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.REDUCE_FLOAT.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.REDUCE_SAME.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.SCALAR.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.SCALAR_BOOL.ordinal()] = 8;
            } catch (NoSuchFieldError e8) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.TRANSFORM_BOOL.ordinal()] = 9;
            } catch (NoSuchFieldError e9) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.PAIRWISE_BOOL.ordinal()] = 10;
            } catch (NoSuchFieldError e10) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.TRANSFORM_ANY.ordinal()] = 11;
            } catch (NoSuchFieldError e11) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.TRANSFORM_FLOAT.ordinal()] = 12;
            } catch (NoSuchFieldError e12) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.TRANSFORM_SAME.ordinal()] = 13;
            } catch (NoSuchFieldError e13) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.TRANSFORM_STRICT.ordinal()] = 14;
            } catch (NoSuchFieldError e14) {
            }
        }
    }

    public CudaExecutioner() {
        this.experimentalMode.set(nativeOps.isExperimentalEnabled());
    }

    /**
     * Returns the device-side native ops binding shared by all CUDA executioners.
     *
     * @return the static {@link NativeOps} instance used for every native call in this class
     */
    public NativeOps getNativeOps() {
        final NativeOps binding = CudaExecutioner.nativeOps;
        return binding;
    }

    /**
     * Returns the name of the last op executed on the calling thread.
     *
     * @return the recorded op name, or {@code null} if none was recorded on this thread (names are
     *         only recorded while CUDA debug mode is enabled)
     */
    public String getLastOp() {
        final String recordedName = lastOp.get();
        return recordedName;
    }

    /**
     * Executes a broadcast op on the current CUDA device.
     *
     * <p>Only {@code Op.Type.BROADCAST} and {@code Op.Type.BROADCAST_BOOL} are dispatched, to
     * {@code nativeOps.execBroadcast} and {@code nativeOps.execBroadcastBool} respectively; any other
     * op type throws.
     *
     * @param broadcastOp the op to execute. {@code x()} is dereferenced unconditionally. {@code y()}
     *        and {@code z()} are null-guarded only for some locals and are dereferenced unconditionally
     *        in the native call arguments.
     * @return {@code broadcastOp.z()}
     * @throws UnsupportedOperationException if the op type is not a broadcast type
     * @throws RuntimeException if the native layer reports a non-zero {@code lastErrorCode()}
     */
    public INDArray exec(BroadcastOp broadcastOp) {
        long profilingConfigurableHookIn = profilingConfigurableHookIn(broadcastOp, new DataBuffer[0]);
        checkForCompression(broadcastOp);
        int[] intVector = broadcastOp.dimensions().toIntVector();
        // Lazily create this thread's reusable extra-pointer array.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        CudaContext deviceContext = AtomicAllocator.getInstance().getDeviceContext();
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(broadcastOp.opName());
        }
        // Host shape-info pointers for y and z, and opaque data buffers for x, y and z.
        Pointer retrieveHostPointer = broadcastOp.y() == null ? null : AddressRetriever.retrieveHostPointer(broadcastOp.y().shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = broadcastOp.z() == null ? null : AddressRetriever.retrieveHostPointer(broadcastOp.z().shapeInfoDataBuffer());
        OpaqueDataBuffer opaqueDataBuffer = broadcastOp.x() == null ? null : ((BaseCudaDataBuffer) broadcastOp.x().data()).getOpaqueDataBuffer();
        OpaqueDataBuffer opaqueDataBuffer2 = broadcastOp.y() == null ? null : ((BaseCudaDataBuffer) broadcastOp.y().data()).getOpaqueDataBuffer();
        OpaqueDataBuffer opaqueDataBuffer3 = broadcastOp.z() == null ? null : ((BaseCudaDataBuffer) broadcastOp.z().data()).getOpaqueDataBuffer();
        // Device-side shape info of x.
        LongPointer pointer = AtomicAllocator.getInstance().getPointer(broadcastOp.x().shapeInfoDataBuffer(), deviceContext);
        // TAD shape info (first) and TAD offsets (second) for x along the broadcast dimensions.
        Pair tADOnlyShapeInfo = tadManager.getTADOnlyShapeInfo(broadcastOp.x(), intVector);
        Pointer retrieveHostPointer3 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo.getFirst());
        Pointer pointer2 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), deviceContext);
        Pointer pointer3 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getSecond(), deviceContext);
        // TAD shape info / offsets for z along the same dimensions.
        Pair tADOnlyShapeInfo2 = tadManager.getTADOnlyShapeInfo(broadcastOp.z(), intVector);
        // Extra-array slots 0-13: x host shape info, stream, device id, the context's allocation /
        // reduction / scalar / special buffers, y and z host shape info, x TAD host shape info,
        // x TAD device shape info and offsets, z TAD device shape info and offsets.
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(broadcastOp.x().shapeInfoDataBuffer()), deviceContext.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), deviceContext.getBufferAllocation(), deviceContext.getBufferReduction(), deviceContext.getBufferScalar(), deviceContext.getBufferSpecial(), retrieveHostPointer, retrieveHostPointer2, retrieveHostPointer3, pointer2, pointer3, AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getFirst(), deviceContext), AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getSecond(), deviceContext)});
        // Result is discarded. NOTE(review): presumably called for its side effect of making the
        // dimensions' constant buffer available on the device — confirm.
        AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(intVector), deviceContext);
        // Switch-map labels: 1 = BROADCAST, 2 = BROADCAST_BOOL.
        switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[broadcastOp.getOpType().ordinal()]) {
            case 1:
                nativeOps.execBroadcast(put, broadcastOp.opNum(), opaqueDataBuffer, AtomicAllocator.getInstance().getHostPointer(broadcastOp.x().shapeInfoDataBuffer()), pointer, opaqueDataBuffer2, AtomicAllocator.getInstance().getHostPointer(broadcastOp.y().shapeInfoDataBuffer()), AtomicAllocator.getInstance().getPointer(broadcastOp.y().shapeInfoDataBuffer(), deviceContext), opaqueDataBuffer3, AtomicAllocator.getInstance().getHostPointer(broadcastOp.z().shapeInfoDataBuffer()), AtomicAllocator.getInstance().getPointer(broadcastOp.z().shapeInfoDataBuffer(), deviceContext), ((BaseCudaDataBuffer) broadcastOp.dimensions().data()).getOpaqueDataBuffer(), AtomicAllocator.getInstance().getHostPointer(broadcastOp.dimensions().shapeInfoDataBuffer()), AtomicAllocator.getInstance().getPointer(broadcastOp.dimensions().shapeInfoDataBuffer(), deviceContext));
                break;
            case 2:
                // The boolean variant takes an extra (here null) pointer argument before the dimensions.
                nativeOps.execBroadcastBool(put, broadcastOp.opNum(), opaqueDataBuffer, AtomicAllocator.getInstance().getHostPointer(broadcastOp.x().shapeInfoDataBuffer()), pointer, opaqueDataBuffer2, AtomicAllocator.getInstance().getHostPointer(broadcastOp.y().shapeInfoDataBuffer()), AtomicAllocator.getInstance().getPointer(broadcastOp.y().shapeInfoDataBuffer(), deviceContext), opaqueDataBuffer3, AtomicAllocator.getInstance().getHostPointer(broadcastOp.z().shapeInfoDataBuffer()), AtomicAllocator.getInstance().getPointer(broadcastOp.z().shapeInfoDataBuffer(), deviceContext), (Pointer) null, ((BaseCudaDataBuffer) broadcastOp.dimensions().data()).getOpaqueDataBuffer(), AtomicAllocator.getInstance().getHostPointer(broadcastOp.dimensions().shapeInfoDataBuffer()), AtomicAllocator.getInstance().getPointer(broadcastOp.dimensions().shapeInfoDataBuffer(), deviceContext));
                break;
            default:
                throw new UnsupportedOperationException("Unknown op type: " + broadcastOp.getOpType());
        }
        // Surface any error recorded by the native layer.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        profilingConfigurableHookOut(broadcastOp, null, profilingConfigurableHookIn);
        return broadcastOp.z();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Runs a reduction op on the CUDA device against an already-prepared result array.
     *
     * <p>Empty reductions copy x into z and return early. Otherwise:
     * <ol>
     *   <li>data types and the requested dimensions are validated;</li>
     *   <li>the per-thread extra-pointer array is filled with host/device shape and TAD information;</li>
     *   <li>the call is dispatched to one native entry point.</li>
     * </ol>
     * The dispatch targets are:
     * <ul>
     *   <li>summary stats (scalar / TAD) for {@link Variance};</li>
     *   <li>reduce3 (all / scalar / TAD) when y is present;</li>
     *   <li>otherwise, the scalar or dimensional reduce variant matching the op's reduce type.</li>
     * </ul>
     *
     * <p>NOTE(review): the empty-reduction early returns do not call
     * {@code profilingConfigurableHookOut} even though {@code profilingConfigurableHookIn} was called.
     *
     * @param reduceOp the reduction. For non-empty reductions z must already be set (for example by
     *        {@code exec(ReduceOp)}), since its shape and data type are read below.
     * @param iArr reduction dimensions; {@code Integer.MAX_VALUE} is exempt from the rank check
     * @return {@code reduceOp.z()}
     * @throws ND4JIllegalStateException if a dimension is >= the rank of x, or if x/y lengths differ
     *         for a non-complex pairwise reduction
     * @throws UnsupportedOperationException for op types other than REDUCE_LONG / REDUCE_BOOL /
     *         REDUCE_FLOAT / REDUCE_SAME on the non-Variance, no-y paths
     * @throws RuntimeException if the native layer reports a non-zero {@code lastErrorCode()}
     */
    public INDArray naiveExec(ReduceOp reduceOp, int... iArr) {
        DataType dataType;
        long profilingConfigurableHookIn = profilingConfigurableHookIn(reduceOp, new DataBuffer[0]);
        // Empty reduction: result is x itself (copied into z, or a dup of x when z is absent).
        if ((reduceOp instanceof BaseReduceOp) && ((BaseReduceOp) reduceOp).isEmptyReduce()) {
            if (reduceOp.z() == null) {
                reduceOp.setZ(reduceOp.x().dup());
                return reduceOp.z();
            }
            Preconditions.checkState(reduceOp.x().equalShapes(reduceOp.z()), "For empty reductions, result (z) array must have same shape as x shape. Got: x=%ndShape, z=%ndShape", reduceOp.x(), reduceOp.z());
            reduceOp.z().assign(reduceOp.x());
            return reduceOp.z();
        }
        // Captured before dispatch; used below to choose scalar vs. dimensional native variants.
        INDArray z = reduceOp.z();
        checkForCompression(reduceOp);
        reduceOp.validateDataTypes((OpContext) null);
        // Reject dimensions outside x's rank (Integer.MAX_VALUE is allowed through).
        for (int i = 0; i < iArr.length; i++) {
            if (iArr[i] >= reduceOp.x().rank() && iArr[i] != Integer.MAX_VALUE) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(iArr) + " contains element that higher then rank of op.X: [" + reduceOp.x().rank() + "]");
            }
        }
        CudaContext deviceContext = AtomicAllocator.getInstance().getDeviceContext();
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(reduceOp.opName());
        }
        // Host shape-info pointers for x, y and z (null when the array is absent).
        Pointer retrieveHostPointer = reduceOp.x() == null ? null : AddressRetriever.retrieveHostPointer(reduceOp.x().shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = reduceOp.y() == null ? null : AddressRetriever.retrieveHostPointer(reduceOp.y().shapeInfoDataBuffer());
        Pointer retrieveHostPointer3 = reduceOp.z() == null ? null : AddressRetriever.retrieveHostPointer(reduceOp.z().shapeInfoDataBuffer());
        // TAD shape info (first) and offsets (second) of x along the reduction dimensions.
        Pair tADOnlyShapeInfo = tadManager.getTADOnlyShapeInfo(reduceOp.x(), iArr);
        Pointer retrieveHostPointer4 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo.getFirst());
        LongPointer pointer = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), deviceContext);
        DataBuffer dataBuffer = (DataBuffer) tADOnlyShapeInfo.getSecond();
        Pointer pointer2 = dataBuffer == null ? null : AtomicAllocator.getInstance().getPointer(dataBuffer, deviceContext);
        // Device-side shape info of x.
        LongPointer pointer3 = AtomicAllocator.getInstance().getPointer(reduceOp.x().shapeInfoDataBuffer(), deviceContext);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // Extra-array slots 0-11: x host shape info, stream, device id, the context's allocation /
        // reduction / scalar / special buffers, y and z host shape info, x TAD host shape info,
        // x TAD device shape info and x TAD device offsets. Slots 12-13 may be filled below for y.
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(reduceOp.x().shapeInfoDataBuffer()), deviceContext.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), deviceContext.getBufferAllocation(), deviceContext.getBufferReduction(), deviceContext.getBufferScalar(), deviceContext.getBufferSpecial(), retrieveHostPointer2, retrieveHostPointer3, retrieveHostPointer4, pointer, pointer2});
        // y-side TAD offsets (pointer4) and shape info (pointer5), used by the reduce3 variants.
        Pointer pointer4 = null;
        Pointer pointer5 = null;
        if (reduceOp.y() != null) {
            if (iArr.length != 0 && (!(iArr.length == 1 && iArr[0] == Integer.MAX_VALUE) && reduceOp.x().tensorAlongDimension(0L, iArr).length() == reduceOp.y().length())) {
                // y is as long as one TAD of x: y's own shape info is used, with a constant
                // LONG {0, 0} buffer standing in for y's offsets.
                // NOTE(review): presumably so every x TAD is compared against all of y — confirm.
                DataBuffer constantBuffer = Nd4j.getConstantHandler().getConstantBuffer(new int[]{0, 0}, DataType.LONG);
                pointer4 = constantBuffer == null ? null : AtomicAllocator.getInstance().getPointer(constantBuffer, deviceContext);
                pointer5 = AtomicAllocator.getInstance().getPointer(reduceOp.y().shapeInfoDataBuffer(), deviceContext);
                put.put(12L, AtomicAllocator.getInstance().getPointer(reduceOp.y().shapeInfoDataBuffer(), deviceContext));
                put.put(13L, (Pointer) null);
            } else {
                // Element-wise pairing: lengths must match unless this is a complex (all-pairs) accumulation.
                if (!reduceOp.isComplexAccumulation() && reduceOp.x().length() != reduceOp.y().length()) {
                    throw new ND4JIllegalStateException("Op.X [" + reduceOp.x().length() + "] and Op.Y [" + reduceOp.y().length() + "] lengths should match");
                }
                // For non-scalar results, y's own TAD shape info / offsets go into slots 12-13.
                if (!reduceOp.z().isScalar()) {
                    Pair tADOnlyShapeInfo2 = tadManager.getTADOnlyShapeInfo(reduceOp.y(), iArr);
                    pointer5 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getFirst(), deviceContext);
                    DataBuffer dataBuffer2 = (DataBuffer) tADOnlyShapeInfo2.getSecond();
                    pointer4 = dataBuffer2 == null ? null : AtomicAllocator.getInstance().getPointer(dataBuffer2, deviceContext);
                    put.put(12L, pointer5);
                    put.put(13L, pointer4);
                }
            }
        }
        // Data type used for the extra-args buffer: x's type for REDUCE_LONG (3) and REDUCE_BOOL (4),
        // z's type otherwise.
        switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[reduceOp.getOpType().ordinal()]) {
            case 3:
            case 4:
                dataType = reduceOp.x().dataType();
                break;
            default:
                dataType = reduceOp.z().dataType();
                break;
        }
        Pointer pointer6 = reduceOp.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(reduceOp.extraArgsDataBuff(dataType), deviceContext) : null;
        // Result is discarded. NOTE(review): presumably called for its side effect of making the
        // dimensions' constant buffer available on the device — confirm.
        AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(iArr), deviceContext);
        OpaqueDataBuffer opaqueDataBuffer = reduceOp.x() == null ? null : ((BaseCudaDataBuffer) reduceOp.x().data()).getOpaqueDataBuffer();
        OpaqueDataBuffer opaqueDataBuffer2 = reduceOp.y() == null ? null : ((BaseCudaDataBuffer) reduceOp.y().data()).getOpaqueDataBuffer();
        OpaqueDataBuffer opaqueDataBuffer3 = reduceOp.z() == null ? null : ((BaseCudaDataBuffer) reduceOp.z().data()).getOpaqueDataBuffer();
        // NOTE(review): the scalar variants below fetch z's device shape info via getPointer(buffer)
        // without deviceContext, unlike the other paths — confirm this is intended.
        if (reduceOp instanceof Variance) {
            if (z.isScalar()) {
                nativeOps.execSummaryStatsScalar(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer3, pointer6, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer()), ((Variance) reduceOp).isBiasCorrected());
            } else {
                nativeOps.execSummaryStatsTad(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer3, pointer6, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer(), deviceContext), ((BaseCudaDataBuffer) reduceOp.dimensions().data()).getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null, ((Variance) reduceOp).isBiasCorrected(), pointer, (LongPointer) pointer2);
            }
        } else if (reduceOp.y() != null) {
            // Pairwise (reduce3) family.
            if (reduceOp.isComplexAccumulation()) {
                nativeOps.execReduce3All(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer3, pointer6, opaqueDataBuffer2, (LongPointer) retrieveHostPointer2, AtomicAllocator.getInstance().getPointer(reduceOp.y().shapeInfoDataBuffer(), deviceContext), opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer(), deviceContext), ((BaseCudaDataBuffer) reduceOp.dimensions().data()).getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null, pointer, new LongPointerWrapper(pointer2), (LongPointer) pointer5, new LongPointerWrapper(pointer4));
            } else if (z.isScalar()) {
                nativeOps.execReduce3Scalar(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer3, pointer6, opaqueDataBuffer2, (LongPointer) retrieveHostPointer2, AtomicAllocator.getInstance().getPointer(reduceOp.y().shapeInfoDataBuffer(), deviceContext), opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer(), deviceContext));
            } else {
                nativeOps.execReduce3Tad(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer3, pointer6, opaqueDataBuffer2, (LongPointer) retrieveHostPointer2, AtomicAllocator.getInstance().getPointer(reduceOp.y().shapeInfoDataBuffer(), deviceContext), opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer(), deviceContext), ((BaseCudaDataBuffer) reduceOp.dimensions().data()).getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null, pointer, (LongPointer) pointer2, (LongPointer) pointer5, (LongPointer) pointer4);
            }
        } else if (z.isScalar()) {
            // Whole-array reductions. The Nd4jCuda.FLOAT32 / Nd4jCuda.DOUBLE case labels are decompiler
            // substitutions for the literals 5 and 6, i.e. the REDUCE_FLOAT and REDUCE_SAME switch-map labels.
            switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[reduceOp.getOpType().ordinal()]) {
                case 3:
                    nativeOps.execReduceLong(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer3, pointer6, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer()));
                    break;
                case 4:
                    nativeOps.execReduceBool(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer3, pointer6, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer()));
                    break;
                case Nd4jCuda.FLOAT32 /* 5 */:
                    nativeOps.execReduceFloat(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer3, pointer6, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer()));
                    break;
                case Nd4jCuda.DOUBLE /* 6 */:
                    nativeOps.execReduceSame(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer3, pointer6, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer()));
                    break;
                default:
                    throw new UnsupportedOperationException();
            }
        } else {
            // Dimensional reductions; same label mapping as the scalar switch above.
            switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[reduceOp.getOpType().ordinal()]) {
                case 3:
                    nativeOps.execReduceLong2(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer3, pointer6, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer(), deviceContext), ((BaseCudaDataBuffer) reduceOp.dimensions().data()).getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
                    break;
                case 4:
                    nativeOps.execReduceBool2(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer3, pointer6, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer(), deviceContext), ((BaseCudaDataBuffer) reduceOp.dimensions().data()).getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
                    break;
                case Nd4jCuda.FLOAT32 /* 5 */:
                    nativeOps.execReduceFloat2(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer3, pointer6, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer(), deviceContext), ((BaseCudaDataBuffer) reduceOp.dimensions().data()).getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
                    break;
                case Nd4jCuda.DOUBLE /* 6 */:
                    nativeOps.execReduceSame2(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer3, pointer6, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(reduceOp.z().shapeInfoDataBuffer(), deviceContext), ((BaseCudaDataBuffer) reduceOp.dimensions().data()).getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
                    break;
                default:
                    throw new UnsupportedOperationException();
            }
        }
        // Surface any error recorded by the native layer.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        profilingConfigurableHookOut(reduceOp, null, profilingConfigurableHookIn);
        return reduceOp.z();
    }

    /**
     * Executes a variance / standard-deviation style summary-stats op through the generic
     * reduction path.
     *
     * @param variance the op to execute
     * @return the op's result array, as returned by {@code exec(ReduceOp)}
     */
    public INDArray exec(Variance variance) {
        final ReduceOp asReduction = variance;
        return exec(asReduction);
    }

    /**
     * Executes a reduction op on the CUDA device, allocating or validating the result array first.
     *
     * <p>Empty reductions copy x into z; z is allocated via {@code x().dup()} when absent. If x is a
     * vector whose reduction shape has as many elements as x (more than one) and there is no y, the
     * reduction is a no-op and {@code reduceOp.noOp()} is returned.
     *
     * <p>Otherwise z is handled in one of two ways before the work is delegated to
     * {@code naiveExec(ReduceOp, int...)}:
     * <ul>
     *   <li>when z is absent or aliases x, a new z is allocated: a
     *       {@code [tadsOfX, tadsOfY]} matrix for complex accumulations, or the reduction shape
     *       otherwise;</li>
     *   <li>when the caller supplied z, its length is checked against the reduction shape.</li>
     * </ul>
     *
     * @param reduceOp the reduction to execute
     * @return {@code reduceOp.z()}, or {@code reduceOp.noOp()} for the vector no-op shortcut
     * @throws ND4JIllegalStateException if x and y have incompatible TAD counts or TAD sizes, or if a
     *         caller-supplied z has the wrong length
     */
    public INDArray exec(ReduceOp reduceOp) {
        checkForCompression(reduceOp);
        // Empty reduction: the result is x itself, copied into z (or a dup of x when z is absent).
        if ((reduceOp instanceof BaseReduceOp) && ((BaseReduceOp) reduceOp).isEmptyReduce()) {
            if (reduceOp.z() == null) {
                reduceOp.setZ(reduceOp.x().dup());
                return reduceOp.z();
            }
            Preconditions.checkState(reduceOp.x().equalShapes(reduceOp.z()), "For empty reductions, result (z) array must have same shape as x shape. Got: x=%ndShape, z=%ndShape", reduceOp.x(), reduceOp.z());
            reduceOp.z().assign(reduceOp.x());
            return reduceOp.z();
        }
        int[] intVector = reduceOp.dimensions().toIntVector();
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // NOTE(review): the result of this call is unused; it is kept so the execution path
        // (including anything the call may throw) stays unchanged.
        Shape.getMaxShape(new INDArray[]{reduceOp.x(), reduceOp.y()});
        boolean wholeArrayReduction = Shape.wholeArrayDimension(intVector) || reduceOp.x().rank() == intVector.length || intVector.length == 0;
        // The reduction shape is derived from x, or from y when y is present and at least as long as x.
        INDArray shapeSource = reduceOp.x();
        if (reduceOp.y() != null && reduceOp.x().length() <= reduceOp.y().length()) {
            shapeSource = reduceOp.y();
        }
        long[] reductionShape = Shape.reductionShape(shapeSource, intVector, true, reduceOp.isKeepDims());
        // Use the long product for every comparison: the int-returning ArrayUtil.prod overflows for
        // shapes with more than Integer.MAX_VALUE elements and would then never match x's length.
        long reductionLength = ArrayUtil.prodLong(reductionShape);
        if (reduceOp.x().isVector() && reduceOp.x().length() == reductionLength && reductionLength > 1 && reduceOp.y() == null) {
            return reduceOp.noOp();
        }
        DataType resultType = reduceOp.resultType();
        if (reduceOp.z() == null || reduceOp.z() == reduceOp.x()) {
            INDArray result;
            if (reduceOp.isComplexAccumulation()) {
                // All-pairs accumulation: one output per (x TAD, y TAD) combination.
                result = Nd4j.createUninitialized(resultType, new long[]{reduceOp.x().tensorsAlongDimension(intVector), reduceOp.y().tensorsAlongDimension(intVector)});
            } else {
                if (reduceOp.y() != null) {
                    if (reduceOp.x().length() == reduceOp.y().length()) {
                        // Same-length x/y: per-dimension reductions need matching TAD counts.
                        if (!wholeArrayReduction && reduceOp.x().tensorsAlongDimension(intVector) != reduceOp.y().tensorsAlongDimension(intVector)) {
                            throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " + Arrays.toString(reduceOp.x().shape()) + ", y shape = " + Arrays.toString(reduceOp.y().shape()) + ", dimension = " + Arrays.toString(intVector) + ")");
                        }
                    } else {
                        // Different lengths: y is paired with each TAD of x, so it must match one TAD's size.
                        if (intVector.length == 0) {
                            throw new ND4JIllegalStateException("TAD vs TAD comparison requires dimension (or other comparison mode was supposed to be used?)");
                        }
                        long tadLength = reduceOp.x().length() / reduceOp.x().tensorsAlongDimension(intVector);
                        if (tadLength != reduceOp.y().length()) {
                            throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution: (x TAD size = " + tadLength + ", y size = " + reduceOp.y().length());
                        }
                    }
                }
                result = Nd4j.create(resultType, reductionShape);
            }
            reduceOp.setZ(result);
        } else if (reduceOp.z().length() != (reductionShape.length == 0 ? 1L : reductionLength)) {
            throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(reduceOp.z().shape()) + "] doesn't match expected [" + Arrays.toString(reductionShape) + "]");
        }
        long profilingConfigurableHookIn = profilingConfigurableHookIn(reduceOp, new DataBuffer[0]);
        naiveExec(reduceOp, intVector);
        profilingConfigurableHookOut(reduceOp, null, profilingConfigurableHookIn);
        return reduceOp.z();
    }

    /**
     * Executes an index-reduce op on the current CUDA device via {@code nativeOps.execIndexReduce}.
     *
     * <p>The op's dimensions are normalized against the rank of x. When x is empty, every reduced
     * axis must have a non-zero size. When z is absent it is allocated as a {@code LONG} array with
     * the reduction shape, honouring {@code isKeepDims()}. An empty z is returned immediately.
     * Every other case is computed by the native call into z, including reductions where z has as
     * many elements as x.
     *
     * @param indexAccumulation the op to execute; x must be non-null
     * @return {@code indexAccumulation.z()}
     * @throws IllegalArgumentException if x is empty and a reduced axis has size 0
     * @throws RuntimeException if the native layer reports a non-zero {@code lastErrorCode()}
     */
    public INDArray exec(IndexAccumulation indexAccumulation) {
        int[] normalizeAxis = Shape.normalizeAxis(indexAccumulation.x().rank(), indexAccumulation.dimensions().toIntVector());
        if (indexAccumulation.x().isEmpty()) {
            for (int i : normalizeAxis) {
                Preconditions.checkArgument(indexAccumulation.x().shape()[i] != 0, "IndexReduce can't be issued along axis with 0 in shape");
            }
        }
        if (indexAccumulation.z() == null) {
            indexAccumulation.setZ(Nd4j.createUninitialized(DataType.LONG, Shape.reductionShape(indexAccumulation.x(), normalizeAxis, true, indexAccumulation.isKeepDims())));
        }
        long profilingConfigurableHookIn = profilingConfigurableHookIn(indexAccumulation, new DataBuffer[0]);
        checkForCompression(indexAccumulation);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // There is deliberately no "x is a vector and x.length() == z.length()" shortcut here.
        // Returning x would hand back the input values, in x's data type, instead of indices written
        // into z. Such reductions go through the native call like every other case.
        if (indexAccumulation.z().isEmpty()) {
            return indexAccumulation.z();
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(indexAccumulation.opName());
        }
        CudaContext deviceContext = AtomicAllocator.getInstance().getDeviceContext();
        // Host shape-info pointers for x, y and z (null when the array is absent).
        Pointer retrieveHostPointer = indexAccumulation.x() == null ? null : AddressRetriever.retrieveHostPointer(indexAccumulation.x().shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = indexAccumulation.y() == null ? null : AddressRetriever.retrieveHostPointer(indexAccumulation.y().shapeInfoDataBuffer());
        Pointer retrieveHostPointer3 = indexAccumulation.z() == null ? null : AddressRetriever.retrieveHostPointer(indexAccumulation.z().shapeInfoDataBuffer());
        // Device-side shape info of x and z.
        LongPointer pointer = AtomicAllocator.getInstance().getPointer(indexAccumulation.x().shapeInfoDataBuffer(), deviceContext);
        LongPointer pointer2 = AtomicAllocator.getInstance().getPointer(indexAccumulation.z().shapeInfoDataBuffer(), deviceContext);
        // TAD shape info (first) and offsets (second) of x along the normalized axes.
        Pair tADOnlyShapeInfo = tadManager.getTADOnlyShapeInfo(indexAccumulation.x(), normalizeAxis);
        Pointer retrieveHostPointer4 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo.getFirst());
        Pointer pointer3 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), deviceContext);
        DataBuffer dataBuffer = (DataBuffer) tADOnlyShapeInfo.getSecond();
        // Extra-array slots 0-11: x host shape info, stream, device id, the context's allocation /
        // reduction / scalar / special buffers, y and z host shape info, x TAD host shape info,
        // x TAD device shape info and x TAD device offsets.
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(indexAccumulation.x().shapeInfoDataBuffer()), deviceContext.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), deviceContext.getBufferAllocation(), deviceContext.getBufferReduction(), deviceContext.getBufferScalar(), deviceContext.getBufferSpecial(), retrieveHostPointer2, retrieveHostPointer3, retrieveHostPointer4, pointer3, dataBuffer == null ? null : AtomicAllocator.getInstance().getPointer(dataBuffer, deviceContext)});
        // Extra args are materialized in x's data type.
        Pointer pointer4 = indexAccumulation.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(indexAccumulation.extraArgsDataBuff(indexAccumulation.x().dataType()), deviceContext) : null;
        // Result is discarded. NOTE(review): presumably called for its side effect of making the
        // axes' constant buffer available on the device — confirm.
        AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(normalizeAxis), deviceContext);
        OpaqueDataBuffer opaqueDataBuffer = indexAccumulation.x() == null ? null : ((BaseCudaDataBuffer) indexAccumulation.x().data()).getOpaqueDataBuffer();
        OpaqueDataBuffer opaqueDataBuffer2 = indexAccumulation.y() == null ? null : ((BaseCudaDataBuffer) indexAccumulation.y().data()).getOpaqueDataBuffer();
        nativeOps.execIndexReduce(put, indexAccumulation.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer, pointer4, indexAccumulation.z() == null ? null : ((BaseCudaDataBuffer) indexAccumulation.z().data()).getOpaqueDataBuffer(), (LongPointer) retrieveHostPointer3, pointer2, ((BaseCudaDataBuffer) indexAccumulation.dimensions().data()).getOpaqueDataBuffer(), indexAccumulation.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
        // Surface any error recorded by the native layer.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        profilingConfigurableHookOut(indexAccumulation, null, profilingConfigurableHookIn);
        return indexAccumulation.z();
    }

    /**
     * Executes {@code op} without an explicit {@link OpContext}; equivalent to
     * {@code exec(op, null)}.
     *
     * @param op the operation to run
     * @return the op's output (z) array after execution
     */
    public INDArray exec(Op op) {
        final OpContext noContext = null;
        return exec(op, noContext);
    }

    /**
     * Dispatches {@code op} to the execution path matching its kind (transform, reduce, scalar,
     * broadcast, index-reduce, random, custom), in that order of precedence.
     * Ops of any other kind are not executed.
     *
     * @param op        the operation to run
     * @param opContext optional context supplying input/output arrays; may be {@code null}
     * @return {@code op.z()} after dispatch
     */
    public INDArray exec(Op op, OpContext opContext) {
        checkForCompression(op);
        if (op instanceof TransformOp) {
            TransformOp transform = (TransformOp) op;
            invoke(transform, opContext);
        } else if (op instanceof ReduceOp) {
            ReduceOp reduction = (ReduceOp) op;
            int[] axes = reduction.dimensions().toIntVector();
            invoke(reduction, opContext, axes);
        } else if (op instanceof ScalarOp) {
            ScalarOp scalar = (ScalarOp) op;
            invoke(scalar, opContext);
        } else if (op instanceof BroadcastOp) {
            BroadcastOp broadcast = (BroadcastOp) op;
            invoke(broadcast, opContext);
        } else if (op instanceof IndexAccumulation) {
            IndexAccumulation indexReduction = (IndexAccumulation) op;
            int[] axes = indexReduction.dimensions().toIntVector();
            invoke(indexReduction, opContext, axes);
        } else if (op instanceof RandomOp) {
            RandomOp randomOp = (RandomOp) op;
            exec(randomOp, opContext, Nd4j.getRandom());
        } else if (op instanceof CustomOp) {
            CustomOp custom = (CustomOp) op;
            exec(custom, opContext);
        }
        INDArray result = op.z();
        return result;
    }

    /**
     * Runs a transform op without an {@link OpContext} and hands the same op instance back.
     *
     * @param transformOp the transform to execute
     * @return {@code transformOp} itself, so callers can read its z array
     */
    public TransformOp execAndReturn(TransformOp transformOp) {
        checkForCompression(transformOp);
        final OpContext noContext = null;
        invoke(transformOp, noContext);
        return transformOp;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Executes a broadcast op on the CUDA device through the native ops layer.
     *
     * <p>Inputs and output are taken from {@code opContext} when given (via getX/getY/getZ).
     * TAD shape/offset buffers for x and z along {@code broadcastOp.getDimension()} are fetched
     * from the TAD manager and passed to native code inside the {@code extraz} pointer array.
     *
     * @param broadcastOp the broadcast op to run
     * @param opContext   optional context supplying arrays; may be {@code null}
     * @return always {@code null} in the current implementation
     * @throws UnsupportedOperationException if the op type maps to neither switch case
     * @throws RuntimeException              if the native layer reports a non-zero error code
     */
    public CudaContext invoke(BroadcastOp broadcastOp, OpContext opContext) {
        long profilingConfigurableHookIn = profilingConfigurableHookIn(broadcastOp, new DataBuffer[0]);
        INDArray x = getX(broadcastOp, opContext);
        INDArray y = getY(broadcastOp, opContext);
        INDArray z = getZ(broadcastOp, opContext);
        checkForCompression(broadcastOp);
        // Lazily allocate the 32-slot extra-pointer array reused across calls.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        CudaContext deviceContext = AtomicAllocator.getInstance().getDeviceContext();
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(broadcastOp.opName());
        }
        // NOTE(review): x is dereferenced here, before the `x == null` checks below, so those checks are effectively dead.
        LongPointer pointer = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), deviceContext);
        Pointer retrieveHostPointer = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
        Pointer retrieveHostPointer3 = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());
        // TAD (tensor-along-dimension) shape info (first) and offsets (second) for x.
        Pair tADOnlyShapeInfo = tadManager.getTADOnlyShapeInfo(x, broadcastOp.getDimension());
        Pointer retrieveHostPointer4 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo.getFirst());
        Pointer pointer2 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), deviceContext);
        Pointer pointer3 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getSecond(), deviceContext);
        // Same TAD info for z along the same dimensions.
        Pair tADOnlyShapeInfo2 = tadManager.getTADOnlyShapeInfo(z, broadcastOp.getDimension());
        // Slot order is positional and must match what the native side expects:
        // 0 host x shape, 1 stream, 2 device id, 3-6 context buffers, 7-8 host y/z shape,
        // 9 host x TAD shape, 10-11 device x TAD shape/offsets, 12-13 device z TAD shape/offsets.
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), deviceContext.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), deviceContext.getBufferAllocation(), deviceContext.getBufferReduction(), deviceContext.getBufferScalar(), deviceContext.getBufferSpecial(), retrieveHostPointer2, retrieveHostPointer3, retrieveHostPointer4, pointer2, pointer3, AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getFirst(), deviceContext), AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getSecond(), deviceContext)});
        LongPointer pointer4 = AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), deviceContext);
        LongPointer pointer5 = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), deviceContext);
        // Result unused. NOTE(review): presumably called to ensure the dimension constant buffer is on-device; confirm.
        AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(broadcastOp.getDimension()), deviceContext);
        OpaqueDataBuffer opaqueDataBuffer = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer opaqueDataBuffer2 = y == null ? null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer opaqueDataBuffer3 = z == null ? null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer();
        // JADX switch-map ordinals. Presumably case 1 = BROADCAST, case 2 = BROADCAST_BOOL; confirm against AnonymousClass1.
        switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[broadcastOp.getOpType().ordinal()]) {
            case 1:
                nativeOps.execBroadcast(put, broadcastOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer, opaqueDataBuffer2, (LongPointer) retrieveHostPointer2, pointer4, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, pointer5, ((BaseCudaDataBuffer) broadcastOp.dimensions().data()).getOpaqueDataBuffer(), broadcastOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
                break;
            case 2:
                nativeOps.execBroadcastBool(put, broadcastOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer, opaqueDataBuffer2, (LongPointer) retrieveHostPointer2, pointer4, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, pointer5, (Pointer) null, ((BaseCudaDataBuffer) broadcastOp.dimensions().data()).getOpaqueDataBuffer(), broadcastOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
                break;
            default:
                throw new UnsupportedOperationException("Unknown opType: " + broadcastOp.getOpType());
        }
        // Surface native-side failures as Java exceptions.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        profilingConfigurableHookOut(broadcastOp, opContext, profilingConfigurableHookIn);
        return null;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Executes an index-reduction op on the CUDA device, writing LONG indices into z.
     *
     * <p>If no axis is given (or only the {@code Integer.MAX_VALUE} "all dimensions" sentinel) and z
     * is missing or aliases x, a LONG scalar output is allocated. Otherwise z is allocated with the
     * reduction shape when missing or aliasing x, and validated against it when present.
     *
     * <p>Earlier versions unconditionally called {@code enableDebug(true)} on the global CUDA
     * configuration here. That permanently switched the whole backend into debug mode as a side
     * effect of any index reduction, and has been removed.
     *
     * @param indexAccumulation the index-reduce op to run
     * @param opContext         optional context supplying arrays; may be {@code null}
     * @param iArr              requested reduction axes (normalized against x's rank)
     * @return always {@code null} in the current implementation
     * @throws IllegalStateException       if a supplied z has the wrong shape
     * @throws ND4JIllegalStateException   if an axis is not below x's rank (and is not the sentinel)
     * @throws RuntimeException            if the native layer reports a non-zero error code
     */
    public CudaContext invoke(IndexAccumulation indexAccumulation, OpContext opContext, int[] iArr) {
        INDArray x = getX(indexAccumulation, opContext);
        INDArray y = getY(indexAccumulation, opContext);
        INDArray z = getZ(indexAccumulation, opContext);
        int[] normalizeAxis = Shape.normalizeAxis(x.rank(), iArr);
        // Full reduction (no axes or the MAX_VALUE sentinel): the result is a LONG scalar.
        if ((normalizeAxis == null || (normalizeAxis.length == 1 && normalizeAxis[0] == Integer.MAX_VALUE)) && (z == x || z == null)) {
            z = Nd4j.createUninitialized(DataType.LONG, new long[0], 'c');
            setZ(z, indexAccumulation, opContext);
        }
        long[] reductionShape = Shape.reductionShape(x, normalizeAxis, true, indexAccumulation.isKeepDims());
        if (z == null || x == z) {
            INDArray createUninitialized = Nd4j.createUninitialized(DataType.LONG, reductionShape);
            setZ(createUninitialized, indexAccumulation, opContext);
            z = createUninitialized;
        } else if (!Arrays.equals(reductionShape, z.shape())) {
            throw new IllegalStateException("Z array shape does not match expected return type for op " + indexAccumulation + ": expected shape " + Arrays.toString(reductionShape) + ", z.shape()=" + Arrays.toString(z.shape()));
        }
        long profilingConfigurableHookIn = profilingConfigurableHookIn(indexAccumulation, new DataBuffer[0]);
        checkForCompression(indexAccumulation);
        // Lazily allocate the 32-slot extra-pointer array reused across calls.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(indexAccumulation.opName());
        }
        if (normalizeAxis != null) {
            for (int i = 0; i < normalizeAxis.length; i++) {
                if (normalizeAxis[i] >= x.rank() && normalizeAxis[i] != Integer.MAX_VALUE) {
                    throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(normalizeAxis) + " contains element that higher then rank of op.X: [" + x.rank() + "]");
                }
            }
        }
        // From here on x and z are non-null: x was dereferenced above and z was allocated or validated.
        CudaContext deviceContext = AtomicAllocator.getInstance().getDeviceContext();
        LongPointer pointer = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), deviceContext);
        Pointer pointer2 = indexAccumulation.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(indexAccumulation.extraArgsDataBuff(x.dataType()), deviceContext) : null;
        Pointer retrieveHostPointer = AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
        Pointer retrieveHostPointer3 = AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());
        // The TAD lookup needs some axis even for a full reduction; fall back to axis 0.
        int[] tadAxis = normalizeAxis == null ? new int[]{0} : normalizeAxis;
        Pair tADOnlyShapeInfo = tadManager.getTADOnlyShapeInfo(x, tadAxis);
        Pointer retrieveHostPointer4 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo.getFirst());
        Pointer pointer3 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), deviceContext);
        DataBuffer dataBuffer = (DataBuffer) tADOnlyShapeInfo.getSecond();
        Pointer pointer4 = dataBuffer == null ? null : AtomicAllocator.getInstance().getPointer(dataBuffer, deviceContext);
        LongPointer pointer5 = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), deviceContext);
        OpaqueDataBuffer opaqueDataBuffer = ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer opaqueDataBuffer2 = ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer();
        // Positional slot layout expected by the native side: host x shape, stream, device id,
        // four context buffers, host y/z shape, host TAD shape, device TAD shape, device TAD offsets.
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), deviceContext.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), deviceContext.getBufferAllocation(), deviceContext.getBufferReduction(), deviceContext.getBufferScalar(), deviceContext.getBufferSpecial(), retrieveHostPointer2, retrieveHostPointer3, retrieveHostPointer4, pointer3, pointer4});
        if (z.isScalar() || normalizeAxis == null || normalizeAxis[0] == Integer.MAX_VALUE) {
            nativeOps.execIndexReduceScalar(put, indexAccumulation.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer, pointer2, opaqueDataBuffer2, (LongPointer) retrieveHostPointer3, pointer5);
        } else {
            // normalizeAxis is non-null in this branch.
            if (normalizeAxis.length > 1) {
                Arrays.sort(normalizeAxis);
            }
            // Result unused. NOTE(review): presumably materializes the axis constant buffer; confirm.
            AtomicAllocator.getInstance().getHostPointer(AtomicAllocator.getInstance().getConstantBuffer(normalizeAxis));
            nativeOps.execIndexReduce(put, indexAccumulation.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer, pointer2, opaqueDataBuffer2, (LongPointer) retrieveHostPointer3, pointer5, ((BaseCudaDataBuffer) indexAccumulation.dimensions().data()).getOpaqueDataBuffer(), indexAccumulation.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        profilingConfigurableHookOut(indexAccumulation, opContext, profilingConfigurableHookIn);
        return null;
    }

    /**
     * Executes a reduction op (including pairwise reduce3 and variance/summary-stats ops) on the
     * CUDA device.
     *
     * <p>Arrays are taken from {@code opContext} when provided. The output array is allocated when
     * missing, and otherwise validated against the expected result type and reduction shape. Newly
     * allocated outputs are always installed through the context-aware {@code setZ(z, op, ctx)}
     * helper, and data-type validation also uses {@code opContext}. Earlier versions passed
     * {@code null} there, which validated the op's own fields instead of the arrays actually being
     * used.
     *
     * @param reduceOp  the reduction to run
     * @param opContext optional context supplying arrays; may be {@code null}
     * @param iArr      requested reduction axes (normalized against x's rank)
     * @return the device context used for execution
     * @throws ND4JIllegalStateException     on invalid axes, mismatched TAD sizes, or a bad z array
     * @throws UnsupportedOperationException for an op type not handled by the dispatch switch
     * @throws RuntimeException              if the native layer reports a non-zero error code
     */
    protected CudaContext invoke(ReduceOp reduceOp, OpContext opContext, int[] iArr) {
        CudaContext deviceContext = AtomicAllocator.getInstance().getDeviceContext();
        INDArray x = getX(reduceOp, opContext);
        INDArray y = getY(reduceOp, opContext);
        INDArray z = getZ(reduceOp, opContext);
        // Empty reduction: the result is a copy of x.
        if ((reduceOp instanceof BaseReduceOp) && ((BaseReduceOp) reduceOp).isEmptyReduce()) {
            if (z == null) {
                setZ(x.dup(), reduceOp, opContext);
                return deviceContext;
            }
            Preconditions.checkState(x.equalShapes(z), "For empty reductions, result (z) array must have same shape as x shape. Got: x=%ndShape, z=%ndShape", x, z);
            z.assign(x);
            return deviceContext;
        }
        // Full boolean reduction of an empty array: the result is the op's empty value.
        if ((reduceOp instanceof BaseReduceBoolOp) && x.isEmpty() && (iArr == null || (iArr.length == 1 && iArr[0] == Integer.MAX_VALUE))) {
            if (z == null) {
                setZ(Nd4j.scalar(((BaseReduceBoolOp) reduceOp).emptyValue()), reduceOp, opContext);
            } else {
                z.assign(((BaseReduceBoolOp) reduceOp).emptyValue());
            }
            return deviceContext;
        }
        long profilingConfigurableHookIn = profilingConfigurableHookIn(reduceOp, new DataBuffer[0]);
        checkForCompression(reduceOp);
        int[] normalizeAxis = Shape.normalizeAxis(x.rank(), iArr);
        // Lazily allocate the 32-slot extra-pointer array reused across calls.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (normalizeAxis == null) {
            normalizeAxis = new int[]{Nd4jCuda.MAX_DIMENSION};
        }
        if (normalizeAxis.length > 1) {
            Arrays.sort(normalizeAxis);
        }
        for (int axis : normalizeAxis) {
            if (axis >= x.rank() && axis != Integer.MAX_VALUE) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(normalizeAxis) + " contains element that higher then rank of op.X: [" + x.rank() + "]");
            }
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(reduceOp.opName());
        }
        // For an empty x no TAD lookup is done: x.data() stands in for the TAD shape and there are no offsets.
        Pair makePair = x.isEmpty() ? Pair.makePair(x.data(), (Object) null) : tadManager.getTADOnlyShapeInfo(x, normalizeAxis);
        Pointer retrieveHostPointer = AddressRetriever.retrieveHostPointer((DataBuffer) makePair.getFirst());
        LongPointer pointer = AtomicAllocator.getInstance().getPointer((DataBuffer) makePair.getFirst(), deviceContext);
        DataBuffer dataBuffer = x.isEmpty() ? null : (DataBuffer) makePair.getSecond();
        Pointer pointer2 = dataBuffer == null ? null : AtomicAllocator.getInstance().getPointer(dataBuffer, deviceContext);
        LongPointer pointer3 = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), deviceContext);
        long[] reductionShape = Shape.reductionShape(x, normalizeAxis, true, reduceOp.isKeepDims());
        // Pairwise (reduce3) ops: y must match either one x TAD or x's TAD count.
        if (y != null) {
            if (x.length() != y.length()) {
                long length = x.length() / x.tensorsAlongDimension(normalizeAxis);
                if (length != y.length()) {
                    throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution: (x TAD size = " + length + ", y size = " + y.length());
                }
            } else if (x.tensorsAlongDimension(normalizeAxis) != y.tensorsAlongDimension(normalizeAxis)) {
                throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " + Arrays.toString(x.shape()) + ", y shape = " + Arrays.toString(y.shape()) + ", dimension = " + Arrays.toString(normalizeAxis) + ")");
            }
        }
        DataType resultType = opContext != null ? reduceOp.resultType(opContext) : reduceOp.resultType();
        if (z == null) {
            INDArray createUninitialized = Nd4j.createUninitialized(resultType, reductionShape);
            setZ(createUninitialized, reduceOp, opContext);
            z = createUninitialized;
        } else if (z.dataType() != resultType || !Arrays.equals(reductionShape, z.shape())) {
            throw new ND4JIllegalStateException("Output array for op " + reduceOp.getClass().getSimpleName() + " should have type " + resultType + " and shape " + Arrays.toString(reductionShape) + " but has datatype " + z.dataType() + " and shape " + Arrays.toString(z.shape()));
        }
        // BOOL / LONG reductions take extra args in x's data type; other reductions use z's.
        Pointer pointer4 = reduceOp.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(reduceOp.extraArgsDataBuff((z.dataType() == DataType.BOOL || reduceOp.getOpType() == Op.Type.REDUCE_LONG) ? x.dataType() : z.dataType()), deviceContext) : null;
        // x and z are non-null here: x was dereferenced above, and z was allocated or validated.
        Pointer retrieveHostPointer2 = AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        Pointer retrieveHostPointer3 = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
        Pointer retrieveHostPointer4 = AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());
        // Positional slot layout: host x shape, stream, device id, four context buffers,
        // host y/z shape, host x TAD shape, device x TAD shape, device x TAD offsets.
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), deviceContext.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), deviceContext.getBufferAllocation(), deviceContext.getBufferReduction(), deviceContext.getBufferScalar(), deviceContext.getBufferSpecial(), retrieveHostPointer3, retrieveHostPointer4, retrieveHostPointer, pointer, pointer2});
        Pair tADOnlyShapeInfo = y == null ? null : tadManager.getTADOnlyShapeInfo(y, normalizeAxis);
        Pointer pointer5 = y == null ? null : AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), deviceContext);
        DataBuffer dataBuffer2 = y == null ? null : (DataBuffer) tADOnlyShapeInfo.getSecond();
        Pointer pointer6 = dataBuffer2 == null ? null : AtomicAllocator.getInstance().getPointer(dataBuffer2, deviceContext);
        // Slots 12/13 carry y's TAD shape/offsets for pairwise ops.
        if (y != null) {
            put.put(12L, pointer5);
            put.put(13L, pointer6);
        }
        LongPointer pointer7 = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), deviceContext);
        OpaqueDataBuffer opaqueDataBuffer = ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer opaqueDataBuffer2 = y == null ? null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer opaqueDataBuffer3 = ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer();
        // Validate against the arrays actually in use (context-supplied when opContext != null).
        reduceOp.validateDataTypes(opContext);
        if (!z.isScalar()) {
            // Result unused. NOTE(review): presumably materializes the axis constant buffer; confirm.
            AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(normalizeAxis), deviceContext);
            if (y != null) {
                nativeOps.execReduce3Tad(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer2, pointer3, pointer4, opaqueDataBuffer2, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), deviceContext), opaqueDataBuffer3, (LongPointer) retrieveHostPointer4, pointer7, ((BaseCudaDataBuffer) reduceOp.dimensions().data()).getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null, pointer, (LongPointer) pointer2, (LongPointer) pointer5, (LongPointer) pointer6);
            } else if (reduceOp instanceof Variance) {
                nativeOps.execSummaryStatsTad(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer2, pointer3, pointer4, opaqueDataBuffer3, (LongPointer) retrieveHostPointer4, pointer7, ((BaseCudaDataBuffer) reduceOp.dimensions().data()).getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null, ((Variance) reduceOp).isBiasCorrected(), pointer, (LongPointer) pointer2);
            } else {
                // JADX switch-map ordinals. Presumably 3..6 = REDUCE_LONG / REDUCE_BOOL / REDUCE_FLOAT / REDUCE_SAME; confirm.
                switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[reduceOp.getOpType().ordinal()]) {
                    case 3:
                        nativeOps.execReduceLong2(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer2, pointer3, pointer4, opaqueDataBuffer3, (LongPointer) retrieveHostPointer4, pointer7, ((BaseCudaDataBuffer) reduceOp.dimensions().data()).getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
                        break;
                    case 4:
                        nativeOps.execReduceBool2(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer2, pointer3, pointer4, opaqueDataBuffer3, (LongPointer) retrieveHostPointer4, pointer7, ((BaseCudaDataBuffer) reduceOp.dimensions().data()).getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
                        break;
                    case Nd4jCuda.FLOAT32 /* 5 */:
                        nativeOps.execReduceFloat2(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer2, pointer3, pointer4, opaqueDataBuffer3, (LongPointer) retrieveHostPointer4, pointer7, ((BaseCudaDataBuffer) reduceOp.dimensions().data()).getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
                        break;
                    case Nd4jCuda.DOUBLE /* 6 */:
                        nativeOps.execReduceSame2(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer2, pointer3, pointer4, opaqueDataBuffer3, (LongPointer) retrieveHostPointer4, pointer7, ((BaseCudaDataBuffer) reduceOp.dimensions().data()).getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
                        break;
                    default:
                        throw new UnsupportedOperationException();
                }
            }
        } else if (reduceOp instanceof Variance) {
            nativeOps.execSummaryStatsScalar(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer2, pointer3, pointer4, opaqueDataBuffer3, (LongPointer) retrieveHostPointer4, pointer7, ((Variance) reduceOp).isBiasCorrected());
        } else if (y != null) {
            nativeOps.execReduce3Scalar(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer2, pointer3, pointer4, opaqueDataBuffer2, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), deviceContext), opaqueDataBuffer3, (LongPointer) retrieveHostPointer4, pointer7);
        } else {
            switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[reduceOp.getOpType().ordinal()]) {
                case 3:
                    nativeOps.execReduceLong(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer2, pointer3, pointer4, opaqueDataBuffer3, (LongPointer) retrieveHostPointer4, pointer7);
                    break;
                case 4:
                    nativeOps.execReduceBool(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer2, pointer3, pointer4, opaqueDataBuffer3, (LongPointer) retrieveHostPointer4, pointer7);
                    break;
                case Nd4jCuda.FLOAT32 /* 5 */:
                    nativeOps.execReduceFloat(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer2, pointer3, pointer4, opaqueDataBuffer3, (LongPointer) retrieveHostPointer4, pointer7);
                    break;
                case Nd4jCuda.DOUBLE /* 6 */:
                    nativeOps.execReduceSame(put, reduceOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer2, pointer3, pointer4, opaqueDataBuffer3, (LongPointer) retrieveHostPointer4, pointer7);
                    break;
                default:
                    throw new UnsupportedOperationException();
            }
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        profilingConfigurableHookOut(reduceOp, opContext, profilingConfigurableHookIn);
        Nd4j.getExecutioner().commit();
        return deviceContext;
    }

    /**
     * Executes a scalar op along dimensions (TAD-wise), where the scalar operand array is
     * {@code scalarOp.y()}.
     *
     * <p>Uses {@code scalarOp.x()/y()/z()} directly rather than any {@link OpContext}. y must be
     * non-null, because its shape info is dereferenced unconditionally.
     *
     * <p>The extra-pointer array is now lazily initialized here too, matching the other invoke
     * paths. Previously a direct call on a thread where it had never been allocated failed with a
     * NullPointerException.
     *
     * @param scalarOp the scalar op to run
     * @param iArr     dimensions to apply along; sorted in place when longer than one element
     * @return always {@code null} in the current implementation
     * @throws UnsupportedOperationException for an op type not handled by the dispatch switch
     * @throws RuntimeException              if the native layer reports a non-zero error code
     */
    protected CudaContext intercept(ScalarOp scalarOp, int[] iArr) {
        long profilingConfigurableHookIn = profilingConfigurableHookIn(scalarOp, new DataBuffer[0]);
        // Note: mutates the caller's array.
        if (iArr != null && iArr.length > 1) {
            Arrays.sort(iArr);
        }
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        CudaContext deviceContext = AtomicAllocator.getInstance().getDeviceContext();
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(scalarOp.opName());
        }
        INDArray x = scalarOp.x();
        INDArray y = scalarOp.y();
        INDArray z = scalarOp.z();
        Pointer retrieveHostPointer = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
        Pointer retrieveHostPointer3 = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());
        LongPointer pointer = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), deviceContext);
        LongPointer pointer2 = AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), deviceContext);
        LongPointer pointer3 = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), deviceContext);
        // TAD shape (first) / offsets (second) for x and for z along iArr.
        Pair tADOnlyShapeInfo = tadManager.getTADOnlyShapeInfo(x, iArr);
        Pointer retrieveHostPointer4 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo.getFirst());
        LongPointer pointer4 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), deviceContext);
        LongPointer pointer5 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getSecond(), deviceContext);
        Pair tADOnlyShapeInfo2 = tadManager.getTADOnlyShapeInfo(z, iArr);
        LongPointer pointer6 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getFirst(), deviceContext);
        LongPointer pointer7 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getSecond(), deviceContext);
        // Positional slot layout: host x shape, stream, device id, four context buffers,
        // host y/z shape, host x TAD shape, device x TAD shape/offsets, device z TAD shape/offsets.
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), deviceContext.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), deviceContext.getBufferAllocation(), deviceContext.getBufferReduction(), deviceContext.getBufferScalar(), deviceContext.getBufferSpecial(), retrieveHostPointer2, retrieveHostPointer3, retrieveHostPointer4, pointer4, pointer5, pointer6, pointer7});
        Pointer pointer8 = scalarOp.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(scalarOp.extraArgsDataBuff(z.dataType()), deviceContext) : null;
        // Result unused. NOTE(review): presumably materializes the dimension constant buffer; confirm.
        AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(iArr), deviceContext);
        OpaqueDataBuffer opaqueDataBuffer = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer opaqueDataBuffer2 = y == null ? null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer opaqueDataBuffer3 = z == null ? null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer();
        // JADX switch-map ordinals. Presumably 7 = SCALAR, 8 = SCALAR_BOOL; confirm against AnonymousClass1.
        // Native argument order is x, z, then the scalars array (y).
        switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[scalarOp.getOpType().ordinal()]) {
            case Nd4jCuda.INT8 /* 7 */:
                nativeOps.execScalarTad(put, scalarOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, pointer3, opaqueDataBuffer2, (LongPointer) retrieveHostPointer2, pointer2, pointer8, ((BaseCudaDataBuffer) scalarOp.dimensions().data()).getOpaqueDataBuffer(), scalarOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null, pointer4, pointer5, pointer6, pointer7);
                break;
            case Nd4jCuda.INT16 /* 8 */:
                nativeOps.execScalarBoolTad(put, scalarOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, pointer3, opaqueDataBuffer2, (LongPointer) retrieveHostPointer2, pointer2, pointer8, ((BaseCudaDataBuffer) scalarOp.dimensions().data()).getOpaqueDataBuffer(), scalarOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null, pointer4, pointer5, pointer6, pointer7);
                break;
            default:
                throw new UnsupportedOperationException();
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        profilingConfigurableHookOut(scalarOp, null, profilingConfigurableHookIn);
        return null;
    }

    /**
     * Runs a scalar op without an {@link OpContext}.
     *
     * @param scalarOp the scalar op to execute
     * @return the op's z array after execution
     */
    public INDArray exec(ScalarOp scalarOp) {
        final OpContext noContext = null;
        invoke(scalarOp, noContext);
        INDArray result = scalarOp.z();
        return result;
    }

    /**
     * Runs a scalar or scalar-bool op on the current CUDA device.
     *
     * <p>If no output array is present, one is allocated. For {@code SCALAR} ops it is an
     * uninitialized array like X; for {@code SCALAR_BOOL} ops it is an uninitialized
     * {@code BOOL} array with X's shape. The same instance is registered on the op/context and
     * handed to the native kernel, so the result is observable through the op's Z.
     *
     * <p>Decompiler note: access modifiers changed from protected.
     *
     * @param scalarOp  op to execute
     * @param opContext optional context supplying the arrays; may be {@code null}
     * @return always {@code null}
     * @throws ND4JIllegalStateException     if the op type is unknown or X and Z lengths differ
     * @throws UnsupportedOperationException if the op type has no native scalar entry point
     * @throws RuntimeException              if the native call reports an error
     */
    public CudaContext invoke(ScalarOp scalarOp, OpContext opContext) {
        long hookStart = profilingConfigurableHookIn(scalarOp, new DataBuffer[0]);
        checkForCompression(scalarOp);

        INDArray x = getX(scalarOp, opContext);
        getY(scalarOp, opContext);
        INDArray z = getZ(scalarOp, opContext);

        if (z == null) {
            switch (scalarOp.getOpType()) {
                case SCALAR:
                    z = x.ulike();
                    break;
                case SCALAR_BOOL:
                    z = Nd4j.createUninitialized(DataType.BOOL, x.shape());
                    break;
                default:
                    throw new ND4JIllegalStateException("Unknown op type: [" + scalarOp.getOpType() + "]");
            }
            // Register exactly the instance the kernel writes into, so op.z() sees the result.
            setZ(z, scalarOp, opContext);
        }

        if (x.length() != z.length()) {
            throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: ["
                    + Arrays.toString(x.shapeInfoDataBuffer().asInt()) + "] != ["
                    + Arrays.toString(z.shapeInfoDataBuffer().asInt()) + "]");
        }

        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(scalarOp.opName());
        }

        // Along-dimension scalar ops are handled by intercept(...) instead of the kernels below.
        if (scalarOp.dimensions() != null) {
            intercept(scalarOp, scalarOp.dimensions().toIntVector());
            return null;
        }

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getDeviceContext();
        INDArray scalar = scalarOp.scalar();

        Pointer xShapeHost = AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        Pointer scalarShapeHost = scalar == null ? null : AddressRetriever.retrieveHostPointer(scalar.shapeInfoDataBuffer());
        Pointer zShapeHost = AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());
        LongPointer xShapeDevice = allocator.getPointer(x.shapeInfoDataBuffer(), context);

        // Extra args are materialized in X's dtype for boolean scalar ops, otherwise in Z's dtype.
        Pointer extraArgs;
        if (scalarOp.extraArgs() != null) {
            DataType extraArgsType = scalarOp.getOpType() == Op.Type.SCALAR_BOOL ? x.dataType() : z.dataType();
            extraArgs = allocator.getPointer(scalarOp.extraArgsDataBuff(extraArgsType), context);
        } else {
            extraArgs = null;
        }
        LongPointer zShapeDevice = allocator.getPointer(z.shapeInfoDataBuffer(), context);

        PointerPointer extras = this.extraz.get().put(new Pointer[] {
                xShapeHost,
                context.getOldStream(),
                allocator.getDeviceIdPointer(),
                context.getBufferAllocation(),
                context.getBufferReduction(),
                context.getBufferScalar(),
                context.getBufferSpecial(),
                scalarShapeHost,
                zShapeHost,
                null,
                null});

        OpaqueDataBuffer xBuffer = ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer scalarBuffer = scalar == null ? null : ((BaseCudaDataBuffer) scalar.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer zBuffer = ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer();

        switch (scalarOp.getOpType()) {
            case SCALAR:
                nativeOps.execScalar(extras, scalarOp.opNum(),
                        xBuffer, (LongPointer) xShapeHost, xShapeDevice,
                        zBuffer, (LongPointer) zShapeHost, zShapeDevice,
                        scalarBuffer, (LongPointer) scalarShapeHost,
                        allocator.getPointer(scalar.shapeInfoDataBuffer(), context),
                        extraArgs);
                break;
            case SCALAR_BOOL:
                nativeOps.execScalarBool(extras, scalarOp.opNum(),
                        xBuffer, (LongPointer) xShapeHost, xShapeDevice,
                        zBuffer, (LongPointer) zShapeHost, zShapeDevice,
                        scalarBuffer, (LongPointer) scalarShapeHost,
                        allocator.getPointer(scalar.shapeInfoDataBuffer(), context),
                        extraArgs);
                break;
            default:
                throw new UnsupportedOperationException("Unknown op type: " + scalarOp.getOpType());
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        profilingConfigurableHookOut(scalarOp, opContext, hookStart);
        return null;
    }

    /**
     * Runs a transform op on the current CUDA device, or a pairwise transform when a Y operand
     * is present.
     *
     * <p>If the op has no Z array, this method creates an uninitialized array of
     * {@code transformOp.resultType()}, shaped and ordered like X, and registers it via
     * {@code setZ}.
     *
     * <p>Decompiler note: access modifiers changed from protected.
     *
     * @param transformOp op to execute
     * @param opContext   optional context supplying the arrays; may be {@code null}
     * @return always {@code null}
     * @throws ND4JIllegalStateException     if X, Y and Z lengths differ for a pairwise transform
     * @throws UnsupportedOperationException for unary op types without a native entry point
     * @throws RuntimeException              if the native call reports an error
     */
    public CudaContext invoke(TransformOp transformOp, OpContext opContext) {
        Pointer pointer;
        long profilingConfigurableHookIn = profilingConfigurableHookIn(transformOp, new DataBuffer[0]);
        INDArray x = getX(transformOp, opContext);
        INDArray y = getY(transformOp, opContext);
        INDArray z = getZ(transformOp, opContext);
        checkForCompression(transformOp);
        AtomicAllocator atomicAllocator = AtomicAllocator.getInstance();
        // Lazily create the reusable extras PointerPointer (32 slots) held in extraz.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        CudaContext deviceContext = atomicAllocator.getDeviceContext();
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(transformOp.opName());
        }
        // Holds the Z array allocated by this method, if any.
        INDArray iNDArray = null;
        // NOTE(review): x is dereferenced here, so the later "x == null" checks can never be true.
        LongPointer pointer2 = atomicAllocator.getPointer(x.shapeInfoDataBuffer(), deviceContext);
        // Always null (decompiler residue); extras slot 18 below is therefore always CudaPointer(0).
        Object[] objArr = null;
        Pointer retrieveHostPointer = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
        // No output supplied: allocate an uninitialized array of the op's result type, shaped and ordered like X.
        if (z == null) {
            iNDArray = Nd4j.createUninitialized(transformOp.resultType(), x.shape(), x.ordering());
            setZ(iNDArray, transformOp, opContext);
            z = iNDArray;
        }
        // Extra args use X's dtype for TRANSFORM_BOOL / PAIRWISE_BOOL ops, otherwise Z's dtype.
        if (transformOp.extraArgs() != null) {
            pointer = atomicAllocator.getPointer(transformOp.extraArgsDataBuff((transformOp.getOpType() == Op.Type.TRANSFORM_BOOL || transformOp.getOpType() == Op.Type.PAIRWISE_BOOL) ? x.dataType() : z.dataType()), deviceContext);
        } else {
            pointer = null;
        }
        Pointer pointer3 = pointer;
        Pointer retrieveHostPointer3 = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());
        transformOp.validateDataTypes(opContext, this.experimentalMode.get());
        LongPointer pointer4 = atomicAllocator.getPointer(z.shapeInfoDataBuffer(), deviceContext);
        // Extras layout passed to native code: [0] X host shape info, [1] old stream, [2] device id,
        // [3..6] context allocation/reduction/scalar/special buffers, [7] Y host shape info,
        // [8] Z host shape info, [18] CudaPointer(0); every other slot is null.
        PointerPointer pointerPointer = this.extraz.get();
        Pointer[] pointerArr = new Pointer[20];
        pointerArr[0] = AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        pointerArr[1] = deviceContext.getOldStream();
        pointerArr[2] = atomicAllocator.getDeviceIdPointer();
        pointerArr[3] = deviceContext.getBufferAllocation();
        pointerArr[4] = deviceContext.getBufferReduction();
        pointerArr[5] = deviceContext.getBufferScalar();
        pointerArr[6] = deviceContext.getBufferSpecial();
        pointerArr[7] = retrieveHostPointer2;
        pointerArr[8] = retrieveHostPointer3;
        pointerArr[9] = null;
        pointerArr[10] = null;
        pointerArr[11] = null;
        pointerArr[12] = null;
        pointerArr[13] = null;
        pointerArr[14] = null;
        pointerArr[15] = null;
        pointerArr[16] = null;
        pointerArr[17] = null;
        pointerArr[18] = new CudaPointer(0 == 0 ? 0L : objArr.length);
        pointerArr[19] = null;
        PointerPointer put = pointerPointer.put(pointerArr);
        OpaqueDataBuffer opaqueDataBuffer = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer opaqueDataBuffer2 = y == null ? null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer opaqueDataBuffer3 = z == null ? null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer();
        if (y == null) {
            // Unary transform. Case labels are decompiler switch-map indices; judging from the native
            // calls made: 9 = bool, 11 = any, 12 = float, 13 = same, 14 = strict. Index 10 and any
            // other type are rejected here. NOTE(review): confirm the mapping against Op.Type.
            switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[transformOp.getOpType().ordinal()]) {
                case Nd4jCuda.INT32 /* 9 */:
                    nativeOps.execTransformBool(put, transformOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer2, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, pointer4, pointer3);
                    break;
                case 10:
                default:
                    throw new UnsupportedOperationException();
                case Nd4jCuda.UINT8 /* 11 */:
                    nativeOps.execTransformAny(put, transformOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer2, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, pointer4, pointer3);
                    break;
                case Nd4jCuda.UINT16 /* 12 */:
                    nativeOps.execTransformFloat(put, transformOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer2, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, pointer4, pointer3);
                    break;
                case Nd4jCuda.UINT32 /* 13 */:
                    nativeOps.execTransformSame(put, transformOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer2, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, pointer4, pointer3);
                    break;
                case Nd4jCuda.UINT64 /* 14 */:
                    nativeOps.execTransformStrict(put, transformOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer2, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, pointer4, pointer3);
                    break;
            }
        } else {
            // Pairwise transform: X, Y and Z must all have the same length.
            LongPointer pointer5 = atomicAllocator.getPointer(y.shapeInfoDataBuffer(), deviceContext);
            if (x.length() != y.length() || x.length() != z.length()) {
                throw new ND4JIllegalStateException("X, Y and Z arguments should have the same length for PairwiseTransform");
            }
            // Switch-map indices 9 and 10 use the boolean pairwise kernel; everything else uses the generic one.
            switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[transformOp.getOpType().ordinal()]) {
                case Nd4jCuda.INT32 /* 9 */:
                case 10:
                    nativeOps.execPairwiseTransformBool(put, transformOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer2, opaqueDataBuffer2, (LongPointer) retrieveHostPointer2, pointer5, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, pointer4, pointer3);
                    break;
                default:
                    nativeOps.execPairwiseTransform(put, transformOp.opNum(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, pointer2, opaqueDataBuffer2, (LongPointer) retrieveHostPointer2, pointer5, opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, pointer4, pointer3);
                    break;
            }
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        // These calls discard their results; presumably they keep the extra-args pointer and the
        // freshly allocated Z reachable until after the native call. NOTE(review): confirm intent.
        if (pointer3 != null) {
            pointer3.address();
        }
        if (iNDArray != null) {
            iNDArray.elementWiseStride();
        }
        profilingConfigurableHookOut(transformOp, opContext, profilingConfigurableHookIn);
        return null;
    }

    /**
     * Allocates an int buffer sized from the batch's sample and installs it as the batch's
     * parameter surface.
     *
     * @param batch batch whose sample determines the allocation size
     * @return the newly created buffer, which is also set on {@code batch}
     */
    protected <T extends Aggregate> DataBuffer getBuffer(Batch<T> batch) {
        final DataBuffer surface = Nd4j.getDataBufferFactory()
                .createInt(batch.getSample().getRequiredBatchMemorySize() * 4, false);
        batch.setParamsSurface(surface);
        return surface;
    }

    /**
     * Batched aggregate execution is not implemented by this executioner.
     *
     * @param batch ignored
     * @throws UnsupportedOperationException always
     */
    public <T extends Aggregate> void exec(Batch<T> batch) {
        throw new UnsupportedOperationException("Batched Aggregate execution is not supported by the CUDA executioner");
    }

    /**
     * Executes aggregates in batches of at most 8192 entries, then synchronizes the current
     * device context's old stream.
     *
     * @param list aggregates to execute; an empty list is a no-op
     */
    public void exec(List<Aggregate> list) {
        if (!list.isEmpty()) {
            for (Batch<Aggregate> batch : Batch.getBatches(list, 8192)) {
                exec(batch);
            }
            AtomicAllocator.getInstance().getDeviceContext().syncOldStream();
        }
    }

    /**
     * Single aggregate execution is not implemented by this executioner.
     *
     * @param aggregate ignored
     * @throws UnsupportedOperationException always
     */
    public void exec(Aggregate aggregate) {
        throw new UnsupportedOperationException("Aggregate execution is not supported by the CUDA executioner");
    }

    /**
     * Executes a random op using the global {@code Nd4j.getRandom()} generator.
     *
     * @param randomOp op to execute
     * @return the op's output array
     */
    public INDArray exec(RandomOp randomOp) {
        Random defaultRandom = Nd4j.getRandom();
        return exec(randomOp, defaultRandom);
    }

    /**
     * Executes a random op with the given generator and no op context.
     *
     * @param randomOp op to execute
     * @param random   generator supplying the native RNG state
     * @return the op's output array
     */
    public INDArray exec(RandomOp randomOp, Random random) {
        final OpContext noContext = null;
        return exec(randomOp, noContext, random);
    }

    /**
     * Executes a random op on the current CUDA device.
     *
     * <p>The native entry point depends on which arrays are present. When X, Y and Z are all
     * present it uses {@code execRandom3}; when X and Z are present but Y is missing it uses
     * {@code execRandom2}; otherwise it uses {@code execRandom}, which takes Z only. Extra
     * arguments are materialized in Z's data type.
     *
     * @param randomOp  op to execute
     * @param opContext optional context supplying the arrays; may be {@code null}
     * @param random    generator whose state pointer is passed to native code
     * @return the op's Z array
     * @throws IllegalStateException if {@code random} has no native state pointer
     * @throws RuntimeException      if the native call reports an error
     */
    public INDArray exec(RandomOp randomOp, OpContext opContext, Random random) {
        INDArray x = getX(randomOp, opContext);
        INDArray y = getY(randomOp, opContext);
        INDArray z = getZ(randomOp, opContext);
        // Triple-arg RNG ops given only Z reuse Z as both X and Y.
        if ((randomOp instanceof BaseRandomOp) && ((BaseRandomOp) randomOp).isTripleArgRngOp() && z != null && x == null && y == null) {
            x = z;
            y = z;
        }
        long profilingConfigurableHookIn = profilingConfigurableHookIn(randomOp, new DataBuffer[0]);
        checkForCompression(randomOp);
        if (random.getStatePointer() == null) {
            throw new IllegalStateException("You should use one of NativeRandom classes for NativeOperations execution");
        }
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(randomOp.opName());
        }
        CudaContext deviceContext = AtomicAllocator.getInstance().getDeviceContext();
        // Extras: [0] Z host shape info, [1] old stream, [2] device id.
        // NOTE(review): z is dereferenced unconditionally here, so a null Z fails with NullPointerException.
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer()), deviceContext.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer()});
        Pointer retrieveHostPointer = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
        Pointer retrieveHostPointer3 = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());
        OpaqueDataBuffer opaqueDataBuffer = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer opaqueDataBuffer2 = y == null ? null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer opaqueDataBuffer3 = z == null ? null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer();
        if (x != null && y != null && z != null) {
            // X, Y and Z present.
            nativeOps.execRandom3(put, randomOp.opNum(), random.getStatePointer(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), deviceContext), opaqueDataBuffer2, (LongPointer) retrieveHostPointer2, AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), deviceContext), opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), deviceContext), AtomicAllocator.getInstance().getPointer(randomOp.extraArgsDataBuff(z.dataType()), deviceContext));
        } else if (x == null || z == null) {
            // Output-only form: only Z is passed.
            nativeOps.execRandom(put, randomOp.opNum(), random.getStatePointer(), opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), deviceContext), AtomicAllocator.getInstance().getPointer(randomOp.extraArgsDataBuff(z.dataType()), deviceContext));
        } else {
            // X and Z present, Y absent.
            nativeOps.execRandom2(put, randomOp.opNum(), random.getStatePointer(), opaqueDataBuffer, (LongPointer) retrieveHostPointer, AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), deviceContext), opaqueDataBuffer3, (LongPointer) retrieveHostPointer3, AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), deviceContext), AtomicAllocator.getInstance().getPointer(randomOp.extraArgsDataBuff(z.dataType()), deviceContext));
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        profilingConfigurableHookOut(randomOp, opContext, profilingConfigurableHookIn);
        return z;
    }

    /**
     * Returns environment properties for the CUDA backend, building them on first use.
     *
     * <p>The first call extends the superclass properties with per-device entries
     * ({@code cuda.devicesInformation}), the device count, the BLAS vendor, free memory and the
     * current bandwidth. Later calls refresh only each device's free/total memory,
     * {@code memory.free} and {@code memoryBandwidth} on the cached instance.
     *
     * @return the cached properties object
     */
    @SuppressWarnings("unchecked")
    public synchronized Properties getEnvironmentInformation() {
        if (this.properties != null) {
            List<Map<String, Object>> devices =
                    (List<Map<String, Object>>) this.properties.get("cuda.devicesInformation");
            for (int device = 0; device < nativeOps.getAvailableDevices(); device++) {
                Map<String, Object> deviceInfo = devices.get(device);
                deviceInfo.put("cuda.freeMemory", Long.valueOf(nativeOps.getDeviceFreeMemory(device)));
                deviceInfo.put("cuda.totalMemory", Long.valueOf(nativeOps.getDeviceTotalMemory(device)));
            }
            this.properties.put("cuda.devicesInformation", devices);
            this.properties.put("memory.free", Long.valueOf(Pointer.maxBytes() - Pointer.totalBytes()));
            this.properties.put("memoryBandwidth", PerformanceTracker.getInstance().getCurrentBandwidth());
            return this.properties;
        }

        Properties info = super.getEnvironmentInformation();
        List<Map<String, Object>> devices = new ArrayList<>();
        for (int device = 0; device < nativeOps.getAvailableDevices(); device++) {
            Map<String, Object> deviceInfo = new HashMap<>();
            deviceInfo.put("cuda.deviceName", nativeOps.getDeviceName(device));
            deviceInfo.put("cuda.freeMemory", Long.valueOf(nativeOps.getDeviceFreeMemory(device)));
            deviceInfo.put("cuda.totalMemory", Long.valueOf(nativeOps.getDeviceTotalMemory(device)));
            deviceInfo.put("cuda.deviceMajor", Long.valueOf(nativeOps.getDeviceMajor(device)));
            deviceInfo.put("cuda.deviceMinor", Long.valueOf(nativeOps.getDeviceMinor(device)));
            devices.add(deviceInfo);
        }
        info.put("backend", "CUDA");
        info.put("cuda.availableDevices", Integer.valueOf(nativeOps.getAvailableDevices()));
        info.put("cuda.devicesInformation", devices);
        info.put("blas.vendor", Nd4j.factory().blas().getBlasVendor().toString());
        info.put("memory.free", Long.valueOf(Pointer.maxBytes() - Pointer.totalBytes()));
        info.put("memoryBandwidth", PerformanceTracker.getInstance().getCurrentBandwidth());
        this.properties = info;
        return this.properties;
    }

    /**
     * Returns the TAD manager used by this executioner.
     *
     * @return the {@code tadManager} instance
     */
    public TADManager getTADManager() {
        final TADManager manager = tadManager;
        return manager;
    }

    /** Prints environment information by delegating entirely to the superclass implementation. */
    public void printEnvironmentInformation() {
        super.printEnvironmentInformation();
    }

    /**
     * Synchronizes the current device context: first its old stream, then its special stream.
     */
    public void commit() {
        final CudaContext ctx = AtomicAllocator.getInstance().getDeviceContext();
        ctx.syncOldStream();
        ctx.syncSpecialStream();
    }

    /**
     * Returns the custom ops known to the native backend, parsing and caching them on first use.
     *
     * <p>The native registry string is a {@code ;}-separated list of entries. Each entry is a
     * {@code :}-separated record: name, hash, numInputs, numOutputs, allowsInplace (1 = true),
     * numTArgs, numIArgs.
     *
     * @return an unmodifiable map from op name to descriptor; empty if the backend reports none
     */
    public synchronized Map<String, CustomOpDescriptor> getCustomOperations() {
        if (this.customOps != null) {
            return this.customOps;
        }
        String registry = nativeOps.getAllCustomOps();
        if (registry == null || registry.isEmpty()) {
            log.warn("No customs ops available!");
            this.customOps = Collections.emptyMap();
            return this.customOps;
        }
        Map<String, CustomOpDescriptor> descriptors = new HashMap<>();
        for (String entry : registry.split(";")) {
            if (entry == null || entry.isEmpty()) {
                continue;
            }
            String[] fields = entry.split(":");
            String opName = fields[0];
            CustomOpDescriptor descriptor = CustomOpDescriptor.builder()
                    .hash(Long.parseLong(fields[1]))
                    .numInputs(Integer.parseInt(fields[2]))
                    .numOutputs(Integer.parseInt(fields[3]))
                    .allowsInplace(Integer.parseInt(fields[4]) == 1)
                    .numTArgs(Integer.parseInt(fields[5]))
                    .numIArgs(Integer.parseInt(fields[6]))
                    .build();
            descriptors.put(opName, descriptor);
        }
        this.customOps = Collections.unmodifiableMap(descriptors);
        return this.customOps;
    }

    /**
     * Decodes a native shape-info buffer into a {@link LongShapeDescriptor}.
     *
     * <p>The first element is read as the rank, and {@code rank * 2 + 4} longs are copied
     * before decoding shape, stride, element-wise stride, order, data type and emptiness.
     *
     * @param longPointer pointer to the native shape-info array
     * @return the decoded shape descriptor
     */
    protected LongShapeDescriptor getShapeFromPointer(LongPointer longPointer) {
        int rank = (int) longPointer.get(0L);
        long[] shapeInfo = new long[rank * 2 + 4];
        for (int idx = 0; idx < shapeInfo.length; idx++) {
            shapeInfo[idx] = longPointer.get(idx);
        }
        return LongShapeDescriptor.fromShape(
                Shape.shape(shapeInfo),
                Shape.stride(shapeInfo),
                Shape.elementWiseStride(shapeInfo),
                Shape.order(shapeInfo),
                ArrayOptionsHelper.dataType(shapeInfo),
                ArrayOptionsHelper.arrayType(shapeInfo) == ArrayType.EMPTY);
    }

    /**
     * Computes output shapes for a custom op using the op's own arguments (no op context).
     *
     * @param customOp op whose outputs are inferred; must not be null
     * @return one shape descriptor per output
     * @throws NullPointerException if {@code customOp} is null
     */
    public List<LongShapeDescriptor> calculateOutputShape(@NonNull CustomOp customOp) {
        if (customOp != null) {
            final OpContext noContext = null;
            return calculateOutputShape(customOp, noContext);
        }
        throw new NullPointerException("op is marked non-null but is null");
    }

    /**
     * Asks the native backend to infer the output shapes of a custom op.
     *
     * <p>Inputs and the integer, floating-point, boolean and data-type arguments come from
     * {@code opContext} when it is non-null, otherwise from {@code customOp}. Any input that is
     * neither on the device nor everywhere is moved to the device first. The native shape list
     * is always released, including when an error is raised.
     *
     * @param customOp  op whose outputs are inferred; must not be null
     * @param opContext optional context supplying arguments; may be null
     * @return one descriptor per output, or an empty list when the op declares at least one input
     *         but none were supplied
     * @throws NullPointerException if {@code customOp} is null
     * @throws RuntimeException     if the native shape function reports an error or returns nothing
     */
    public List<LongShapeDescriptor> calculateOutputShape(@NonNull CustomOp customOp, OpContext opContext) {
        if (customOp == null) {
            throw new NullPointerException("op is marked non-null but is null");
        }
        Nd4j.getExecutioner().commit();
        long opHash = customOp.opHash();

        int numInputArguments = opContext != null ? opContext.numInputArguments() : customOp.numInputArguments();
        if (numInputArguments == 0 && customOp.getDescriptor().getNumInputs() >= 1) {
            if (log.isTraceEnabled()) {
                log.trace("Could not calculate output shape for op {}: number of input args was 0", customOp.getClass().getName());
            }
            return Collections.emptyList();
        }

        // Host buffer pointers occupy [0, n); device buffer pointers occupy [n, 2n).
        PointerPointer inputBuffers = new PointerPointer(numInputArguments * 2);
        PointerPointer inputShapes = new PointerPointer(numInputArguments);
        int inputIdx = 0;
        for (INDArray input : opContext != null ? opContext.getInputArrays() : customOp.inputArguments()) {
            AffinityManager.Location location = Nd4j.getAffinityManager().getActiveLocation(input);
            if (location != AffinityManager.Location.DEVICE && location != AffinityManager.Location.EVERYWHERE) {
                Nd4j.getAffinityManager().ensureLocation(input, AffinityManager.Location.DEVICE);
            }
            if (!input.isEmpty()) {
                inputBuffers.put(inputIdx, input.data().addressPointer());
                inputBuffers.put(inputIdx + numInputArguments, AtomicAllocator.getInstance().getPointer(input.data()));
            }
            inputShapes.put(inputIdx, input.shapeInfoDataBuffer().addressPointer());
            inputIdx++;
        }

        int numIArguments = opContext != null ? opContext.numIArguments() : customOp.numIArguments();
        LongPointer iArgs = numIArguments > 0 ? new LongPointer(numIArguments) : null;
        int iIdx = 0;
        if (opContext != null) {
            for (Long value : opContext.getIArguments()) {
                iArgs.put(iIdx++, value.longValue());
            }
        } else {
            for (long value : customOp.iArgs()) {
                iArgs.put(iIdx++, value);
            }
        }

        int numTArguments = opContext != null ? opContext.numTArguments() : customOp.numTArguments();
        DoublePointer tArgs = numTArguments > 0 ? new DoublePointer(numTArguments) : null;
        int numBArguments = opContext != null ? opContext.numBArguments() : customOp.numBArguments();
        BooleanPointer bArgs = numBArguments > 0 ? new BooleanPointer(numBArguments) : null;
        int numDArguments = opContext != null ? opContext.numDArguments() : customOp.numDArguments();
        IntPointer dArgs = numDArguments > 0 ? new IntPointer(numDArguments) : null;

        int bIdx = 0;
        if (opContext != null) {
            for (Boolean value : opContext.getBArguments()) {
                bArgs.put(bIdx++, value.booleanValue());
            }
        } else {
            for (boolean value : customOp.bArgs()) {
                bArgs.put(bIdx++, value);
            }
        }

        int tIdx = 0;
        if (opContext != null) {
            for (Double value : opContext.getTArguments()) {
                tArgs.put(tIdx++, value.doubleValue());
            }
        } else {
            for (double value : customOp.tArgs()) {
                tArgs.put(tIdx++, value);
            }
        }

        int dIdx = 0;
        if (opContext != null) {
            for (DataType value : opContext.getDArguments()) {
                dArgs.put(dIdx++, value.toInt());
            }
        } else {
            for (DataType value : customOp.dArgs()) {
                dArgs.put(dIdx++, value.toInt());
            }
        }

        OpaqueShapeList shapeList = nativeOps.calculateOutputShapes2((PointerPointer) null, opHash,
                inputBuffers, inputShapes, numInputArguments,
                tArgs, numTArguments,
                iArgs, numIArguments,
                bArgs, numBArguments,
                dArgs, numDArguments);
        try {
            if (nativeOps.lastErrorCode() != 0) {
                throw new RuntimeException(nativeOps.lastErrorMessage());
            }
            if (shapeList == null) {
                throw new RuntimeException("Native shape function returned no result for op [" + customOp.opName() + "]");
            }
            List<LongShapeDescriptor> result = new ArrayList<>();
            for (int i = 0; i < nativeOps.getShapeListSize(shapeList); i++) {
                // Shape info is copied into Java memory before the native list is released below.
                result.add(getShapeFromPointer(new PagedPointer(nativeOps.getShape(shapeList, i)).asLongPointer()));
            }
            return result;
        } finally {
            if (shapeList != null) {
                nativeOps.deleteShapeList(shapeList);
            }
        }
    }

    /**
     * Executes a custom op, allocating its outputs first when none were supplied.
     *
     * <p>When the op has no outputs and is not in-place, output arrays are created from
     * {@link #calculateOutputShape(CustomOp)} and the context's shape function is overridden.
     * A CUDA op context is then populated from the op and executed. The context is always closed
     * via try-with-resources, including when execution fails. After execution, the input and
     * output pointers are re-actualized and the global RNG states are updated from the context.
     *
     * @param customOp op to execute
     * @return the arrays returned by {@code exec(customOp, context)}
     * @throws ND4JIllegalStateException if outputs are missing and cannot be inferred
     * @throws ND4JOpProfilerException   propagated unchanged from the profiler
     * @throws RuntimeException          wrapping any other execution failure
     */
    public INDArray[] exec(CustomOp customOp) {
        Nd4j.getExecutioner().commit();

        boolean shapeFunctionOverride = false;
        if (customOp.numOutputArguments() == 0 && !customOp.isInplaceCall()) {
            try {
                List<LongShapeDescriptor> outputShapes = calculateOutputShape(customOp);
                if (outputShapes.isEmpty()) {
                    throw new ND4JIllegalStateException("Op name " + customOp.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
                }
                for (LongShapeDescriptor descriptor : outputShapes) {
                    customOp.addOutputArgument(new INDArray[] {Nd4j.create(descriptor, false)});
                }
                shapeFunctionOverride = true;
            } catch (Exception e) {
                throw new ND4JIllegalStateException("Op name " + customOp.opName() + " - no output arrays were provided and calculateOutputShape failed to execute", e);
            }
        }

        // Result unused; kept from the original in case acquiring the device context has side effects.
        AtomicAllocator.getInstance().getDeviceContext();
        String opName = customOp.opName();
        try (CudaOpContext context = (CudaOpContext) buildContext()) {
            if (shapeFunctionOverride) {
                context.shapeFunctionOverride(true);
            }
            context.markInplace(customOp.isInplaceCall());
            context.setRngStates(Nd4j.getRandom().rootState(), Nd4j.getRandom().nodeState());
            context.setInputArrays(customOp.inputArguments());
            context.setOutputArrays(customOp.outputArguments());
            context.setBArguments(customOp.bArgs());
            context.setIArguments(customOp.iArgs());
            context.setTArguments(customOp.tArgs());
            context.setDArguments(customOp.dArgs());

            INDArray[] result = exec(customOp, context);

            Pair<Long, Long> rngStates = context.getRngStates();
            for (INDArray input : customOp.inputArguments()) {
                if (!input.isEmpty()) {
                    ((BaseCudaDataBuffer) input.data()).actualizePointerAndIndexer();
                }
            }
            for (INDArray output : customOp.outputArguments()) {
                if (!output.isEmpty()) {
                    ((BaseCudaDataBuffer) output.data()).actualizePointerAndIndexer();
                }
            }
            Nd4j.getRandom().setStates(rngStates.getFirst().longValue(), rngStates.getSecond().longValue());
            return result;
        } catch (ND4JOpProfilerException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException("Op [" + opName + "] execution failed", e);
        }
    }

    /**
     * Sets debug mode on this executioner and forwards the same flag to the native backend.
     *
     * @param enabled new debug-mode flag
     */
    public void enableDebugMode(boolean enabled) {
        final boolean flag = enabled;
        this.debug.set(flag);
        nativeOps.enableDebugMode(flag);
    }

    /**
     * Sets verbose mode on this executioner and forwards the same flag to the native backend.
     *
     * @param enabled new verbose-mode flag
     */
    public void enableVerboseMode(boolean enabled) {
        final boolean flag = enabled;
        this.verbose.set(flag);
        nativeOps.enableVerboseMode(flag);
    }

    /**
     * Registers a graph with the native backend under the given identifier.
     *
     * @param graphId identifier passed to the native registry
     * @param graph   pointer to the graph data handed to native code
     * @throws RuntimeException if the native call reports an error
     */
    public void registerGraph(long graphId, Pointer graph) {
        nativeOps.registerGraph((PointerPointer) null, graphId, graph);
        final int status = nativeOps.lastErrorCode();
        if (status == 0) {
            return;
        }
        throw new RuntimeException(nativeOps.lastErrorMessage());
    }

    /**
     * Executes a previously registered native graph with the given inputs.
     *
     * <p>When the graph produces no variables, an empty map is returned. Copying results back
     * is not implemented: the first output is copied into a fresh array and then an
     * {@link UnsupportedOperationException} is thrown. The native variables set is released on
     * every path once execution has succeeded.
     *
     * @param j    identifier of the registered graph
     * @param map  input arrays keyed by name
     * @param map2 native input index for each name in {@code map}
     * @return an empty map when the graph yields no variables
     * @throws NullPointerException          if {@code map} or {@code map2} is null
     * @throws ND4JIllegalStateException     if the native status is not OK
     * @throws UnsupportedOperationException if the graph yields outputs
     * @throws RuntimeException              if a native call reports an error
     */
    public Map<String, INDArray> executeGraph(long j, @NonNull Map<String, INDArray> map, @NonNull Map<String, Integer> map2) {
        if (map == null) {
            throw new NullPointerException("map is marked non-null but is null");
        }
        if (map2 == null) {
            throw new NullPointerException("reverseMap is marked non-null but is null");
        }
        Nd4j.getExecutioner().commit();

        int numInputs = map.size();
        PointerPointer inputBuffers = new PointerPointer(numInputs * 2);
        PointerPointer inputShapes = new PointerPointer(numInputs * 2);
        IntPointer inputIndices = new IntPointer(numInputs);
        int idx = 0;
        for (String name : new ArrayList<>(map.keySet())) {
            INDArray array = map.get(name);
            inputBuffers.put(idx, AtomicAllocator.getInstance().getHostPointer(array));
            inputShapes.put(idx, AtomicAllocator.getInstance().getHostPointer(array.shapeInfoDataBuffer()));
            inputIndices.put(idx, map2.get(name).intValue());
            idx++;
        }

        Map<String, INDArray> results = new LinkedHashMap<>();
        OpaqueVariablesSet variables = nativeOps.executeStoredGraph((PointerPointer) null, j, inputBuffers, inputShapes, inputIndices, numInputs);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        try {
            OpStatus status = OpStatus.byNumber(nativeOps.getVariablesSetStatus(variables));
            if (status != OpStatus.ND4J_STATUS_OK) {
                throw new ND4JIllegalStateException("Op execution failed: " + status);
            }
            if (nativeOps.getVariablesSetSize(variables) <= 0) {
                if (nativeOps.lastErrorCode() != 0) {
                    throw new RuntimeException(nativeOps.lastErrorMessage());
                }
                return results;
            }

            OpaqueVariable variable = nativeOps.getVariable(variables, 0);
            // Results unused; the calls are kept from the original.
            nativeOps.getVariableId(variable);
            nativeOps.getVariableIndex(variable);
            LongPointer shapeInfoPointer = nativeOps.getVariableShape(variable);
            Pointer buffer = nativeOps.getVariableBuffer(variable);

            long[] shapeInfo = new long[(int) shapeInfoPointer.get(0L) * 2 + 4];
            for (int k = 0; k < shapeInfo.length; k++) {
                shapeInfo[k] = shapeInfoPointer.get(k);
            }
            long[] shapeOf = Shape.shapeOf(shapeInfo);
            INDArray array = Nd4j.create(shapeOf, Shape.stridesOf(shapeInfo), 0L, Shape.order(shapeInfo));
            // The copy completes before the finally block releases the native set that owns 'buffer'.
            Pointer.memcpy(AtomicAllocator.getInstance().getHostPointer(array), buffer, ArrayUtil.prod(shapeOf) * Nd4j.sizeOfDataType());
            throw new UnsupportedOperationException("Returning graph outputs is not supported by the CUDA executioner");
        } finally {
            if (variables != null) {
                nativeOps.deleteVariablesSet(variables);
            }
        }
    }

    /**
     * Unregisters a graph from the native backend.
     *
     * @param graphId identifier previously passed to {@code registerGraph}
     * @throws RuntimeException if the native call reports an error
     */
    public void forgetGraph(long graphId) {
        nativeOps.unregisterGraph((PointerPointer) null, graphId);
        final int status = nativeOps.lastErrorCode();
        if (status == 0) {
            return;
        }
        throw new RuntimeException(nativeOps.lastErrorMessage());
    }

    /**
     * Forwards the element threshold to the native backend.
     *
     * @param threshold value passed to {@code nativeOps.setElementThreshold}
     */
    public void setElementsThreshold(int threshold) {
        final int value = threshold;
        nativeOps.setElementThreshold(value);
    }

    /**
     * Forwards the TAD threshold to the native backend.
     *
     * @param threshold value passed to {@code nativeOps.setTADThreshold}
     */
    public void setTadThreshold(int threshold) {
        final int value = threshold;
        nativeOps.setTADThreshold(value);
    }

    /**
     * Identifies this executioner's backend.
     *
     * @return always {@link OpExecutioner.ExecutionerType#CUDA}
     */
    public OpExecutioner.ExecutionerType type() {
        final OpExecutioner.ExecutionerType backend = OpExecutioner.ExecutionerType.CUDA;
        return backend;
    }

    /**
     * Reads the string stored at position {@code j} of a CUDA UTF-8 buffer.
     *
     * <p>The buffer's indexer is read as a {@code LongIndexer}. The long at {@code j} is wrapped as
     * a native address ({@link PagedPointer}) of a {@code utf8string} struct. That struct's byte
     * buffer is bounded to its reported length and then decoded.
     *
     * @param dataBuffer buffer to read from; must be a {@link CudaUtf8Buffer}
     * @param j index of the string within the buffer
     * @return the decoded string
     * @throws IllegalArgumentException if {@code dataBuffer} is not a {@link CudaUtf8Buffer}
     */
    public String getString(DataBuffer dataBuffer, long j) {
        Preconditions.checkArgument(dataBuffer instanceof CudaUtf8Buffer, "Expected Utf8Buffer");
        // A plain Indexer has no get(long); the raw long address needs the LongIndexer view.
        long address = ((org.bytedeco.javacpp.indexer.LongIndexer) dataBuffer.indexer()).get(j);
        // Keep the struct in a local: both its buffer and its length are needed below.
        Nd4jCuda.utf8string utf8 = new Nd4jCuda.utf8string((Pointer) new PagedPointer(address));
        return utf8._buffer().capacity(utf8._length()).getString();
    }

    /**
     * Reports the current value of the {@code experimentalMode} flag.
     *
     * @return {@code true} if experimental mode is enabled
     */
    public boolean isExperimentalMode() {
        final boolean enabled = experimentalMode.get();
        return enabled;
    }

    /**
     * Applies {@code updates} to {@code array} at the positions given by {@code indices}.
     * The work is done by the native {@code scatterUpdate} call over TADs taken along {@code axis}.
     *
     * <p>TAD shape info is computed for both {@code array} and {@code updates} over the same
     * {@code axis}. The length of the second element of the updates' TAD pair must equal the
     * number of indices.
     *
     * @param updateOp  update operation; only its {@code ordinal()} is passed to native code
     * @param iNDArray  target array ("array")
     * @param iNDArray2 indices ("indices"); its length is passed as the count of sub-arrays
     * @param iNDArray3 update values ("updates")
     * @param iArr      dimensions ("axis") used to build the TADs of array and updates
     * @throws NullPointerException  if any argument is null
     * @throws IllegalStateException if the updates' TAD count does not match the number of indices
     * @throws RuntimeException      if the native layer reports an error
     */
    public void scatterUpdate(ScatterUpdate.UpdateOp updateOp, @NonNull INDArray iNDArray, @NonNull INDArray iNDArray2, @NonNull INDArray iNDArray3, @NonNull int[] iArr) {
        // Explicit null checks (Lombok-style messages) duplicating the @NonNull contract.
        if (iNDArray == null) {
            throw new NullPointerException("array is marked non-null but is null");
        }
        if (iNDArray2 == null) {
            throw new NullPointerException("indices is marked non-null but is null");
        }
        if (iNDArray3 == null) {
            throw new NullPointerException("updates is marked non-null but is null");
        }
        if (iArr == null) {
            throw new NullPointerException("axis is marked non-null but is null");
        }
        CudaContext deviceContext = AtomicAllocator.getInstance().getDeviceContext();
        // TAD pairs for target and updates: getFirst() is passed as a shape-info buffer below.
        // getSecond() is presumably the per-TAD offsets buffer -- confirm against DeviceTADManager.
        Pair tADOnlyShapeInfo = tadManager.getTADOnlyShapeInfo(iNDArray, iArr);
        Pair tADOnlyShapeInfo2 = tadManager.getTADOnlyShapeInfo(iNDArray3, iArr);
        // There must be exactly one update sub-array per index.
        if (((DataBuffer) tADOnlyShapeInfo2.getSecond()).length() != iNDArray2.length()) {
            throw new IllegalStateException("Number of updates doesn't match number of indices. Bad dimensions used?");
        }
        // Lazily allocate the 32-slot extras PointerPointer on first use.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // Extras: slot 0 null, slot 1 the device context's stream.
        // Target and updates: host data and host offsets are passed as null. Only the host
        // shape-info plus the device data / device TAD shape / device second-of-pair buffers are given.
        // Indices: host and device pointers for both data and shape info.
        nativeOps.scatterUpdate(this.extraz.get().put(new Pointer[]{null, deviceContext.getOldStream()}), updateOp.ordinal(), (int) iNDArray2.length(), (Pointer) null, AtomicAllocator.getInstance().getHostPointer((DataBuffer) tADOnlyShapeInfo.getFirst()), (LongPointer) null, AtomicAllocator.getInstance().getPointer(iNDArray, deviceContext), AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst()), AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getSecond()), (Pointer) null, AtomicAllocator.getInstance().getHostPointer((DataBuffer) tADOnlyShapeInfo2.getFirst()), (LongPointer) null, AtomicAllocator.getInstance().getPointer(iNDArray3, deviceContext), AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getFirst()), AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getSecond()), AtomicAllocator.getInstance().getHostPointer(iNDArray2), AtomicAllocator.getInstance().getHostPointer(iNDArray2.shapeInfoDataBuffer()), AtomicAllocator.getInstance().getPointer(iNDArray2, deviceContext), AtomicAllocator.getInstance().getPointer(iNDArray2.shapeInfoDataBuffer(), deviceContext));
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
    }

    /**
     * Creates a fresh op context for this backend.
     *
     * @return a new {@link CudaOpContext}
     */
    public OpContext buildContext() {
        final OpContext context = new CudaOpContext();
        return context;
    }

    /**
     * Executes {@code customOp} natively using the state held in {@code opContext}.
     *
     * <p>The context (which must be a {@link CudaOpContext}) is bound to the current device
     * context's stream and reduction/allocation buffers before {@code execCustomOp2} runs.
     * Profiling hooks wrap the successful execution.
     *
     * @param customOp  op to execute; its hash selects the native implementation
     * @param opContext context carrying inputs/outputs; must be a {@link CudaOpContext}
     * @return the context's output arrays, or an empty array if there are none
     * @throws RuntimeException if the native layer reports an error or a non-zero status
     */
    public INDArray[] exec(CustomOp customOp, OpContext opContext) {
        final long hookStart = profilingConfigurableHookIn(customOp, opContext);
        final CudaContext ctx = AtomicAllocator.getInstance().getDeviceContext();
        final CudaOpContext cudaOpContext = (CudaOpContext) opContext;
        cudaOpContext.setCudaStream(ctx.getOldStream(), ctx.getBufferReduction(), ctx.getBufferAllocation());

        final int status = nativeOps.execCustomOp2((PointerPointer) null, customOp.opHash(), opContext.contextPointer());
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        if (status != 0) {
            throw new RuntimeException("Op [" + customOp.opName() + "] execution failed");
        }

        profilingConfigurableHookOut(customOp, opContext, hookStart);

        final List<INDArray> outputs = opContext.getOutputArrays();
        if (outputs.isEmpty()) {
            return new INDArray[0];
        }
        return outputs.toArray(new INDArray[outputs.size()]);
    }

    /**
     * Computes summary statistics of {@code iNDArray} via the native {@code inspectArray} call.
     *
     * <p>Host data is synchronized first. The native call fills a {@code DebugInfo} struct, whose
     * fields are copied into the returned {@link INDArrayStatistics}.
     *
     * @param iNDArray array to inspect
     * @return min/max/mean/std-dev plus inf, NaN, negative, positive and zero counts
     * @throws NullPointerException if {@code iNDArray} is null
     * @throws RuntimeException     if the native layer reports an error
     */
    public INDArrayStatistics inspectArray(@NonNull INDArray iNDArray) {
        if (iNDArray == null) {
            throw new NullPointerException("array is marked non-null but is null");
        }
        final Nd4jCuda.DebugInfo info = new Nd4jCuda.DebugInfo();
        final AtomicAllocator allocator = AtomicAllocator.getInstance();
        final CudaContext ctx = allocator.getDeviceContext();
        allocator.synchronizeHostData(iNDArray);

        if (extraz.get() == null) {
            extraz.set(new PointerPointer(32L));
        }
        final PointerPointer extras = extraz.get().put(new Pointer[]{
                null,
                ctx.getOldStream(),
                allocator.getDeviceIdPointer(),
                ctx.getBufferAllocation(),
                ctx.getBufferReduction(),
                ctx.getBufferScalar(),
                ctx.getBufferSpecial()});

        nativeOps.inspectArray(extras,
                allocator.getHostPointer(iNDArray),
                allocator.getHostPointer(iNDArray.shapeInfoDataBuffer()),
                allocator.getPointer(iNDArray, ctx),
                allocator.getPointer(iNDArray.shapeInfoDataBuffer()),
                info);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }

        return INDArrayStatistics.builder()
                .minValue(info._minValue())
                .maxValue(info._maxValue())
                .meanValue(info._meanValue())
                .stdDevValue(info._stdDevValue())
                .countInf(info._infCount())
                .countNaN(info._nanCount())
                .countNegative(info._negativeCount())
                .countPositive(info._positiveCount())
                .countZero(info._zeroCount())
                .build();
    }

    /**
     * Builds a shape-info buffer through the native {@code shapeBuffer} call. The native primary
     * and special pointers are wrapped in a {@link CudaLongDataBuffer}.
     *
     * @param jArr     shape; its length is passed as the rank
     * @param jArr2    strides, forwarded to native code
     * @param j        element-wise stride, forwarded to native code
     * @param c        array order, forwarded to native code
     * @param dataType data type encoded into the shape info
     * @param z        empty-array flag, forwarded to native code
     * @return buffer of length {@code Shape.shapeInfoLength(jArr.length)} over the native shape info
     * @throws RuntimeException if the native layer reports an error before or after the call
     */
    public DataBuffer createShapeInfo(long[] jArr, long[] jArr2, long j, char c, DataType dataType, boolean z) {
        // Surface any error left over from a previous native call before issuing a new one.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        OpaqueConstantShapeBuffer shapeBuffer;
        // Temporary native copies of shape/strides: release them as soon as the call returns
        // instead of leaving them to the GC-driven deallocator.
        try (LongPointer shapePointer = new LongPointer(jArr);
             LongPointer stridePointer = new LongPointer(jArr2)) {
            shapeBuffer = nativeOps.shapeBuffer(jArr.length, shapePointer, stridePointer, dataType.toInt(), c, j, z);
        }
        try {
            if (nativeOps.lastErrorCode() != 0) {
                throw new RuntimeException(nativeOps.lastErrorMessage());
            }
            return new CudaLongDataBuffer(nativeOps.getConstantShapeBufferPrimary(shapeBuffer), nativeOps.getConstantShapeBufferSpecial(shapeBuffer), Shape.shapeInfoLength(jArr.length));
        } finally {
            // Delete the opaque wrapper on every path, not only on success.
            if (shapeBuffer != null) {
                nativeOps.deleteConstantShapeBuffer(shapeBuffer);
            }
        }
    }

    /**
     * Builds a shape-info buffer through the native {@code shapeBufferEx} call. The native primary
     * and special pointers are wrapped in a {@link CudaLongDataBuffer}.
     *
     * @param jArr     shape; its length is passed as the rank
     * @param jArr2    strides, forwarded to native code
     * @param j        element-wise stride, forwarded to native code
     * @param c        array order, forwarded to native code
     * @param dataType data type encoded into the shape info
     * @param j2       extra value passed as the last argument of {@code shapeBufferEx}
     *                 (presumably array-option flags -- confirm against the native API)
     * @return buffer of length {@code Shape.shapeInfoLength(jArr.length)} over the native shape info
     * @throws RuntimeException if the native layer reports an error before or after the call
     */
    public DataBuffer createShapeInfo(long[] jArr, long[] jArr2, long j, char c, DataType dataType, long j2) {
        // Surface any error left over from a previous native call before issuing a new one.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        OpaqueConstantShapeBuffer shapeBufferEx;
        // Temporary native copies of shape/strides: release them as soon as the call returns
        // instead of leaving them to the GC-driven deallocator.
        try (LongPointer shapePointer = new LongPointer(jArr);
             LongPointer stridePointer = new LongPointer(jArr2)) {
            shapeBufferEx = nativeOps.shapeBufferEx(jArr.length, shapePointer, stridePointer, dataType.toInt(), c, j, j2);
        }
        try {
            if (nativeOps.lastErrorCode() != 0) {
                throw new RuntimeException(nativeOps.lastErrorMessage());
            }
            return new CudaLongDataBuffer(nativeOps.getConstantShapeBufferPrimary(shapeBufferEx), nativeOps.getConstantShapeBufferSpecial(shapeBufferEx), Shape.shapeInfoLength(jArr.length));
        } finally {
            // Delete the opaque wrapper on every path, not only on success.
            if (shapeBufferEx != null) {
                nativeOps.deleteConstantShapeBuffer(shapeBufferEx);
            }
        }
    }

    /**
     * Obtains TAD shape info and offsets for {@code iNDArray} along the dimensions {@code iArr}
     * via the native {@code tadOnlyShapeInfo} call.
     *
     * @param iNDArray array whose shape info is used
     * @param iArr     dimensions along which the TADs are taken
     * @return a {@link TadPack} of (shape-info buffer, offsets buffer); the offsets buffer's length
     *         is the native-reported number of TADs
     * @throws RuntimeException if the native layer reports an error before or after the call
     */
    public TadPack tadShapeInfoAndOffsets(INDArray iNDArray, int[] iArr) {
        // Surface any error left over from a previous native call before issuing a new one.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        OpaqueTadPack tadOnlyShapeInfo;
        // Temporary native copy of the dimensions: released as soon as the call returns.
        try (IntPointer dimensions = new IntPointer(iArr)) {
            tadOnlyShapeInfo = nativeOps.tadOnlyShapeInfo(iNDArray.shapeInfoDataBuffer().addressPointer(), dimensions, iArr.length);
        }
        try {
            if (nativeOps.lastErrorCode() != 0) {
                throw new RuntimeException(nativeOps.lastErrorMessage());
            }
            CudaLongDataBuffer shapeInfo = new CudaLongDataBuffer((Pointer) nativeOps.getPrimaryShapeInfo(tadOnlyShapeInfo), (Pointer) nativeOps.getSpecialShapeInfo(tadOnlyShapeInfo), nativeOps.getShapeInfoLength(tadOnlyShapeInfo));
            CudaLongDataBuffer offsets = new CudaLongDataBuffer((Pointer) nativeOps.getPrimaryOffsets(tadOnlyShapeInfo), (Pointer) nativeOps.getSpecialOffsets(tadOnlyShapeInfo), nativeOps.getNumberOfTads(tadOnlyShapeInfo));
            return new TadPack(shapeInfo, offsets);
        } finally {
            // Delete the opaque pack wrapper on every path, not only on success.
            if (tadOnlyShapeInfo != null) {
                nativeOps.deleteTadPack(tadOnlyShapeInfo);
            }
        }
    }

    /**
     * Creates a constant buffer from {@code jArr} via the native {@code constantBufferLong} call.
     *
     * @param jArr     values to place in the constant buffer
     * @param dataType data type of the resulting buffer
     * @return a buffer over the native primary/special pointers, marked constant
     * @throws RuntimeException if the native layer reports an error before or after the call
     */
    public DataBuffer createConstantBuffer(long[] jArr, DataType dataType) {
        // Surface any error left over from a previous native call before issuing a new one.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        OpaqueConstantDataBuffer constantBufferLong;
        // Temporary native copy of the values: released as soon as the call returns.
        try (LongPointer values = new LongPointer(jArr)) {
            constantBufferLong = nativeOps.constantBufferLong(dataType.toInt(), values, jArr.length);
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        DataBuffer createBuffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(constantBufferLong), nativeOps.getConstantDataBufferSpecial(constantBufferLong), jArr.length, dataType);
        createBuffer.setConstant(true);
        return createBuffer;
    }

    /**
     * Creates a constant buffer from {@code dArr} via the native {@code constantBufferDouble} call.
     *
     * @param dArr     values to place in the constant buffer
     * @param dataType data type of the resulting buffer
     * @return a buffer over the native primary/special pointers, marked constant
     * @throws RuntimeException if the native layer reports an error before or after the call
     */
    public DataBuffer createConstantBuffer(double[] dArr, DataType dataType) {
        // Surface any error left over from a previous native call before issuing a new one.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        OpaqueConstantDataBuffer constantBufferDouble;
        // Temporary native copy of the values: released as soon as the call returns.
        try (DoublePointer values = new DoublePointer(dArr)) {
            constantBufferDouble = nativeOps.constantBufferDouble(dataType.toInt(), values, dArr.length);
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        DataBuffer createBuffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(constantBufferDouble), nativeOps.getConstantDataBufferSpecial(constantBufferDouble), dArr.length, dataType);
        createBuffer.setConstant(true);
        return createBuffer;
    }

    /**
     * Runs the native light benchmark suite.
     *
     * @param z flag forwarded unchanged to {@code nativeOps.runLightBenchmarkSuit}
     * @return the string produced by the native benchmark run
     */
    public String runLightBenchmarkSuit(boolean z) {
        final String report = nativeOps.runLightBenchmarkSuit(z);
        return report;
    }

    /**
     * Runs the native full benchmark suite.
     *
     * @param z flag forwarded unchanged to {@code nativeOps.runFullBenchmarkSuit}
     * @return the string produced by the native benchmark run
     */
    public String runFullBenchmarkSuit(boolean z) {
        final String report = nativeOps.runFullBenchmarkSuit(z);
        return report;
    }
}
