package ai.djl.mxnet.engine;

import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDResource;
import ai.djl.ndarray.NDUtils;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.recurrent.RNN;
import ai.djl.util.PairList;
import ai.djl.util.Preconditions;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:ai/djl/mxnet/engine/MxNDArrayEx.class */
public class MxNDArrayEx implements NDArrayEx {
    /** The MXNet array that every operation in this helper reads from (and in-place ops write to). */
    private MxNDArray array;

    /**
     * Wraps an MXNet array so its extended (engine-specific) operations can be invoked.
     *
     * <p>Note: the decompiler reported that this constructor was originally package-private.
     *
     * @param mxArray the array to operate on
     */
    public MxNDArrayEx(MxNDArray mxArray) {
        array = mxArray;
    }

    /**
     * Computes the shape produced by broadcasting two shapes against each other.
     *
     * <p>The shapes are right-aligned; a dimension missing from the shorter shape counts as 1.
     * Along each axis the two sizes must be equal, or one of them must be 1, in which case the
     * other size is taken.
     *
     * @param lhs the first operand's shape
     * @param rhs the second operand's shape
     * @return the broadcast shape, whose rank is the larger of the two ranks
     * @throws IllegalArgumentException if an axis has two different sizes, neither of which is 1
     */
    private Shape deriveBroadcastedShape(Shape lhs, Shape rhs) {
        int rank = Math.max(lhs.dimension(), rhs.dimension());
        int lhsOffset = rank - lhs.dimension();
        int rhsOffset = rank - rhs.dimension();
        long[] dims = new long[rank];
        for (int axis = 0; axis < rank; axis++) {
            long a = axis < lhsOffset ? 1L : lhs.get(axis - lhsOffset);
            long b = axis < rhsOffset ? 1L : rhs.get(axis - rhsOffset);
            boolean compatible = a == b || a == 1 || b == 1;
            if (!compatible) {
                throw new IllegalArgumentException(
                        "operands could not be broadcast together with shapes " + lhs + " " + rhs);
            }
            // When a == b this picks either; otherwise the non-1 size wins.
            dims[axis] = a == 1 ? b : a;
        }
        return new Shape(dims);
    }

    /**
     * Reverse scalar division through MXNet's {@code _rdiv_scalar} operator (the scalar divided
     * element-wise by this array).
     *
     * @param n the scalar, passed to MXNet as its string form
     * @return a new array with the result
     */
    @Override
    public NDArray rdiv(Number n) {
        String scalar = n.toString();
        MxOpParams params = new MxOpParams();
        params.add("scalar", scalar);
        return getManager().invoke("_rdiv_scalar", (NDArray) array, (PairList<String, ?>) params);
    }

    /**
     * Reverse division: returns {@code other / this}, delegating to {@code other.div}.
     *
     * @param other the numerator array
     * @return a new array with the result
     */
    @Override
    public NDArray rdiv(NDArray other) {
        NDArray numerator = other;
        return numerator.div(array);
    }

    /**
     * In-place reverse scalar division via {@code _rdiv_scalar}, writing the result back into
     * this array.
     *
     * @param n the scalar, passed to MXNet as its string form
     * @return this (mutated) array
     */
    @Override
    public NDArray rdivi(Number n) {
        MxOpParams params = new MxOpParams();
        params.add("scalar", n.toString());
        NDArray[] sources = {array};
        NDArray[] destinations = {array};
        getManager().invoke("_rdiv_scalar", sources, destinations, params);
        return array;
    }

    /**
     * In-place reverse division: stores {@code other / this} into this array using MXNet's
     * {@code elemwise_div} operator.
     *
     * @param other the numerator array
     * @return this (mutated) array
     */
    @Override
    public NDArray rdivi(NDArray other) {
        NDArray[] operands = {other, array};
        NDArray[] output = {array};
        getManager().invoke("elemwise_div", operands, output, (PairList<String, ?>) null);
        return array;
    }

    /**
     * Reverse scalar subtraction: returns {@code n - this}, computed as {@code -(this - n)}.
     *
     * @param n the scalar minuend
     * @return a new array with the result
     */
    @Override
    public NDArray rsub(Number n) {
        NDArray difference = array.sub(n);
        return difference.neg();
    }

    /**
     * Reverse subtraction: returns {@code other - this}, computed as {@code -(this - other)}.
     *
     * @param other the minuend array
     * @return a new array with the result
     */
    @Override
    public NDArray rsub(NDArray other) {
        NDArray difference = array.sub(other);
        return difference.neg();
    }

    /**
     * In-place reverse scalar subtraction. It subtracts {@code n} in place, then negates the
     * result in place.
     *
     * @param n the scalar minuend
     * @return the array returned by the in-place negation
     */
    @Override
    public NDArray rsubi(Number n) {
        NDArray difference = array.subi(n);
        return difference.negi();
    }

    /**
     * In-place reverse subtraction. It subtracts {@code other} in place, then negates the result
     * in place.
     *
     * @param other the minuend array
     * @return the array returned by the in-place negation
     */
    @Override
    public NDArray rsubi(NDArray other) {
        NDArray difference = array.subi(other);
        return difference.negi();
    }

    /**
     * Reverse scalar modulo through MXNet's {@code _npi_rmod_scalar} operator.
     *
     * @param n the scalar, passed to MXNet as its string form
     * @return a new array with the result
     */
    @Override
    public NDArray rmod(Number n) {
        String scalar = n.toString();
        MxOpParams params = new MxOpParams();
        params.add("scalar", scalar);
        return getManager().invoke("_npi_rmod_scalar", (NDArray) array, (PairList<String, ?>) params);
    }

    /**
     * Reverse modulo: returns {@code other mod this}, delegating to {@code other.mod}.
     *
     * @param other the dividend array
     * @return a new array with the result
     */
    @Override
    public NDArray rmod(NDArray other) {
        NDArray dividend = other;
        return dividend.mod(array);
    }

    /**
     * In-place reverse scalar modulo via {@code _npi_rmod_scalar}, writing into this array.
     *
     * @param n the scalar, passed to MXNet as its string form
     * @return this (mutated) array
     */
    @Override
    public NDArray rmodi(Number n) {
        MxOpParams params = new MxOpParams();
        params.add("scalar", n.toString());
        NDArray[] sources = {array};
        NDArray[] destinations = {array};
        getManager().invoke("_npi_rmod_scalar", sources, destinations, params);
        return array;
    }

    /**
     * In-place reverse modulo: stores {@code other mod this} into this array via {@code _npi_mod}.
     *
     * @param other the dividend array
     * @return this (mutated) array
     */
    @Override
    public NDArray rmodi(NDArray other) {
        NDArray[] operands = {other, array};
        NDArray[] output = {array};
        getManager().invoke("_npi_mod", operands, output, (PairList<String, ?>) null);
        return array;
    }

    /**
     * Reverse scalar power through MXNet's {@code _npi_rpower_scalar} operator (the scalar raised
     * to each element).
     *
     * @param n the scalar base, passed to MXNet as its string form
     * @return a new array with the result
     */
    @Override
    public NDArray rpow(Number n) {
        String scalar = n.toString();
        MxOpParams params = new MxOpParams();
        params.add("scalar", scalar);
        return getManager().invoke("_npi_rpower_scalar", (NDArray) array, (PairList<String, ?>) params);
    }

    /**
     * In-place reverse scalar power via {@code _npi_rpower_scalar}, writing into this array.
     *
     * @param n the scalar base, passed to MXNet as its string form
     * @return this (mutated) array
     */
    @Override
    public NDArray rpowi(Number n) {
        MxOpParams params = new MxOpParams();
        params.add("scalar", n.toString());
        NDArray[] sources = {array};
        NDArray[] destinations = {array};
        getManager().invoke("_npi_rpower_scalar", sources, destinations, params);
        return array;
    }

    /**
     * Applies ReLU via {@code _npx_activation} with {@code act_type=relu}.
     *
     * @return a new array with the activation applied
     */
    @Override
    public NDArray relu() {
        MxOpParams params = new MxOpParams();
        String activationType = "relu";
        params.addParam("act_type", activationType);
        return getManager().invoke("_npx_activation", (NDArray) array, (PairList<String, ?>) params);
    }

    /**
     * Applies the logistic sigmoid via {@code _npx_activation} with {@code act_type=sigmoid}.
     *
     * @return a new array with the activation applied
     */
    @Override
    public NDArray sigmoid() {
        MxOpParams params = new MxOpParams();
        String activationType = "sigmoid";
        params.addParam("act_type", activationType);
        return getManager().invoke("_npx_activation", (NDArray) array, (PairList<String, ?>) params);
    }

    /**
     * Applies hyperbolic tangent via {@code _npx_activation} with {@code act_type=tanh}.
     *
     * @return a new array with the activation applied
     */
    @Override
    public NDArray tanh() {
        MxOpParams params = new MxOpParams();
        String activationType = "tanh";
        params.addParam("act_type", activationType);
        return getManager().invoke("_npx_activation", (NDArray) array, (PairList<String, ?>) params);
    }

    /**
     * Applies softplus, which MXNet names {@code softrelu}, via {@code _npx_activation}.
     *
     * @return a new array with the activation applied
     */
    @Override
    public NDArray softPlus() {
        MxOpParams params = new MxOpParams();
        String activationType = "softrelu";
        params.addParam("act_type", activationType);
        return getManager().invoke("_npx_activation", (NDArray) array, (PairList<String, ?>) params);
    }

    /**
     * Applies softsign via {@code _npx_activation} with {@code act_type=softsign}.
     *
     * @return a new array with the activation applied
     */
    @Override
    public NDArray softSign() {
        MxOpParams params = new MxOpParams();
        String activationType = "softsign";
        params.addParam("act_type", activationType);
        return getManager().invoke("_npx_activation", (NDArray) array, (PairList<String, ?>) params);
    }

    /**
     * Applies leaky ReLU via {@code _npx_leaky_relu} with {@code act_type=leaky}.
     *
     * @param alpha the negative-side slope, passed as {@code slope}
     * @return a new array with the activation applied
     */
    @Override
    public NDArray leakyRelu(float alpha) {
        MxOpParams params = new MxOpParams();
        params.addParam("act_type", "leaky");
        params.addParam("slope", alpha);
        return getManager().invoke("_npx_leaky_relu", (NDArray) array, (PairList<String, ?>) params);
    }

    /**
     * Applies ELU via {@code _npx_leaky_relu} with {@code act_type=elu}.
     *
     * @param alpha the ELU coefficient, passed as {@code slope}
     * @return a new array with the activation applied
     */
    @Override
    public NDArray elu(float alpha) {
        MxOpParams params = new MxOpParams();
        params.addParam("act_type", "elu");
        params.addParam("slope", alpha);
        return getManager().invoke("_npx_leaky_relu", (NDArray) array, (PairList<String, ?>) params);
    }

    /**
     * Applies SELU via {@code _npx_leaky_relu} with {@code act_type=selu}.
     *
     * @return a new array with the activation applied
     */
    @Override
    public NDArray selu() {
        MxOpParams params = new MxOpParams();
        String activationType = "selu";
        params.addParam("act_type", activationType);
        return getManager().invoke("_npx_leaky_relu", (NDArray) array, (PairList<String, ?>) params);
    }

    /**
     * Applies GELU via {@code _npx_leaky_relu} with {@code act_type=gelu}.
     *
     * @return a new array with the activation applied
     */
    @Override
    public NDArray gelu() {
        MxOpParams params = new MxOpParams();
        String activationType = "gelu";
        params.addParam("act_type", activationType);
        return getManager().invoke("_npx_leaky_relu", (NDArray) array, (PairList<String, ?>) params);
    }

    /**
     * Max pooling through MXNet's {@code _npx_pooling} operator.
     *
     * @param kernelShape the pooling window shape ({@code kernel})
     * @param stride the window stride
     * @param padding the padding ({@code pad})
     * @param ceilMode {@code true} selects the {@code full} pooling convention, otherwise
     *     {@code valid}
     * @return a new pooled array
     */
    @Override
    public NDArray maxPool(Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        String convention = ceilMode ? "full" : "valid";
        MxOpParams params = new MxOpParams();
        params.addParam("kernel", kernelShape);
        params.add("pool_type", "max");
        params.addParam("stride", stride);
        params.addParam("pad", padding);
        params.add("pooling_convention", convention);
        return getManager().invoke("_npx_pooling", getArray(), params);
    }

    /**
     * Global max pooling through {@code _npx_pooling} with {@code global_pool=True}, flattened to
     * two dimensions.
     *
     * <p>The pooled array returned by MXNet is an intermediate. It is reshaped to
     * {@code (size(0), size(1))} of its own shape and then closed.
     *
     * @return a new 2-D array built from the first two dimensions of the pooled result
     */
    @Override
    public NDArray globalMaxPool() {
        MxOpParams params = new MxOpParams();
        params.add("kernel", getGlobalPoolingShapes(1L));
        params.add("pad", getGlobalPoolingShapes(0L));
        params.add("pool_type", "max");
        params.addParam("global_pool", true);
        // try-with-resources always closes the intermediate and, if reshape fails, keeps the
        // original exception (a close() failure is attached as suppressed instead of masking it).
        try (NDArray pooled = getManager().invoke("_npx_pooling", getArray(), params)) {
            Shape pooledShape = pooled.getShape();
            return pooled.reshape(
                    new long[] {pooledShape.size(new int[] {0}), pooledShape.size(new int[] {1})});
        }
    }

    /**
     * Average pooling through MXNet's {@code _npx_pooling} operator.
     *
     * @param kernelShape the pooling window shape ({@code kernel})
     * @param stride the window stride
     * @param padding the padding ({@code pad})
     * @param ceilMode {@code true} selects the {@code full} pooling convention, otherwise
     *     {@code valid}
     * @param countIncludePad whether padded positions count toward the average
     * @return a new pooled array
     */
    @Override
    public NDArray avgPool(
            Shape kernelShape, Shape stride, Shape padding, boolean ceilMode, boolean countIncludePad) {
        String convention = ceilMode ? "full" : "valid";
        MxOpParams params = new MxOpParams();
        params.addParam("kernel", kernelShape);
        params.add("pool_type", "avg");
        params.addParam("stride", stride);
        params.addParam("pad", padding);
        params.add("pooling_convention", convention);
        params.addParam("count_include_pad", countIncludePad);
        return getManager().invoke("_npx_pooling", getArray(), params);
    }

    /**
     * Global average pooling through {@code _npx_pooling} with {@code global_pool=True},
     * flattened to two dimensions.
     *
     * <p>The pooled array returned by MXNet is an intermediate. It is reshaped to
     * {@code (size(0), size(1))} of its own shape and then closed.
     *
     * @return a new 2-D array built from the first two dimensions of the pooled result
     */
    @Override
    public NDArray globalAvgPool() {
        MxOpParams params = new MxOpParams();
        params.add("kernel", getGlobalPoolingShapes(1L));
        params.add("pad", getGlobalPoolingShapes(0L));
        params.add("pool_type", "avg");
        params.addParam("global_pool", true);
        // try-with-resources always closes the intermediate and, if reshape fails, keeps the
        // original exception (a close() failure is attached as suppressed instead of masking it).
        try (NDArray pooled = getManager().invoke("_npx_pooling", getArray(), params)) {
            Shape pooledShape = pooled.getShape();
            return pooled.reshape(
                    new long[] {pooledShape.size(new int[] {0}), pooledShape.size(new int[] {1})});
        }
    }

    /**
     * Lp pooling through MXNet's {@code _npx_pooling} operator.
     *
     * @param normType the norm order {@code p}; must be integral because MXNet only accepts an
     *     integer {@code p_value}
     * @param kernelShape the pooling window shape ({@code kernel})
     * @param stride the window stride
     * @param padding the padding ({@code pad})
     * @param ceilMode {@code true} selects the {@code full} pooling convention, otherwise
     *     {@code valid}
     * @return a new pooled array
     * @throws IllegalArgumentException if {@code normType} is not an integer value
     */
    @Override
    public NDArray lpPool(float normType, Shape kernelShape, Shape stride, Shape padding, boolean ceilMode) {
        int p = (int) normType;
        if (p != normType) {
            throw new IllegalArgumentException("float type of normType is not supported in MXNet engine, please use integer instead");
        }
        String convention = ceilMode ? "full" : "valid";
        MxOpParams params = new MxOpParams();
        params.addParam("p_value", p);
        params.addParam("kernel", kernelShape);
        params.add("pool_type", "lp");
        params.addParam("stride", stride);
        params.addParam("pad", padding);
        params.add("pooling_convention", convention);
        return getManager().invoke("_npx_pooling", getArray(), params);
    }

    /**
     * Global Lp pooling through {@code _npx_pooling} with {@code global_pool=True}, flattened to
     * two dimensions.
     *
     * <p>The pooled array returned by MXNet is an intermediate. It is reshaped to
     * {@code (size(0), size(1))} of its own shape and then closed.
     *
     * @param normType the norm order {@code p}; must be integral because MXNet only accepts an
     *     integer {@code p_value}
     * @return a new 2-D array built from the first two dimensions of the pooled result
     * @throws IllegalArgumentException if {@code normType} is not an integer value
     */
    @Override
    public NDArray globalLpPool(float normType) {
        int p = (int) normType;
        if (p != normType) {
            throw new IllegalArgumentException("float type of normType is not supported in MXNet engine, please use integer instead");
        }
        MxOpParams params = new MxOpParams();
        params.add("pool_type", "lp");
        params.addParam("p_value", p);
        params.addParam("global_pool", true);
        // try-with-resources always closes the intermediate and, if reshape fails, keeps the
        // original exception (a close() failure is attached as suppressed instead of masking it).
        try (NDArray pooled = getManager().invoke("_npx_pooling", getArray(), params)) {
            Shape pooledShape = pooled.getShape();
            return pooled.reshape(
                    new long[] {pooledShape.size(new int[] {0}), pooledShape.size(new int[] {1})});
        }
    }

    /**
     * Performs one AdaDelta step in place, implemented with NDArray arithmetic rather than a fused
     * MXNet operator.
     *
     * <p>{@code inputs} holds, in order: the weight (updated in place), the gradient (rescaled in
     * place), the running average of squared gradients, and the running average of squared
     * updates. The two running averages are updated in place.
     *
     * @param inputs {@code [weight, grad, s, delta]}
     * @param weights additional resources attached to the temporary manager for the duration of
     *     the step
     * @param weightDecay coefficient of the weight term added to the gradient
     * @param rescaleGrad factor the gradient is multiplied by (in place)
     * @param clipGrad when positive, the gradient is clipped to {@code [-clipGrad, clipGrad]}
     * @param rho decay rate of both running averages
     * @param epsilon stabilizer added inside both square roots
     */
    @Override
    public void adadeltaUpdate(
            NDList inputs,
            NDList weights,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float rho,
            float epsilon) {
        NDArray weight = (NDArray) inputs.get(0);
        NDArray grad = (NDArray) inputs.get(1);
        NDArray s = (NDArray) inputs.get(2);
        NDArray delta = (NDArray) inputs.get(3);
        // The base manager is meant to release the intermediate arrays created below; close() runs
        // exactly once, and a close() failure during an exception is kept as suppressed.
        try (NDManager subManager = NDManager.newBaseManager()) {
            subManager.tempAttachAll(new NDResource[] {inputs, weights});

            // Preprocess the gradient. Note that clipping produces a new array, so the stored
            // gradient is only rescaled, never clipped or decayed.
            grad.muli(rescaleGrad);
            if (clipGrad > 0.0f) {
                grad = grad.clip(-clipGrad, clipGrad);
            }
            grad.addi(weight.mul(weightDecay));

            // s = rho * s + (1 - rho) * grad^2
            s.muli(rho).addi(grad.square().mul(1.0f - rho));
            // update = sqrt(delta + eps) / sqrt(s + eps) * grad
            NDArray update = delta.add(epsilon).sqrt().div(s.add(epsilon).sqrt()).mul(grad);
            // delta = rho * delta + (1 - rho) * update^2
            delta.muli(rho).addi(update.square().mul(1.0f - rho));

            weight.subi(update);
        }
    }

    /**
     * Performs an AdaGrad step with MXNet's fused {@code adagrad_update} operator, writing into
     * {@code weights}.
     *
     * @param inputs the operator inputs
     * @param weights the operator outputs
     * @param learningRate passed as {@code lr}
     * @param weightDecay passed as {@code wd}
     * @param rescaleGrad passed as {@code rescale_grad}
     * @param clipGrad passed as {@code clip_gradient}
     * @param epsilon passed as {@code epsilon}
     */
    @Override
    public void adagradUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float epsilon) {
        MxOpParams params = new MxOpParams();
        params.addParam("lr", learningRate);
        params.addParam("wd", weightDecay);
        params.addParam("rescale_grad", rescaleGrad);
        params.addParam("clip_gradient", clipGrad);
        params.addParam("epsilon", epsilon);
        MxNDManager manager = getManager();
        manager.invoke("adagrad_update", inputs, weights, params);
    }

    /**
     * Performs an Adam step with MXNet's fused {@code adam_update} operator, writing into
     * {@code weights}.
     *
     * @param inputs the operator inputs
     * @param weights the operator outputs
     * @param learningRate passed as {@code lr}
     * @param weightDecay passed as {@code wd}
     * @param rescaleGrad passed as {@code rescale_grad}
     * @param clipGrad passed as {@code clip_gradient}
     * @param beta1 passed as {@code beta1}
     * @param beta2 passed as {@code beta2}
     * @param epsilon passed as {@code epsilon}
     * @param lazyUpdate passed as {@code lazy_update}
     */
    @Override
    public void adamUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float beta1,
            float beta2,
            float epsilon,
            boolean lazyUpdate) {
        MxOpParams params = new MxOpParams();
        params.addParam("lr", learningRate);
        params.addParam("wd", weightDecay);
        params.addParam("rescale_grad", rescaleGrad);
        params.addParam("clip_gradient", clipGrad);
        params.addParam("beta1", beta1);
        params.addParam("beta2", beta2);
        params.addParam("epsilon", epsilon);
        params.addParam("lazy_update", lazyUpdate);
        MxNDManager manager = getManager();
        manager.invoke("adam_update", inputs, weights, params);
    }

    /**
     * Performs an RMSProp step with a fused MXNet operator, writing into {@code weights}.
     *
     * <p>When {@code centered} is set, the {@code rmspropalex_update} variant is used and
     * {@code gamma2} is passed. Otherwise {@code rmsprop_update} is used and {@code gamma2} is
     * ignored.
     *
     * @param inputs the operator inputs
     * @param weights the operator outputs
     * @param learningRate passed as {@code lr}
     * @param weightDecay passed as {@code wd}
     * @param rescaleGrad passed as {@code rescale_grad}
     * @param clipGrad passed as {@code clip_gradient}
     * @param gamma1 passed as {@code gamma1}
     * @param gamma2 passed as {@code gamma2} (centered variant only)
     * @param epsilon passed as {@code epsilon}
     * @param centered selects {@code rmspropalex_update} instead of {@code rmsprop_update}
     */
    @Override
    public void rmspropUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float gamma1,
            float gamma2,
            float epsilon,
            boolean centered) {
        MxOpParams params = new MxOpParams();
        params.addParam("lr", learningRate);
        params.addParam("wd", weightDecay);
        params.addParam("rescale_grad", rescaleGrad);
        params.addParam("clip_gradient", clipGrad);
        params.addParam("gamma1", gamma1);
        params.addParam("epsilon", epsilon);
        if (centered) {
            params.addParam("gamma2", gamma2);
            getManager().invoke("rmspropalex_update", inputs, weights, params);
            return;
        }
        getManager().invoke("rmsprop_update", inputs, weights, params);
    }

    /**
     * Performs a Nesterov-momentum step with MXNet's fused {@code nag_mom_update} operator,
     * writing into {@code weights}.
     *
     * @param inputs the operator inputs
     * @param weights the operator outputs
     * @param learningRate passed as {@code lr}
     * @param weightDecay passed as {@code wd}
     * @param rescaleGrad passed as {@code rescale_grad}
     * @param clipGrad passed as {@code clip_gradient}
     * @param momentum passed as {@code momentum}
     */
    @Override
    public void nagUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float momentum) {
        MxOpParams params = new MxOpParams();
        params.addParam("lr", learningRate);
        params.addParam("wd", weightDecay);
        params.addParam("rescale_grad", rescaleGrad);
        params.addParam("clip_gradient", clipGrad);
        params.addParam("momentum", momentum);
        MxNDManager manager = getManager();
        manager.invoke("nag_mom_update", inputs, weights, params);
    }

    /**
     * Performs an SGD step with a fused MXNet operator, writing into {@code weights}.
     *
     * <p>A momentum of exactly {@code 0} uses {@code sgd_update}. Any other value (including NaN)
     * uses {@code sgd_mom_update} and passes the momentum.
     *
     * @param inputs the operator inputs
     * @param weights the operator outputs
     * @param learningRate passed as {@code lr}
     * @param weightDecay passed as {@code wd}
     * @param rescaleGrad passed as {@code rescale_grad}
     * @param clipGrad passed as {@code clip_gradient}
     * @param momentum selects the momentum variant when non-zero
     * @param lazyUpdate passed as {@code lazy_update}
     */
    @Override
    public void sgdUpdate(
            NDList inputs,
            NDList weights,
            float learningRate,
            float weightDecay,
            float rescaleGrad,
            float clipGrad,
            float momentum,
            boolean lazyUpdate) {
        MxOpParams params = new MxOpParams();
        params.addParam("lr", learningRate);
        params.addParam("wd", weightDecay);
        params.addParam("rescale_grad", rescaleGrad);
        params.addParam("clip_gradient", clipGrad);
        params.addParam("lazy_update", lazyUpdate);
        if (momentum != 0.0f) {
            params.addParam("momentum", momentum);
            getManager().invoke("sgd_mom_update", inputs, weights, params);
            return;
        }
        getManager().invoke("sgd_update", inputs, weights, params);
    }

    /**
     * Convolution through MXNet's {@code _npx_convolution} operator.
     *
     * <p>The kernel shape is the weight shape from axis 2 onward, and the number of filters is
     * the weight's first dimension.
     *
     * @param input the input array
     * @param weight the filter weights
     * @param bias the bias, or {@code null} for {@code no_bias}
     * @param stride passed as {@code stride}
     * @param padding passed as {@code pad}
     * @param dilation passed as {@code dilate}
     * @param groups passed as {@code num_group}
     * @return the operator's outputs
     */
    @Override
    public NDList convolution(
            NDArray input,
            NDArray weight,
            NDArray bias,
            Shape stride,
            Shape padding,
            Shape dilation,
            int groups) {
        MxOpParams params = new MxOpParams();
        params.addParam("kernel", weight.getShape().slice(2));
        params.addParam("stride", stride);
        params.addParam("pad", padding);
        params.addParam("dilate", dilation);
        params.addParam("num_group", groups);
        params.addParam("num_filter", weight.getShape().get(0));
        NDList operands = new NDList(new NDArray[] {input, weight});
        boolean noBias = bias == null;
        params.add("no_bias", noBias);
        if (!noBias) {
            operands.add(bias);
        }
        return getManager().invoke("_npx_convolution", operands, params);
    }

    /**
     * Transposed convolution through MXNet's {@code _npx_deconvolution} operator.
     *
     * <p>The kernel shape is the weight shape from axis 2 onward, and the number of filters is
     * the weight's first dimension.
     *
     * @param input the input array
     * @param weight the filter weights
     * @param bias the bias, or {@code null} for {@code no_bias}
     * @param stride passed as {@code stride}
     * @param padding passed as {@code pad}
     * @param outPadding passed as {@code adj}
     * @param dilation passed as {@code dilate}
     * @param groups passed as {@code num_group}
     * @return the operator's outputs
     */
    @Override
    public NDList deconvolution(
            NDArray input,
            NDArray weight,
            NDArray bias,
            Shape stride,
            Shape padding,
            Shape outPadding,
            Shape dilation,
            int groups) {
        MxOpParams params = new MxOpParams();
        params.addParam("kernel", weight.getShape().slice(2));
        params.addParam("stride", stride);
        params.addParam("pad", padding);
        params.addParam("adj", outPadding);
        params.addParam("dilate", dilation);
        params.addParam("num_group", groups);
        params.addParam("num_filter", weight.getShape().get(0));
        NDList operands = new NDList(new NDArray[] {input, weight});
        boolean noBias = bias == null;
        params.add("no_bias", noBias);
        if (!noBias) {
            operands.add(bias);
        }
        return getManager().invoke("_npx_deconvolution", operands, params);
    }

    /**
     * Fully-connected layer through {@code _npx_fully_connected} with {@code flatten=False}.
     *
     * @param input the input array
     * @param weight the weight; its size along axis 0 is passed as {@code num_hidden}
     * @param bias the bias, or {@code null} for {@code no_bias}
     * @return the operator's outputs
     */
    @Override
    public NDList linear(NDArray input, NDArray weight, NDArray bias) {
        boolean noBias = bias == null;
        MxOpParams params = new MxOpParams();
        params.addParam("num_hidden", weight.size(0));
        params.addParam("flatten", false);
        params.addParam("no_bias", noBias);
        NDList operands = new NDList(new NDArray[] {input, weight});
        if (!noBias) {
            operands.add(bias);
        }
        return getManager().invoke("_npx_fully_connected", operands, params);
    }

    /**
     * Embedding lookup through MXNet's {@code _npx_embedding} operator.
     *
     * @param input the indices to look up
     * @param weight the embedding table; dimensions 0 and 1 are passed as {@code input_dim} and
     *     {@code output_dim}
     * @param sparseFormat the gradient format; only {@code DENSE} and {@code ROW_SPARSE} are
     *     accepted, and its numeric value is passed as {@code sparse_grad}
     * @return the operator's outputs
     * @throws IllegalArgumentException for any other sparse format
     */
    @Override
    public NDList embedding(NDArray input, NDArray weight, SparseFormat sparseFormat) {
        boolean supported =
                sparseFormat.equals(SparseFormat.DENSE) || sparseFormat.equals(SparseFormat.ROW_SPARSE);
        if (!supported) {
            throw new IllegalArgumentException("MXNet only supports row sparse");
        }
        MxOpParams params = new MxOpParams();
        Shape tableShape = weight.getShape();
        long inputDim = tableShape.get(0);
        long outputDim = weight.getShape().get(1);
        params.addParam("input_dim", inputDim);
        params.addParam("output_dim", outputDim);
        params.addParam("sparse_grad", sparseFormat.getValue());
        NDList operands = new NDList(new NDArray[] {input, weight});
        return getManager().invoke("_npx_embedding", operands, params);
    }

    /**
     * Parametric ReLU via {@code _npx_leaky_relu} with {@code act_type=prelu}.
     *
     * @param input the input array
     * @param alpha the learned slope array, passed as the operator's second input
     * @return the operator's outputs
     */
    @Override
    public NDList prelu(NDArray input, NDArray alpha) {
        NDList operands = new NDList(new NDArray[] {input, alpha});
        MxOpParams params = new MxOpParams();
        params.addParam("act_type", "prelu");
        return getManager().invoke("_npx_leaky_relu", operands, params);
    }

    /**
     * Dropout through MXNet's {@code _npx_dropout} operator.
     *
     * @param input the input array
     * @param rate the drop probability, passed as {@code p}
     * @param training must match {@link JnaUtils#autogradIsTraining()}
     * @return the operator's outputs
     * @throws IllegalArgumentException if {@code training} does not match the autograd mode
     */
    @Override
    public NDList dropout(NDArray input, float rate, boolean training) {
        boolean autogradTraining = JnaUtils.autogradIsTraining();
        if (training != autogradTraining) {
            throw new IllegalArgumentException("the mode of dropout in MXNet should align with the mode of GradientCollector");
        }
        MxOpParams params = new MxOpParams();
        params.addParam("p", rate);
        NDList operands = new NDList(new NDArray[] {input});
        return getManager().invoke("_npx_dropout", operands, params);
    }

    /**
     * Layer normalization over the trailing {@code normalizedShape} dimensions via
     * {@code _npx_layer_norm}.
     *
     * <p>MXNet normalizes along a single axis, so the trailing normalized dimensions of the input
     * are flattened into one last axis, and gamma and beta are flattened to 1-D. The operator
     * runs with {@code axis=-1`, and its first output is reshaped back to the input's shape.
     *
     * @param input the input array
     * @param normalizedShape the trailing shape to normalize over
     * @param gamma the scale, reshaped to {@code (normalizedShape.size())}
     * @param beta the shift, reshaped to {@code (normalizedShape.size())}
     * @param eps passed as {@code eps}
     * @return a single-element list containing the normalized array in the input's shape
     */
    @Override
    public NDList layerNorm(NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
        MxOpParams params = new MxOpParams();
        params.addParam("axis", -1);
        params.addParam("eps", eps);

        long normalizedSize = normalizedShape.size();
        int leadingDims = Math.toIntExact(input.getShape().dimension() - normalizedShape.dimension());
        Shape flatShape = input.getShape().slice(0, leadingDims).add(new long[] {normalizedSize});

        NDArray flatInput = input.reshape(flatShape);
        NDArray flatGamma = gamma.reshape(new long[] {normalizedSize});
        NDArray flatBeta = beta.reshape(new long[] {normalizedSize});
        NDList operands = new NDList(new NDArray[] {flatInput, flatGamma, flatBeta});

        NDArray normalized = (NDArray) getManager().invoke("_npx_layer_norm", operands, params).get(0);
        return new NDList(new NDArray[] {normalized.reshape(input.getShape())});
    }

    /**
     * Batch normalization through MXNet's {@code _npx_batch_norm} operator.
     *
     * <p>Operator inputs are passed in the order (input, gamma, beta, runningMean, runningVar).
     * {@code fix_gamma} is set when {@code gamma} is {@code null}. NOTE(review): a {@code null}
     * gamma is still placed into the operand list; confirm that {@code invoke} tolerates this.
     *
     * @param input the input array
     * @param runningMean the running mean (4th operator input)
     * @param runningVar the running variance (5th operator input)
     * @param gamma the scale (2nd operator input), or {@code null}
     * @param beta the shift (3rd operator input)
     * @param axis passed as {@code axis}
     * @param momentum passed as {@code momentum}
     * @param eps passed as {@code eps}
     * @param training must match {@link JnaUtils#autogradIsTraining()}
     * @return the operator's outputs
     * @throws IllegalArgumentException if {@code training} does not match the autograd mode
     */
    @Override
    public NDList batchNorm(
            NDArray input,
            NDArray runningMean,
            NDArray runningVar,
            NDArray gamma,
            NDArray beta,
            int axis,
            float momentum,
            float eps,
            boolean training) {
        if (training != JnaUtils.autogradIsTraining()) {
            throw new IllegalArgumentException("the mode of batchNorm in MXNet should align with the mode of GradientCollector");
        }
        MxOpParams params = new MxOpParams();
        params.addParam("axis", axis);
        params.addParam("fix_gamma", gamma == null);
        params.addParam("eps", eps);
        params.addParam("momentum", momentum);
        NDList operands = new NDList(new NDArray[] {input, gamma, beta, runningMean, runningVar});
        return getManager().invoke("_npx_batch_norm", operands, params);
    }

    /**
     * Runs a vanilla (tanh or ReLU) recurrent network through MXNet's fused {@code _npx_rnn}
     * operator.
     *
     * <p>Every parameter array is flattened, and the results are concatenated into one packed
     * vector, which is attached to the input's manager. The temporary flattened arrays are closed
     * right after packing. When {@code batchFirst} is set, axes 0 and 1 of the input are swapped
     * before the call and axes 0 and 1 of the output are swapped back afterwards (presumably
     * because the operator is time-major; TODO confirm).
     *
     * @param input the input sequence
     * @param state the initial state; its last dimension is passed as {@code state_size}
     * @param params the parameter arrays; exactly
     *     {@code numLayers * (hasBiases ? 4 : 2) * (bidirectional ? 2 : 1)} are required (checked
     *     with {@code Preconditions.checkArgument})
     * @param hasBiases whether {@code params} includes bias arrays
     * @param numLayers passed as {@code num_layers}
     * @param activation {@code TANH} selects {@code rnn_tanh}; any other value selects
     *     {@code rnn_relu}
     * @param dropRate passed as {@code p}
     * @param training must match {@link JnaUtils#autogradIsTraining()}
     * @param bidirectional passed as {@code bidirectional}
     * @param batchFirst whether to swap axes 0/1 around the operator call
     * @return the operator outputs; in batch-first mode, the swapped output followed by the
     *     operator's second output
     * @throws IllegalArgumentException if {@code training} does not match the autograd mode
     */
    @Override
    public NDList rnn(
            NDArray input,
            NDArray state,
            NDList params,
            boolean hasBiases,
            int numLayers,
            RNN.Activation activation,
            double dropRate,
            boolean training,
            boolean bidirectional,
            boolean batchFirst) {
        int expected = numLayers * (hasBiases ? 4 : 2) * (bidirectional ? 2 : 1);
        Preconditions.checkArgument(params.size() == expected, "The size of Params is incorrect expect " + expected + " parameters but got " + params.size());
        if (training != JnaUtils.autogradIsTraining()) {
            throw new IllegalArgumentException("the mode of rnn in MXNet should align with the mode of GradientCollector");
        }
        if (batchFirst) {
            input = input.swapAxes(0, 1);
        }
        MxOpParams opParams = new MxOpParams();
        opParams.addParam("p", dropRate);
        opParams.addParam("state_size", state.getShape().tail());
        opParams.addParam("num_layers", numLayers);
        opParams.addParam("bidirectional", bidirectional);
        opParams.addParam("state_outputs", true);
        opParams.addParam("mode", activation == RNN.Activation.TANH ? "rnn_tanh" : "rnn_relu");

        NDList operands = new NDList();
        operands.add(input);
        // The flattened copies only live long enough to be packed; closing them here (and only
        // here) keeps the operator call outside their cleanup scope.
        try (NDList flattened = new NDList()) {
            for (NDArray param : params) {
                flattened.add(param.flatten());
            }
            NDArray packed = NDArrays.concat(flattened);
            packed.attach(input.getManager());
            operands.add(packed);
        }
        operands.add(state);

        NDList result = getManager().invoke("_npx_rnn", operands, opParams);
        if (!batchFirst) {
            return result;
        }
        // The time-major head is replaced by its swapped copy, so it is closed afterwards.
        try (NDArray timeMajorOutput = result.head()) {
            return new NDList(new NDArray[] {timeMajorOutput.swapAxes(0, 1), (NDArray) result.get(1)});
        }
    }

    /**
     * Runs a GRU through MXNet's fused {@code _npx_rnn} operator with {@code mode=gru}.
     *
     * <p>Every parameter array is flattened, and the results are concatenated into one packed
     * vector, which is attached to the input's manager. The temporary flattened arrays are closed
     * right after packing. When {@code batchFirst} is set, axes 0 and 1 of the input are swapped
     * before the call and axes 0 and 1 of the output are swapped back afterwards (presumably
     * because the operator is time-major; TODO confirm).
     *
     * @param input the input sequence
     * @param state the initial state; its last dimension is passed as {@code state_size}
     * @param params the parameter arrays; exactly
     *     {@code numLayers * (hasBiases ? 4 : 2) * (bidirectional ? 2 : 1)} are required (checked
     *     with {@code Preconditions.checkArgument})
     * @param hasBiases whether {@code params} includes bias arrays
     * @param numLayers passed as {@code num_layers}
     * @param dropRate passed as {@code p}
     * @param training must match {@link JnaUtils#autogradIsTraining()}
     * @param bidirectional passed as {@code bidirectional}
     * @param batchFirst whether to swap axes 0/1 around the operator call
     * @return the operator outputs; in batch-first mode, the swapped output followed by the
     *     operator's second output
     * @throws IllegalArgumentException if {@code training} does not match the autograd mode
     */
    @Override
    public NDList gru(
            NDArray input,
            NDArray state,
            NDList params,
            boolean hasBiases,
            int numLayers,
            double dropRate,
            boolean training,
            boolean bidirectional,
            boolean batchFirst) {
        int expected = numLayers * (hasBiases ? 4 : 2) * (bidirectional ? 2 : 1);
        Preconditions.checkArgument(params.size() == expected, "The size of Params is incorrect expect " + expected + " parameters but got " + params.size());
        if (training != JnaUtils.autogradIsTraining()) {
            throw new IllegalArgumentException("the mode of gru in MXNet should align with the mode of GradientCollector");
        }
        if (batchFirst) {
            input = input.swapAxes(0, 1);
        }
        MxOpParams opParams = new MxOpParams();
        opParams.addParam("p", dropRate);
        opParams.addParam("state_size", state.getShape().tail());
        opParams.addParam("num_layers", numLayers);
        opParams.addParam("bidirectional", bidirectional);
        opParams.addParam("state_outputs", true);
        opParams.addParam("mode", "gru");

        NDList operands = new NDList();
        operands.add(input);
        // The flattened copies only live long enough to be packed; closing them here (and only
        // here) keeps the operator call outside their cleanup scope.
        try (NDList flattened = new NDList()) {
            for (NDArray param : params) {
                flattened.add(param.flatten());
            }
            NDArray packed = NDArrays.concat(flattened);
            packed.attach(input.getManager());
            operands.add(packed);
        }
        operands.add(state);

        NDList result = getManager().invoke("_npx_rnn", operands, opParams);
        if (!batchFirst) {
            return result;
        }
        // The time-major head is replaced by its swapped copy, so it is closed afterwards.
        try (NDArray timeMajorOutput = result.head()) {
            return new NDList(new NDArray[] {timeMajorOutput.swapAxes(0, 1), (NDArray) result.get(1)});
        }
    }

    /**
     * Runs MXNet's fused {@code _npx_rnn} operator in {@code lstm} mode.
     *
     * <p>The individual parameter arrays are flattened and concatenated into one packed array
     * (attached to the input's manager), which is passed to the operator together with the input
     * and the initial states.
     *
     * @param nDArray the input; when {@code z4} is true its axes 0 and 1 are swapped before the
     *     operator runs
     * @param nDList the initial states; the trailing dimension of the first state's shape is passed
     *     as {@code state_size}
     * @param nDList2 the parameter arrays; must contain exactly {@code i * (z ? 4 : 2) * (z3 ? 2 :
     *     1)} arrays
     * @param z selects 4 (true) or 2 (false) parameters per layer and direction (presumably whether
     *     biases are included — confirm against the caller)
     * @param i the number of layers ({@code num_layers})
     * @param d the dropout probability passed as {@code p}
     * @param z2 the training flag; must equal {@link JnaUtils#autogradIsTraining()}
     * @param z3 whether the LSTM is bidirectional (doubles the expected parameter count)
     * @param z4 whether axes 0 and 1 of the input and of the first output are swapped around the
     *     operator call (batch-first layout)
     * @return the operator outputs; when {@code z4} is true the list holds the first output with
     *     axes 0 and 1 swapped back, followed by the second and third outputs
     * @throws IllegalArgumentException if the parameter count is wrong or {@code z2} does not match
     *     the current autograd training mode
     */
    public NDList lstm(NDArray nDArray, NDList nDList, NDList nDList2, boolean z, int i, double d, boolean z2, boolean z3, boolean z4) {
        int expectedParams = i * (z ? 4 : 2) * (z3 ? 2 : 1);
        Preconditions.checkArgument(
                nDList2.size() == expectedParams,
                "The size of Params is incorrect expect "
                        + expectedParams
                        + " parameters but got "
                        + nDList2.size());
        if (z2 != JnaUtils.autogradIsTraining()) {
            throw new IllegalArgumentException(
                    "the mode of lstm in MXNet should align with the mode of GradientCollector");
        }
        NDArray input = z4 ? nDArray.swapAxes(0, 1) : nDArray;

        MxOpParams mxOpParams = new MxOpParams();
        mxOpParams.addParam("mode", "lstm");
        mxOpParams.addParam("p", d);
        mxOpParams.addParam("state_size", nDList.head().getShape().tail());
        mxOpParams.addParam("state_outputs", true);
        mxOpParams.addParam("num_layers", i);
        mxOpParams.addParam("bidirectional", z3);
        mxOpParams.addParam("lstm_state_clip_nan", true);

        NDList inputs = new NDList();
        inputs.add(input);
        // The flattened copies are only needed to build the packed parameter array, so they are
        // closed exactly once as soon as the concatenation exists (or if building it fails).
        try (NDList flattened = new NDList()) {
            for (NDArray param : nDList2) {
                flattened.add(param.flatten());
            }
            NDArray packed = NDArrays.concat(flattened);
            packed.attach(input.getManager());
            inputs.add(packed);
        }
        inputs.addAll(nDList);

        NDList outputs = getManager().invoke("_npx_rnn", inputs, mxOpParams);
        if (!z4) {
            return outputs;
        }
        // Swap the first output back to batch-first; the original first output is discarded.
        try (NDArray firstOutput = outputs.head()) {
            return new NDList(firstOutput.swapAxes(0, 1), outputs.get(1), outputs.get(2));
        }
    }

    /**
     * Applies the {@code _npx__image_normalize} operator to the wrapped array.
     *
     * @param fArr values passed as the {@code mean} tuple parameter
     * @param fArr2 values passed as the {@code std} tuple parameter
     * @return the array produced by the operator
     */
    public NDArray normalize(float[] fArr, float[] fArr2) {
        MxOpParams normParams = new MxOpParams();
        normParams.addTupleParam("mean", fArr);
        normParams.addTupleParam("std", fArr2);
        NDArray source = this.array;
        PairList<String, ?> opParams = normParams;
        return getManager().invoke("_npx__image_normalize", source, opParams);
    }

    /**
     * Applies the parameterless {@code _npx__image_to_tensor} operator to the wrapped array.
     *
     * @return the array produced by the operator
     */
    public NDArray toTensor() {
        NDArray source = this.array;
        PairList<String, ?> noParams = null;
        return getManager().invoke("_npx__image_to_tensor", source, noParams);
    }

    /**
     * Applies the {@code _npx__image_resize} operator to the wrapped array.
     *
     * @param i first element of the {@code size} tuple
     * @param i2 second element of the {@code size} tuple
     * @param i3 value passed as {@code interp}
     * @return the array produced by the operator
     * @throws IllegalArgumentException if the wrapped array is empty
     */
    public NDArray resize(int i, int i2, int i3) {
        NDArray source = this.array;
        boolean nothingToResize = source.isEmpty();
        if (nothingToResize) {
            throw new IllegalArgumentException("attempt to resize of an empty NDArray");
        }
        MxOpParams resizeParams = new MxOpParams();
        resizeParams.addTupleParam("size", i, i2);
        resizeParams.addParam("interp", i3);
        PairList<String, ?> opParams = resizeParams;
        return getManager().invoke("_npx__image_resize", source, opParams);
    }

    /**
     * Applies the {@code _npx__image_crop} operator to the wrapped array.
     *
     * @param i value passed as {@code x}
     * @param i2 value passed as {@code y}
     * @param i3 value passed as {@code width}
     * @param i4 value passed as {@code height}
     * @return the array produced by the operator
     */
    public NDArray crop(int i, int i2, int i3, int i4) {
        Integer left = Integer.valueOf(i);
        Integer top = Integer.valueOf(i2);
        Integer cropWidth = Integer.valueOf(i3);
        Integer cropHeight = Integer.valueOf(i4);
        MxOpParams cropParams = new MxOpParams();
        cropParams.add("x", left);
        cropParams.add("y", top);
        cropParams.add("width", cropWidth);
        cropParams.add("height", cropHeight);
        NDArray source = this.array;
        PairList<String, ?> opParams = cropParams;
        return getManager().invoke("_npx__image_crop", source, opParams);
    }

    /**
     * Applies the {@code _npx__image_random_flip_left_right} operator to the wrapped array.
     *
     * @return the array produced by the operator
     * @throws UnsupportedOperationException if the wrapped array's device type is {@code "gpu"}
     */
    public NDArray randomFlipLeftRight() {
        String deviceType = this.array.getDevice().getDeviceType();
        boolean onGpu = deviceType.equals("gpu");
        if (onGpu) {
            throw new UnsupportedOperationException("randomFlipLeftRight is not supported on GPU");
        }
        NDArray source = this.array;
        PairList<String, ?> noParams = null;
        return getManager().invoke("_npx__image_random_flip_left_right", source, noParams);
    }

    /**
     * Applies the {@code _npx__image_random_flip_top_bottom} operator to the wrapped array.
     *
     * @return the array produced by the operator
     * @throws UnsupportedOperationException if the wrapped array's device type is {@code "gpu"}
     */
    public NDArray randomFlipTopBottom() {
        String deviceType = this.array.getDevice().getDeviceType();
        boolean onGpu = deviceType.equals("gpu");
        if (onGpu) {
            throw new UnsupportedOperationException("randomFlipTopBottom is not supported on GPU");
        }
        NDArray source = this.array;
        PairList<String, ?> noParams = null;
        return getManager().invoke("_npx__image_random_flip_top_bottom", source, noParams);
    }

    /**
     * Applies the {@code _npx__image_random_brightness} operator to the wrapped array.
     *
     * @param f spread around 1: the operator receives {@code max(0, 1 - f)} as {@code min_factor}
     *     and {@code 1 + f} as {@code max_factor}
     * @return the array produced by the operator
     * @throws UnsupportedOperationException if the wrapped array's device type is {@code "gpu"}
     */
    public NDArray randomBrightness(float f) {
        String deviceType = this.array.getDevice().getDeviceType();
        if (deviceType.equals("gpu")) {
            throw new UnsupportedOperationException("randomBrightness is not supported on GPU");
        }
        float lowerFactor = Math.max(0.0f, 1.0f - f);
        float upperFactor = 1.0f + f;
        MxOpParams brightnessParams = new MxOpParams();
        brightnessParams.addParam("min_factor", lowerFactor);
        brightnessParams.addParam("max_factor", upperFactor);
        NDArray source = this.array;
        PairList<String, ?> opParams = brightnessParams;
        return getManager().invoke("_npx__image_random_brightness", source, opParams);
    }

    /**
     * Applies the {@code _npx__image_random_hue} operator to the wrapped array.
     *
     * @param f spread around 1: the operator receives {@code max(0, 1 - f)} as {@code min_factor}
     *     and {@code 1 + f} as {@code max_factor}
     * @return the array produced by the operator
     * @throws UnsupportedOperationException if the wrapped array's device type is {@code "gpu"}
     */
    public NDArray randomHue(float f) {
        String deviceType = this.array.getDevice().getDeviceType();
        if (deviceType.equals("gpu")) {
            throw new UnsupportedOperationException("randomHue is not supported on GPU");
        }
        float lowerFactor = Math.max(0.0f, 1.0f - f);
        float upperFactor = 1.0f + f;
        MxOpParams hueParams = new MxOpParams();
        hueParams.addParam("min_factor", lowerFactor);
        hueParams.addParam("max_factor", upperFactor);
        NDArray source = this.array;
        PairList<String, ?> opParams = hueParams;
        return getManager().invoke("_npx__image_random_hue", source, opParams);
    }

    /**
     * Applies the {@code _npx__image_random_color_jitter} operator to the wrapped array.
     *
     * @param f value passed as {@code brightness}
     * @param f2 value passed as {@code contrast}
     * @param f3 value passed as {@code saturation}
     * @param f4 value passed as {@code hue}
     * @return the array produced by the operator
     * @throws UnsupportedOperationException if the wrapped array's device type is {@code "gpu"}
     */
    public NDArray randomColorJitter(float f, float f2, float f3, float f4) {
        String deviceType = this.array.getDevice().getDeviceType();
        if (deviceType.equals("gpu")) {
            throw new UnsupportedOperationException("randomColorJitter is not supported on GPU");
        }
        MxOpParams jitterParams = new MxOpParams();
        jitterParams.addParam("brightness", f);
        jitterParams.addParam("contrast", f2);
        jitterParams.addParam("saturation", f3);
        jitterParams.addParam("hue", f4);
        NDArray source = this.array;
        PairList<String, ?> opParams = jitterParams;
        return getManager().invoke("_npx__image_random_color_jitter", source, opParams);
    }

    /**
     * Creates a new MXNet indexer bound to the wrapped array's manager.
     *
     * @return a fresh {@code MxNDArrayIndexer}
     */
    public NDArrayIndexer getIndexer() {
        MxNDArrayIndexer indexer = new MxNDArrayIndexer(array.m4getManager());
        return indexer;
    }

    /**
     * Element-wise selection between the wrapped array and {@code nDArray2} using MXNet's
     * {@code where} operator.
     *
     * <p>A BOOLEAN condition is converted to INT32 before the call, presumably because the
     * operator does not take boolean input (TODO confirm). When the two value arrays differ in
     * shape, they are broadcast to a common shape first. Every temporary created here (the
     * converted condition and any broadcast copies) is closed before returning, whether the call
     * succeeds or fails.
     *
     * @param nDArray the condition array
     * @param nDArray2 the alternative values; must have the same data type as the wrapped array
     * @return the array produced by {@code where}
     * @throws IllegalArgumentException if the data types differ or the shapes cannot be broadcast
     *     together
     */
    public NDArray where(NDArray nDArray, NDArray nDArray2) {
        // Validate before allocating anything so the failure path has nothing to release.
        if (this.array.getDataType() != nDArray2.getDataType()) {
            throw new IllegalArgumentException(
                    "DataType mismatch, required "
                            + this.array.getDataType()
                            + " actual "
                            + nDArray2.getDataType());
        }
        NDArray condition =
                nDArray.getDataType() == DataType.BOOLEAN
                        ? nDArray.toType(DataType.INT32, false)
                        : nDArray;
        NDArray lhs = this.array;
        NDArray rhs = nDArray2;
        try {
            if (!this.array.shapeEquals(nDArray2)) {
                Shape target = deriveBroadcastedShape(this.array.getShape(), nDArray2.getShape());
                if (!target.equals(this.array.getShape())) {
                    lhs = this.array.broadcast(target);
                }
                if (!target.equals(nDArray2.getShape())) {
                    rhs = nDArray2.broadcast(target);
                }
            }
            return getManager()
                    .invoke(
                            "where",
                            new NDArray[] {condition, lhs, rhs},
                            (PairList<String, ?>) null);
        } finally {
            // Release only the temporaries this method created; never the caller's arrays.
            if (lhs != this.array) {
                lhs.close();
            }
            if (rhs != nDArray2) {
                rhs.close();
            }
            if (condition != nDArray) {
                condition.close();
            }
        }
    }

    /**
     * Stacks the wrapped array followed by every array of {@code nDList} along axis {@code i}
     * using {@code _npi_stack}.
     *
     * @param nDList the arrays to stack after the wrapped array, in list order
     * @param i the axis passed as {@code axis}
     * @return the stacked array
     */
    public NDArray stack(NDList nDList, int i) {
        NDArray[] operands = new NDArray[1 + nDList.size()];
        operands[0] = this.array;
        int slot = 1;
        for (NDArray other : nDList) {
            operands[slot++] = other;
        }
        MxOpParams stackParams = new MxOpParams();
        stackParams.addParam("axis", i);
        return getManager().invoke("_npi_stack", operands, stackParams);
    }

    /**
     * Concatenates the wrapped array followed by every array of {@code nDList} along axis
     * {@code i} using {@code _npi_concatenate}, after validating the list with
     * {@link NDUtils#checkConcatInput}.
     *
     * @param nDList the arrays to append after the wrapped array, in list order
     * @param i the axis passed as {@code axis}
     * @return the concatenated array
     */
    public NDArray concat(NDList nDList, int i) {
        NDUtils.checkConcatInput(nDList);
        MxOpParams concatParams = new MxOpParams();
        concatParams.addParam("axis", i);
        int count = nDList.size();
        NDArray[] operands = new NDArray[count + 1];
        operands[0] = this.array;
        for (int k = 0; k < count; k++) {
            operands[k + 1] = nDList.get(k);
        }
        return getManager().invoke("_npi_concatenate", operands, concatParams);
    }

    /**
     * Invokes the {@code MultiBoxTarget} operator on {@code nDList}.
     *
     * @param nDList the operator inputs
     * @param f value passed as {@code overlap_threshold}
     * @param f2 value passed as {@code ignore_label}
     * @param f3 value passed as {@code negative_mining_ratio}
     * @param f4 value passed as {@code negative_mining_thresh}
     * @param i value passed as {@code minimum_negative_samples}
     * @return the operator outputs
     */
    public NDList multiBoxTarget(NDList nDList, float f, float f2, float f3, float f4, int i) {
        MxOpParams targetParams = new MxOpParams();
        Integer minNegatives = Integer.valueOf(i);
        targetParams.add("minimum_negative_samples", minNegatives);
        Float overlap = Float.valueOf(f);
        targetParams.add("overlap_threshold", overlap);
        Float ignored = Float.valueOf(f2);
        targetParams.add("ignore_label", ignored);
        Float miningRatio = Float.valueOf(f3);
        targetParams.add("negative_mining_ratio", miningRatio);
        Float miningThresh = Float.valueOf(f4);
        targetParams.add("negative_mining_thresh", miningThresh);
        return getManager().invoke("MultiBoxTarget", nDList, targetParams);
    }

    /**
     * Invokes the {@code MultiBoxPrior} operator with the wrapped array as its only input.
     *
     * @param list value passed as {@code sizes}
     * @param list2 value passed as {@code ratios}
     * @param list3 value passed as {@code steps}
     * @param list4 value passed as {@code offsets}
     * @param z value passed as {@code clip}
     * @return the operator outputs
     */
    public NDList multiBoxPrior(List<Float> list, List<Float> list2, List<Float> list3, List<Float> list4, boolean z) {
        Boolean clipFlag = Boolean.valueOf(z);
        MxOpParams priorParams = new MxOpParams();
        priorParams.add("sizes", list);
        priorParams.add("ratios", list2);
        priorParams.add("steps", list3);
        priorParams.add("offsets", list4);
        priorParams.add("clip", clipFlag);
        NDArray[] single = {this.array};
        NDList inputs = new NDList(single);
        return getManager().invoke("MultiBoxPrior", inputs, priorParams);
    }

    /**
     * Invokes the {@code MultiBoxDetection} operator on {@code nDList}.
     *
     * @param nDList the operator inputs
     * @param z value passed as {@code clip}
     * @param f value passed as {@code threshold}
     * @param i value passed as {@code background_id}
     * @param f2 value passed as {@code nms_threshold}
     * @param z2 value passed as {@code force_suppress}
     * @param i2 value passed as {@code nms_topk}
     * @return the operator outputs
     */
    public NDList multiBoxDetection(NDList nDList, boolean z, float f, int i, float f2, boolean z2, int i2) {
        MxOpParams detectParams = new MxOpParams();
        Boolean clipFlag = Boolean.valueOf(z);
        detectParams.add("clip", clipFlag);
        Float scoreThreshold = Float.valueOf(f);
        detectParams.add("threshold", scoreThreshold);
        Integer backgroundId = Integer.valueOf(i);
        detectParams.add("background_id", backgroundId);
        Float nmsThreshold = Float.valueOf(f2);
        detectParams.add("nms_threshold", nmsThreshold);
        Boolean forceSuppress = Boolean.valueOf(z2);
        detectParams.add("force_suppress", forceSuppress);
        Integer nmsTopK = Integer.valueOf(i2);
        detectParams.add("nms_topk", nmsTopK);
        return getManager().invoke("MultiBoxDetection", nDList, detectParams);
    }

    /**
     * Returns the array this extension operates on.
     *
     * @return the wrapped array
     */
    public NDArray getArray() {
        NDArray wrapped = array;
        return wrapped;
    }

    /** Returns the MXNet manager of the wrapped array; every operator here is invoked through it. */
    private MxNDManager getManager() {
        MxNDManager owner = array.m4getManager();
        return owner;
    }

    /**
     * Returns the number of pooled dimensions for global pooling: the wrapped array's rank minus
     * the two leading axes (presumably batch and channel — confirm against callers).
     *
     * @return the pooled dimension count, between 1 and 3 inclusive
     * @throws IllegalStateException if the resulting count is outside [1, 3]
     */
    private int getGlobalPoolingDim() {
        int dimension = getArray().getShape().dimension() - 2;
        if (dimension < 1 || dimension > 3) {
            throw new IllegalStateException(
                    "GlobalPooling only supports 1 to 3 dimensions, "
                            + dimension
                            + "D is not supported.");
        }
        return dimension;
    }

    /**
     * Builds a shape with one entry per pooled dimension (see {@code getGlobalPoolingDim}), where
     * every entry equals {@code j}.
     *
     * @param j the value used for every dimension
     * @return the filled shape
     */
    private Shape getGlobalPoolingShapes(long j) {
        int rank = getGlobalPoolingDim();
        long[] dims = new long[rank];
        for (int k = 0; k < rank; k++) {
            dims[k] = j;
        }
        return new Shape(dims);
    }
}
