package ai.djl.mxnet.engine;

import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.mxnet.jna.NativeResource;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.PairList;
import ai.djl.util.Utils;
import com.sun.jna.Pointer;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/* loaded from: input_file:ai/djl/mxnet/engine/Symbol.class */
public class Symbol extends NativeResource {
    private String[] outputs;
    private MxNDManager manager;

    /* JADX INFO: Access modifiers changed from: package-private */
    public Symbol(MxNDManager mxNDManager, Pointer pointer) {
        super(pointer);
        this.manager = mxNDManager;
        mxNDManager.attach(getUid(), this);
    }

    public static Symbol load(MxNDManager mxNDManager, String str) {
        return new Symbol(mxNDManager, JnaUtils.createSymbolFromFile(str));
    }

    public String[] getArgNames() {
        return JnaUtils.listSymbolArguments(getHandle());
    }

    public String[] getAuxNames() {
        return JnaUtils.listSymbolAuxiliaryStates(getHandle());
    }

    public String[] getAllNames() {
        return JnaUtils.listSymbolNames(getHandle());
    }

    public String[] getOutputNames() {
        if (this.outputs == null) {
            this.outputs = JnaUtils.listSymbolOutputs(getHandle());
        }
        return this.outputs;
    }

    private String[] getInternalOutputNames() {
        return JnaUtils.listSymbolOutputs(getInternals().getHandle());
    }

    public Symbol copy() {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    public Symbol get(int i) {
        return new Symbol(this.manager, JnaUtils.getSymbolOutput(getInternals().getHandle(), i));
    }

    public Symbol get(String str) {
        int indexOf = Utils.indexOf(getInternalOutputNames(), str);
        if (indexOf < 0) {
            throw new IllegalArgumentException("Cannot find output that matches name: " + str);
        }
        return get(indexOf);
    }

    public Symbol getInternals() {
        return new Symbol(this.manager, JnaUtils.getSymbolInternals(getHandle()));
    }

    public List<String> getLayerNames() {
        String[] internalOutputNames = getInternalOutputNames();
        LinkedHashSet linkedHashSet = new LinkedHashSet(Arrays.asList(getAllNames()));
        return (List) Arrays.stream(internalOutputNames).filter(str -> {
            return !linkedHashSet.contains(str);
        }).collect(Collectors.toList());
    }

    public Map<String, Shape> inferShape(PairList<String, Shape> pairList) {
        List<List<Shape>> inferShape = JnaUtils.inferShape(this, pairList);
        if (inferShape == null) {
            throw new IllegalArgumentException("Cannot infer shape based on the data provided!");
        }
        List<Shape> list = inferShape.get(0);
        List<Shape> list2 = inferShape.get(1);
        List<Shape> list3 = inferShape.get(2);
        String[] argNames = getArgNames();
        String[] auxNames = getAuxNames();
        String[] outputNames = getOutputNames();
        ConcurrentHashMap concurrentHashMap = new ConcurrentHashMap();
        for (int i = 0; i < argNames.length; i++) {
            concurrentHashMap.put(argNames[i], list.get(i));
        }
        for (int i2 = 0; i2 < auxNames.length; i2++) {
            concurrentHashMap.put(auxNames[i2], list3.get(i2));
        }
        for (int i3 = 0; i3 < outputNames.length; i3++) {
            concurrentHashMap.put(outputNames[i3], list2.get(i3));
        }
        return concurrentHashMap;
    }

    public String toString() {
        return Arrays.toString(getOutputNames());
    }

    @Override // ai.djl.mxnet.jna.NativeResource, java.lang.AutoCloseable
    public void close() {
        Pointer andSet = this.handle.getAndSet(null);
        if (andSet != null) {
            this.manager.detach(getUid());
            JnaUtils.freeSymbol(andSet);
            this.manager = null;
        }
    }
}
