package ai.djl.ndarray.internal;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.pooling.PoolingConvention;
import ai.djl.util.PairList;
import java.util.List;

/**
 * Internal, engine-facing extension operations for an {@link NDArray}.
 *
 * <p>Engines implement the abstract methods. The default methods are composed only from public
 * {@link NDArray} operations applied to the array returned by {@link #getArray()}.
 */
public interface NDArrayEx {

    // Reverse-operand arithmetic.
    // NOTE(review): by naming convention these presumably compute "operand OP array" instead of
    // "array OP operand", and the "...i" variants presumably operate in place -- confirm against
    // the engine implementations.

    NDArray rdiv(Number number);

    NDArray rdiv(NDArray nDArray);

    NDArray rdivi(Number number);

    NDArray rdivi(NDArray nDArray);

    NDArray rsub(Number number);

    NDArray rsub(NDArray nDArray);

    NDArray rsubi(Number number);

    NDArray rsubi(NDArray nDArray);

    NDArray rmod(Number number);

    NDArray rmod(NDArray nDArray);

    NDArray rmodi(Number number);

    NDArray rmodi(NDArray nDArray);

    NDArray rpow(Number number);

    NDArray rpowi(Number number);

    // Element-wise activations supplied by the engine (named after the standard activation
    // functions; the float arguments of leakyRelu/elu are presumably their alpha -- confirm).

    NDArray relu();

    NDArray sigmoid();

    NDArray tanh();

    NDArray softrelu();

    NDArray softsign();

    NDArray leakyRelu(float f);

    NDArray elu(float f);

    NDArray selu();

    NDArray gelu();

    /**
     * Applies the swish activation {@code x * sigmoid(beta * x)} to {@link #getArray()}.
     *
     * @param beta scalar that multiplies the input before the sigmoid is taken
     * @return the activated array
     */
    default NDArray swish(float beta) {
        NDArray x = getArray();
        return Activation.sigmoid(x.mul(beta)).mul(x);
    }

    /**
     * Applies the mish activation {@code x * tanh(ln(1 + e^x))} to {@link #getArray()}.
     *
     * @return the activated array
     */
    default NDArray mish() {
        NDArray x = getArray();
        // Mish is defined on softplus(x) = ln(1 + e^x): the logarithm must be the natural one.
        // A base-2 log would scale softplus by 1/ln(2) and yield a different function.
        return x.exp().add(1).log().tanh().mul(x);
    }

    // Pooling.
    // NOTE(review): the three Shape arguments are presumably kernel, stride and padding; the
    // trailing boolean of avgPool and int of lpPool/globalLpPool are presumably "count include
    // pad" and the norm order p -- confirm against the engine implementations.

    NDArray maxPool(Shape shape, Shape shape2, Shape shape3, PoolingConvention poolingConvention);

    NDArray globalMaxPool();

    NDArray sumPool(Shape shape, Shape shape2, Shape shape3, PoolingConvention poolingConvention);

    NDArray globalSumPool();

    NDArray avgPool(Shape shape, Shape shape2, Shape shape3, PoolingConvention poolingConvention, boolean z);

    NDArray globalAvgPool();

    NDArray lpPool(Shape shape, Shape shape2, Shape shape3, PoolingConvention poolingConvention, int i);

    NDArray globalLpPool(int i);

    // Optimizer steps. They return nothing, so their effect is presumably an in-place update of
    // the arrays held by the NDList arguments (weights/gradients/state) -- confirm.

    void adamUpdate(NDList nDList, NDList nDList2, float f, float f2, float f3, float f4, float f5, float f6, float f7, boolean z);

    void nagUpdate(NDList nDList, NDList nDList2, float f, float f2, float f3, float f4, float f5);

    void sgdUpdate(NDList nDList, NDList nDList2, float f, float f2, float f3, float f4, float f5, boolean z);

    // Neural-network layer operators. The trailing PairList presumably carries extra
    // engine-specific options -- confirm against the engine implementations.

    NDList convolution(NDList nDList, Shape shape, Shape shape2, Shape shape3, Shape shape4, int i, int i2, String str, boolean z, PairList<String, Object> pairList);

    NDList fullyConnected(NDList nDList, long j, boolean z, boolean z2, PairList<String, Object> pairList);

    NDList embedding(NDList nDList, int i, int i2, boolean z, DataType dataType, PairList<String, Object> pairList);

    NDList prelu(NDList nDList, PairList<String, Object> pairList);

    NDList dropout(NDList nDList, float f, int[] iArr, PairList<String, Object> pairList);

    NDList batchNorm(NDList nDList, float f, float f2, int i, boolean z, boolean z2, PairList<String, Object> pairList);

    NDList rnn(NDList nDList, String str, long j, float f, int i, boolean z, boolean z2, boolean z3, PairList<String, Object> pairList);

    NDList lstm(NDList nDList, long j, float f, int i, boolean z, boolean z2, boolean z3, double d, double d2, PairList<String, Object> pairList);

    /**
     * Normalizes {@link #getArray()} per channel as {@code (array - mean) / std}.
     *
     * <p>A rank-3 array is treated as channel-first {@code (C, ...)} and any other rank as
     * {@code (N, C, ...)}. The statistics are reshaped to {@code (C, 1, 1)} or
     * {@code (1, C, 1, 1)} respectively and broadcast over the remaining axes. {@code C} is
     * {@code mean.length}, so any channel count works (for example 3 for RGB or 1 for grayscale).
     *
     * @param mean per-channel mean, one entry per channel
     * @param std per-channel standard deviation; must have the same length as {@code mean}
     * @return the normalized array
     * @throws IllegalArgumentException if {@code mean} and {@code std} differ in length
     */
    default NDArray normalize(float[] mean, float[] std) {
        if (mean.length != std.length) {
            throw new IllegalArgumentException(
                    "mean and std must have the same length, got "
                            + mean.length
                            + " and "
                            + std.length);
        }
        NDArray array = getArray();
        NDManager manager = array.getManager();
        int channels = mean.length;
        Shape shape =
                array.getShape().dimension() == 3
                        ? new Shape(channels, 1, 1)
                        : new Shape(1, channels, 1, 1);
        // The statistic arrays are temporaries. try-with-resources closes both even when the
        // arithmetic throws, and records close() failures as suppressed exceptions.
        try (NDArray meanArr = manager.create(mean, shape);
                NDArray stdArr = manager.create(std, shape)) {
            // sub() (not subi()) is used so the in-place divi() acts on the intermediate
            // result rather than on the caller's array.
            return array.sub(meanArr).divi(stdArr);
        }
    }

    /**
     * Converts an image-like array into a float32 tensor with the last axis moved in front of
     * the spatial axes, scaling every value by {@code 1/255}.
     *
     * <p>A rank-3 input {@code (H, W, C)} becomes {@code (C, H, W)}; a rank-4 input
     * {@code (N, H, W, C)} becomes {@code (N, C, H, W)}. The scaling presumably maps 8-bit pixel
     * values into {@code [0, 1]}. Other ranks are not supported by the transpose below.
     *
     * @return the converted float32 array
     */
    default NDArray toTensor() {
        NDArray array = getArray();
        int dimension = array.getShape().dimension();
        if (dimension == 3) {
            // Add a leading batch axis so rank-3 and rank-4 inputs share one transpose.
            array = array.expandDims(0);
        }
        // Convert before dividing. An engine that keeps the input dtype for scalar ops would
        // otherwise divide integer pixels (e.g. uint8) with integer semantics. The float
        // divisor keeps the arithmetic in float32 instead of going through a float64 scalar.
        if (!DataType.FLOAT32.equals(array.getDataType())) {
            array = array.toType(DataType.FLOAT32, false);
        }
        NDArray tensor = array.div(255.0f).transpose(0, 3, 1, 2);
        if (dimension == 3) {
            tensor = tensor.squeeze(0);
        }
        return tensor;
    }

    // NOTE(review): whether the arguments are (width, height) or (height, width) is not visible
    // here -- confirm against the engine implementations.
    NDArray resize(int i, int i2);

    /**
     * Crops a rectangular region out of {@link #getArray()} using a string index.
     *
     * <p>The array is indexed as {@code (rows, columns, trailing...)}. A rank-4 array gets an
     * extra leading {@code ":"} so every entry of its first axis is kept. The slices
     * {@code y:y+height} and {@code x:x+width} are taken on the row and column axes, and the
     * last indexed axis is kept whole.
     *
     * @param x offset along the column axis
     * @param y offset along the row axis
     * @param width extent along the column axis
     * @param height extent along the row axis
     * @return the cropped array
     */
    default NDArray crop(int x, int y, int width, int height) {
        NDArray array = getArray();
        String region = y + ":" + (y + height) + "," + x + ":" + (x + width) + ",:";
        String index = array.getShape().dimension() == 4 ? ":," + region : region;
        return array.get(index);
    }

    NDArray pick(NDArray nDArray, int i, boolean z, String str);

    /**
     * Same as {@link #pick(NDArray, int, boolean, String)} with the mode fixed to
     * {@code "clip"}.
     */
    default NDArray pick(NDArray nDArray, int i, boolean z) {
        return pick(nDArray, i, z, "clip");
    }

    NDArray where(NDArray nDArray, NDArray nDArray2);

    NDArray stack(NDList nDList, int i);

    /** Same as {@link #stack(NDList, int)} with the int argument fixed to {@code 0}. */
    default NDArray stack(NDList nDList) {
        return stack(nDList, 0);
    }

    NDArray concat(NDList nDList, int i);

    /** Same as {@link #concat(NDList, int)} with the int argument fixed to {@code 0}. */
    default NDArray concat(NDList nDList) {
        return concat(nDList, 0);
    }

    // Multibox helpers (presumably SSD-style target/prior/detection generation -- confirm).

    NDList multiBoxTarget(NDList nDList, float f, float f2, float f3, float f4, int i);

    NDList multiBoxPrior(List<Float> list, List<Float> list2, List<Float> list3, List<Float> list4, boolean z);

    NDList multiBoxDetection(NDList nDList, boolean z, float f, int i, float f2, boolean z2, int i2);

    /**
     * Returns the array that the default methods of this interface operate on.
     *
     * @return the underlying array
     */
    NDArray getArray();
}
