package org.nd4j.linalg.jcublas.ops.executioner;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import lombok.NonNull;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.ShortPointer;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.tad.DeviceTADManager;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.aggregates.Batch;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpStatus;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.CopyOp;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.LongPointerWrapper;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.Nd4jCuda;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.class */
public class CudaExecutioner extends DefaultOpExecutioner {
    /** SLF4J logger for this class. */
    private static final Logger log = LoggerFactory.getLogger(CudaExecutioner.class);
    /** Native ops handle taken from {@link NativeOpsHolder}'s device native ops; the native exec* calls in this class go through it. */
    protected static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    /** Source of the TAD shape-info/offset buffer pairs ({@code getTADOnlyShapeInfo}) that the exec methods pass to native code. */
    protected static TADManager tadManager = new DeviceTADManager();
    /**
     * Properties object, declared volatile and transient.
     * NOTE(review): presumably a lazily built description of the execution environment - confirm against its readers.
     */
    protected volatile transient Properties properties;
    /**
     * Per-thread {@link PointerPointer} scratch area. The exec methods create it lazily with capacity 32 and
     * refill it with the "extra" pointers handed to each native call as its first argument.
     */
    protected ThreadLocal<PointerPointer> extraz = new ThreadLocal<>();
    /** Per-thread name of the most recently started op; the exec methods only record it when the CUDA configuration reports debug mode. */
    protected ThreadLocal<String> lastOp = new ThreadLocal<>();
    /**
     * Custom op descriptors keyed by string; starts out {@code null}.
     * NOTE(review): presumably keyed by op name and populated on first use - confirm.
     */
    protected Map<String, CustomOpDescriptor> customOps = null;

    /**
     * Exposes the native operations handle this executioner dispatches to.
     *
     * @return the statically held {@link NativeOps} instance shared by all executioners of this class
     */
    public NativeOps getNativeOps() {
        final NativeOps deviceOps = CudaExecutioner.nativeOps;
        return deviceOps;
    }

    /**
     * Reports the op name most recently recorded for the calling thread.
     *
     * @return the value held in the thread-local {@code lastOp}, or {@code null} if nothing was recorded
     *         on this thread
     */
    public String getLastOp() {
        final String recordedOpName = lastOp.get();
        return recordedOpName;
    }

    /**
     * Executes a broadcast op along the given dimensions through the native CUDA ops.
     *
     * <p>The method validates the dimensions against the rank of {@code x}, gathers host and device
     * pointers for {@code x}, {@code y}, {@code z} and the TAD descriptors of {@code x} and {@code z},
     * and then calls the double, float or half variant of {@code execBroadcast*} depending on the data
     * type of {@code x} (anything that is neither DOUBLE nor FLOAT goes to the half variant).
     *
     * <p>Data pointers are kept as plain {@link Pointer}s and cast at the type-specific call site; a
     * variable declared as {@code DoublePointer} cannot legally be cast to {@code FloatPointer} or
     * {@code ShortPointer}.
     *
     * <p>Side effects: {@code iArr} is sorted in place, and the thread-local extra-pointer scratch area
     * is (re)filled.
     *
     * @param broadcastOp op to execute; although the host shape-info lookups tolerate a {@code null}
     *                    {@code y}/{@code z}, both are dereferenced unconditionally further down
     * @param iArr        dimensions to broadcast along
     * @return {@code broadcastOp.z()}
     * @throws ND4JIllegalStateException if a dimension other than {@code Integer.MAX_VALUE} is greater
     *                                   than or equal to the rank of {@code x}
     */
    public INDArray exec(BroadcastOp broadcastOp, int... iArr) {
        long profilingStart = profilingHookIn(broadcastOp);
        checkForCompression(broadcastOp);
        validateDataType(Nd4j.dataType(), broadcastOp);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        Arrays.sort(iArr);
        for (int dimension : iArr) {
            if (dimension >= broadcastOp.x().rank() && dimension != Integer.MAX_VALUE) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(iArr) + " contains element that higher then rank of op.X: [" + broadcastOp.x().rank() + "]");
            }
        }

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(broadcastOp.z(), broadcastOp.x(), broadcastOp.y());
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(broadcastOp.opName());
        }

        Pointer hostYShapeInfo = broadcastOp.y() == null ? null : AddressRetriever.retrieveHostPointer(broadcastOp.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = broadcastOp.z() == null ? null : AddressRetriever.retrieveHostPointer(broadcastOp.z().shapeInfoDataBuffer());

        // Element type is only known at dispatch time, so the data pointers stay untyped until then.
        Pointer x = allocator.getPointer(broadcastOp.x(), context);
        Pointer y = allocator.getPointer(broadcastOp.y(), context);
        Pointer z = allocator.getPointer(broadcastOp.z(), context);
        IntPointer xShapeInfo = (IntPointer) allocator.getPointer(broadcastOp.x().shapeInfoDataBuffer(), context);

        Pair<?, ?> xTadBuffers = tadManager.getTADOnlyShapeInfo(broadcastOp.x(), iArr);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer) xTadBuffers.getFirst());
        Pointer devTadShapeInfo = allocator.getPointer((DataBuffer) xTadBuffers.getFirst(), context);
        Pointer devTadOffsets = allocator.getPointer((DataBuffer) xTadBuffers.getSecond(), context);
        Pair<?, ?> zTadBuffers = tadManager.getTADOnlyShapeInfo(broadcastOp.z(), iArr);

        // Slot layout of the extra-pointer block handed to the native call.
        PointerPointer extraPointers = this.extraz.get().put(new Pointer[]{
                AddressRetriever.retrieveHostPointer(broadcastOp.x().shapeInfoDataBuffer()), // 0: host x shape info
                context.getOldStream(),                                                      // 1
                allocator.getDeviceIdPointer(),                                              // 2
                context.getBufferAllocation(),                                               // 3
                context.getBufferReduction(),                                                // 4
                context.getBufferScalar(),                                                   // 5
                context.getBufferSpecial(),                                                  // 6
                hostYShapeInfo,                                                              // 7
                hostZShapeInfo,                                                              // 8
                hostTadShapeInfo,                                                            // 9
                devTadShapeInfo,                                                             // 10
                devTadOffsets,                                                               // 11
                allocator.getPointer((DataBuffer) zTadBuffers.getFirst(), context),          // 12: z TAD shape info
                allocator.getPointer((DataBuffer) zTadBuffers.getSecond(), context)          // 13: z TAD offsets
        });

        IntPointer dimensionPointer = (IntPointer) allocator.getPointer(allocator.getConstantBuffer(iArr), context);
        // Same lookups for every data type, so fetch them once (y shape info first, then z, as before).
        IntPointer yShapeInfo = (IntPointer) allocator.getPointer(broadcastOp.y().shapeInfoDataBuffer(), context);
        IntPointer zShapeInfo = (IntPointer) allocator.getPointer(broadcastOp.z().shapeInfoDataBuffer(), context);

        DataBuffer.Type dataType = broadcastOp.x().data().dataType();
        if (dataType == DataBuffer.Type.DOUBLE) {
            nativeOps.execBroadcastDouble(extraPointers, broadcastOp.opNum(),
                    (DoublePointer) x, xShapeInfo,
                    (DoublePointer) y, yShapeInfo,
                    (DoublePointer) z, zShapeInfo,
                    dimensionPointer, iArr.length);
        } else if (dataType == DataBuffer.Type.FLOAT) {
            nativeOps.execBroadcastFloat(extraPointers, broadcastOp.opNum(),
                    (FloatPointer) x, xShapeInfo,
                    (FloatPointer) y, yShapeInfo,
                    (FloatPointer) z, zShapeInfo,
                    dimensionPointer, iArr.length);
        } else {
            nativeOps.execBroadcastHalf(extraPointers, broadcastOp.opNum(),
                    (ShortPointer) x, xShapeInfo,
                    (ShortPointer) y, yShapeInfo,
                    (ShortPointer) z, zShapeInfo,
                    dimensionPointer, iArr.length);
        }

        allocator.registerAction(context, broadcastOp.z(), broadcastOp.x(), broadcastOp.y());
        profilingHookOut(broadcastOp, profilingStart);
        return broadcastOp.z();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Runs an accumulation (reduction) through the native CUDA ops without preparing the result array.
     *
     * <p>{@code exec(Accumulation, int...)} in this class allocates or resets {@code z} and then
     * delegates here. This method only validates the dimensions, collects pointers and dispatches to
     * one native call chosen by:
     * <ul>
     *   <li>the data type of {@code x}: DOUBLE, FLOAT, otherwise the half variants;</li>
     *   <li>the op kind: {@link Variance} (summary stats), an op with {@code y} (reduce3, or
     *       reduce3-all when {@code isComplexAccumulation()}), or a plain reduce;</li>
     *   <li>whether {@code z} is scalar: scalar variants return the value, which is then assigned to
     *       {@code z} and stored via {@code setFinalResult}.</li>
     * </ul>
     *
     * <p>Data pointers are kept as plain {@link Pointer}s and cast at the type-specific call site; a
     * variable declared as {@code DoublePointer} cannot legally be cast to {@code FloatPointer} or
     * {@code ShortPointer}.
     *
     * @param accumulation op to execute; {@code z} must already be set
     * @param iArr         dimensions to reduce along (not sorted or normalised here)
     * @return {@code accumulation.z()}
     * @throws ND4JIllegalStateException if a dimension other than {@code Integer.MAX_VALUE} is greater
     *                                   than or equal to the rank of {@code x}, or if {@code y} is
     *                                   present, the op is not a complex accumulation, {@code y} does
     *                                   not match a single TAD of {@code x}, and the lengths of
     *                                   {@code x} and {@code y} differ
     */
    public INDArray naiveExec(Accumulation accumulation, int... iArr) {
        long profilingStart = profilingHookIn(accumulation);
        // Captured up front: scalar detection and scalar result assignment both use this reference.
        INDArray result = accumulation.z();
        validateDataType(Nd4j.dataType(), accumulation);
        for (int dimension : iArr) {
            if (dimension >= accumulation.x().rank() && dimension != Integer.MAX_VALUE) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(iArr) + " contains element that higher then rank of op.X: [" + accumulation.x().rank() + "]");
            }
        }

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(accumulation.z(), accumulation.x(), accumulation.y());
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(accumulation.opName());
        }

        Pointer hostYShapeInfo = accumulation.y() == null ? null : AddressRetriever.retrieveHostPointer(accumulation.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = accumulation.z() == null ? null : AddressRetriever.retrieveHostPointer(accumulation.z().shapeInfoDataBuffer());

        Pair<?, ?> xTadBuffers = tadManager.getTADOnlyShapeInfo(accumulation.x(), iArr);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer) xTadBuffers.getFirst());
        Pointer devTadShapeInfo = allocator.getPointer((DataBuffer) xTadBuffers.getFirst(), context);
        DataBuffer xTadOffsetsBuffer = (DataBuffer) xTadBuffers.getSecond();
        Pointer devTadOffsets = xTadOffsetsBuffer == null ? null : allocator.getPointer(xTadOffsetsBuffer, context);

        // Element type is only known at dispatch time, so the data pointers stay untyped until then.
        Pointer x = allocator.getPointer(accumulation.x(), context);
        IntPointer xShapeInfo = (IntPointer) allocator.getPointer(accumulation.x().shapeInfoDataBuffer(), context);

        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // Slots 0-11 of the extra-pointer block; slots 12/13 are filled below when y is present.
        PointerPointer extraPointers = this.extraz.get().put(new Pointer[]{
                AddressRetriever.retrieveHostPointer(accumulation.x().shapeInfoDataBuffer()), // 0: host x shape info
                context.getOldStream(),                                                       // 1
                allocator.getDeviceIdPointer(),                                               // 2
                context.getBufferAllocation(),                                                // 3
                context.getBufferReduction(),                                                 // 4
                context.getBufferScalar(),                                                    // 5
                context.getBufferSpecial(),                                                   // 6
                hostYShapeInfo,                                                               // 7
                hostZShapeInfo,                                                               // 8
                hostTadShapeInfo,                                                             // 9
                devTadShapeInfo,                                                              // 10
                devTadOffsets                                                                 // 11
        });

        Pointer yTadOffsets = null;
        Pointer yTadShapeInfo = null;
        if (accumulation.y() != null) {
            boolean wholeArray = iArr.length == 1 && iArr[0] == Integer.MAX_VALUE;
            if (!wholeArray && accumulation.x().tensorAlongDimension(0, iArr).lengthLong() == accumulation.y().lengthLong()) {
                // y has the length of one TAD of x: y's own shape info serves as its TAD shape info,
                // with a constant {0, 0} offsets buffer.
                DataBuffer zeroOffsets = Nd4j.getConstantHandler().getConstantBuffer(new int[]{0, 0});
                yTadOffsets = zeroOffsets == null ? null : allocator.getPointer(zeroOffsets, context);
                yTadShapeInfo = allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context);
                extraPointers.put(12L, yTadShapeInfo);
                extraPointers.put(13L, (Pointer) null);
            } else {
                if (!accumulation.isComplexAccumulation() && accumulation.x().lengthLong() != accumulation.y().lengthLong()) {
                    throw new ND4JIllegalStateException("Op.X [" + accumulation.x().lengthLong() + "] and Op.Y [" + accumulation.y().lengthLong() + "] lengths should match");
                }
                Pair<?, ?> yTadBuffers = tadManager.getTADOnlyShapeInfo(accumulation.y(), iArr);
                yTadShapeInfo = allocator.getPointer((DataBuffer) yTadBuffers.getFirst(), context);
                DataBuffer yTadOffsetsBuffer = (DataBuffer) yTadBuffers.getSecond();
                yTadOffsets = yTadOffsetsBuffer == null ? null : allocator.getPointer(yTadOffsetsBuffer, context);
                extraPointers.put(12L, yTadShapeInfo);
                extraPointers.put(13L, yTadOffsets);
            }
        }

        Pointer extraArgs = accumulation.extraArgs() != null ? allocator.getPointer(accumulation.extraArgsDataBuff(), context) : null;
        Pointer dimensionPointer = allocator.getPointer(allocator.getConstantBuffer(iArr), context);

        DataBuffer.Type dataType = accumulation.x().data().dataType();
        if (dataType == DataBuffer.Type.DOUBLE) {
            if (accumulation instanceof Variance) {
                if (result.isScalar()) {
                    double value = nativeOps.execSummaryStatsScalarDouble(extraPointers, accumulation.opNum(),
                            (DoublePointer) x, xShapeInfo, (DoublePointer) extraArgs,
                            ((Variance) accumulation).isBiasCorrected());
                    allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
                    result.assign(Double.valueOf(value));
                    accumulation.setFinalResult(value);
                } else {
                    nativeOps.execSummaryStatsDouble(extraPointers, accumulation.opNum(),
                            (DoublePointer) x, xShapeInfo, (DoublePointer) extraArgs,
                            (DoublePointer) allocator.getPointer(accumulation.z(), context),
                            (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                            (IntPointer) dimensionPointer, iArr.length,
                            ((Variance) accumulation).isBiasCorrected());
                    allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
                }
            } else if (accumulation.y() != null) {
                if (accumulation.isComplexAccumulation()) {
                    nativeOps.execReduce3AllDouble(extraPointers, accumulation.opNum(),
                            (DoublePointer) x, xShapeInfo, (DoublePointer) extraArgs,
                            (DoublePointer) allocator.getPointer(accumulation.y(), context),
                            (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context),
                            (DoublePointer) allocator.getPointer(accumulation.z(), context),
                            (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                            (IntPointer) dimensionPointer, iArr.length,
                            (IntPointer) devTadShapeInfo, new LongPointerWrapper(devTadOffsets),
                            (IntPointer) yTadShapeInfo, new LongPointerWrapper(yTadOffsets));
                    allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
                } else if (result.isScalar()) {
                    double value = nativeOps.execReduce3ScalarDouble(extraPointers, accumulation.opNum(),
                            (DoublePointer) x, xShapeInfo, (DoublePointer) extraArgs,
                            (DoublePointer) allocator.getPointer(accumulation.y(), context),
                            (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context));
                    allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
                    result.assign(Double.valueOf(value));
                    accumulation.setFinalResult(value);
                } else {
                    nativeOps.execReduce3Double(extraPointers, accumulation.opNum(),
                            (DoublePointer) x, xShapeInfo, (DoublePointer) extraArgs,
                            (DoublePointer) allocator.getPointer(accumulation.y(), context),
                            (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context),
                            (DoublePointer) allocator.getPointer(accumulation.z(), context),
                            (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                            (IntPointer) dimensionPointer, iArr.length);
                    allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
                }
            } else if (result.isScalar()) {
                double value = nativeOps.execReduceScalarDouble(extraPointers, accumulation.opNum(),
                        (DoublePointer) x, xShapeInfo, (DoublePointer) extraArgs);
                allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
                result.assign(Double.valueOf(value));
                accumulation.setFinalResult(value);
            } else {
                nativeOps.execReduceDouble(extraPointers, accumulation.opNum(),
                        (DoublePointer) x, xShapeInfo, (DoublePointer) extraArgs,
                        (DoublePointer) allocator.getPointer(accumulation.z(), context),
                        (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                        (IntPointer) dimensionPointer, iArr.length);
                allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
            }
        } else if (dataType == DataBuffer.Type.FLOAT) {
            if (accumulation instanceof Variance) {
                if (result.isScalar()) {
                    float value = nativeOps.execSummaryStatsScalarFloat(extraPointers, accumulation.opNum(),
                            (FloatPointer) x, xShapeInfo, (FloatPointer) extraArgs,
                            ((Variance) accumulation).isBiasCorrected());
                    allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
                    result.assign(Float.valueOf(value));
                    accumulation.setFinalResult(value);
                } else {
                    nativeOps.execSummaryStatsFloat(extraPointers, accumulation.opNum(),
                            (FloatPointer) x, xShapeInfo, (FloatPointer) extraArgs,
                            (FloatPointer) allocator.getPointer(accumulation.z(), context),
                            (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                            (IntPointer) dimensionPointer, iArr.length,
                            ((Variance) accumulation).isBiasCorrected());
                    allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
                }
            } else if (accumulation.y() != null) {
                if (accumulation.isComplexAccumulation()) {
                    nativeOps.execReduce3AllFloat(extraPointers, accumulation.opNum(),
                            (FloatPointer) x, xShapeInfo, (FloatPointer) extraArgs,
                            (FloatPointer) allocator.getPointer(accumulation.y(), context),
                            (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context),
                            (FloatPointer) allocator.getPointer(accumulation.z(), context),
                            (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                            (IntPointer) dimensionPointer, iArr.length,
                            (IntPointer) devTadShapeInfo, new LongPointerWrapper(devTadOffsets),
                            (IntPointer) yTadShapeInfo, new LongPointerWrapper(yTadOffsets));
                    allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
                } else if (result.isScalar()) {
                    float value = nativeOps.execReduce3ScalarFloat(extraPointers, accumulation.opNum(),
                            (FloatPointer) x, xShapeInfo, (FloatPointer) extraArgs,
                            (FloatPointer) allocator.getPointer(accumulation.y(), context),
                            (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context));
                    // NOTE(review): this is the only scalar branch that assigns the result before
                    // registering the action; order kept as found - confirm whether it is intentional.
                    result.assign(Float.valueOf(value));
                    accumulation.setFinalResult(value);
                    allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
                } else {
                    nativeOps.execReduce3Float(extraPointers, accumulation.opNum(),
                            (FloatPointer) x, xShapeInfo, (FloatPointer) extraArgs,
                            (FloatPointer) allocator.getPointer(accumulation.y(), context),
                            (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context),
                            (FloatPointer) allocator.getPointer(accumulation.z(), context),
                            (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                            (IntPointer) dimensionPointer, iArr.length);
                    allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
                }
            } else if (result.isScalar()) {
                float value = nativeOps.execReduceScalarFloat(extraPointers, accumulation.opNum(),
                        (FloatPointer) x, xShapeInfo, (FloatPointer) extraArgs);
                allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
                result.assign(Float.valueOf(value));
                accumulation.setFinalResult(value);
            } else {
                nativeOps.execReduceFloat(extraPointers, accumulation.opNum(),
                        (FloatPointer) x, xShapeInfo, (FloatPointer) extraArgs,
                        (FloatPointer) allocator.getPointer(accumulation.z(), context),
                        (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                        (IntPointer) dimensionPointer, iArr.length);
                allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
            }
        } else if (accumulation instanceof Variance) {
            // Everything that is neither DOUBLE nor FLOAT is routed to the half variants.
            if (result.isScalar()) {
                float value = nativeOps.execSummaryStatsScalarHalf(extraPointers, accumulation.opNum(),
                        (ShortPointer) x, xShapeInfo, (ShortPointer) extraArgs,
                        ((Variance) accumulation).isBiasCorrected());
                allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
                result.assign(Float.valueOf(value));
                accumulation.setFinalResult(value);
            } else {
                nativeOps.execSummaryStatsHalf(extraPointers, accumulation.opNum(),
                        (ShortPointer) x, xShapeInfo, (ShortPointer) extraArgs,
                        (ShortPointer) allocator.getPointer(accumulation.z(), context),
                        (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                        (IntPointer) dimensionPointer, iArr.length,
                        ((Variance) accumulation).isBiasCorrected());
                allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
            }
        } else if (accumulation.y() != null) {
            if (accumulation.isComplexAccumulation()) {
                nativeOps.execReduce3AllHalf(extraPointers, accumulation.opNum(),
                        (ShortPointer) x, xShapeInfo, (ShortPointer) extraArgs,
                        (ShortPointer) allocator.getPointer(accumulation.y(), context),
                        (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context),
                        (ShortPointer) allocator.getPointer(accumulation.z(), context),
                        (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                        (IntPointer) dimensionPointer, iArr.length,
                        (IntPointer) devTadShapeInfo, new LongPointerWrapper(devTadOffsets),
                        (IntPointer) yTadShapeInfo, new LongPointerWrapper(yTadOffsets));
                allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
            } else if (result.isScalar()) {
                float value = nativeOps.execReduce3ScalarHalf(extraPointers, accumulation.opNum(),
                        (ShortPointer) x, xShapeInfo, (ShortPointer) extraArgs,
                        (ShortPointer) allocator.getPointer(accumulation.y(), context),
                        (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context));
                allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
                result.assign(Float.valueOf(value));
                accumulation.setFinalResult(value);
            } else {
                nativeOps.execReduce3Half(extraPointers, accumulation.opNum(),
                        (ShortPointer) x, xShapeInfo, (ShortPointer) extraArgs,
                        (ShortPointer) allocator.getPointer(accumulation.y(), context),
                        (IntPointer) allocator.getPointer(accumulation.y().shapeInfoDataBuffer(), context),
                        (ShortPointer) allocator.getPointer(accumulation.z(), context),
                        (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                        (IntPointer) dimensionPointer, iArr.length);
                allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
            }
        } else if (result.isScalar()) {
            float value = nativeOps.execReduceScalarHalf(extraPointers, accumulation.opNum(),
                    (ShortPointer) x, xShapeInfo, (ShortPointer) extraArgs);
            allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
            result.assign(Float.valueOf(value));
            accumulation.setFinalResult(value);
        } else {
            nativeOps.execReduceHalf(extraPointers, accumulation.opNum(),
                    (ShortPointer) x, xShapeInfo, (ShortPointer) extraArgs,
                    (ShortPointer) allocator.getPointer(accumulation.z(), context),
                    (IntPointer) allocator.getPointer(accumulation.z().shapeInfoDataBuffer(), context),
                    (IntPointer) dimensionPointer, iArr.length);
            allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
        }

        profilingHookOut(accumulation, profilingStart);
        return accumulation.z();
    }

    /**
     * Executes an accumulation (reduction) along the given dimensions.
     *
     * <p>This method does the preparation and then delegates the native dispatch to
     * {@code naiveExec(Accumulation, int...)}:
     * <ol>
     *   <li>sorts the dimensions, rejects any that are out of range, and maps negative dimensions to
     *       {@code dimension + rank(x)};</li>
     *   <li>collapses "all dimensions" into the single whole-array marker
     *       {@code Nd4jCuda.MAX_DIMENSION};</li>
     *   <li>derives the result shape, padded to two entries when one or none remain;</li>
     *   <li>returns {@code accumulation.noOp()} for a vector {@code x} whose length already equals
     *       the result length (greater than 1) when there is no {@code y};</li>
     *   <li>allocates {@code z} when it is missing or aliases {@code x}, otherwise checks its length
     *       and resets it to the op's zero value.</li>
     * </ol>
     *
     * <p>Side effects: {@code iArr} is sorted and normalised in place, and {@code z} may be replaced
     * on the op.
     *
     * @param accumulation op to execute
     * @param iArr         dimensions to reduce along; negative values count from the last dimension
     * @return {@code accumulation.z()} after execution, or the result of {@code noOp()} for the
     *         shortcut case
     * @throws ND4JIllegalStateException if a dimension is out of range, if {@code y} does not match
     *                                   the TAD length of {@code x} when a new result is allocated, or
     *                                   if an existing {@code z} has the wrong length
     */
    public INDArray exec(Accumulation accumulation, int... iArr) {
        long profilingStart = profilingHookIn(accumulation);
        checkForCompression(accumulation);
        validateDataType(Nd4j.dataType(), accumulation);
        Arrays.sort(iArr);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }

        int[] maxShape = Shape.getMaxShape(new INDArray[]{accumulation.x(), accumulation.y()});
        for (int dimension : iArr) {
            if (dimension >= maxShape.length && dimension != Integer.MAX_VALUE) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(iArr) + " contains element that higher then rank of op.X: [" + accumulation.x().rank() + "]");
            }
        }

        boolean normalizedNegative = false;
        for (int i = 0; i < iArr.length; i++) {
            if (iArr[i] < 0) {
                iArr[i] += accumulation.x().rank();
                normalizedNegative = true;
            }
        }
        if (normalizedNegative) {
            // Shifting negative entries breaks the ordering established above (e.g. [-1, 0] becomes
            // [rank-1, 0]); restore it so the iArr[0] test below and naiveExec see sorted dimensions.
            Arrays.sort(iArr);
        }

        if (iArr.length == accumulation.x().rank()) {
            iArr = new int[]{Nd4jCuda.MAX_DIMENSION};
        }

        int[] resultShape = Shape.wholeArrayDimension(iArr) ? new int[]{1, 1} : ArrayUtil.removeIndex(maxShape, iArr);
        if (resultShape.length == 1) {
            // Pad to two entries: [1, n] when reducing along dimension 0, [n, 1] otherwise.
            resultShape = iArr[0] == 0 ? new int[]{1, resultShape[0]} : new int[]{resultShape[0], 1};
        } else if (resultShape.length == 0) {
            resultShape = new int[]{1, 1};
        }

        if (accumulation.x().isVector() && accumulation.x().length() == ArrayUtil.prod(resultShape) && ArrayUtil.prodLong(resultShape) > 1 && accumulation.y() == null) {
            return accumulation.noOp();
        }

        if (accumulation.z() == null || accumulation.z() == accumulation.x()) {
            INDArray freshResult = null;
            if (accumulation.isComplexAccumulation()) {
                freshResult = Nd4j.create(accumulation.x().tensorssAlongDimension(iArr), accumulation.y().tensorssAlongDimension(iArr));
            } else {
                if (accumulation.y() != null && accumulation.x().tensorAlongDimension(0, iArr).lengthLong() != accumulation.y().lengthLong()) {
                    throw new ND4JIllegalStateException("Number of TADs along dimension doesn't match");
                }
                if (Math.abs(accumulation.zeroDouble()) <= Nd4j.EPS_THRESHOLD) {
                    freshResult = Nd4j.zeros(resultShape);
                } else if (accumulation.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                    freshResult = Nd4j.valueArrayOf(resultShape, accumulation.zeroDouble());
                } else if (accumulation.x().data().dataType() == DataBuffer.Type.FLOAT) {
                    freshResult = Nd4j.valueArrayOf(resultShape, accumulation.zeroFloat());
                } else if (accumulation.x().data().dataType() == DataBuffer.Type.HALF) {
                    freshResult = Nd4j.valueArrayOf(resultShape, accumulation.zeroHalf());
                }
            }
            accumulation.setZ(freshResult);
        } else {
            if (accumulation.z().lengthLong() != ArrayUtil.prodLong(resultShape)) {
                throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(accumulation.z().shape()) + "] doesn't match expected [" + Arrays.toString(resultShape) + "]");
            }
            // Reset the caller-supplied target to the op's zero value before accumulating into it.
            if (accumulation.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                accumulation.z().assign(Double.valueOf(accumulation.zeroDouble()));
            } else if (accumulation.x().data().dataType() == DataBuffer.Type.FLOAT) {
                accumulation.z().assign(Float.valueOf(accumulation.zeroFloat()));
            } else if (accumulation.x().data().dataType() == DataBuffer.Type.HALF) {
                accumulation.z().assign(Float.valueOf(accumulation.zeroHalf()));
            }
        }

        naiveExec(accumulation, iArr);
        profilingHookOut(accumulation, profilingStart);
        return accumulation.z();
    }

    public INDArray exec(IndexAccumulation indexAccumulation, int... iArr) {
        long profilingHookIn = profilingHookIn(indexAccumulation);
        checkForCompression(indexAccumulation);
        validateDataType(Nd4j.dataType(), indexAccumulation);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        Arrays.sort(iArr);
        for (int i = 0; i < iArr.length; i++) {
            if (iArr[i] >= indexAccumulation.x().rank() && iArr[i] != Integer.MAX_VALUE) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(iArr) + " contains element that higher then rank of op.X: [" + indexAccumulation.x().rank() + "]");
            }
        }
        for (int i2 = 0; i2 < iArr.length; i2++) {
            if (iArr[i2] < 0) {
                int i3 = i2;
                iArr[i3] = iArr[i3] + indexAccumulation.x().rank();
            }
        }
        if (iArr.length == indexAccumulation.x().rank()) {
            iArr = new int[]{Nd4jCuda.MAX_DIMENSION};
        }
        int[] removeIndex = Shape.wholeArrayDimension(iArr) ? new int[]{1, 1} : ArrayUtil.removeIndex(indexAccumulation.x().shape(), iArr);
        if (indexAccumulation.x().isVector() && indexAccumulation.x().length() == ArrayUtil.prod(removeIndex)) {
            return indexAccumulation.x();
        }
        if (removeIndex.length == 1) {
            removeIndex = iArr[0] == 0 ? new int[]{1, removeIndex[0]} : new int[]{removeIndex[0], 1};
        } else if (removeIndex.length == 0) {
            removeIndex = new int[]{1, 1};
        }
        INDArray iNDArray = null;
        if (0.0d + Math.abs(indexAccumulation.zeroDouble()) <= Nd4j.EPS_THRESHOLD) {
            iNDArray = Nd4j.zeros(removeIndex);
        } else if (indexAccumulation.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            iNDArray = Nd4j.valueArrayOf(removeIndex, indexAccumulation.zeroDouble());
        } else if (indexAccumulation.x().data().dataType() == DataBuffer.Type.FLOAT) {
            iNDArray = Nd4j.valueArrayOf(removeIndex, indexAccumulation.zeroFloat());
        } else if (indexAccumulation.x().data().dataType() == DataBuffer.Type.HALF) {
            iNDArray = Nd4j.valueArrayOf(removeIndex, indexAccumulation.zeroHalf());
        }
        indexAccumulation.setZ(iNDArray);
        if (iArr.length == indexAccumulation.x().rank()) {
            iArr = new int[]{Nd4jCuda.MAX_DIMENSION};
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(indexAccumulation.opName());
        }
        CudaContext prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction(indexAccumulation.z(), indexAccumulation.x(), indexAccumulation.y());
        Pointer retrieveHostPointer = indexAccumulation.y() == null ? null : AddressRetriever.retrieveHostPointer(indexAccumulation.y().shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = indexAccumulation.z() == null ? null : AddressRetriever.retrieveHostPointer(indexAccumulation.z().shapeInfoDataBuffer());
        DoublePointer pointer = AtomicAllocator.getInstance().getPointer(indexAccumulation.x(), prepareAction);
        IntPointer pointer2 = AtomicAllocator.getInstance().getPointer(indexAccumulation.x().shapeInfoDataBuffer(), prepareAction);
        DoublePointer pointer3 = AtomicAllocator.getInstance().getPointer(indexAccumulation.z(), prepareAction);
        IntPointer pointer4 = AtomicAllocator.getInstance().getPointer(indexAccumulation.z().shapeInfoDataBuffer(), prepareAction);
        Pair tADOnlyShapeInfo = tadManager.getTADOnlyShapeInfo(indexAccumulation.x(), iArr);
        Pointer retrieveHostPointer3 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo.getFirst());
        Pointer pointer5 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), prepareAction);
        DataBuffer dataBuffer = (DataBuffer) tADOnlyShapeInfo.getSecond();
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(indexAccumulation.x().shapeInfoDataBuffer()), prepareAction.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), prepareAction.getBufferAllocation(), prepareAction.getBufferReduction(), prepareAction.getBufferScalar(), prepareAction.getBufferSpecial(), retrieveHostPointer, retrieveHostPointer2, retrieveHostPointer3, pointer5, dataBuffer == null ? null : AtomicAllocator.getInstance().getPointer(dataBuffer, prepareAction)});
        Pointer pointer6 = indexAccumulation.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(indexAccumulation.extraArgsDataBuff(), prepareAction) : null;
        IntPointer pointer7 = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(iArr), prepareAction);
        if (indexAccumulation.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            nativeOps.execIndexReduceDouble(put, indexAccumulation.opNum(), pointer, pointer2, (DoublePointer) pointer6, pointer3, pointer4, pointer7, iArr.length);
        } else if (indexAccumulation.x().data().dataType() == DataBuffer.Type.FLOAT) {
            nativeOps.execIndexReduceFloat(put, indexAccumulation.opNum(), (FloatPointer) pointer, pointer2, (FloatPointer) pointer6, (FloatPointer) pointer3, pointer4, pointer7, iArr.length);
        } else {
            nativeOps.execIndexReduceHalf(put, indexAccumulation.opNum(), (ShortPointer) pointer, pointer2, (ShortPointer) pointer6, (ShortPointer) pointer3, pointer4, pointer7, iArr.length);
        }
        AtomicAllocator.getInstance().registerAction(prepareAction, indexAccumulation.z(), indexAccumulation.x(), indexAccumulation.y());
        profilingHookOut(indexAccumulation, profilingHookIn);
        return indexAccumulation.z();
    }

    /**
     * Executes {@code op} along the given dimensions by delegating to the superclass.
     *
     * <p>The op is first handed to {@code checkForCompression}; the dimension array is
     * then sorted in place before delegation, so the caller's array is reordered as a
     * side effect.
     *
     * @param op   the operation to execute
     * @param iArr the dimensions to execute along; sorted in place (ascending)
     * @return the value returned by the superclass implementation
     */
    public Op exec(Op op, int... iArr) {
        checkForCompression(op);

        // The delegate receives the dimensions in ascending order.
        Arrays.sort(iArr);

        final Op executed = super.exec(op, iArr);
        return executed;
    }

    /**
     * Executes a single op, dispatching either to the superclass implementation or to
     * the type-specific {@code invoke} overload of this class.
     *
     * <p>The superclass path is taken when {@code op.x()} is an {@link IComplexNDArray},
     * when {@code executionMode()} is {@code ExecutionMode.JAVA}, or when the op is a
     * {@link CopyOp}. On that path the host copies of {@code x} and {@code y} are
     * synchronized before execution and a host write is ticked on {@code z} afterwards
     * (each only when the respective array is non-null).
     *
     * <p>Otherwise the op is routed by type: {@link TransformOp}, {@link Accumulation},
     * {@link ScalarOp}, {@link BroadcastOp}, {@link IndexAccumulation}, in that order of
     * precedence. An op matching none of these types is returned without being executed.
     *
     * @param op the operation to execute
     * @return {@code op} itself, on every path. Previously the superclass path returned
     *         {@code null}, which made the result depend on the dispatch path taken.
     */
    public Op exec(Op op) {
        checkForCompression(op);

        boolean useSuperclassPath = (op.x() instanceof IComplexNDArray)
                || executionMode() == OpExecutioner.ExecutionMode.JAVA
                || (op instanceof CopyOp);

        if (useSuperclassPath) {
            // Inputs must be current on the host before the superclass runs the op.
            if (op.x() != null) {
                AtomicAllocator.getInstance().synchronizeHostData(op.x());
            }
            if (op.y() != null) {
                AtomicAllocator.getInstance().synchronizeHostData(op.y());
            }
            super.exec(op);
            // Record that the result was written on the host side.
            if (op.z() != null) {
                AtomicAllocator.getInstance().tickHostWrite(op.z());
            }
            // Return the op, consistent with the dispatch path below.
            return op;
        }

        if (op instanceof TransformOp) {
            invoke((TransformOp) op);
        } else if (op instanceof Accumulation) {
            invoke((Accumulation) op, (int[]) null);
        } else if (op instanceof ScalarOp) {
            invoke((ScalarOp) op);
        } else if (op instanceof BroadcastOp) {
            invoke((BroadcastOp) op);
        } else if (op instanceof IndexAccumulation) {
            invoke((IndexAccumulation) op, (int[]) null);
        }
        return op;
    }

    /**
     * Runs a transform op and hands back its result array.
     *
     * <p>The op is passed to {@code checkForCompression}, then executed through
     * {@code invoke(TransformOp)}; the value of {@code transformOp.z()} read after
     * execution is returned.
     *
     * @param transformOp the transform to execute
     * @return {@code transformOp.z()} as observed after the op has been invoked
     */
    public INDArray execAndReturn(TransformOp transformOp) {
        checkForCompression(transformOp);
        invoke(transformOp);

        final INDArray output = transformOp.z();
        return output;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Executes a broadcast op through {@code nativeOps}.
     *
     * <p>Steps, in order: open the profiling hook, run {@code checkForCompression} and
     * {@code validateDataType}, make sure the {@code extraz} pointer holder exists,
     * obtain a context from the flow controller for {@code z}, {@code x} and {@code y},
     * resolve the data/shape pointers and TAD shape info for {@code x} and {@code z}
     * along {@code broadcastOp.getDimension()}, call the native broadcast routine that
     * matches the data type of {@code x}, register the action with the allocator and
     * close the profiling hook.
     *
     * <p>NOTE(review): {@code y} is null-checked when its host shape pointer is
     * retrieved, but is dereferenced unconditionally further down
     * ({@code broadcastOp.y().shapeInfoDataBuffer()}), so a null {@code y} fails there.
     *
     * @param broadcastOp the broadcast op to execute; its {@code z} receives the result
     * @return always {@code null}
     */
    public CudaContext invoke(BroadcastOp broadcastOp) {
        long profilingHookIn = profilingHookIn(broadcastOp);
        checkForCompression(broadcastOp);
        validateDataType(Nd4j.dataType(), broadcastOp);
        // Lazily create the holder used for the extra-pointers array.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        CudaContext prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction(broadcastOp.z(), broadcastOp.x(), broadcastOp.y());
        // In debug configuration, remember the name of the op being executed.
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(broadcastOp.opName());
        }
        // x data and x shape info, resolved against the prepared context.
        DoublePointer pointer = AtomicAllocator.getInstance().getPointer(broadcastOp.x(), prepareAction);
        IntPointer pointer2 = AtomicAllocator.getInstance().getPointer(broadcastOp.x().shapeInfoDataBuffer(), prepareAction);
        // Host-side shape info pointers for y and z (null when the array is absent).
        Pointer retrieveHostPointer = broadcastOp.y() == null ? null : AddressRetriever.retrieveHostPointer(broadcastOp.y().shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = broadcastOp.z() == null ? null : AddressRetriever.retrieveHostPointer(broadcastOp.z().shapeInfoDataBuffer());
        // TAD info for x along the broadcast dimensions: first buffer is used as shape info;
        // the second is presumably the TAD offsets - TODO confirm against the TAD manager.
        Pair tADOnlyShapeInfo = tadManager.getTADOnlyShapeInfo(broadcastOp.x(), broadcastOp.getDimension());
        Pointer retrieveHostPointer3 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo.getFirst());
        Pointer pointer3 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), prepareAction);
        Pointer pointer4 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getSecond(), prepareAction);
        // Same TAD lookup for z; both of its buffers go into the last two extras slots.
        Pair tADOnlyShapeInfo2 = tadManager.getTADOnlyShapeInfo(broadcastOp.z(), broadcastOp.getDimension());
        // Extras layout (by index): 0 host x shape, 1 old stream, 2 device id, 3-6 context buffers
        // (allocation, reduction, scalar, special), 7 host y shape, 8 host z shape,
        // 9 host x TAD shape, 10-11 x TAD buffers, 12-13 z TAD buffers.
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(broadcastOp.x().shapeInfoDataBuffer()), prepareAction.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), prepareAction.getBufferAllocation(), prepareAction.getBufferReduction(), prepareAction.getBufferScalar(), prepareAction.getBufferSpecial(), retrieveHostPointer, retrieveHostPointer2, retrieveHostPointer3, pointer3, pointer4, AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getFirst(), prepareAction), AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getSecond(), prepareAction)});
        // y and z data/shape pointers, plus the dimension array as a constant buffer.
        DoublePointer pointer5 = AtomicAllocator.getInstance().getPointer(broadcastOp.y(), prepareAction);
        IntPointer pointer6 = AtomicAllocator.getInstance().getPointer(broadcastOp.y().shapeInfoDataBuffer(), prepareAction);
        DoublePointer pointer7 = AtomicAllocator.getInstance().getPointer(broadcastOp.z(), prepareAction);
        IntPointer pointer8 = AtomicAllocator.getInstance().getPointer(broadcastOp.z().shapeInfoDataBuffer(), prepareAction);
        IntPointer pointer9 = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(broadcastOp.getDimension()), prepareAction);
        // Dispatch on the data type of x; anything that is neither DOUBLE nor FLOAT takes the half-precision call.
        if (broadcastOp.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            nativeOps.execBroadcastDouble(put, broadcastOp.opNum(), pointer, pointer2, pointer5, pointer6, pointer7, pointer8, pointer9, broadcastOp.getDimension().length);
        } else if (broadcastOp.x().data().dataType() == DataBuffer.Type.FLOAT) {
            nativeOps.execBroadcastFloat(put, broadcastOp.opNum(), (FloatPointer) pointer, pointer2, (FloatPointer) pointer5, pointer6, (FloatPointer) pointer7, pointer8, pointer9, broadcastOp.getDimension().length);
        } else {
            nativeOps.execBroadcastHalf(put, broadcastOp.opNum(), (ShortPointer) pointer, pointer2, (ShortPointer) pointer5, pointer6, (ShortPointer) pointer7, pointer8, pointer9, broadcastOp.getDimension().length);
        }
        // Pair the earlier prepareAction with its registration, then close the profiling hook.
        AtomicAllocator.getInstance().registerAction(prepareAction, broadcastOp.z(), broadcastOp.x(), broadcastOp.y());
        profilingHookOut(broadcastOp, profilingHookIn);
        return null;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Executes an index accumulation (index reduce) through {@code nativeOps}.
     *
     * <p>When {@code iArr} is {@code null} or is the single element
     * {@code Integer.MAX_VALUE}, and {@code z} is either unset or the same instance as
     * {@code x}, {@code z} is replaced with a fresh scalar before execution.
     *
     * <p>Two native paths exist:
     * <ul>
     *   <li>dimensional: taken when {@code z} is not a scalar, {@code iArr} is non-null
     *       and {@code iArr[0] != Integer.MAX_VALUE}; {@code iArr} is sorted in place
     *       and the result is written into {@code z} by the native call;</li>
     *   <li>scalar: otherwise; the native call returns the value, which is stored via
     *       {@code setFinalResult} (truncated to {@code int}) and assigned into
     *       {@code z}.</li>
     * </ul>
     * In both paths the native routine is selected by the data type of {@code x}
     * (DOUBLE, FLOAT, anything else uses the half-precision call).
     *
     * <p>A {@code null} dimension array is a supported input (the single-op
     * {@code exec(Op)} entry point passes {@code null}); the dimension validation is
     * therefore skipped for {@code null} instead of dereferencing it.
     *
     * @param indexAccumulation the op to execute
     * @param iArr              dimensions to reduce along, or {@code null} for the
     *                          scalar path; may be sorted in place
     * @return always {@code null}
     * @throws ND4JIllegalStateException if a dimension is {@code >= x.rank()} and is
     *                                   not {@code Integer.MAX_VALUE}
     */
    public CudaContext invoke(IndexAccumulation indexAccumulation, int[] iArr) {
        long profilingHookIn = profilingHookIn(indexAccumulation);
        if ((iArr == null || (iArr.length == 1 && iArr[0] == Integer.MAX_VALUE)) && (indexAccumulation.z() == indexAccumulation.x() || indexAccumulation.z() == null)) {
            indexAccumulation.setZ(Nd4j.scalar(0.0d));
        }
        checkForCompression(indexAccumulation);
        validateDataType(Nd4j.dataType(), indexAccumulation);
        // Lazily create the holder used for the extra-pointers array.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(indexAccumulation.opName());
        }
        // NOTE(review): this switches debug mode on for the shared configuration on every call
        // and never restores it; it looks like leftover diagnostics. Kept as-is because the
        // effect of removing it cannot be judged from this method alone - confirm and remove.
        CudaEnvironment.getInstance().getConfiguration().enableDebug(true);
        // Validate dimensions only when some were supplied; null means "whole array".
        if (iArr != null) {
            for (int dimension : iArr) {
                if (dimension >= indexAccumulation.x().rank() && dimension != Integer.MAX_VALUE) {
                    throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(iArr) + " contains element that higher then rank of op.X: [" + indexAccumulation.x().rank() + "]");
                }
            }
        }
        // A scalar z is not handed to the flow controller; it is filled on the Java side below.
        CudaContext prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction(indexAccumulation.z().isScalar() ? null : indexAccumulation.z(), indexAccumulation.x(), indexAccumulation.y());
        DoublePointer pointer = AtomicAllocator.getInstance().getPointer(indexAccumulation.x(), prepareAction);
        IntPointer pointer2 = AtomicAllocator.getInstance().getPointer(indexAccumulation.x().shapeInfoDataBuffer(), prepareAction);
        Pointer pointer3 = indexAccumulation.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(indexAccumulation.extraArgsDataBuff(), prepareAction) : null;
        Pointer retrieveHostPointer = indexAccumulation.y() == null ? null : AddressRetriever.retrieveHostPointer(indexAccumulation.y().shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = indexAccumulation.z() == null ? null : AddressRetriever.retrieveHostPointer(indexAccumulation.z().shapeInfoDataBuffer());
        // The TAD lookup needs a dimension array; fall back to {0} when none was given.
        int[] iArr2 = iArr;
        if (iArr2 == null) {
            iArr2 = new int[]{0};
        }
        Pair tADOnlyShapeInfo = tadManager.getTADOnlyShapeInfo(indexAccumulation.x(), iArr2);
        Pointer retrieveHostPointer3 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo.getFirst());
        Pointer pointer4 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), prepareAction);
        DataBuffer dataBuffer = (DataBuffer) tADOnlyShapeInfo.getSecond();
        // Extras layout (by index): 0 host x shape, 1 old stream, 2 device id, 3-6 context buffers,
        // 7 host y shape, 8 host z shape, 9 host TAD shape, 10 TAD shape, 11 second TAD buffer (may be null).
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(indexAccumulation.x().shapeInfoDataBuffer()), prepareAction.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), prepareAction.getBufferAllocation(), prepareAction.getBufferReduction(), prepareAction.getBufferScalar(), prepareAction.getBufferSpecial(), retrieveHostPointer, retrieveHostPointer2, retrieveHostPointer3, pointer4, dataBuffer == null ? null : AtomicAllocator.getInstance().getPointer(dataBuffer, prepareAction)});
        if (!indexAccumulation.z().isScalar() && iArr != null && iArr[0] != Integer.MAX_VALUE) {
            // Dimensional path: the native call writes its result into z.
            Arrays.sort(iArr);
            DoublePointer pointer5 = AtomicAllocator.getInstance().getPointer(indexAccumulation.z(), prepareAction);
            IntPointer pointer6 = AtomicAllocator.getInstance().getPointer(indexAccumulation.z().shapeInfoDataBuffer(), prepareAction);
            IntPointer pointer7 = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(iArr), prepareAction);
            if (indexAccumulation.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                nativeOps.execIndexReduceDouble(put, indexAccumulation.opNum(), pointer, pointer2, (DoublePointer) pointer3, pointer5, pointer6, pointer7, iArr.length);
            } else if (indexAccumulation.x().data().dataType() == DataBuffer.Type.FLOAT) {
                nativeOps.execIndexReduceFloat(put, indexAccumulation.opNum(), (FloatPointer) pointer, pointer2, (FloatPointer) pointer3, (FloatPointer) pointer5, pointer6, pointer7, iArr.length);
            } else {
                nativeOps.execIndexReduceHalf(put, indexAccumulation.opNum(), (ShortPointer) pointer, pointer2, (ShortPointer) pointer3, (ShortPointer) pointer5, pointer6, pointer7, iArr.length);
            }
        } else if (indexAccumulation.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            // Scalar path: the native call returns the value; store it on the op and in z.
            double execIndexReduceScalarDouble = nativeOps.execIndexReduceScalarDouble(put, indexAccumulation.opNum(), pointer, pointer2, (DoublePointer) pointer3);
            indexAccumulation.setFinalResult((int) execIndexReduceScalarDouble);
            indexAccumulation.z().assign(Double.valueOf(execIndexReduceScalarDouble));
        } else if (indexAccumulation.x().data().dataType() == DataBuffer.Type.FLOAT) {
            float execIndexReduceScalarFloat = nativeOps.execIndexReduceScalarFloat(put, indexAccumulation.opNum(), (FloatPointer) pointer, pointer2, (FloatPointer) pointer3);
            indexAccumulation.setFinalResult((int) execIndexReduceScalarFloat);
            indexAccumulation.z().assign(Float.valueOf(execIndexReduceScalarFloat));
        } else {
            float execIndexReduceScalarHalf = nativeOps.execIndexReduceScalarHalf(put, indexAccumulation.opNum(), (ShortPointer) pointer, pointer2, (ShortPointer) pointer3);
            indexAccumulation.setFinalResult((int) execIndexReduceScalarHalf);
            indexAccumulation.z().assign(Float.valueOf(execIndexReduceScalarHalf));
        }
        // z is deliberately passed as null here, on both paths, exactly as the original code did.
        AtomicAllocator.getInstance().registerAction(prepareAction, null, indexAccumulation.x(), indexAccumulation.y());
        profilingHookOut(indexAccumulation, profilingHookIn);
        return null;
    }

    /**
     * Executes an accumulation (reduce, reduce3 or summary-stats) through {@code nativeOps}.
     *
     * <p>A {@code null} dimension array is replaced with {@code {Nd4jCuda.MAX_DIMENSION}};
     * the array is then sorted in place and validated against the rank of {@code x}.
     *
     * <p>The result array is always allocated here and installed with {@code setZ},
     * replacing whatever {@code z} the op carried. Its shape is {@code {1, 1}} for a
     * whole-array reduction, otherwise the shape of {@code x} with the reduced
     * dimensions removed (padded to rank 2 as a row or column).
     *
     * <p>If the new {@code z} is not a scalar, the dimensional native routine writes
     * into it; if it is a scalar, the scalar native routine's return value is stored
     * with {@code setFinalResult}. The routine is chosen by the data type of {@code x}
     * and by whether {@code y} is present (reduce3) or the op is a {@link Variance}
     * (summary stats).
     *
     * <p>NOTE(review): the early return for vector inputs happens after
     * {@code prepareAction} but before {@code registerAction} and
     * {@code profilingHookOut}, and leaves {@code z} untouched.
     *
     * @param accumulation the op to execute
     * @param iArr         dimensions to reduce along, or {@code null} for the whole
     *                     array; sorted in place when non-null
     * @return the context obtained from {@code prepareAction}, or {@code null} on the
     *         early return for vector inputs
     * @throws ND4JIllegalStateException if a dimension is {@code >= x.rank()} and is
     *                                   not {@code Integer.MAX_VALUE}
     */
    protected CudaContext invoke(Accumulation accumulation, int[] iArr) {
        long profilingHookIn = profilingHookIn(accumulation);
        checkForCompression(accumulation);
        validateDataType(Nd4j.dataType(), accumulation);
        // Lazily create the holder used for the extra-pointers array.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // No dimensions supplied: substitute the MAX_DIMENSION marker.
        if (iArr == null) {
            iArr = new int[]{Nd4jCuda.MAX_DIMENSION};
        }
        Arrays.sort(iArr);
        for (int i = 0; i < iArr.length; i++) {
            if (iArr[i] >= accumulation.x().rank() && iArr[i] != Integer.MAX_VALUE) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(iArr) + " contains element that higher then rank of op.X: [" + accumulation.x().rank() + "]");
            }
        }
        // Prepared with the z the op carried on entry; z is replaced further down.
        CudaContext prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction(accumulation.z(), accumulation.x(), accumulation.y());
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(accumulation.opName());
        }
        Pointer retrieveHostPointer = accumulation.y() == null ? null : AddressRetriever.retrieveHostPointer(accumulation.y().shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = accumulation.z() == null ? null : AddressRetriever.retrieveHostPointer(accumulation.z().shapeInfoDataBuffer());
        // TAD info for x along the reduction dimensions.
        Pair tADOnlyShapeInfo = tadManager.getTADOnlyShapeInfo(accumulation.x(), iArr);
        Pointer retrieveHostPointer3 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo.getFirst());
        Pointer pointer = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), prepareAction);
        DataBuffer dataBuffer = (DataBuffer) tADOnlyShapeInfo.getSecond();
        // Extras layout (by index): 0 host x shape, 1 old stream, 2 device id, 3-6 context buffers,
        // 7 host y shape, 8 host z shape, 9 host x TAD shape, 10 x TAD shape, 11 second x TAD buffer (may be null).
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(accumulation.x().shapeInfoDataBuffer()), prepareAction.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), prepareAction.getBufferAllocation(), prepareAction.getBufferReduction(), prepareAction.getBufferScalar(), prepareAction.getBufferSpecial(), retrieveHostPointer, retrieveHostPointer2, retrieveHostPointer3, pointer, dataBuffer == null ? null : AtomicAllocator.getInstance().getPointer(dataBuffer, prepareAction)});
        // With a second operand, its TAD buffers are appended in extras slots 12 and 13.
        if (accumulation.y() != null) {
            Pair tADOnlyShapeInfo2 = tadManager.getTADOnlyShapeInfo(accumulation.y(), iArr);
            Pointer pointer2 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getFirst(), prepareAction);
            DataBuffer dataBuffer2 = (DataBuffer) tADOnlyShapeInfo2.getSecond();
            Pointer pointer3 = dataBuffer2 == null ? null : AtomicAllocator.getInstance().getPointer(dataBuffer2, prepareAction);
            put.put(12L, pointer2);
            put.put(13L, pointer3);
        }
        DoublePointer pointer4 = AtomicAllocator.getInstance().getPointer(accumulation.x(), prepareAction);
        IntPointer pointer5 = AtomicAllocator.getInstance().getPointer(accumulation.x().shapeInfoDataBuffer(), prepareAction);
        Pointer pointer6 = accumulation.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(accumulation.extraArgsDataBuff(), prepareAction) : null;
        // Result shape: {1, 1} for a whole-array reduction, else x's shape minus the reduced dimensions.
        int[] removeIndex = Shape.wholeArrayDimension(iArr) ? new int[]{1, 1} : ArrayUtil.removeIndex(accumulation.x().shape(), iArr);
        // Pad to rank 2: a single remaining extent becomes a row when dimension 0 was reduced, else a column.
        if (removeIndex.length == 1) {
            removeIndex = iArr[0] == 0 ? new int[]{1, removeIndex[0]} : new int[]{removeIndex[0], 1};
        } else if (removeIndex.length == 0) {
            removeIndex = new int[]{1, 1};
        }
        // Vector input whose length equals the result size: nothing is executed.
        // NOTE(review): this exits without registerAction/profilingHookOut - confirm that is intended.
        if (accumulation.x().isVector() && accumulation.x().length() == ArrayUtil.prod(removeIndex)) {
            return null;
        }
        // Allocate the result, pre-filled with the op's zero value (plain zeros when that value is within EPS of 0).
        // If the zero value is non-trivial and x is none of DOUBLE/FLOAT/HALF, the result stays null.
        INDArray iNDArray = null;
        if (0.0d + Math.abs(accumulation.zeroDouble()) <= Nd4j.EPS_THRESHOLD) {
            iNDArray = Nd4j.zeros(removeIndex);
        } else if (accumulation.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            iNDArray = Nd4j.valueArrayOf(removeIndex, accumulation.zeroDouble());
        } else if (accumulation.x().data().dataType() == DataBuffer.Type.FLOAT) {
            iNDArray = Nd4j.valueArrayOf(removeIndex, accumulation.zeroFloat());
        } else if (accumulation.x().data().dataType() == DataBuffer.Type.HALF) {
            iNDArray = Nd4j.valueArrayOf(removeIndex, accumulation.zeroHalf());
        }
        accumulation.setZ(iNDArray);
        if (!accumulation.z().isScalar()) {
            // Dimensional path: the native routine writes into z.
            // Precedence here is y present (reduce3), then Variance (summary stats), then plain reduce.
            DoublePointer pointer7 = AtomicAllocator.getInstance().getPointer(accumulation.z(), prepareAction);
            IntPointer pointer8 = AtomicAllocator.getInstance().getPointer(accumulation.z().shapeInfoDataBuffer(), prepareAction);
            IntPointer pointer9 = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(iArr), prepareAction);
            if (accumulation.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                if (accumulation.y() != null) {
                    nativeOps.execReduce3Double(put, accumulation.opNum(), pointer4, pointer5, (DoublePointer) pointer6, AtomicAllocator.getInstance().getPointer(accumulation.y(), prepareAction), AtomicAllocator.getInstance().getPointer(accumulation.y().shapeInfoDataBuffer(), prepareAction), pointer7, pointer8, pointer9, iArr.length);
                } else if (accumulation instanceof Variance) {
                    nativeOps.execSummaryStatsDouble(put, accumulation.opNum(), pointer4, pointer5, (DoublePointer) pointer6, pointer7, pointer8, pointer9, iArr.length, ((Variance) accumulation).isBiasCorrected());
                } else {
                    nativeOps.execReduceDouble(put, accumulation.opNum(), pointer4, pointer5, (DoublePointer) pointer6, pointer7, pointer8, pointer9, iArr.length);
                }
            } else if (accumulation.x().data().dataType() == DataBuffer.Type.FLOAT) {
                if (accumulation.y() != null) {
                    nativeOps.execReduce3Float(put, accumulation.opNum(), (FloatPointer) pointer4, pointer5, (FloatPointer) pointer6, AtomicAllocator.getInstance().getPointer(accumulation.y(), prepareAction), AtomicAllocator.getInstance().getPointer(accumulation.y().shapeInfoDataBuffer(), prepareAction), (FloatPointer) pointer7, pointer8, pointer9, iArr.length);
                } else if (accumulation instanceof Variance) {
                    nativeOps.execSummaryStatsFloat(put, accumulation.opNum(), (FloatPointer) pointer4, pointer5, (FloatPointer) pointer6, (FloatPointer) pointer7, pointer8, pointer9, iArr.length, ((Variance) accumulation).isBiasCorrected());
                } else {
                    nativeOps.execReduceFloat(put, accumulation.opNum(), (FloatPointer) pointer4, pointer5, (FloatPointer) pointer6, (FloatPointer) pointer7, pointer8, pointer9, iArr.length);
                }
            } else if (accumulation.y() != null) {
                // Any other data type of x takes the half-precision routines.
                nativeOps.execReduce3Half(put, accumulation.opNum(), (ShortPointer) pointer4, pointer5, (ShortPointer) pointer6, AtomicAllocator.getInstance().getPointer(accumulation.y(), prepareAction), AtomicAllocator.getInstance().getPointer(accumulation.y().shapeInfoDataBuffer(), prepareAction), (ShortPointer) pointer7, pointer8, pointer9, iArr.length);
            } else if (accumulation instanceof Variance) {
                nativeOps.execSummaryStatsHalf(put, accumulation.opNum(), (ShortPointer) pointer4, pointer5, (ShortPointer) pointer6, (ShortPointer) pointer7, pointer8, pointer9, iArr.length, ((Variance) accumulation).isBiasCorrected());
            } else {
                nativeOps.execReduceHalf(put, accumulation.opNum(), (ShortPointer) pointer4, pointer5, (ShortPointer) pointer6, (ShortPointer) pointer7, pointer8, pointer9, iArr.length);
            }
        } else if (accumulation.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            // Scalar path: the native routine returns the value, stored via setFinalResult.
            // Precedence here differs from the dimensional path: Variance is checked before y.
            if (accumulation instanceof Variance) {
                accumulation.setFinalResult(nativeOps.execSummaryStatsScalarDouble(put, accumulation.opNum(), pointer4, pointer5, (DoublePointer) pointer6, ((Variance) accumulation).isBiasCorrected()));
            } else if (accumulation.y() != null) {
                accumulation.setFinalResult(nativeOps.execReduce3ScalarDouble(put, accumulation.opNum(), pointer4, pointer5, (DoublePointer) pointer6, AtomicAllocator.getInstance().getPointer(accumulation.y(), prepareAction), AtomicAllocator.getInstance().getPointer(accumulation.y().shapeInfoDataBuffer(), prepareAction)));
            } else {
                accumulation.setFinalResult(nativeOps.execReduceScalarDouble(put, accumulation.opNum(), pointer4, pointer5, (DoublePointer) pointer6));
            }
        } else if (accumulation.x().data().dataType() == DataBuffer.Type.FLOAT) {
            if (accumulation instanceof Variance) {
                accumulation.setFinalResult(nativeOps.execSummaryStatsScalarFloat(put, accumulation.opNum(), (FloatPointer) pointer4, pointer5, (FloatPointer) pointer6, ((Variance) accumulation).isBiasCorrected()));
            } else if (accumulation.y() != null) {
                accumulation.setFinalResult(nativeOps.execReduce3ScalarFloat(put, accumulation.opNum(), (FloatPointer) pointer4, pointer5, (FloatPointer) pointer6, AtomicAllocator.getInstance().getPointer(accumulation.y(), prepareAction), AtomicAllocator.getInstance().getPointer(accumulation.y().shapeInfoDataBuffer(), prepareAction)));
            } else {
                accumulation.setFinalResult(nativeOps.execReduceScalarFloat(put, accumulation.opNum(), (FloatPointer) pointer4, pointer5, (FloatPointer) pointer6));
            }
        } else if (accumulation instanceof Variance) {
            // Any other data type of x takes the half-precision scalar routines.
            accumulation.setFinalResult(nativeOps.execSummaryStatsScalarHalf(put, accumulation.opNum(), (ShortPointer) pointer4, pointer5, (ShortPointer) pointer6, ((Variance) accumulation).isBiasCorrected()));
        } else if (accumulation.y() != null) {
            accumulation.setFinalResult(nativeOps.execReduce3ScalarHalf(put, accumulation.opNum(), (ShortPointer) pointer4, pointer5, (ShortPointer) pointer6, AtomicAllocator.getInstance().getPointer(accumulation.y(), prepareAction), AtomicAllocator.getInstance().getPointer(accumulation.y().shapeInfoDataBuffer(), prepareAction)));
        } else {
            accumulation.setFinalResult(nativeOps.execReduceScalarHalf(put, accumulation.opNum(), (ShortPointer) pointer4, pointer5, (ShortPointer) pointer6));
        }
        // Registered with the z installed above, not the one passed to prepareAction.
        AtomicAllocator.getInstance().registerAction(prepareAction, accumulation.z(), accumulation.x(), accumulation.y());
        profilingHookOut(accumulation, profilingHookIn);
        return prepareAction;
    }

    /**
     * Executes a scalar op along the given dimensions through the dimensional
     * {@code nativeOps.execScalar*} routines.
     *
     * <p>Within this file it is called from {@code invoke(ScalarOp)} when the op has a
     * non-null dimension. This method does not create the {@code extraz} holder itself;
     * it relies on {@code this.extraz.get()} already being non-null, which
     * {@code invoke(ScalarOp)} ensures before delegating.
     *
     * <p>The {@code y} data pointer is handed to the native call in the argument slot
     * after the {@code z} shape info - presumably the per-TAD scalars; TODO confirm
     * against the native signature.
     *
     * <p>If the data type of {@code x} is none of HALF, FLOAT or DOUBLE, no native call
     * is made, but the action is still registered.
     *
     * @param scalarOp the scalar op to execute
     * @param iArr     dimensions to apply the op along; sorted in place
     * @return always {@code null}
     */
    protected CudaContext intercept(ScalarOp scalarOp, int[] iArr) {
        long profilingHookIn = profilingHookIn(scalarOp);
        Arrays.sort(iArr);
        CudaContext prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction(scalarOp.z(), scalarOp.x(), scalarOp.y());
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(scalarOp.opName());
        }
        // Host-side shape info pointers for y and z (null when the array is absent).
        Pointer retrieveHostPointer = scalarOp.y() == null ? null : AddressRetriever.retrieveHostPointer(scalarOp.y().shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = scalarOp.z() == null ? null : AddressRetriever.retrieveHostPointer(scalarOp.z().shapeInfoDataBuffer());
        // Data pointers for x, y, z and shape info pointers for x and z.
        ShortPointer pointer = AtomicAllocator.getInstance().getPointer(scalarOp.x(), prepareAction);
        ShortPointer pointer2 = AtomicAllocator.getInstance().getPointer(scalarOp.y(), prepareAction);
        ShortPointer pointer3 = AtomicAllocator.getInstance().getPointer(scalarOp.z(), prepareAction);
        IntPointer pointer4 = AtomicAllocator.getInstance().getPointer(scalarOp.x().shapeInfoDataBuffer(), prepareAction);
        IntPointer pointer5 = AtomicAllocator.getInstance().getPointer(scalarOp.z().shapeInfoDataBuffer(), prepareAction);
        // TAD info for x and z along the requested dimensions.
        Pair tADOnlyShapeInfo = tadManager.getTADOnlyShapeInfo(scalarOp.x(), iArr);
        Pointer retrieveHostPointer3 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo.getFirst());
        Pointer pointer6 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), prepareAction);
        Pointer pointer7 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo.getSecond(), prepareAction);
        Pair tADOnlyShapeInfo2 = tadManager.getTADOnlyShapeInfo(scalarOp.z(), iArr);
        // Extras layout (by index): 0 host x shape, 1 old stream, 2 device id, 3-6 context buffers,
        // 7 host y shape, 8 host z shape, 9 host x TAD shape, 10-11 x TAD buffers, 12-13 z TAD buffers.
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(scalarOp.x().shapeInfoDataBuffer()), prepareAction.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), prepareAction.getBufferAllocation(), prepareAction.getBufferReduction(), prepareAction.getBufferScalar(), prepareAction.getBufferSpecial(), retrieveHostPointer, retrieveHostPointer2, retrieveHostPointer3, pointer6, pointer7, AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getFirst(), prepareAction), AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo2.getSecond(), prepareAction)});
        Pointer pointer8 = scalarOp.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(scalarOp.extraArgsDataBuff(), prepareAction) : null;
        IntPointer pointer9 = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(iArr), prepareAction);
        // Dispatch on the data type of x; an unmatched type executes nothing.
        if (scalarOp.x().data().dataType() == DataBuffer.Type.HALF) {
            nativeOps.execScalarHalf(put, scalarOp.opNum(), pointer, pointer4, pointer3, pointer5, pointer2, (ShortPointer) pointer8, pointer9, iArr.length);
        } else if (scalarOp.x().data().dataType() == DataBuffer.Type.FLOAT) {
            nativeOps.execScalarFloat(put, scalarOp.opNum(), (FloatPointer) pointer, pointer4, (FloatPointer) pointer3, pointer5, (FloatPointer) pointer2, (FloatPointer) pointer8, pointer9, iArr.length);
        } else if (scalarOp.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            nativeOps.execScalarDouble(put, scalarOp.opNum(), (DoublePointer) pointer, pointer4, (DoublePointer) pointer3, pointer5, (DoublePointer) pointer2, (DoublePointer) pointer8, pointer9, iArr.length);
        }
        // Unlike the sibling invoke methods, registration goes through the flow controller directly.
        AtomicAllocator.getInstance().getFlowController().registerAction(prepareAction, scalarOp.z(), scalarOp.x(), scalarOp.y());
        profilingHookOut(scalarOp, profilingHookIn);
        return null;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Executes a scalar op through {@code nativeOps}.
     *
     * <p>{@code x} and {@code z} must have the same length. When the op carries a
     * non-null dimension, execution is delegated to {@code intercept(ScalarOp, int[])}
     * and this method returns immediately afterwards.
     *
     * <p>Otherwise one of two native call forms is used, per data type of {@code x}
     * (DOUBLE, FLOAT, anything else uses the half-precision call):
     * <ul>
     *   <li>the shape-info form, when {@code x.elementWiseStride() < 1} or {@code x}
     *       and {@code z} differ in ordering;</li>
     *   <li>the element-wise-stride form, passing both strides and {@code n()},
     *       otherwise.</li>
     * </ul>
     *
     * @param scalarOp the scalar op to execute; its {@code z} receives the result
     * @return always {@code null}
     * @throws ND4JIllegalStateException if {@code x} and {@code z} differ in length
     */
    public CudaContext invoke(ScalarOp scalarOp) {
        long profilingHookIn = profilingHookIn(scalarOp);
        checkForCompression(scalarOp);
        validateDataType(Nd4j.dataType(), scalarOp);
        // The check compares x against z, so the message names op.Z (it used to say op.Y).
        if (scalarOp.x().length() != scalarOp.z().length()) {
            throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: [" + Arrays.toString(scalarOp.x().shapeInfoDataBuffer().asInt()) + "] != [" + Arrays.toString(scalarOp.z().shapeInfoDataBuffer().asInt()) + "]");
        }
        // Lazily create the holder used for the extra-pointers array; intercept() relies on it too.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(scalarOp.opName());
        }
        // Dimensional scalar ops are handled entirely by intercept().
        if (scalarOp.getDimension() != null) {
            intercept(scalarOp, scalarOp.getDimension());
            return null;
        }
        CudaContext prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction(scalarOp.z(), scalarOp.x(), scalarOp.y());
        // Host-side shape info pointers for y and z (null when the array is absent).
        Pointer retrieveHostPointer = scalarOp.y() == null ? null : AddressRetriever.retrieveHostPointer(scalarOp.y().shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = scalarOp.z() == null ? null : AddressRetriever.retrieveHostPointer(scalarOp.z().shapeInfoDataBuffer());
        DoublePointer pointer = AtomicAllocator.getInstance().getPointer(scalarOp.x(), prepareAction);
        IntPointer pointer2 = AtomicAllocator.getInstance().getPointer(scalarOp.x().shapeInfoDataBuffer(), prepareAction);
        Pointer pointer3 = scalarOp.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(scalarOp.extraArgsDataBuff(), prepareAction) : null;
        DoublePointer pointer4 = AtomicAllocator.getInstance().getPointer(scalarOp.z(), prepareAction);
        IntPointer pointer5 = AtomicAllocator.getInstance().getPointer(scalarOp.z().shapeInfoDataBuffer(), prepareAction);
        // Extras layout (by index): 0 host x shape, 1 old stream, 2 device id, 3-6 context buffers,
        // 7 host y shape, 8 host z shape, 9-10 unused (null) on this non-dimensional path.
        PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(scalarOp.x().shapeInfoDataBuffer()), prepareAction.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), prepareAction.getBufferAllocation(), prepareAction.getBufferReduction(), prepareAction.getBufferScalar(), prepareAction.getBufferSpecial(), retrieveHostPointer, retrieveHostPointer2, null, null});
        if (scalarOp.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            if (scalarOp.x().elementWiseStride() < 1 || scalarOp.z().ordering() != scalarOp.x().ordering()) {
                // No usable element-wise stride, or mismatched ordering: use the shape-info form.
                nativeOps.execScalarDouble(put, scalarOp.opNum(), pointer, pointer2, pointer4, pointer5, scalarOp.scalar().doubleValue(), (DoublePointer) pointer3);
            } else {
                nativeOps.execScalarDouble(put, scalarOp.opNum(), pointer, scalarOp.x().elementWiseStride(), pointer4, scalarOp.z().elementWiseStride(), scalarOp.scalar().doubleValue(), (DoublePointer) pointer3, scalarOp.n());
            }
        } else if (scalarOp.x().data().dataType() == DataBuffer.Type.FLOAT) {
            if (scalarOp.x().elementWiseStride() < 1 || scalarOp.z().ordering() != scalarOp.x().ordering()) {
                nativeOps.execScalarFloat(put, scalarOp.opNum(), (FloatPointer) pointer, pointer2, (FloatPointer) pointer4, pointer5, scalarOp.scalar().floatValue(), (FloatPointer) pointer3);
            } else {
                nativeOps.execScalarFloat(put, scalarOp.opNum(), (FloatPointer) pointer, scalarOp.x().elementWiseStride(), (FloatPointer) pointer4, scalarOp.z().elementWiseStride(), scalarOp.scalar().floatValue(), (FloatPointer) pointer3, scalarOp.n());
            }
        } else if (scalarOp.x().elementWiseStride() < 1 || scalarOp.z().ordering() != scalarOp.x().ordering()) {
            // Any other data type of x takes the half-precision routines.
            nativeOps.execScalarHalf(put, scalarOp.opNum(), (ShortPointer) pointer, pointer2, (ShortPointer) pointer4, pointer5, scalarOp.scalar().floatValue(), (ShortPointer) pointer3);
        } else {
            nativeOps.execScalarHalf(put, scalarOp.opNum(), (ShortPointer) pointer, scalarOp.x().elementWiseStride(), (ShortPointer) pointer4, scalarOp.z().elementWiseStride(), scalarOp.scalar().floatValue(), (ShortPointer) pointer3, scalarOp.n());
        }
        AtomicAllocator.getInstance().registerAction(prepareAction, scalarOp.z(), scalarOp.x(), scalarOp.y());
        profilingHookOut(scalarOp, profilingHookIn);
        return null;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Launches a {@link TransformOp} through {@code nativeOps}.
     *
     * <p>When {@code transformOp.y()} is non-null one of the {@code execPairwiseTransform*} calls is
     * used, otherwise one of the {@code execTransform*} calls. The Double/Float/Half variant is picked
     * from {@code transformOp.x().data().dataType()} (anything that is neither DOUBLE nor FLOAT goes to
     * the Half variant), and within each variant either the shape-info overload or the
     * element-wise-stride overload (which additionally receives {@code transformOp.n()}) is selected.
     *
     * @param transformOp op to run; its {@code x()} and {@code z()} are dereferenced unconditionally
     * @return always {@code null}
     */
    public CudaContext invoke(TransformOp transformOp) {
        long profilingHookIn = profilingHookIn(transformOp);
        checkForCompression(transformOp);
        validateDataType(Nd4j.dataType(), transformOp);
        AtomicAllocator atomicAllocator = AtomicAllocator.getInstance();
        // Lazily create the PointerPointer held by extraz; it is reused for the extras packed below.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // opNum 7 with a scalar y: y is replaced by Nd4j.valueArrayOf(x.shape(), y.getDouble(0)),
        // with an executioner commit before and after the replacement.
        if (transformOp.opNum() == 7 && transformOp.y() != null && transformOp.y().isScalar()) {
            Nd4j.getExecutioner().commit();
            transformOp.setY(Nd4j.valueArrayOf(transformOp.x().shape(), transformOp.y().getDouble(0)));
            Nd4j.getExecutioner().commit();
        }
        CudaContext prepareAction = atomicAllocator.getFlowController().prepareAction(transformOp.z(), transformOp.x(), transformOp.y());
        // In debug mode, remember the name of the op being launched.
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(transformOp.opName());
        }
        // Stays null unless the opNum 41 branch below allocates a temporary array.
        INDArray iNDArray = null;
        DoublePointer pointer = atomicAllocator.getPointer(transformOp.x(), prepareAction);
        IntPointer pointer2 = atomicAllocator.getPointer(transformOp.x().shapeInfoDataBuffer(), prepareAction);
        Pointer pointer3 = transformOp.extraArgs() != null ? atomicAllocator.getPointer(transformOp.extraArgsDataBuff(), prepareAction) : null;
        Pointer retrieveHostPointer = transformOp.y() == null ? null : AddressRetriever.retrieveHostPointer(transformOp.y().shapeInfoDataBuffer());
        Pointer retrieveHostPointer2 = transformOp.z() == null ? null : AddressRetriever.retrieveHostPointer(transformOp.z().shapeInfoDataBuffer());
        Pointer pointer4 = null;
        Pointer pointer5 = null;
        Pointer pointer6 = null;
        int[] iArr = null;
        // opNum 41 with extra args: extraArgs()[0] is the number of dimensions and
        // extraArgs()[1..] are the dimensions themselves.
        if (transformOp.opNum() == 41 && transformOp.extraArgs() != null) {
            iArr = new int[((Integer) transformOp.extraArgs()[0]).intValue()];
            for (int i = 0; i < iArr.length; i++) {
                iArr[i] = ((Integer) transformOp.extraArgs()[i + 1]).intValue();
            }
            // Negative dimensions are shifted by the rank of x.
            for (int i2 = 0; i2 < iArr.length; i2++) {
                if (iArr[i2] < 0) {
                    int i3 = i2;
                    iArr[i3] = iArr[i3] + transformOp.x().rank();
                }
            }
            // As many dimensions as x has rank: collapse them to the single MAX_DIMENSION marker.
            if (iArr.length == transformOp.x().rank()) {
                iArr = new int[]{Nd4jCuda.MAX_DIMENSION};
            }
            // Shape of the temporary array: x's shape without the selected dimensions, padded to
            // two entries ({1, 1} when nothing remains or the whole array is selected).
            int[] removeIndex = Shape.wholeArrayDimension(iArr) ? new int[]{1, 1} : ArrayUtil.removeIndex(transformOp.x().shape(), iArr);
            if (removeIndex.length == 1) {
                removeIndex = iArr[0] == 0 ? new int[]{1, removeIndex[0]} : new int[]{removeIndex[0], 1};
            } else if (removeIndex.length == 0) {
                removeIndex = new int[]{1, 1};
            }
            iNDArray = Nd4j.zeros(removeIndex);
            // Overwrites the value destined for extras slot 7 (otherwise y's host shape info)
            // with the temporary array's shape-info pointer.
            retrieveHostPointer = atomicAllocator.getPointer(iNDArray.shapeInfoDataBuffer(), prepareAction);
            DataBuffer constantBuffer = atomicAllocator.getConstantBuffer(iArr);
            pointer4 = atomicAllocator.getPointer(constantBuffer, prepareAction);
            pointer5 = atomicAllocator.getHostPointer(constantBuffer);
            pointer6 = atomicAllocator.getPointer(iNDArray, prepareAction);
        }
        Pointer pointer7 = null;
        Pointer pointer8 = null;
        Pointer pointer9 = null;
        Pointer pointer10 = null;
        Pointer pointer11 = null;
        Pointer pointer12 = null;
        // opNum 38..41 need TAD shape info from tadManager. For 38..40 two lookups are made on x
        // (dimension 0 and dimension 1); for 41 a single lookup is made on z with the dimensions
        // parsed above (iArr is still null here if opNum 41 came without extra args).
        if (transformOp.opNum() >= 38 && transformOp.opNum() <= 41) {
            if (transformOp.opNum() != 41) {
                Pair tADOnlyShapeInfo = tadManager.getTADOnlyShapeInfo(transformOp.x(), new int[]{0});
                Pair tADOnlyShapeInfo2 = tadManager.getTADOnlyShapeInfo(transformOp.x(), new int[]{1});
                pointer7 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo.getFirst());
                pointer8 = atomicAllocator.getPointer((DataBuffer) tADOnlyShapeInfo.getFirst(), prepareAction);
                pointer9 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo2.getFirst());
                pointer10 = atomicAllocator.getPointer((DataBuffer) tADOnlyShapeInfo2.getFirst(), prepareAction);
                // The second element of each pair may be null; the pointer is then left null too.
                DataBuffer dataBuffer = (DataBuffer) tADOnlyShapeInfo.getSecond();
                pointer11 = dataBuffer == null ? null : atomicAllocator.getPointer(dataBuffer, prepareAction);
                DataBuffer dataBuffer2 = (DataBuffer) tADOnlyShapeInfo2.getSecond();
                pointer12 = dataBuffer2 == null ? null : atomicAllocator.getPointer(dataBuffer2, prepareAction);
            } else {
                Pair tADOnlyShapeInfo3 = tadManager.getTADOnlyShapeInfo(transformOp.z(), iArr);
                pointer7 = AddressRetriever.retrieveHostPointer((DataBuffer) tADOnlyShapeInfo3.getFirst());
                pointer8 = AtomicAllocator.getInstance().getPointer((DataBuffer) tADOnlyShapeInfo3.getFirst(), prepareAction);
                DataBuffer dataBuffer3 = (DataBuffer) tADOnlyShapeInfo3.getSecond();
                pointer11 = dataBuffer3 == null ? null : atomicAllocator.getPointer(dataBuffer3, prepareAction);
            }
        }
        DoublePointer pointer13 = atomicAllocator.getPointer(transformOp.z(), prepareAction);
        IntPointer pointer14 = atomicAllocator.getPointer(transformOp.z().shapeInfoDataBuffer(), prepareAction);
        // Pack the 19 extra pointers that are handed to every native call below as first argument.
        PointerPointer pointerPointer = this.extraz.get();
        Pointer[] pointerArr = new Pointer[19];
        pointerArr[0] = AddressRetriever.retrieveHostPointer(transformOp.x().shapeInfoDataBuffer());
        pointerArr[1] = prepareAction.getOldStream();
        pointerArr[2] = atomicAllocator.getDeviceIdPointer();
        pointerArr[3] = prepareAction.getBufferAllocation();
        pointerArr[4] = prepareAction.getBufferReduction();
        pointerArr[5] = prepareAction.getBufferScalar();
        pointerArr[6] = prepareAction.getBufferSpecial();
        pointerArr[7] = retrieveHostPointer;
        pointerArr[8] = retrieveHostPointer2;
        pointerArr[9] = pointer7;
        pointerArr[10] = pointer8;
        pointerArr[11] = pointer11;
        pointerArr[12] = pointer9;
        pointerArr[13] = pointer10;
        pointerArr[14] = pointer12;
        pointerArr[15] = pointer4;
        pointerArr[16] = pointer5;
        pointerArr[17] = pointer6;
        // Slot 18 carries the number of parsed dimensions (0 when there are none) wrapped in a CudaPointer.
        pointerArr[18] = new CudaPointer(iArr == null ? 0L : iArr.length);
        PointerPointer put = pointerPointer.put(pointerArr);
        if (transformOp.y() != null) {
            // Pairwise path: x and y are combined into z.
            DoublePointer pointer15 = atomicAllocator.getPointer(transformOp.y(), prepareAction);
            IntPointer pointer16 = atomicAllocator.getPointer(transformOp.y().shapeInfoDataBuffer(), prepareAction);
            int elementWiseStride = transformOp.x().elementWiseStride();
            int elementWiseStride2 = transformOp.y().elementWiseStride();
            int elementWiseStride3 = transformOp.z().elementWiseStride();
            boolean isRowVector = transformOp.x().isRowVector();
            boolean isRowVector2 = transformOp.y().isRowVector();
            boolean isRowVector3 = transformOp.z().isRowVector();
            // In each data-type branch the shape-info overload is used when any "irregular" condition
            // holds (bad stride, isExecSpecial, differing orderings) unless x, y and z are all row
            // vectors sharing one stride >= 1; otherwise the stride overload is used.
            if (transformOp.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                // NOTE(review): this branch tests "z stride < 1" where the FLOAT and Half branches
                // test "x stride != y stride" — confirm the asymmetry is intended.
                if ((elementWiseStride < 1 || elementWiseStride2 < 1 || elementWiseStride3 < 1 || transformOp.isExecSpecial() || transformOp.x().ordering() != transformOp.y().ordering() || transformOp.x().ordering() != transformOp.z().ordering()) && !(elementWiseStride >= 1 && elementWiseStride2 == elementWiseStride && elementWiseStride3 == elementWiseStride && isRowVector && isRowVector2 && isRowVector3)) {
                    nativeOps.execPairwiseTransformDouble(put, transformOp.opNum(), pointer, pointer2, pointer15, pointer16, pointer13, pointer14, (DoublePointer) pointer3);
                } else {
                    nativeOps.execPairwiseTransformDouble(put, transformOp.opNum(), pointer, elementWiseStride, pointer15, elementWiseStride2, pointer13, elementWiseStride3, (DoublePointer) pointer3, transformOp.n());
                }
            } else if (transformOp.x().data().dataType() == DataBuffer.Type.FLOAT) {
                if ((elementWiseStride < 1 || elementWiseStride2 < 1 || elementWiseStride != elementWiseStride2 || transformOp.isExecSpecial() || transformOp.x().ordering() != transformOp.y().ordering() || transformOp.x().ordering() != transformOp.z().ordering()) && !(elementWiseStride >= 1 && elementWiseStride2 == elementWiseStride && elementWiseStride3 == elementWiseStride && isRowVector && isRowVector2 && isRowVector3)) {
                    nativeOps.execPairwiseTransformFloat(put, transformOp.opNum(), (FloatPointer) pointer, pointer2, (FloatPointer) pointer15, pointer16, (FloatPointer) pointer13, pointer14, (FloatPointer) pointer3);
                } else {
                    nativeOps.execPairwiseTransformFloat(put, transformOp.opNum(), (FloatPointer) pointer, elementWiseStride, (FloatPointer) pointer15, elementWiseStride2, (FloatPointer) pointer13, elementWiseStride3, (FloatPointer) pointer3, transformOp.n());
                }
            } else if ((elementWiseStride < 1 || elementWiseStride2 < 1 || elementWiseStride != transformOp.y().elementWiseStride() || transformOp.isExecSpecial() || transformOp.x().ordering() != transformOp.y().ordering() || transformOp.x().ordering() != transformOp.z().ordering()) && !(elementWiseStride >= 1 && elementWiseStride2 == elementWiseStride && elementWiseStride3 == elementWiseStride && isRowVector && isRowVector2 && isRowVector3)) {
                nativeOps.execPairwiseTransformHalf(put, transformOp.opNum(), (ShortPointer) pointer, pointer2, (ShortPointer) pointer15, pointer16, (ShortPointer) pointer13, pointer14, (ShortPointer) pointer3);
            } else {
                nativeOps.execPairwiseTransformHalf(put, transformOp.opNum(), (ShortPointer) pointer, elementWiseStride, (ShortPointer) pointer15, elementWiseStride2, (ShortPointer) pointer13, elementWiseStride3, (ShortPointer) pointer3, transformOp.n());
            }
        // Single-input path: the shape-info overload is used when x's stride is < 1, the op is
        // flagged isExecSpecial, or x and z differ in ordering; otherwise the stride overload.
        } else if (transformOp.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            if (transformOp.x().elementWiseStride() < 1 || transformOp.isExecSpecial() || transformOp.z().ordering() != transformOp.x().ordering()) {
                nativeOps.execTransformDouble(put, transformOp.opNum(), pointer, pointer2, pointer13, pointer14, (DoublePointer) pointer3);
            } else {
                nativeOps.execTransformDouble(put, transformOp.opNum(), pointer, transformOp.x().elementWiseStride(), pointer13, transformOp.z().elementWiseStride(), (DoublePointer) pointer3, transformOp.n());
            }
        } else if (transformOp.x().data().dataType() == DataBuffer.Type.FLOAT) {
            if (transformOp.x().elementWiseStride() < 1 || transformOp.isExecSpecial() || transformOp.z().ordering() != transformOp.x().ordering()) {
                nativeOps.execTransformFloat(put, transformOp.opNum(), (FloatPointer) pointer, pointer2, (FloatPointer) pointer13, pointer14, (FloatPointer) pointer3);
            } else {
                nativeOps.execTransformFloat(put, transformOp.opNum(), (FloatPointer) pointer, transformOp.x().elementWiseStride(), (FloatPointer) pointer13, transformOp.z().elementWiseStride(), (FloatPointer) pointer3, transformOp.n());
            }
        } else if (transformOp.x().elementWiseStride() < 1 || transformOp.isExecSpecial() || transformOp.z().ordering() != transformOp.x().ordering()) {
            nativeOps.execTransformHalf(put, transformOp.opNum(), (ShortPointer) pointer, pointer2, (ShortPointer) pointer13, pointer14, (ShortPointer) pointer3);
        } else {
            nativeOps.execTransformHalf(put, transformOp.opNum(), (ShortPointer) pointer, transformOp.x().elementWiseStride(), (ShortPointer) pointer13, transformOp.z().elementWiseStride(), (ShortPointer) pointer3, transformOp.n());
        }
        AtomicAllocator.getInstance().registerAction(prepareAction, transformOp.z(), transformOp.x(), transformOp.y());
        // The results of the two calls below are discarded.
        // NOTE(review): presumably they only keep the extra-args pointer and the temporary array
        // reachable until the native call has been issued — confirm before removing.
        if (pointer3 != null) {
            pointer3.address();
        }
        if (iNDArray != null) {
            iNDArray.elementWiseStride();
        }
        profilingHookOut(transformOp, profilingHookIn);
        return null;
    }

    /**
     * Allocates the int buffer used as parameter surface for {@code batch} and attaches it to the batch.
     *
     * <p>The buffer length is four times the sample aggregate's
     * {@code getRequiredBatchMemorySize()}; the second argument passed to {@code createInt} is
     * {@code false}.
     *
     * @param batch batch that receives the new buffer via {@code setParamsSurface}
     * @return the buffer that was attached to the batch
     */
    protected <T extends Aggregate> DataBuffer getBuffer(Batch<T> batch) {
        DataBuffer paramsSurface = Nd4j.getDataBufferFactory()
                .createInt(batch.getSample().getRequiredBatchMemorySize() * 4, false);
        batch.setParamsSurface(paramsSurface);
        return paramsSurface;
    }

    /**
     * Packs every aggregate of {@code batch} into the batch's parameter surface and launches the
     * matching {@code execAggregateBatchFloat/Double/Half} native call for {@code Nd4j.dataType()}.
     *
     * <p>The surface is written through differently typed views of the same host memory, so each
     * section offset below is expressed in elements of the view it is used with. Sections, in order:
     * <ol>
     *   <li>header: 5 ints per aggregate (argument, shape, index-argument, real-argument and
     *       int-array counts);</li>
     *   <li>index arguments: {@code maxIndexArguments()} ints per aggregate;</li>
     *   <li>int arrays: {@code maxIntArrays() * maxIntArraySize()} ints per aggregate;</li>
     *   <li>real arguments: {@code maxRealArguments()} values per aggregate;</li>
     *   <li>array pointers: {@code maxArguments()} pointers per aggregate;</li>
     *   <li>shape pointers: {@code maxShapes()} pointers per aggregate.</li>
     * </ol>
     * Every section is sized for {@code Batch.getBatchLimit() * 16} aggregates.
     *
     * @param batch batch of same-typed aggregates to execute
     */
    public <T extends Aggregate> void exec(Batch<T> batch) {
        DataBuffer buffer = getBuffer(batch);
        CudaContext cudaContext = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
        IntPointer asIntPointer = new CudaPointer(AtomicAllocator.getInstance().getHostPointer(buffer)).asIntPointer();
        AllocationPoint allocationPoint = AtomicAllocator.getInstance().getAllocationPoint(buffer);
        int maxIntArrays = batch.getSample().maxIntArrays();
        int maxIntArraySize = batch.getSample().maxIntArraySize();

        // Section start offsets. The first three are in ints; they are exactly the terms that the
        // real-argument offset is built from, so the sections tile the surface without gaps.
        final int maxTypes = 5;
        int indexPos = maxTypes * (Batch.getBatchLimit() * 16);
        int intArraysPos = indexPos + (batch.getSample().maxIndexArguments() * (Batch.getBatchLimit() * 16));
        // Converted from int units to units of the real data type.
        int realPos = (intArraysPos + ((maxIntArrays * maxIntArraySize) * (Batch.getBatchLimit() * 16))) / (Nd4j.dataType() == DataBuffer.Type.DOUBLE ? 2 : 1);
        if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            realPos *= 2;
        }
        // Converted from real-type units to pointer units.
        int argsPos = (realPos + (batch.getSample().maxRealArguments() * (Batch.getBatchLimit() * 16))) / (Nd4j.dataType() == DataBuffer.Type.FLOAT ? 2 : 1);
        if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            argsPos /= 4;
        }
        int shapesPos = argsPos + (batch.getSample().maxArguments() * Batch.getBatchLimit() * 16);

        for (int i = 0; i < batch.getNumAggregates(); i++) {
            Aggregate aggregate = (Aggregate) batch.getAggregates().get(i);

            // Header: the five counts of this aggregate.
            int headerPos = i * maxTypes;
            asIntPointer.put(headerPos, aggregate.getArguments().size());
            asIntPointer.put(headerPos + 1, aggregate.getShapes().size());
            asIntPointer.put(headerPos + 2, aggregate.getIndexingArguments().size());
            asIntPointer.put(headerPos + 3, aggregate.getRealArguments().size());
            asIntPointer.put(headerPos + 4, aggregate.getIntArrayArguments().size());

            // Index arguments.
            for (int i2 = 0; i2 < aggregate.getIndexingArguments().size(); i2++) {
                asIntPointer.put(indexPos + (i * batch.getSample().maxIndexArguments()) + i2, ((Integer) aggregate.getIndexingArguments().get(i2)).intValue());
            }

            // Int arrays; null entries leave their slot untouched.
            int i3 = maxIntArrays * maxIntArraySize;
            for (int i4 = 0; i4 < aggregate.getIntArrayArguments().size(); i4++) {
                int i5 = (i * i3) + (i4 * maxIntArraySize);
                if (aggregate.getIntArrayArguments().get(i4) != null) {
                    for (int i6 = 0; i6 < ((int[]) aggregate.getIntArrayArguments().get(i4)).length; i6++) {
                        asIntPointer.put(intArraysPos + i5 + i6, ((int[]) aggregate.getIntArrayArguments().get(i4))[i6]);
                    }
                }
            }

            // Real arguments, written through a view matching the global data type.
            if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
                FloatPointer floatPointer = new FloatPointer(asIntPointer);
                for (int i7 = 0; i7 < aggregate.getRealArguments().size(); i7++) {
                    floatPointer.put(realPos + (i * aggregate.maxRealArguments()) + i7, ((Number) aggregate.getRealArguments().get(i7)).floatValue());
                }
            } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
                DoublePointer doublePointer = new DoublePointer(asIntPointer);
                for (int i8 = 0; i8 < aggregate.getRealArguments().size(); i8++) {
                    doublePointer.put(realPos + (i * aggregate.maxRealArguments()) + i8, ((Number) aggregate.getRealArguments().get(i8)).doubleValue());
                }
            } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
                ShortPointer shortPointer = new ShortPointer(asIntPointer);
                for (int i9 = 0; i9 < aggregate.getRealArguments().size(); i9++) {
                    shortPointer.put(realPos + (i * aggregate.maxRealArguments()) + i9, BaseDataBuffer.fromFloat(((Number) aggregate.getRealArguments().get(i9)).floatValue()));
                }
            }

            // Array and shape pointers; every referenced buffer is marked as written on device.
            PointerPointer pointerPointer = new PointerPointer(asIntPointer);
            for (int i10 = 0; i10 < aggregate.getArguments().size(); i10++) {
                int argsOffset = argsPos + (i * batch.getSample().maxArguments());
                if (aggregate.getArguments().get(i10) != null) {
                    pointerPointer.put(argsOffset + i10, AtomicAllocator.getInstance().getPointer((INDArray) aggregate.getArguments().get(i10), cudaContext));
                    AtomicAllocator.getInstance().getAllocationPoint((INDArray) aggregate.getArguments().get(i10)).tickDeviceWrite();
                }
            }
            for (int i11 = 0; i11 < aggregate.getShapes().size(); i11++) {
                int shapesOffset = shapesPos + (i * batch.getSample().maxShapes());
                if (aggregate.getShapes().get(i11) != null) {
                    pointerPointer.put(shapesOffset + i11, AtomicAllocator.getInstance().getPointer((DataBuffer) aggregate.getShapes().get(i11), cudaContext));
                    AtomicAllocator.getInstance().getAllocationPoint((DataBuffer) aggregate.getShapes().get(i11)).tickDeviceWrite();
                }
            }
        }

        // The surface was filled on the host; flag it before its pointer is requested below.
        allocationPoint.tickHostWrite();

        // Launch extras: slot 0 unused, then stream, number of blocks (capped by the configured
        // maximum grid size), threads per instance and shared memory size.
        PointerPointer pointerPointer2 = new PointerPointer(32L);
        pointerPointer2.put(0L, (Pointer) null);
        pointerPointer2.put(1L, cudaContext.getOldStream());
        pointerPointer2.put(2L, new CudaPointer(Math.min(batch.getNumAggregates(), CudaEnvironment.getInstance().getConfiguration().getMaximumGridSize())));
        pointerPointer2.put(3L, new CudaPointer(batch.getSample().getThreadsPerInstance()));
        pointerPointer2.put(4L, new CudaPointer(batch.getSample().getSharedMemorySize()));
        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            nativeOps.execAggregateBatchFloat(pointerPointer2, batch.getNumAggregates(), batch.opNum(), batch.getSample().maxArguments(), batch.getSample().maxShapes(), batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(), batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), AtomicAllocator.getInstance().getPointer(buffer, cudaContext));
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            nativeOps.execAggregateBatchDouble(pointerPointer2, batch.getNumAggregates(), batch.opNum(), batch.getSample().maxArguments(), batch.getSample().maxShapes(), batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(), batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), AtomicAllocator.getInstance().getPointer(buffer, cudaContext));
        } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            nativeOps.execAggregateBatchHalf(pointerPointer2, batch.getNumAggregates(), batch.opNum(), batch.getSample().maxArguments(), batch.getSample().maxShapes(), batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(), batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), AtomicAllocator.getInstance().getPointer(buffer, cudaContext));
        }
        // Kept from the original: the surface is flagged as host-written a second time after launch.
        allocationPoint.tickHostWrite();
    }

    /**
     * Executes a list of aggregates by splitting it into batches of up to 8192 entries
     * (via {@code Batch.getBatches}) and running each batch in turn, then synchronises the
     * old stream of the current device context. An empty list is a no-op.
     *
     * @param list aggregates to execute
     */
    public void exec(List<Aggregate> list) {
        if (list.size() != 0) {
            for (Batch<Aggregate> batch : Batch.getBatches(list, 8192)) {
                exec(batch);
            }
            CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
            context.syncOldStream();
        }
    }

    /**
     * Executes a single aggregate through {@code nativeOps.execAggregateFloat/Double/Half},
     * chosen by {@code Nd4j.dataType()}.
     *
     * <p>Device addresses of the aggregate's arrays, shape buffers and int arrays are collected
     * into {@code long[]} tables (a null entry leaves address 0), index arguments are copied into
     * an int buffer and real arguments into an array created from their double values.
     *
     * @param aggregate aggregate to execute
     */
    public void exec(Aggregate aggregate) {
        int numArguments = aggregate.getArguments().size();
        int numShapes = aggregate.getShapes().size();
        int numIndexArguments = aggregate.getIndexingArguments().size();
        int numIntArrays = aggregate.getIntArrayArguments().size();
        int numRealArguments = aggregate.getRealArguments().size();

        CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();

        // Launch extras: slot 0 unused, stream, a single block, threads per instance, shared memory.
        PointerPointer extras = new PointerPointer(32L);
        extras.put(0L, (Pointer) null);
        extras.put(1L, context.getOldStream());
        extras.put(2L, new CudaPointer(1L));
        extras.put(3L, new CudaPointer(aggregate.getThreadsPerInstance()));
        extras.put(4L, new CudaPointer(aggregate.getSharedMemorySize()));

        // Array arguments: record the device address and mark each array as written on device.
        long[] argumentAddresses = new long[numArguments];
        for (int a = 0; a < numArguments; a++) {
            INDArray argument = (INDArray) aggregate.getArguments().get(a);
            if (argument == null) {
                continue;
            }
            argumentAddresses[a] = AtomicAllocator.getInstance().getPointer(argument, context).address();
            AtomicAllocator.getInstance().getAllocationPoint(argument).tickDeviceWrite();
        }
        PointerPointer argumentTable = new PointerPointer(
                AtomicAllocator.getInstance().getPointer(AllocationUtils.getPointersBuffer(argumentAddresses), context));

        // Shape buffers, handled the same way as the arrays.
        long[] shapeAddresses = new long[numShapes];
        for (int s = 0; s < numShapes; s++) {
            DataBuffer shape = (DataBuffer) aggregate.getShapes().get(s);
            if (shape == null) {
                continue;
            }
            shapeAddresses[s] = AtomicAllocator.getInstance().getPointer(shape, context).address();
            AtomicAllocator.getInstance().getAllocationPoint(shape).tickDeviceWrite();
        }
        PointerPointer shapeTable = new PointerPointer(
                AtomicAllocator.getInstance().getPointer(AllocationUtils.getPointersBuffer(shapeAddresses), context));

        // Int arrays: each non-null array is copied into a fresh int buffer.
        long[] intArrayAddresses = new long[numIntArrays];
        for (int k = 0; k < numIntArrays; k++) {
            int[] values = (int[]) aggregate.getIntArrayArguments().get(k);
            if (values != null) {
                DataBuffer valuesBuffer = Nd4j.getDataBufferFactory().createInt(values);
                intArrayAddresses[k] = AtomicAllocator.getInstance().getPointer(valuesBuffer, context).address();
            }
        }
        PointerPointer intArrayTable = new PointerPointer(
                AtomicAllocator.getInstance().getPointer(AllocationUtils.getPointersBuffer(intArrayAddresses), context));

        // Index arguments as a plain int buffer.
        int[] indexArguments = new int[numIndexArguments];
        for (int k = 0; k < numIndexArguments; k++) {
            indexArguments[k] = ((Integer) aggregate.getIndexingArguments().get(k)).intValue();
        }
        DataBuffer indexBuffer = Nd4j.getDataBufferFactory().createInt(indexArguments);

        // Real arguments, converted to double before the array is created.
        double[] realArguments = new double[numRealArguments];
        for (int k = 0; k < numRealArguments; k++) {
            realArguments[k] = ((Number) aggregate.getRealArguments().get(k)).doubleValue();
        }
        INDArray realArray = Nd4j.create(realArguments);

        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            nativeOps.execAggregateFloat(extras, aggregate.opNum(), argumentTable, numArguments, shapeTable, numShapes,
                    AtomicAllocator.getInstance().getPointer(indexBuffer, context), numIndexArguments,
                    intArrayTable, numIntArrays,
                    AtomicAllocator.getInstance().getPointer(realArray.data(), context), numRealArguments);
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            nativeOps.execAggregateDouble(extras, aggregate.opNum(), argumentTable, numArguments, shapeTable, numShapes,
                    AtomicAllocator.getInstance().getPointer(indexBuffer, context), numIndexArguments,
                    intArrayTable, numIntArrays,
                    AtomicAllocator.getInstance().getPointer(realArray.data(), context), numRealArguments);
        } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            nativeOps.execAggregateHalf(extras, aggregate.opNum(), argumentTable, numArguments, shapeTable, numShapes,
                    AtomicAllocator.getInstance().getPointer(indexBuffer, context), numIndexArguments,
                    intArrayTable, numIntArrays,
                    AtomicAllocator.getInstance().getPointer(realArray.data(), context), numRealArguments);
        }
    }

    /**
     * Executes a random op with the generator returned by {@code Nd4j.getRandom()}.
     *
     * @param randomOp op to execute
     * @return whatever {@link #exec(RandomOp, Random)} returns for this op
     */
    public INDArray exec(RandomOp randomOp) {
        Random defaultRandom = Nd4j.getRandom();
        return exec(randomOp, defaultRandom);
    }

    /**
     * Executes a random op with the supplied generator through
     * {@code nativeOps.execRandomFloat/Double/Half}, chosen by {@code Nd4j.dataType()}.
     *
     * <p>Three native overloads are used depending on which operands are set: x, y and z all
     * present; x and z present; otherwise the z-only form.
     *
     * @param randomOp op to execute
     * @param random   generator; must expose a non-null state buffer
     * @return {@code randomOp.z()}
     * @throws IllegalStateException if {@code random.getStateBuffer()} is null
     */
    public INDArray exec(RandomOp randomOp, Random random) {
        long profilingHookIn = profilingHookIn(randomOp);
        checkForCompression(randomOp);
        validateDataType(Nd4j.dataType(), randomOp);

        if (random.getStateBuffer() == null) {
            throw new IllegalStateException("You should use one of NativeRandom classes for NativeOperations execution");
        }
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(randomOp.opName());
        }

        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(randomOp.z(), randomOp.x(), randomOp.y());

        // Extras: host shape info of z, the stream and the device id.
        PointerPointer extras = this.extraz.get().put(new Pointer[]{
                AddressRetriever.retrieveHostPointer(randomOp.z().shapeInfoDataBuffer()),
                context.getOldStream(),
                AtomicAllocator.getInstance().getDeviceIdPointer()});

        if (randomOp.x() != null && randomOp.y() != null && randomOp.z() != null) {
            // x, y -> z
            if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
                nativeOps.execRandomFloat(extras, randomOp.opNum(), random.getStatePointer(),
                        AtomicAllocator.getInstance().getPointer(randomOp.x(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.x().shapeInfoDataBuffer(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.y(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.y().shapeInfoDataBuffer(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.z(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.z().shapeInfoDataBuffer(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.extraArgsDataBuff(), context));
            } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
                nativeOps.execRandomDouble(extras, randomOp.opNum(), random.getStatePointer(),
                        AtomicAllocator.getInstance().getPointer(randomOp.x(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.x().shapeInfoDataBuffer(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.y(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.y().shapeInfoDataBuffer(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.z(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.z().shapeInfoDataBuffer(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.extraArgsDataBuff(), context));
            } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
                nativeOps.execRandomHalf(extras, randomOp.opNum(), random.getStatePointer(),
                        AtomicAllocator.getInstance().getPointer(randomOp.x(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.x().shapeInfoDataBuffer(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.y(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.y().shapeInfoDataBuffer(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.z(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.z().shapeInfoDataBuffer(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.extraArgsDataBuff(), context));
            }
        } else if (randomOp.x() != null && randomOp.z() != null) {
            // x -> z (y is absent)
            if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
                nativeOps.execRandomFloat(extras, randomOp.opNum(), random.getStatePointer(),
                        AtomicAllocator.getInstance().getPointer(randomOp.x(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.x().shapeInfoDataBuffer(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.z(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.z().shapeInfoDataBuffer(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.extraArgsDataBuff(), context));
            } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
                nativeOps.execRandomDouble(extras, randomOp.opNum(), random.getStatePointer(),
                        AtomicAllocator.getInstance().getPointer(randomOp.x(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.x().shapeInfoDataBuffer(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.z(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.z().shapeInfoDataBuffer(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.extraArgsDataBuff(), context));
            } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
                nativeOps.execRandomHalf(extras, randomOp.opNum(), random.getStatePointer(),
                        AtomicAllocator.getInstance().getPointer(randomOp.x(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.x().shapeInfoDataBuffer(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.z(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.z().shapeInfoDataBuffer(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.extraArgsDataBuff(), context));
            }
        } else {
            // z only
            if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
                nativeOps.execRandomFloat(extras, randomOp.opNum(), random.getStatePointer(),
                        AtomicAllocator.getInstance().getPointer(randomOp.z(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.z().shapeInfoDataBuffer(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.extraArgsDataBuff(), context));
            } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
                nativeOps.execRandomDouble(extras, randomOp.opNum(), random.getStatePointer(),
                        AtomicAllocator.getInstance().getPointer(randomOp.z(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.z().shapeInfoDataBuffer(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.extraArgsDataBuff(), context));
            } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
                nativeOps.execRandomHalf(extras, randomOp.opNum(), random.getStatePointer(),
                        AtomicAllocator.getInstance().getPointer(randomOp.z(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.z().shapeInfoDataBuffer(), context),
                        AtomicAllocator.getInstance().getPointer(randomOp.extraArgsDataBuff(), context));
            }
        }

        AtomicAllocator.getInstance().getFlowController().registerAction(context, randomOp.z(), randomOp.x(), randomOp.y());
        profilingHookOut(randomOp, profilingHookIn);
        return randomOp.z();
    }

    /**
     * Returns the environment report for this executioner.
     *
     * <p>On the first call the report is built from the parent's report plus CUDA-specific
     * entries (backend name, device count, a per-device list of maps, BLAS vendor, free
     * JavaCPP memory and the current bandwidth figures) and cached in {@code this.properties}.
     * On later calls only the per-device free/total memory, {@code memory.free} and
     * {@code memoryBandwidth} entries are re-queried and overwritten in the cached object.
     *
     * @return the cached (and freshly updated) {@link Properties} instance
     */
    @SuppressWarnings("unchecked")
    public synchronized Properties getEnvironmentInformation() {
        if (this.properties != null) {
            // Cached report exists: refresh only the values that change over time.
            List<Map<String, Object>> knownDevices = (List<Map<String, Object>>) this.properties.get("cuda.devicesInformation");
            for (int device = 0; device < nativeOps.getAvailableDevices(); device++) {
                Map<String, Object> deviceInfo = knownDevices.get(device);
                CudaPointer devicePointer = new CudaPointer(device);
                deviceInfo.put("cuda.freeMemory", Long.valueOf(nativeOps.getDeviceFreeMemory(devicePointer)));
                deviceInfo.put("cuda.totalMemory", Long.valueOf(nativeOps.getDeviceTotalMemory(devicePointer)));
            }
            this.properties.put("cuda.devicesInformation", knownDevices);
            this.properties.put("memory.free", Long.valueOf(Pointer.maxBytes() - Pointer.totalBytes()));
            this.properties.put("memoryBandwidth", PerformanceTracker.getInstance().getCurrentBandwidth());
            return this.properties;
        }

        // First call: start from the parent's report and append one map per visible device.
        Properties report = super.getEnvironmentInformation();
        List<Map<String, Object>> devices = new ArrayList<>();
        for (int device = 0; device < nativeOps.getAvailableDevices(); device++) {
            Map<String, Object> deviceInfo = new HashMap<>();
            CudaPointer devicePointer = new CudaPointer(device);
            deviceInfo.put("cuda.deviceName", nativeOps.getDeviceName(devicePointer));
            deviceInfo.put("cuda.freeMemory", Long.valueOf(nativeOps.getDeviceFreeMemory(devicePointer)));
            deviceInfo.put("cuda.totalMemory", Long.valueOf(nativeOps.getDeviceTotalMemory(devicePointer)));
            deviceInfo.put("cuda.deviceMajor", Long.valueOf(nativeOps.getDeviceMajor(devicePointer)));
            deviceInfo.put("cuda.deviceMinor", Long.valueOf(nativeOps.getDeviceMinor(devicePointer)));
            // Appending keeps list index == device ordinal, which the refresh path relies on.
            devices.add(deviceInfo);
        }
        report.put("backend", "CUDA");
        report.put("cuda.availableDevices", Integer.valueOf(nativeOps.getAvailableDevices()));
        report.put("cuda.devicesInformation", devices);
        report.put("blas.vendor", Nd4j.factory().blas().getBlasVendor().toString());
        report.put("memory.free", Long.valueOf(Pointer.maxBytes() - Pointer.totalBytes()));
        report.put("memoryBandwidth", PerformanceTracker.getInstance().getCurrentBandwidth());
        this.properties = report;
        return this.properties;
    }

    /**
     * Returns the {@code tadManager} held by this executioner.
     *
     * @return the current {@link TADManager}
     */
    public TADManager getTADManager() {
        return tadManager;
    }

    /**
     * Logs the parent's environment report and then one line per CUDA device.
     *
     * <p>Each device line carries the device name, its major/minor version pair and its
     * total and free memory, all read from the {@code cuda.devicesInformation} list produced
     * by {@link #getEnvironmentInformation()}. The free-memory value is included so that the
     * "Total/free memory" label matches what is actually printed.
     */
    @SuppressWarnings("unchecked")
    public void printEnvironmentInformation() {
        super.printEnvironmentInformation();
        List<Map<String, Object>> devices = (List<Map<String, Object>>) getEnvironmentInformation().get("cuda.devicesInformation");
        for (Map<String, Object> device : devices) {
            log.info("Device Name: [{}]; CC: [{}.{}]; Total/free memory: [{}/{}]", new Object[]{device.get("cuda.deviceName"), device.get("cuda.deviceMajor"), device.get("cuda.deviceMinor"), device.get("cuda.totalMemory"), device.get("cuda.freeMemory")});
        }
    }

    /**
     * Synchronizes the current device context by calling {@code syncOldStream()} followed by
     * {@code syncSpecialStream()} on it.
     */
    public void commit() {
        final CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
        context.syncOldStream();
        context.syncSpecialStream();
    }

    /**
     * Threshold-encodes the data buffer of {@code iNDArray} into an INT buffer by running three
     * native passes ({@code encodeThresholdP1*}, {@code encodeThresholdP2Int},
     * {@code encodeThresholdP3*}) selected by the global {@code Nd4j.dataType()}.
     *
     * @param iNDArray array whose data buffer is encoded
     * @param d        threshold; it is narrowed to {@code float} before being handed to the native calls
     *                 and before being stored in the result header
     * @param num      optional cap on the element count read back after pass 1; {@code null} means no cap
     * @return an array created from the encoded INT buffer and the shape info of {@code iNDArray},
     *         or {@code null} when the count read back after pass 1 is below 2
     */
    public INDArray thresholdEncode(INDArray iNDArray, double d, Integer num) {
        DataBuffer data = iNDArray.data();
        // Ceil-division of the buffer length by MAX_NUM_THREADS: number of chunks handled by the native passes.
        int length = (int) ((data.length() / Nd4jCuda.MAX_NUM_THREADS) + (data.length() % ((long) Nd4jCuda.MAX_NUM_THREADS) == 0 ? 0 : 1));
        CudaContext cudaContext = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
        // Int buffer with one slot more than there are chunks; slot 0 is read back below as the count.
        // It is allocated in the current workspace when one is active.
        DataBuffer createInt = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createInt(length + 1, true) : Nd4j.getDataBufferFactory().createInt(length + 1, true, Nd4j.getMemoryManager().getCurrentWorkspace());
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // Slot 1 of the extras array carries the context's "old" stream.
        PointerPointer put = this.extraz.get().put(1L, cudaContext.getOldStream());
        // Pass 1. NOTE(review): the DOUBLE and HALF variants also receive the threshold as float.
        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP1Float(put, AtomicAllocator.getInstance().getPointer(data), data.length(), AtomicAllocator.getInstance().getPointer(createInt), (float) d);
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP1Double(put, AtomicAllocator.getInstance().getPointer(data), data.length(), AtomicAllocator.getInstance().getPointer(createInt), (float) d);
        } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP1Half(put, AtomicAllocator.getInstance().getPointer(data), data.length(), AtomicAllocator.getInstance().getPointer(createInt), (float) d);
        }
        // Presumably flags the device copy as the freshest one so getInt(0) below sees pass-1 output - TODO confirm.
        AtomicAllocator.getInstance().getAllocationPoint(createInt).tickDeviceWrite();
        int i = createInt.getInt(0L);
        // Fewer than two counted elements: nothing is encoded.
        if (i < 2) {
            return null;
        }
        // Apply the optional cap and write the capped count back into slot 0.
        if (num != null && i > num.intValue()) {
            i = num.intValue();
            createInt.put(0L, i);
        }
        // Result buffer: 4 header ints followed by i payload ints.
        DataBuffer createInt2 = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createInt(4 + i, false) : Nd4j.getDataBufferFactory().createInt(4 + i, false, Nd4j.getMemoryManager().getCurrentWorkspace());
        AtomicAllocator.getInstance().getAllocationPoint(createInt2).tickHostWrite();
        // Header: [0] count, [1] original buffer length, [2] float bits of the threshold, [3] zero.
        createInt2.put(0L, i);
        createInt2.put(1L, (int) data.length());
        createInt2.put(2L, Float.floatToIntBits((float) d));
        AtomicAllocator.getInstance().getAllocationPoint(createInt2).tickHostWrite();
        createInt2.put(3L, 0);
        int i2 = length;
        int i3 = 0;
        ArrayList arrayList = new ArrayList();
        // First sweep only counts levels: the size is repeatedly reduced by a factor of 2 * 512 until it reaches 1.
        // NOTE(review): the increment tests the initial chunk count (length), not the reduced size,
        // so i3 can exceed the number of buffers allocated by the second sweep - confirm this is intended.
        do {
            int max = Math.max(1, (int) Math.ceil(i2 / (2.0f * 512)));
            if (length > 1) {
                i3++;
            }
            i2 = max;
        } while (i2 > 1);
        long[] jArr = new long[i3];
        int i4 = 0;
        int i5 = length;
        // Buffer with one 8-byte slot per level; it receives the addresses collected below.
        DataBuffer createDouble = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createDouble(jArr.length, false) : Nd4j.getDataBufferFactory().createDouble(jArr.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
        // Second sweep: allocate an int buffer for every level whose reduced size is above 1 and record its address.
        do {
            int max2 = Math.max(1, (int) Math.ceil(i5 / (2.0f * 512)));
            if (max2 > 1) {
                DataBuffer createInt3 = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createInt(max2, false) : Nd4j.getDataBufferFactory().createInt(max2, false, Nd4j.getMemoryManager().getCurrentWorkspace());
                arrayList.add(createInt3);
                int i6 = i4;
                i4++;
                jArr[i6] = AtomicAllocator.getInstance().getPointer(createInt3).address();
            }
            i5 = max2;
        } while (i5 > 1);
        // Copy the recorded addresses (8 bytes each) into createDouble and expose it through extras slot 2.
        AtomicAllocator.getInstance().memcpyBlocking(createDouble, new LongPointer(jArr), jArr.length * 8, 0L);
        put.put(2L, AtomicAllocator.getInstance().getPointer(createDouble));
        // Pass 2 reads the pass-1 buffer and writes into a fresh int buffer with one entry per chunk.
        DataBuffer createInt4 = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createInt(length, true) : Nd4j.getDataBufferFactory().createInt(length, true, Nd4j.getMemoryManager().getCurrentWorkspace());
        NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP2Int(put, AtomicAllocator.getInstance().getPointer(createInt), length, AtomicAllocator.getInstance().getPointer(createInt4));
        AtomicAllocator.getInstance().getAllocationPoint(createInt4).tickDeviceWrite();
        // Pass 3 takes the source data, the pass-2 buffer and the result buffer.
        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP3Float(put, AtomicAllocator.getInstance().getPointer(data), AtomicAllocator.getInstance().getPointer(createInt4), data.length(), AtomicAllocator.getInstance().getPointer(createInt2));
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP3Double(put, AtomicAllocator.getInstance().getPointer(data), AtomicAllocator.getInstance().getPointer(createInt4), data.length(), AtomicAllocator.getInstance().getPointer(createInt2));
        } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP3Half(put, AtomicAllocator.getInstance().getPointer(data), AtomicAllocator.getInstance().getPointer(createInt4), data.length(), AtomicAllocator.getInstance().getPointer(createInt2));
        }
        // Both the result buffer and the source buffer are ticked as device-written.
        // NOTE(review): ticking the source suggests pass 3 also modifies the input in place - confirm.
        AtomicAllocator.getInstance().getAllocationPoint(createInt2).tickDeviceWrite();
        AtomicAllocator.getInstance().getAllocationPoint(data).tickDeviceWrite();
        // Results of the next three calls are discarded; presumably they only keep these objects
        // reachable until the native work above has been issued - do not remove without confirming.
        put.address();
        createDouble.address();
        arrayList.getClass();
        return Nd4j.createArrayFromShapeBuffer(createInt2, iNDArray.shapeInfoDataBuffer());
    }

    /**
     * Threshold-encodes {@code iNDArray} without a cap on the number of encoded elements.
     *
     * <p>Delegates to {@link #thresholdEncode(INDArray, double, Integer)} with a {@code null} cap.
     *
     * @param iNDArray array to encode
     * @param d        threshold value
     * @return whatever the three-argument overload returns (possibly {@code null})
     */
    public INDArray thresholdEncode(INDArray iNDArray, double d) {
        final Integer noCap = null;
        return thresholdEncode(iNDArray, d, noCap);
    }

    /**
     * Decodes a threshold-encoded INT array into {@code iNDArray2} through the native
     * {@code decodeThreshold*} call matching the global {@code Nd4j.dataType()}.
     *
     * @param iNDArray  encoded array; its buffer must be of type INT, with the element count at
     *                  index 0 and the original length at index 1
     * @param iNDArray2 target array; its length must equal the stored original length
     * @return {@code iNDArray2}
     * @throws UnsupportedOperationException if the encoded buffer is not INT
     * @throws ND4JIllegalStateException     if the stored length differs from the target length
     */
    public INDArray thresholdDecode(INDArray iNDArray, INDArray iNDArray2) {
        final DataBuffer encoded = iNDArray.data();
        if (encoded.dataType() != DataBuffer.Type.INT) {
            throw new UnsupportedOperationException();
        }

        // Header fields written by the encoder.
        final long encodedCount = encoded.getInt(0L);
        final long storedLength = encoded.getInt(1L);
        final long targetLength = iNDArray2.lengthLong();
        if (targetLength != storedLength) {
            throw new ND4JIllegalStateException("originalLength [" + storedLength + "] stored in encoded array doesn't match target length [" + targetLength + "]");
        }

        final DataBuffer target = iNDArray2.data();
        final CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();

        // Lazily create the per-thread extras array, then publish the stream in slot 1.
        PointerPointer holder = this.extraz.get();
        if (holder == null) {
            holder = new PointerPointer(32L);
            this.extraz.set(holder);
        }
        final PointerPointer extras = holder.put(1L, context.getOldStream());

        final DataBuffer.Type globalType = Nd4j.dataType();
        if (globalType == DataBuffer.Type.FLOAT) {
            nativeOps.decodeThresholdFloat(extras, AtomicAllocator.getInstance().getPointer(encoded), encodedCount, AtomicAllocator.getInstance().getPointer(target));
        } else if (globalType == DataBuffer.Type.DOUBLE) {
            nativeOps.decodeThresholdDouble(extras, AtomicAllocator.getInstance().getPointer(encoded), encodedCount, AtomicAllocator.getInstance().getPointer(target));
        } else if (globalType == DataBuffer.Type.HALF) {
            nativeOps.decodeThresholdHalf(extras, AtomicAllocator.getInstance().getPointer(encoded), encodedCount, AtomicAllocator.getInstance().getPointer(target));
        }

        AtomicAllocator.getInstance().getAllocationPoint(target).tickDeviceWrite();
        return iNDArray2;
    }

    /**
     * Bitmap-encodes {@code iNDArray} into the INT buffer of {@code iNDArray2} using the native
     * {@code encodeBitmap*} call matching the source buffer's data type.
     *
     * <p>Before the native call the first four ints of the target buffer are set to: source
     * length, source length, float bits of the threshold, and 1.
     *
     * @param iNDArray  source array
     * @param iNDArray2 target array; its buffer must be INT and hold {@code length / 16 + 5} elements
     * @param d         threshold, narrowed to {@code float}
     * @return the value returned by the native encode call
     * @throws ND4JIllegalStateException if the target has the wrong length or type, or the source
     *                                   data type is not FLOAT, DOUBLE or HALF
     */
    public long bitmapEncode(INDArray iNDArray, INDArray iNDArray2, double d) {
        final long sourceLength = iNDArray.lengthLong();
        final long requiredLength = (sourceLength / 16) + 5;
        if (iNDArray2.data().length() != requiredLength) {
            throw new ND4JIllegalStateException("Length of target array should be " + requiredLength);
        }
        if (iNDArray2.data().dataType() != DataBuffer.Type.INT) {
            throw new ND4JIllegalStateException("Target array should have INT dataType");
        }

        // Header of the encoded buffer.
        final DataBuffer encoded = iNDArray2.data();
        encoded.put(0L, (int) sourceLength);
        encoded.put(1L, (int) sourceLength);
        encoded.put(2L, Float.floatToIntBits((float) d));
        encoded.put(3L, 1);

        final CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(iNDArray, new INDArray[0]);

        PointerPointer holder = this.extraz.get();
        if (holder == null) {
            holder = new PointerPointer(32L);
            this.extraz.set(holder);
        }
        final PointerPointer extras = holder.put(new Pointer[]{AtomicAllocator.getInstance().getHostPointer(iNDArray), context.getOldStream(), context.getBufferScalar(), context.getBufferReduction()});

        final DataBuffer.Type sourceType = iNDArray.data().dataType();
        final long nativeResult;
        if (sourceType == DataBuffer.Type.FLOAT) {
            nativeResult = nativeOps.encodeBitmapFloat(extras, AtomicAllocator.getInstance().getPointer(iNDArray, context), sourceLength, AtomicAllocator.getInstance().getPointer(encoded, context), (float) d);
        } else if (sourceType == DataBuffer.Type.DOUBLE) {
            nativeResult = nativeOps.encodeBitmapDouble(extras, AtomicAllocator.getInstance().getPointer(iNDArray, context), sourceLength, AtomicAllocator.getInstance().getPointer(encoded, context), (float) d);
        } else if (sourceType == DataBuffer.Type.HALF) {
            nativeResult = nativeOps.encodeBitmapHalf(extras, AtomicAllocator.getInstance().getPointer(iNDArray, context), sourceLength, AtomicAllocator.getInstance().getPointer(encoded, context), (float) d);
        } else {
            throw new ND4JIllegalStateException("Unknown dataType " + sourceType);
        }

        AtomicAllocator.getInstance().getFlowController().registerAction(context, iNDArray, new INDArray[0]);
        AtomicAllocator.getInstance().getAllocationPoint(encoded).tickDeviceWrite();
        return nativeResult;
    }

    /**
     * Decodes a bitmap-encoded array into {@code iNDArray2} using the native
     * {@code decodeBitmap*} call matching the target buffer's data type.
     *
     * @param iNDArray  encoded array
     * @param iNDArray2 target array that receives the decoded values
     * @return {@code iNDArray2}
     * @throws ND4JIllegalStateException if the target data type is not FLOAT, DOUBLE or HALF
     */
    public INDArray bitmapDecode(INDArray iNDArray, INDArray iNDArray2) {
        final CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(iNDArray2, new INDArray[0]);

        PointerPointer holder = this.extraz.get();
        if (holder == null) {
            holder = new PointerPointer(32L);
            this.extraz.set(holder);
        }
        final PointerPointer extras = holder.put(new Pointer[]{AtomicAllocator.getInstance().getHostPointer(iNDArray2), context.getOldStream(), context.getBufferScalar(), context.getBufferReduction()});

        final DataBuffer.Type targetType = iNDArray2.data().dataType();
        if (targetType == DataBuffer.Type.FLOAT) {
            nativeOps.decodeBitmapFloat(extras, AtomicAllocator.getInstance().getPointer(iNDArray.data(), context), iNDArray2.lengthLong(), AtomicAllocator.getInstance().getPointer(iNDArray2, context));
        } else if (targetType == DataBuffer.Type.DOUBLE) {
            nativeOps.decodeBitmapDouble(extras, AtomicAllocator.getInstance().getPointer(iNDArray.data(), context), iNDArray2.lengthLong(), AtomicAllocator.getInstance().getPointer(iNDArray2, context));
        } else if (targetType == DataBuffer.Type.HALF) {
            nativeOps.decodeBitmapHalf(extras, AtomicAllocator.getInstance().getPointer(iNDArray.data(), context), iNDArray2.lengthLong(), AtomicAllocator.getInstance().getPointer(iNDArray2, context));
        } else {
            throw new ND4JIllegalStateException("Unknown dataType " + targetType);
        }

        AtomicAllocator.getInstance().getFlowController().registerAction(context, iNDArray2, new INDArray[0]);
        return iNDArray2;
    }

    /**
     * Returns the custom operations reported by the native backend, keyed by op name.
     *
     * <p>The native side supplies a single string of {@code ;}-separated entries, each of the
     * form {@code name:hash:numInputs:numOutputs:allowsInplace:numTArgs:numIArgs}. The parsed
     * result is cached in {@code this.customOps}; an empty or missing string yields (and caches)
     * an empty map after logging a warning.
     *
     * @return an unmodifiable map of op name to descriptor, never {@code null}
     */
    public synchronized Map<String, CustomOpDescriptor> getCustomOperations() {
        if (this.customOps != null) {
            return this.customOps;
        }

        final String serialized = nativeOps.getAllCustomOps();
        if (serialized == null || serialized.isEmpty()) {
            log.warn("No customs ops available!");
            this.customOps = Collections.emptyMap();
            return this.customOps;
        }

        final Map<String, CustomOpDescriptor> parsed = new HashMap<>();
        for (String entry : serialized.split(";")) {
            if (entry.isEmpty()) {
                continue;
            }
            final String[] fields = entry.split(":");
            final String opName = fields[0];
            final long hash = Long.parseLong(fields[1]);
            final int numInputs = Integer.parseInt(fields[2]);
            final int numOutputs = Integer.parseInt(fields[3]);
            final boolean allowsInplace = Integer.parseInt(fields[4]) == 1;
            final int numTArgs = Integer.parseInt(fields[5]);
            final int numIArgs = Integer.parseInt(fields[6]);
            parsed.put(opName, CustomOpDescriptor.builder().hash(hash).numInputs(numInputs).numOutputs(numOutputs).allowsInplace(allowsInplace).numTArgs(numTArgs).numIArgs(numIArgs).build());
        }
        this.customOps = Collections.unmodifiableMap(parsed);
        return this.customOps;
    }

    /**
     * Copies a length-prefixed run of ints out of {@code intPointer}: element 0 gives the number
     * of values, and elements 1..n are the values themselves.
     *
     * @param intPointer pointer positioned at the length prefix
     * @return a new array holding the n values that follow the prefix
     */
    protected int[] getShapeFromPointer(IntPointer intPointer) {
        final int rank = intPointer.get(0L);
        final int[] shape = new int[rank];
        // Source index runs one ahead of the destination index because of the prefix slot.
        for (int source = 1; source <= rank; source++) {
            shape[source - 1] = intPointer.get(source);
        }
        return shape;
    }

    /**
     * Asks the native backend for the output shapes of {@code customOp}.
     *
     * <p>Pending work is committed first. The addresses of the input buffers and their shape
     * info, together with the op's integer and floating-point arguments, are packed into native
     * pointers and handed to the {@code calculateOutputShapes*} call that matches the global
     * {@code Nd4j.dataType()}. Floating-point arguments are stored at the precision of that data
     * type; in particular the DOUBLE path stores them as full doubles, as {@code exec(CustomOp)}
     * does.
     *
     * @param customOp the op to inspect
     * @return one shape per entry of the native result; an empty list when the global data type
     *         is none of FLOAT, DOUBLE or HALF
     * @throws NullPointerException if {@code customOp} is {@code null}
     * @throws RuntimeException     if the native call returns no shape list
     */
    public List<int[]> calculateOutputShape(@NonNull CustomOp customOp) {
        if (customOp == null) {
            throw new NullPointerException("op");
        }
        Nd4j.getExecutioner().commit();

        // The lower-cased name is only used to make failures identifiable.
        final String opName = customOp.opName().toLowerCase();
        final long opHash = customOp.opHash();
        final int numInputs = customOp.inputArguments().length;
        final int numTArgs = customOp.tArgs().length;
        final int numIArgs = customOp.iArgs().length;

        // Parallel arrays: data address and shape-info address of every input.
        final PointerPointer inputBuffers = new PointerPointer(numInputs);
        final PointerPointer inputShapes = new PointerPointer(numInputs);
        int inputIndex = 0;
        for (INDArray input : customOp.inputArguments()) {
            inputBuffers.put(inputIndex, input.data().addressPointer());
            inputShapes.put(inputIndex, input.shapeInfoDataBuffer().addressPointer());
            inputIndex++;
        }

        // A null pointer is passed when there are no integer arguments.
        final IntPointer iArgsPointer = numIArgs > 0 ? new IntPointer(numIArgs) : null;
        int iArgIndex = 0;
        for (int value : customOp.iArgs()) {
            iArgsPointer.put(iArgIndex++, value);
        }

        final DataBuffer.Type dataType = Nd4j.dataType();
        if (dataType == DataBuffer.Type.FLOAT) {
            final FloatPointer tArgsPointer = numTArgs > 0 ? new FloatPointer(numTArgs) : null;
            int tArgIndex = 0;
            for (double value : customOp.tArgs()) {
                tArgsPointer.put(tArgIndex++, (float) value);
            }
            return collectAndReleaseShapes((Nd4jCuda.ShapeList) nativeOps.calculateOutputShapesFloat((PointerPointer) null, opHash, inputBuffers, inputShapes, numInputs, tArgsPointer, numTArgs, iArgsPointer, numIArgs), opName);
        }
        if (dataType == DataBuffer.Type.DOUBLE) {
            final DoublePointer tArgsPointer = numTArgs > 0 ? new DoublePointer(numTArgs) : null;
            int tArgIndex = 0;
            for (double value : customOp.tArgs()) {
                // Stored without narrowing: the native side receives doubles.
                tArgsPointer.put(tArgIndex++, value);
            }
            return collectAndReleaseShapes((Nd4jCuda.ShapeList) nativeOps.calculateOutputShapesDouble((PointerPointer) null, opHash, inputBuffers, inputShapes, numInputs, tArgsPointer, numTArgs, iArgsPointer, numIArgs), opName);
        }
        if (dataType == DataBuffer.Type.HALF) {
            final ShortPointer tArgsPointer = numTArgs > 0 ? new ShortPointer(numTArgs) : null;
            int tArgIndex = 0;
            for (double value : customOp.tArgs()) {
                tArgsPointer.put(tArgIndex++, ArrayUtil.toHalf((float) value));
            }
            return collectAndReleaseShapes((Nd4jCuda.ShapeList) nativeOps.calculateOutputShapesHalf((PointerPointer) null, opHash, inputBuffers, inputShapes, numInputs, tArgsPointer, numTArgs, iArgsPointer, numIArgs), opName);
        }
        return new ArrayList<>();
    }

    /**
     * Copies every shape out of a native shape list and then deletes the list, even when
     * copying fails part-way.
     *
     * @param shapeList native result; {@code null} is treated as a failed calculation
     * @param opName    op name used in the failure message
     * @return the extracted shapes in list order
     * @throws RuntimeException if {@code shapeList} is {@code null}
     */
    private List<int[]> collectAndReleaseShapes(Nd4jCuda.ShapeList shapeList, String opName) {
        if (shapeList == null) {
            throw new RuntimeException("Native output shape calculation returned no result for op [" + opName + "]");
        }
        try {
            final List<int[]> shapes = new ArrayList<>();
            for (int index = 0; index < shapeList.size(); index++) {
                shapes.add(getShapeFromPointer(new PagedPointer(shapeList.at(index)).asIntPointer()));
            }
            return shapes;
        } finally {
            nativeOps.deleteShapeList(shapeList);
        }
    }

    public void exec(CustomOp customOp) {
        Nd4j.getExecutioner().commit();
        if (customOp.opName().equalsIgnoreCase("im2col")) {
            DataBuffer.Type dataType = Nd4j.dataType();
            INDArray iNDArray = customOp.inputArguments()[0];
            INDArray iNDArray2 = customOp.outputArguments()[0];
            CudaContext prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction(iNDArray2, iNDArray);
            if (this.extraz.get() == null) {
                this.extraz.set(new PointerPointer(32L));
            }
            PointerPointer put = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(iNDArray.shapeInfoDataBuffer()), prepareAction.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), prepareAction.getBufferAllocation(), prepareAction.getBufferReduction(), prepareAction.getBufferScalar(), prepareAction.getBufferSpecial(), null, AddressRetriever.retrieveHostPointer(iNDArray2.shapeInfoDataBuffer())});
            DoublePointer pointer = AtomicAllocator.getInstance().getPointer(iNDArray, prepareAction);
            DoublePointer pointer2 = AtomicAllocator.getInstance().getPointer(iNDArray2, prepareAction);
            IntPointer pointer3 = AtomicAllocator.getInstance().getPointer(iNDArray.shapeInfoDataBuffer(), prepareAction);
            IntPointer pointer4 = AtomicAllocator.getInstance().getPointer(iNDArray2.shapeInfoDataBuffer(), prepareAction);
            double d = 0.0d;
            if (customOp.tArgs() != null && customOp.tArgs().length > 0) {
                d = customOp.tArgs()[0];
            }
            DoublePointer pointer5 = AtomicAllocator.getInstance().getPointer(Nd4j.getConstantHandler().getConstantBuffer(new double[]{customOp.iArgs()[0], customOp.iArgs()[1], customOp.iArgs()[2], customOp.iArgs()[3], customOp.iArgs()[4], customOp.iArgs()[5], customOp.iArgs()[6], customOp.iArgs()[7], customOp.iArgs()[8], d}), prepareAction);
            if (dataType == DataBuffer.Type.DOUBLE) {
                nativeOps.execTransformDouble(put, 37, pointer, pointer3, pointer2, pointer4, pointer5);
            } else if (dataType == DataBuffer.Type.FLOAT) {
                nativeOps.execTransformFloat(put, 37, (FloatPointer) pointer, pointer3, (FloatPointer) pointer2, pointer4, (FloatPointer) pointer5);
            } else if (dataType == DataBuffer.Type.HALF) {
                nativeOps.execTransformHalf(put, 37, (ShortPointer) pointer, pointer3, (ShortPointer) pointer2, pointer4, (ShortPointer) pointer5);
            }
            AtomicAllocator.getInstance().getFlowController().registerAction(prepareAction, iNDArray2, iNDArray);
            return;
        }
        if (customOp.opName().equalsIgnoreCase("col2im")) {
            DataBuffer.Type dataType2 = Nd4j.dataType();
            INDArray iNDArray3 = customOp.inputArguments()[0];
            INDArray iNDArray4 = customOp.outputArguments()[0];
            CudaContext prepareAction2 = AtomicAllocator.getInstance().getFlowController().prepareAction(iNDArray4, iNDArray3);
            if (this.extraz.get() == null) {
                this.extraz.set(new PointerPointer(32L));
            }
            PointerPointer put2 = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(iNDArray3.shapeInfoDataBuffer()), prepareAction2.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), prepareAction2.getBufferAllocation(), prepareAction2.getBufferReduction(), prepareAction2.getBufferScalar(), prepareAction2.getBufferSpecial(), null, AddressRetriever.retrieveHostPointer(iNDArray4.shapeInfoDataBuffer())});
            DoublePointer pointer6 = AtomicAllocator.getInstance().getPointer(iNDArray3, prepareAction2);
            DoublePointer pointer7 = AtomicAllocator.getInstance().getPointer(iNDArray4, prepareAction2);
            IntPointer pointer8 = AtomicAllocator.getInstance().getPointer(iNDArray3.shapeInfoDataBuffer(), prepareAction2);
            IntPointer pointer9 = AtomicAllocator.getInstance().getPointer(iNDArray4.shapeInfoDataBuffer(), prepareAction2);
            DoublePointer pointer10 = AtomicAllocator.getInstance().getPointer(Nd4j.getConstantHandler().getConstantBuffer(new double[]{customOp.iArgs()[0], customOp.iArgs()[1], customOp.iArgs()[2], customOp.iArgs()[3], customOp.iArgs()[4], customOp.iArgs()[5], customOp.iArgs()[6], customOp.iArgs()[7]}), prepareAction2);
            if (dataType2 == DataBuffer.Type.DOUBLE) {
                nativeOps.execTransformDouble(put2, 36, pointer6, pointer8, pointer7, pointer9, pointer10);
            } else if (dataType2 == DataBuffer.Type.FLOAT) {
                nativeOps.execTransformFloat(put2, 36, (FloatPointer) pointer6, pointer8, (FloatPointer) pointer7, pointer9, (FloatPointer) pointer10);
            } else if (dataType2 == DataBuffer.Type.HALF) {
                nativeOps.execTransformHalf(put2, 36, (ShortPointer) pointer6, pointer8, (ShortPointer) pointer7, pointer9, (ShortPointer) pointer10);
            }
            AtomicAllocator.getInstance().getFlowController().registerAction(prepareAction2, iNDArray4, iNDArray3);
            return;
        }
        if (customOp.opName().equalsIgnoreCase("pooling2d")) {
            DataBuffer.Type dataType3 = Nd4j.dataType();
            INDArray iNDArray5 = customOp.inputArguments()[0];
            INDArray iNDArray6 = customOp.outputArguments()[0];
            CudaContext prepareAction3 = AtomicAllocator.getInstance().getFlowController().prepareAction(iNDArray6, iNDArray5);
            if (this.extraz.get() == null) {
                this.extraz.set(new PointerPointer(32L));
            }
            PointerPointer put3 = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(iNDArray5.shapeInfoDataBuffer()), prepareAction3.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), prepareAction3.getBufferAllocation(), prepareAction3.getBufferReduction(), prepareAction3.getBufferScalar(), prepareAction3.getBufferSpecial(), null, AddressRetriever.retrieveHostPointer(iNDArray6.shapeInfoDataBuffer())});
            DoublePointer pointer11 = AtomicAllocator.getInstance().getPointer(iNDArray5, prepareAction3);
            DoublePointer pointer12 = AtomicAllocator.getInstance().getPointer(iNDArray6, prepareAction3);
            IntPointer pointer13 = AtomicAllocator.getInstance().getPointer(iNDArray5.shapeInfoDataBuffer(), prepareAction3);
            IntPointer pointer14 = AtomicAllocator.getInstance().getPointer(iNDArray6.shapeInfoDataBuffer(), prepareAction3);
            DoublePointer pointer15 = AtomicAllocator.getInstance().getPointer(Nd4j.getConstantHandler().getConstantBuffer(new double[]{customOp.iArgs()[0], customOp.iArgs()[1], customOp.iArgs()[2], customOp.iArgs()[3], customOp.iArgs()[4], customOp.iArgs()[5], customOp.iArgs()[6], customOp.iArgs()[7], customOp.iArgs()[8]}), prepareAction3);
            if (dataType3 == DataBuffer.Type.DOUBLE) {
                nativeOps.execTransformDouble(put3, 71, pointer11, pointer13, pointer12, pointer14, pointer15);
            } else if (dataType3 == DataBuffer.Type.FLOAT) {
                nativeOps.execTransformFloat(put3, 71, (FloatPointer) pointer11, pointer13, (FloatPointer) pointer12, pointer14, (FloatPointer) pointer15);
            } else if (dataType3 == DataBuffer.Type.HALF) {
                nativeOps.execTransformHalf(put3, 71, (ShortPointer) pointer11, pointer13, (ShortPointer) pointer12, pointer14, (ShortPointer) pointer15);
            }
            AtomicAllocator.getInstance().getFlowController().registerAction(prepareAction3, iNDArray6, iNDArray5);
            return;
        }
        Nd4j.getExecutioner().commit();
        long profilingHookIn = profilingHookIn(customOp);
        CudaContext cudaContext = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        PointerPointer put4 = this.extraz.get().put(new Pointer[]{new CudaPointer(1L), cudaContext.getOldStream(), cudaContext.getBufferScalar(), cudaContext.getBufferReduction()});
        INDArray[] outputArguments = customOp.outputArguments();
        INDArray[] inputArguments = customOp.inputArguments();
        if (outputArguments.length == 0 && !customOp.isInplaceCall()) {
            throw new ND4JIllegalStateException("You can't execute non-inplace CustomOp without outputs being specified");
        }
        customOp.opName().toLowerCase();
        long opHash = customOp.opHash();
        PointerPointer pointerPointer = new PointerPointer(inputArguments.length * 2);
        PointerPointer pointerPointer2 = new PointerPointer(inputArguments.length * 2);
        int i = 0;
        for (INDArray iNDArray7 : inputArguments) {
            Pointer hostPointer = AtomicAllocator.getInstance().getHostPointer(iNDArray7.shapeInfoDataBuffer());
            pointerPointer2.put(i, AtomicAllocator.getInstance().getHostPointer(iNDArray7));
            pointerPointer.put(i, hostPointer);
            Pointer pointer16 = AtomicAllocator.getInstance().getPointer(iNDArray7.shapeInfoDataBuffer(), cudaContext);
            pointerPointer2.put(i + inputArguments.length, AtomicAllocator.getInstance().getPointer(iNDArray7, cudaContext));
            pointerPointer.put(i + inputArguments.length, pointer16);
            if (customOp.isInplaceCall()) {
                AtomicAllocator.getInstance().getAllocationPoint(iNDArray7).tickHostWrite();
            }
            i++;
        }
        PointerPointer pointerPointer3 = new PointerPointer(outputArguments.length * 2);
        PointerPointer pointerPointer4 = new PointerPointer(outputArguments.length * 2);
        int i2 = 0;
        for (INDArray iNDArray8 : outputArguments) {
            pointerPointer4.put(i2, AtomicAllocator.getInstance().getHostPointer(iNDArray8));
            pointerPointer3.put(i2, AtomicAllocator.getInstance().getHostPointer(iNDArray8.shapeInfoDataBuffer()));
            pointerPointer4.put(i2 + outputArguments.length, AtomicAllocator.getInstance().getPointer(iNDArray8, cudaContext));
            pointerPointer3.put(i2 + outputArguments.length, AtomicAllocator.getInstance().getPointer(iNDArray8.shapeInfoDataBuffer(), cudaContext));
            AtomicAllocator.getInstance().getAllocationPoint(iNDArray8).tickHostWrite();
            i2++;
        }
        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            FloatPointer floatPointer = customOp.tArgs().length > 0 ? new FloatPointer(customOp.tArgs().length) : null;
            IntPointer intPointer = customOp.iArgs().length > 0 ? new IntPointer(customOp.iArgs().length) : null;
            int i3 = 0;
            for (double d2 : customOp.tArgs()) {
                int i4 = i3;
                i3++;
                floatPointer.put(i4, (float) d2);
            }
            int i5 = 0;
            for (int i6 : customOp.iArgs()) {
                int i7 = i5;
                i5++;
                intPointer.put(i7, i6);
            }
            OpStatus byNumber = OpStatus.byNumber(nativeOps.execCustomOpFloat(put4, opHash, pointerPointer2, pointerPointer, inputArguments.length, pointerPointer4, pointerPointer3, outputArguments.length, floatPointer, customOp.tArgs().length, intPointer, customOp.iArgs().length, customOp.isInplaceCall()));
            if (byNumber != OpStatus.ND4J_STATUS_OK) {
                throw new ND4JIllegalStateException("Op execution failed: " + byNumber);
            }
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            DoublePointer doublePointer = customOp.tArgs().length > 0 ? new DoublePointer(customOp.tArgs().length) : null;
            IntPointer intPointer2 = customOp.iArgs().length > 0 ? new IntPointer(customOp.iArgs().length) : null;
            int i8 = 0;
            for (double d3 : customOp.tArgs()) {
                int i9 = i8;
                i8++;
                doublePointer.put(i9, d3);
            }
            for (int i10 : customOp.iArgs()) {
                int i11 = i8;
                i8++;
                intPointer2.put(i11, i10);
            }
            OpStatus byNumber2 = OpStatus.byNumber(nativeOps.execCustomOpDouble(put4, opHash, pointerPointer2, pointerPointer, inputArguments.length, pointerPointer4, pointerPointer3, outputArguments.length, doublePointer, customOp.tArgs().length, intPointer2, customOp.iArgs().length, customOp.isInplaceCall()));
            if (byNumber2 != OpStatus.ND4J_STATUS_OK) {
                throw new ND4JIllegalStateException("Op execution failed: " + byNumber2);
            }
        } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            ShortPointer shortPointer = customOp.tArgs().length > 0 ? new ShortPointer(customOp.tArgs().length) : null;
            IntPointer intPointer3 = customOp.iArgs().length > 0 ? new IntPointer(customOp.iArgs().length) : null;
            int i12 = 0;
            for (double d4 : customOp.tArgs()) {
                int i13 = i12;
                i12++;
                shortPointer.put(i13, ArrayUtil.toHalf((float) d4));
            }
            int i14 = 0;
            for (int i15 : customOp.iArgs()) {
                int i16 = i14;
                i14++;
                intPointer3.put(i16, i15);
            }
            OpStatus byNumber3 = OpStatus.byNumber(nativeOps.execCustomOpHalf(put4, opHash, pointerPointer2, pointerPointer, inputArguments.length, pointerPointer4, pointerPointer3, outputArguments.length, shortPointer, customOp.tArgs().length, intPointer3, customOp.iArgs().length, customOp.isInplaceCall()));
            if (byNumber3 != OpStatus.ND4J_STATUS_OK) {
                throw new ND4JIllegalStateException("Op execution failed: " + byNumber3);
            }
        }
        profilingHookOut(customOp, profilingHookIn);
    }

    /**
     * Forwards the given flag to {@code nativeOps.enableDebugMode}.
     *
     * @param z value handed unchanged to the native call (presumably {@code true} turns
     *          debug mode on; confirm against the native API)
     */
    public void enableDebugMode(boolean z) {
        nativeOps.enableDebugMode(z);
    }

    /**
     * Forwards the given flag to {@code nativeOps.enableVerboseMode}.
     *
     * @param z value handed unchanged to the native call (presumably {@code true} turns
     *          verbose mode on; confirm against the native API)
     */
    public void enableVerboseMode(boolean z) {
        nativeOps.enableVerboseMode(z);
    }

    /**
     * Hands a graph to the native backend under the given id, choosing the native entry
     * point that matches the data type reported by {@code Nd4j.dataType()}.
     *
     * <p>FLOAT, DOUBLE and HALF are dispatched to their respective
     * {@code registerGraph*} call; for any other data type nothing is invoked.
     * The first native argument is always a {@code null} {@link PointerPointer}.
     *
     * @param j       id passed to the native registration call
     * @param pointer pointer passed to the native registration call as the graph payload
     */
    public void registerGraph(long j, Pointer pointer) {
        final DataBuffer.Type activeType = Nd4j.dataType();

        if (activeType == DataBuffer.Type.FLOAT) {
            nativeOps.registerGraphFloat((PointerPointer) null, j, pointer);
            return;
        }

        if (activeType == DataBuffer.Type.DOUBLE) {
            nativeOps.registerGraphDouble((PointerPointer) null, j, pointer);
            return;
        }

        if (activeType == DataBuffer.Type.HALF) {
            nativeOps.registerGraphHalf((PointerPointer) null, j, pointer);
        }
    }

    /**
     * Executes a stored graph through the native backend and returns the arrays it produced.
     *
     * <p>Pending operations are committed first. For every entry of {@code map}, in the
     * iteration order of its key set, the host pointer of the array and the host pointer of
     * its shape-info buffer are packed into two pointer tables, and the entry's position is
     * recorded as its input index. The native entry point is selected from
     * {@code Nd4j.dataType()} (FLOAT, DOUBLE or HALF); for any other data type no native
     * call is made and an empty map is returned.
     *
     * <p>Each variable in the returned native set is copied into a freshly created
     * {@link INDArray} and stored under the input name whose position equals the variable's
     * id. The native variables set is deleted after its contents have been copied.
     *
     * @param j   id of the graph passed to the native execution call
     * @param map input arrays keyed by name
     * @return insertion-ordered map of result arrays, keyed by the input name at the
     *         position given by each result variable's id
     * @throws ND4JIllegalStateException if the native call reports a status other than
     *                                   {@code ND4J_STATUS_OK}
     */
    public Map<String, INDArray> executeGraph(long j, Map<String, INDArray> map) {
        commit();

        final int inputCount = map.size();
        PointerPointer hostBuffers = new PointerPointer(inputCount * 2);
        PointerPointer hostShapes = new PointerPointer(inputCount * 2);
        IntPointer inputIndices = new IntPointer(inputCount);

        // Snapshot of the key order: result variable ids are resolved against this list below.
        List<String> inputNames = new ArrayList<>(map.keySet());
        int position = 0;
        for (String inputName : inputNames) {
            INDArray input = map.get(inputName);
            hostBuffers.put(position, AtomicAllocator.getInstance().getHostPointer(input));
            hostShapes.put(position, AtomicAllocator.getInstance().getHostPointer(input.shapeInfoDataBuffer()));
            inputIndices.put(position, position);
            position++;
        }

        Map<String, INDArray> results = new LinkedHashMap<>();
        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            Nd4jCuda.FloatVariablesSet variables = (Nd4jCuda.FloatVariablesSet) nativeOps.executeStoredGraphFloat((PointerPointer) null, j, hostBuffers, hostShapes, inputIndices, inputCount);
            OpStatus status = OpStatus.byNumber(variables.status());
            if (status != OpStatus.ND4J_STATUS_OK) {
                throw new ND4JIllegalStateException("Op execution failed: " + status);
            }
            for (int v = 0; v < variables.size(); v++) {
                Nd4jCuda.FloatVariable variable = variables.at(v);
                int id = variable.id();
                INDArray copy = copyGraphResultToNDArray(variable.getNDArray().shapeInfo(), variable.getNDArray().buffer());
                results.put(inputNames.get(id), copy);
            }
            nativeOps.deleteVariablesSetFloat(variables);
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            Nd4jCuda.DoubleVariablesSet variables = (Nd4jCuda.DoubleVariablesSet) nativeOps.executeStoredGraphDouble((PointerPointer) null, j, hostBuffers, hostShapes, inputIndices, inputCount);
            OpStatus status = OpStatus.byNumber(variables.status());
            if (status != OpStatus.ND4J_STATUS_OK) {
                throw new ND4JIllegalStateException("Op execution failed: " + status);
            }
            for (int v = 0; v < variables.size(); v++) {
                Nd4jCuda.DoubleVariable variable = variables.at(v);
                int id = variable.id();
                INDArray copy = copyGraphResultToNDArray(variable.getNDArray().shapeInfo(), variable.getNDArray().buffer());
                results.put(inputNames.get(id), copy);
            }
            nativeOps.deleteVariablesSetDouble(variables);
        } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            // NOTE(review): the HALF result is cast to the Double variables-set type, exactly as
            // in the code this replaces. Confirm whether a half-precision set type should be used.
            Nd4jCuda.DoubleVariablesSet variables = (Nd4jCuda.DoubleVariablesSet) nativeOps.executeStoredGraphHalf((PointerPointer) null, j, hostBuffers, hostShapes, inputIndices, inputCount);
            OpStatus status = OpStatus.byNumber(variables.status());
            if (status != OpStatus.ND4J_STATUS_OK) {
                throw new ND4JIllegalStateException("Op execution failed: " + status);
            }
            for (int v = 0; v < variables.size(); v++) {
                Nd4jCuda.DoubleVariable variable = variables.at(v);
                int id = variable.id();
                INDArray copy = copyGraphResultToNDArray(variable.getNDArray().shapeInfo(), variable.getNDArray().buffer());
                results.put(inputNames.get(id), copy);
            }
            nativeOps.deleteVariablesSetHalf(variables);
        }
        return results;
    }

    /**
     * Builds a new {@link INDArray} from a native result: reads the shape-info ints, creates
     * an array with that shape, strides and order, and copies the native buffer into the new
     * array's host memory.
     *
     * @param shapeInfo native shape-info buffer; element 0 is read as the rank and
     *                  {@code rank * 2 + 4} ints are copied in total
     * @param buffer    native data buffer to copy from
     * @return the newly created array holding a copy of {@code buffer}
     */
    private static INDArray copyGraphResultToNDArray(IntPointer shapeInfo, Pointer buffer) {
        int[] shapeBuffer = new int[(shapeInfo.get(0L) * 2) + 4];
        for (int k = 0; k < shapeBuffer.length; k++) {
            shapeBuffer[k] = shapeInfo.get(k);
        }

        int[] shape = Shape.shapeOf(shapeBuffer);
        INDArray array = Nd4j.create(shape, Shape.stridesOf(shapeBuffer), 0L, Shape.order(shapeBuffer));

        // Byte count = number of elements (product of the shape) times the current data type's size.
        Pointer.memcpy(AtomicAllocator.getInstance().getHostPointer(array), buffer, ArrayUtil.prod(shape) * Nd4j.sizeOfDataType());
        // The data was written through the host pointer; tickHostWrite presumably records that
        // on the allocation point so the host copy is treated as current (confirm).
        AtomicAllocator.getInstance().getAllocationPoint(array).tickHostWrite();
        return array;
    }

    /**
     * Calls {@code nativeOps.unregisterGraph} for the given id, passing a {@code null}
     * {@link PointerPointer} as the first argument.
     *
     * @param j id handed to the native unregister call (presumably the id a graph was
     *          registered under; confirm against callers)
     */
    public void forgetGraph(long j) {
        nativeOps.unregisterGraph((PointerPointer) null, j);
    }

    /**
     * Forwards the given value to {@code nativeOps.setElementThreshold}.
     *
     * @param i threshold value handed unchanged to the native call
     */
    public void setElementsThreshold(int i) {
        nativeOps.setElementThreshold(i);
    }

    /**
     * Forwards the given value to {@code nativeOps.setTADThreshold}.
     *
     * @param i threshold value handed unchanged to the native call
     */
    public void setTadThreshold(int i) {
        nativeOps.setTADThreshold(i);
    }
}
