package ai.djl.pytorch.jni;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.engine.PtSymbolBlock;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/* loaded from: input_file:ai/djl/pytorch/jni/IValueUtils.class */
/**
 * Static helpers that bridge DJL's PyTorch {@link PtNDArray}s and the native IValue handles
 * exposed through {@link PyTorchLibrary#LIB}.
 *
 * <p>Ownership: {@link #forward} deletes every IValue handle it obtains from the native layer
 * once that handle has been converted. This includes the handles of nested list/tuple elements
 * and map keys/values, and it also happens when a conversion fails part-way through.
 */
public final class IValueUtils {

    private IValueUtils() {
    }

    /**
     * Creates a native IValue from the tensor handle backing {@code ptNDArray}.
     *
     * @param ptNDArray the array whose native handle is wrapped
     * @return the IValue handle returned by {@code iValueCreateFromTensor}
     */
    public static Pointer toIValuePointer(PtNDArray ptNDArray) {
        return PyTorchLibrary.LIB.iValueCreateFromTensor(ptNDArray.getHandle());
    }

    /** Returns whether the native layer reports the IValue {@code pointer} as a tensor. */
    public static boolean isNDArray(Pointer pointer) {
        return PyTorchLibrary.LIB.iValueIsTensor(pointer);
    }

    /** Returns whether the native layer reports the IValue {@code pointer} as a tensor list. */
    public static boolean isNDList(Pointer pointer) {
        return PyTorchLibrary.LIB.iValueIsTensorList(pointer);
    }

    /** Returns whether the native layer reports the IValue {@code pointer} as a generic list. */
    public static boolean isList(Pointer pointer) {
        return PyTorchLibrary.LIB.iValueIsList(pointer);
    }

    /** Returns whether the native layer reports the IValue {@code pointer} as a tuple. */
    public static boolean isTuple(Pointer pointer) {
        return PyTorchLibrary.LIB.iValueIsTuple(pointer);
    }

    /** Returns whether the native layer reports the IValue {@code pointer} as a map. */
    public static boolean isMap(Pointer pointer) {
        return PyTorchLibrary.LIB.iValueIsMap(pointer);
    }

    /** Returns whether the native layer reports the IValue {@code pointer} as a string. */
    public static boolean isString(Pointer pointer) {
        return PyTorchLibrary.LIB.iValueIsString(pointer);
    }

    /**
     * Extracts the tensor held by an IValue and wraps it in an array owned by {@code ptNDManager}.
     *
     * @param pointer the IValue handle; it is not deleted here
     * @param ptNDManager the manager that creates the resulting array
     * @return the new array
     */
    public static PtNDArray toNDArray(Pointer pointer, PtNDManager ptNDManager) {
        return ptNDManager.create(PyTorchLibrary.LIB.iValueToTensor(pointer));
    }

    /**
     * Extracts every tensor of a tensor-list IValue into a new {@link NDList}, preserving the
     * order returned by {@code iValueToTensorList}.
     *
     * @param pointer the IValue handle; it is not deleted here
     * @param ptNDManager the manager that creates each array
     * @return the arrays, in native order
     */
    public static NDList toNDList(Pointer pointer, PtNDManager ptNDManager) {
        Pointer[] tensorHandles = PyTorchLibrary.LIB.iValueToTensorList(pointer);
        NDList arrays = new NDList();
        for (Pointer tensorHandle : tensorHandles) {
            arrays.add(ptNDManager.create(tensorHandle));
        }
        return arrays;
    }

    /** Returns the string held by the IValue {@code pointer}, as given by the native layer. */
    public static String toString(Pointer pointer) {
        return PyTorchLibrary.LIB.iValueToString(pointer);
    }

    /**
     * Returns the element handles of a tuple or list IValue.
     *
     * <p>Tuples go through {@code iValueToListFromTuple}; anything else goes through
     * {@code iValueToList}. {@link #forward} deletes these element handles after use, so other
     * callers should do the same.
     *
     * @param pointer the tuple or list IValue handle; it is not deleted here
     * @return the element handles
     */
    public static Pointer[] toIValueArray(Pointer pointer) {
        if (isTuple(pointer)) {
            return PyTorchLibrary.LIB.iValueToListFromTuple(pointer);
        }
        return PyTorchLibrary.LIB.iValueToList(pointer);
    }

    /**
     * Returns the entries of a map IValue as key handle to value handle, in the order the native
     * layer returned them.
     *
     * <p>{@link #forward} deletes both key and value handles after use, so other callers should do
     * the same.
     *
     * @param pointer the map IValue handle; it is not deleted here
     * @return a mutable map preserving native entry order
     * @throws IllegalStateException if the native layer returns an odd number of handles
     */
    public static Map<Pointer, Pointer> toIValueMap(Pointer pointer) {
        Pointer[] pairs = toIValuePairs(pointer);
        // Insertion order mirrors the native order. A hash-based map would order entries by
        // Pointer.hashCode(), which has nothing to do with the order of the source entries.
        Map<Pointer, Pointer> entries = new LinkedHashMap<>(pairs.length);
        for (int i = 0; i < pairs.length; i += 2) {
            entries.put(pairs[i], pairs[i + 1]);
        }
        return entries;
    }

    /**
     * Fetches the flat {@code [key0, value0, key1, value1, ...]} handle array of a map IValue.
     *
     * @throws IllegalStateException if the array has odd length; all returned handles are
     *     deleted before throwing so they do not leak
     */
    private static Pointer[] toIValuePairs(Pointer pointer) {
        Pointer[] pairs = PyTorchLibrary.LIB.iValueToMap(pointer);
        if (pairs.length % 2 != 0) {
            deleteIValues(pairs, 0);
            throw new IllegalStateException(
                    "iValueToMap returned an odd number of handles: " + pairs.length);
        }
        return pairs;
    }

    /** Deletes the IValue handles {@code handles[from..]}. */
    private static void deleteIValues(Pointer[] handles, int from) {
        for (int i = from; i < handles.length; i++) {
            PyTorchLibrary.LIB.torchDeleteIValue(handles[i]);
        }
    }

    /**
     * Flattens an IValue into an {@link NDList} and deletes the IValue handle.
     *
     * <p>Supported shapes:
     * <ul>
     *   <li>a tensor becomes one array;
     *   <li>a tensor list becomes its arrays;
     *   <li>lists and tuples are flattened recursively;
     *   <li>maps are expected to have string keys and tensor values (the code calls
     *       {@link #toString(Pointer)} on keys and {@link #toNDArray} on values), and each value
     *       becomes an array named after its key.
     * </ul>
     *
     * <p>{@code pointer} is deleted in every case, including when an exception is thrown.
     *
     * @throws UnsupportedOperationException if the IValue (or a nested one) has another type
     */
    private static NDList forwardHelper(Pointer pointer, PtNDManager ptNDManager) {
        try {
            NDList result = new NDList();
            if (isNDArray(pointer)) {
                result.add(toNDArray(pointer, ptNDManager));
            } else if (isNDList(pointer)) {
                result.addAll(toNDList(pointer, ptNDManager));
            } else if (isList(pointer) || isTuple(pointer)) {
                addSequenceElements(pointer, ptNDManager, result);
            } else if (isMap(pointer)) {
                addMapEntries(pointer, ptNDManager, result);
            } else {
                throw new UnsupportedOperationException("Unsupported IValue type");
            }
            return result;
        } finally {
            PyTorchLibrary.LIB.torchDeleteIValue(pointer);
        }
    }

    /**
     * Recursively flattens each element of a list/tuple IValue into {@code out}.
     *
     * <p>If an element fails, the elements not yet visited are deleted.
     */
    private static void addSequenceElements(Pointer pointer, PtNDManager ptNDManager, NDList out) {
        Pointer[] elements = toIValueArray(pointer);
        int next = 0; // index of the first element not yet handed to forwardHelper
        try {
            while (next < elements.length) {
                // forwardHelper deletes the element it is given, even when it throws.
                Pointer element = elements[next++];
                out.addAll(forwardHelper(element, ptNDManager));
            }
        } finally {
            deleteIValues(elements, next);
        }
    }

    /**
     * Converts each map entry to an array named after its key and appends it to {@code out}.
     *
     * <p>Each entry's key and value handles are deleted once the entry is processed. If an entry
     * fails, the entries not yet visited are deleted too.
     */
    private static void addMapEntries(Pointer pointer, PtNDManager ptNDManager, NDList out) {
        Pointer[] pairs = toIValuePairs(pointer);
        int next = 0; // index of the first key handle not yet processed
        try {
            while (next < pairs.length) {
                Pointer key = pairs[next];
                Pointer value = pairs[next + 1];
                next += 2;
                try {
                    String name = toString(key);
                    PtNDArray array = toNDArray(value, ptNDManager);
                    array.setName(name);
                    out.add(array);
                } finally {
                    PyTorchLibrary.LIB.torchDeleteIValue(key);
                    PyTorchLibrary.LIB.torchDeleteIValue(value);
                }
            }
        } finally {
            deleteIValues(pairs, next);
        }
    }

    /**
     * Runs the module behind {@code ptSymbolBlock} on {@code nDList} and flattens the result.
     *
     * @param ptSymbolBlock the block whose native module handle is executed
     * @param nDList the inputs; every element must be a {@link PtNDArray}, and the first element's
     *     manager owns the outputs
     * @param z passed unchanged to {@code moduleForward} (presumably the training-mode flag;
     *     confirm against PyTorchLibrary)
     * @return the flattened outputs (see {@link #forwardHelper})
     * @throws IndexOutOfBoundsException if {@code nDList} is empty
     */
    public static NDList forward(PtSymbolBlock ptSymbolBlock, NDList nDList, boolean z) {
        // Resolve the manager before the native call. get(0) throws on an empty list, and
        // failing here means no result handle is created and then leaked.
        PtNDManager manager = (PtNDManager) nDList.get(0).getManager();
        Pointer[] inputHandles =
                nDList.stream()
                        .map(array -> ((PtNDArray) array).getHandle())
                        .toArray(Pointer[]::new);
        Pointer result =
                PyTorchLibrary.LIB.moduleForward(ptSymbolBlock.getHandle(), inputHandles, z);
        return forwardHelper(result, manager);
    }
}
