package ai.djl.mxnet.engine;

import ai.djl.Device;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.mxnet.jna.NativeResource;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import com.sun.jna.Pointer;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Java-side wrapper around a native MXNet CachedOp handle.
 *
 * <p>The native op is invoked with a flat input array that has one slot per entry in {@code
 * parameters}. Slots listed in {@code paramIndices} are bound to values from a {@link
 * ParameterStore}; slots listed in {@code dataIndices} are bound to the caller's arrays, either by
 * array name or, for unnamed arrays, by position. The instance attaches itself to its {@link
 * MxNDManager} on construction and detaches when closed.
 */
public class CachedOp extends NativeResource {

    private static final Logger logger = LoggerFactory.getLogger(CachedOp.class);

    /** All op input slots; its size fixes the length of the native input array. */
    private final List<Parameter> parameters;
    /** Slot positions (indices into {@link #parameters}) bound from the ParameterStore. */
    private final List<Integer> paramIndices;
    /** Ordered (input name, slot position) pairs for the data inputs. */
    private final PairList<String, Integer> dataIndices;
    /** Name-keyed view of {@link #dataIndices}, used to bind named arrays. */
    private final Map<String, Integer> dataIndicesMap;
    /** Owning manager; set to {@code null} once the native handle has been freed. */
    private MxNDManager manager;
    /** Input array built by the most recent {@link #forward} call; overwritten on every call. */
    private MxNDArray[] debugInputs;

    /**
     * Wraps a native CachedOp pointer and attaches the wrapper to {@code manager}.
     *
     * @param pointer the native CachedOp pointer
     * @param manager the manager this resource is attached to
     * @param parameters all op input slots, indexed by slot position
     * @param paramIndices slot positions that are filled from a ParameterStore
     * @param dataIndices ordered (input name, slot position) pairs for the data inputs
     */
    public CachedOp(
            Pointer pointer,
            MxNDManager manager,
            List<Parameter> parameters,
            List<Integer> paramIndices,
            PairList<String, Integer> dataIndices) {
        super(pointer);
        this.parameters = parameters;
        this.dataIndices = dataIndices;
        this.paramIndices = paramIndices;
        this.dataIndicesMap = dataIndices.toMap();
        this.manager = manager;
        manager.attach(getUid(), this);
    }

    /**
     * Runs the cached op on {@code data}.
     *
     * <p>The first array in {@code data} determines the device used to fetch parameters and the
     * manager passed to the native invocation. Any data slot the caller does not supply is filled
     * with a new array of shape {@code (batchSize)}, where {@code batchSize} is the first
     * dimension of that first array.
     *
     * @param parameterStore source of parameter values
     * @param data the data inputs; named arrays are bound by name, unnamed arrays by position
     * @return the outputs of the native op
     * @throws NullPointerException if a parameter value is missing from {@code parameterStore}
     * @throws IllegalArgumentException if an array name is unknown, or there are more unnamed
     *     arrays than data slots
     */
    public NDList forward(ParameterStore parameterStore, NDList data) {
        MxNDArray[] inputs = new MxNDArray[parameters.size()];
        // Exposed through getInputNDArray(); replaced on every call.
        debugInputs = inputs;
        NDArray head = data.head();
        Device device = head.getDevice();
        MxNDManager inputManager = (MxNDManager) head.getManager();

        bindParameters(parameterStore, device, inputs);
        bindData(data, inputs);
        fillMissingData(head, inputManager, inputs);
        return new NDList(JnaUtils.cachedOpInvoke(inputManager, getHandle(), inputs));
    }

    /** Fills each parameter slot of {@code inputs} with its value on {@code device}. */
    private void bindParameters(ParameterStore parameterStore, Device device, MxNDArray[] inputs) {
        for (int index : paramIndices) {
            MxNDArray value = (MxNDArray) parameterStore.getValue(parameters.get(index), device);
            if (value == null) {
                throw new NullPointerException("Failed to find parameter from parameterStore");
            }
            inputs[index] = value;
        }
    }

    /** Places each caller-supplied array into its slot: by name if named, else by position. */
    private void bindData(NDList data, MxNDArray[] inputs) {
        int position = 0;
        for (NDArray array : data) {
            inputs[indexOf(array.getName(), position++)] = (MxNDArray) array;
        }
    }

    /**
     * Fills every data slot the caller left empty with a new array of shape {@code (batchSize)}
     * created by {@code inputManager}. {@code batchSize} is read from {@code head} only when a
     * slot is actually empty.
     */
    private void fillMissingData(NDArray head, MxNDManager inputManager, MxNDArray[] inputs) {
        for (Pair<String, Integer> pair : dataIndices) {
            int slot = pair.getValue();
            if (inputs[slot] != null) {
                continue;
            }
            long batchSize = head.getShape().get(0);
            String name = pair.getKey();
            // NOTE(review): these names are presumably label placeholders that are expected to be
            // absent at inference time, hence no warning for them — confirm.
            if (!"prob_label".equals(name) && !"softmax_label".equals(name)) {
                logger.warn(
                        "Input {} not found, set NDArray to Shape({}) by default",
                        name,
                        batchSize);
            }
            inputs[slot] = (MxNDArray) inputManager.create(new Shape(batchSize));
        }
    }

    /**
     * Returns the input array built by the most recent {@link #forward} call, or {@code null} if
     * {@code forward} has not been called. The internal array is returned, not a copy.
     */
    MxNDArray[] getInputNDArray() {
        return debugInputs;
    }

    /**
     * Detaches this op from its manager and frees the native CachedOp. The handle is cleared
     * atomically, so only the first call does any work.
     */
    @Override
    public void close() {
        Pointer pointer = this.handle.getAndSet(null);
        if (pointer != null) {
            manager.detach(getUid());
            JnaUtils.freeCachedOp(pointer);
            manager = null;
        }
    }

    /**
     * Resolves the native input slot for a data array.
     *
     * @param name the array's name, or {@code null} to bind by position
     * @param position the array's position in the caller's input list
     * @return the slot index in the native input array
     * @throws IllegalArgumentException if {@code name} is not a known input, or {@code name} is
     *     {@code null} and {@code position} has no corresponding data slot
     */
    private int indexOf(String name, int position) {
        if (name == null) {
            if (position < 0 || position >= dataIndices.size()) {
                throw new IllegalArgumentException(
                        "Unnamed input at position "
                                + position
                                + " has no matching data slot, expected inputs: "
                                + dataIndicesMap.keySet());
            }
            return dataIndices.valueAt(position);
        }
        Integer slot = dataIndicesMap.get(name);
        if (slot == null) {
            throw new IllegalArgumentException(
                    "Unknown input name: "
                            + name
                            + ", expected inputs: "
                            + dataIndicesMap.keySet().toString());
        }
        return slot;
    }
}
