package ai.djl.pytorch.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.index.NDArrayIndexer;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.index.dim.NDIndexBooleans;
import ai.djl.ndarray.index.full.NDIndexFullPick;
import ai.djl.ndarray.index.full.NDIndexFullSlice;
import ai.djl.ndarray.index.full.NDIndexFullTake;
import ai.djl.ndarray.types.Shape;
import ai.djl.pytorch.jni.JniUtils;
import java.nio.Buffer;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
import java.util.Stack;

/* loaded from: input_file:ai/djl/pytorch/engine/PtNDArrayIndexer.class */
public class PtNDArrayIndexer extends NDArrayIndexer {
    private PtNDManager manager;

    /* JADX INFO: Access modifiers changed from: package-private */
    public PtNDArrayIndexer(PtNDManager ptNDManager) {
        this.manager = ptNDManager;
    }

    public NDArray get(NDArray nDArray, NDIndexFullPick nDIndexFullPick) {
        return JniUtils.pick(this.manager.mo180from(nDArray), this.manager.mo180from(nDIndexFullPick.getIndices()), nDIndexFullPick.getAxis());
    }

    public NDArray get(NDArray nDArray, NDIndexFullTake nDIndexFullTake) {
        return JniUtils.take(this.manager.mo180from(nDArray), this.manager.mo180from(nDIndexFullTake.getIndices()), this.manager);
    }

    public NDArray get(NDArray nDArray, NDIndexFullSlice nDIndexFullSlice) {
        PtNDArray index = JniUtils.index(this.manager.mo180from(nDArray), nDIndexFullSlice.getMin(), nDIndexFullSlice.getMax(), nDIndexFullSlice.getStep(), this.manager);
        try {
            PtNDArray m47squeeze = index.m47squeeze(nDIndexFullSlice.getToSqueeze());
            if (index != null) {
                index.close();
            }
            return m47squeeze;
        } catch (Throwable th) {
            if (index != null) {
                try {
                    index.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
            throw th;
        }
    }

    public NDArray get(NDArray nDArray, NDIndex nDIndex) {
        if (nDIndex.getRank() == 0) {
            if (nDArray.getShape().isScalar()) {
                return nDArray.getManager() == this.manager ? nDArray.duplicate() : this.manager.mo179create((Buffer) nDArray.toByteBuffer(), nDArray.getShape(), nDArray.getDataType());
            }
            nDIndex.addAllDim();
        }
        return (nDArray == null || (nDArray instanceof PtNDArray)) ? JniUtils.indexAdv((PtNDArray) nDArray, nDIndex, this.manager) : JniUtils.indexAdv(this.manager.mo179create((Buffer) nDArray.toByteBuffer(), nDArray.getShape(), nDArray.getDataType()), nDIndex, this.manager);
    }

    public void set(NDArray nDArray, NDIndex nDIndex, Object obj) {
        PtNDArray mo179create = nDArray instanceof PtNDArray ? (PtNDArray) nDArray : this.manager.mo179create((Buffer) nDArray.toByteBuffer(), nDArray.getShape(), nDArray.getDataType());
        if (obj instanceof Number) {
            JniUtils.indexAdvPut(mo179create, nDIndex, (PtNDArray) this.manager.create((Number) obj));
        } else {
            if (!(obj instanceof NDArray)) {
                throw new IllegalArgumentException("The type of value to assign cannot be other than NDArray and Number.");
            }
            JniUtils.indexAdvPut(mo179create, nDIndex, this.manager.mo180from((NDArray) obj));
        }
    }

    public void set(NDArray nDArray, NDIndexFullSlice nDIndexFullSlice, NDArray nDArray2) {
        Shape shape;
        Stack stack = new Stack();
        stack.add(nDArray2);
        stack.add(((NDArray) stack.peek()).toDevice(nDArray.getDevice(), false));
        Shape shape2 = nDIndexFullSlice.getShape();
        while (true) {
            shape = shape2;
            if (shape.size() <= nDArray2.size()) {
                break;
            } else {
                shape2 = shape.slice(1);
            }
        }
        stack.add(((NDArray) stack.peek()).reshape(shape));
        stack.add(((NDArray) stack.peek()).broadcast(nDIndexFullSlice.getShape()));
        JniUtils.indexSet(this.manager.mo180from(nDArray), this.manager.mo180from((NDArray) stack.peek()), nDIndexFullSlice.getMin(), nDIndexFullSlice.getMax(), nDIndexFullSlice.getStep());
        Iterator it = stack.iterator();
        while (it.hasNext()) {
            NDArray nDArray3 = (NDArray) it.next();
            if (nDArray3 != nDArray2) {
                nDArray3.close();
            }
        }
    }

    public void set(NDArray nDArray, NDIndexBooleans nDIndexBooleans, NDArray nDArray2) {
        NDArray index = nDIndexBooleans.getIndex();
        try {
            JniUtils.booleanMaskSet(this.manager.mo180from(nDArray), this.manager.mo180from(nDArray2), this.manager.mo180from(index));
            if (index != null) {
                index.close();
            }
        } catch (Throwable th) {
            if (index != null) {
                try {
                    index.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
            throw th;
        }
    }

    public void set(NDArray nDArray, NDIndexFullSlice nDIndexFullSlice, Number number) {
        set(nDArray, nDIndexFullSlice, nDArray.getManager().create(number));
    }
}
