package ai.djl.mxnet.engine;

import ai.djl.MalformedModelException;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:ai/djl/mxnet/engine/MxSymbolBlock.class */
/**
 * A {@link SymbolBlock} backed by an MXNet {@link Symbol}.
 *
 * <p>Execution goes through a {@link CachedOp} that is created lazily on the first call to
 * {@link #forward}. Output and parameter shapes are inferred from the symbol on first use and then
 * cached; all of these derived objects are discarded when {@link #removeLastBlock()} changes the
 * symbol.
 */
public class MxSymbolBlock extends AbstractBlock implements SymbolBlock {

    /** Encoding version written by {@link #saveParameters} and required by {@link #loadParameters}. */
    private static final byte VERSION = 2;

    private NDManager manager;
    // Built lazily in forward(); must be released whenever the symbol is replaced.
    private CachedOp op;
    private Symbol symbol;
    // One Parameter per name reported by symbol.getAllNames(), in that order.
    private List<Parameter> mxNetParams;
    // Shape caches computed from the input shapes given on first use.
    private Map<String, Shape> paramShapes;
    private Shape[] outputShapes;

    /**
     * Creates a block wrapping the given MXNet symbol.
     *
     * <p>A {@link Parameter} is created for every name in {@code symbol.getAllNames()}. Names that
     * also appear in {@code symbol.getAuxNames()} get {@code false} as the last {@code Parameter}
     * constructor argument (presumably "requires gradient" — confirm against {@code Parameter}).
     * Parameters are not registered with the block until {@link #setInputNames(List)} is called.
     *
     * @param manager the manager used to create the cached op and to load parameters
     * @param symbol the MXNet symbol graph this block executes
     */
    public MxSymbolBlock(NDManager manager, Symbol symbol) {
        super(VERSION);
        this.manager = manager;
        this.symbol = symbol;
        this.inputNames = new ArrayList<>();
        String[] allNames = symbol.getAllNames();
        this.mxNetParams = new ArrayList<>(allNames.length);
        Set<String> auxNames = new HashSet<>(Arrays.asList(symbol.getAuxNames()));
        for (String name : allNames) {
            mxNetParams.add(new Parameter(name, this, inferType(name), !auxNames.contains(name)));
        }
    }

    /**
     * Sets the names of the symbol's data inputs and registers every other symbol argument as a
     * parameter of this block.
     *
     * @param inputNames the input names; the list is stored as given, not copied
     */
    public void setInputNames(List<String> inputNames) {
        this.inputNames = inputNames;
        Set<String> inputs = new HashSet<>(inputNames);
        for (Parameter parameter : mxNetParams) {
            if (!inputs.contains(parameter.getName())) {
                addParameter(parameter);
            }
        }
    }

    /**
     * Returns every parameter derived from the symbol, including ones whose names are inputs.
     *
     * @return the internal (mutable) parameter list
     */
    public List<Parameter> getAllParameters() {
        return mxNetParams;
    }

    /**
     * Returns the layer names reported by the underlying symbol.
     *
     * @return the symbol's layer names
     */
    public List<String> getLayerNames() {
        return symbol.getLayerNames();
    }

    /**
     * Returns the MXNet symbol currently backing this block.
     *
     * @return the current symbol
     */
    public Symbol getSymbol() {
        return symbol;
    }

    /**
     * Describes the block inputs as (name, empty shape) pairs, one per input name.
     *
     * @return the input names, each paired with {@code new Shape(new long[0])}
     */
    public PairList<String, Shape> describeInput() {
        PairList<String, Shape> description = new PairList<>();
        for (String name : inputNames) {
            description.add(name, new Shape(new long[0]));
        }
        return description;
    }

    /**
     * Runs the symbol on the given inputs, creating the cached op on first use.
     *
     * <p>The {@code training} flag and {@code params} are not consulted by this implementation.
     *
     * @param parameterStore store supplying parameter values to the cached op
     * @param inputs the input arrays
     * @param training unused here
     * @param params unused here
     * @return the cached op's outputs
     */
    public NDList forward(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        if (op == null) {
            op = JnaUtils.createCachedOp(this, manager);
        }
        return op.forward(parameterStore, inputs);
    }

    /**
     * Returns the shapes of the symbol's outputs.
     *
     * <p>The result is computed on the first call and cached. Later calls return the cached array
     * regardless of {@code inputShapes}, until {@link #removeLastBlock()} resets the cache.
     *
     * @param manager unused here
     * @param inputShapes shapes of the inputs, in {@code inputNames} order
     * @return the inferred output shapes
     */
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        if (outputShapes == null) {
            String[] outputNames = symbol.getOutputNames();
            Shape[] shapes = new Shape[outputNames.length];
            for (int i = 0; i < shapes.length; i++) {
                shapes[i] = getParameterShape(outputNames[i], inputShapes);
            }
            outputShapes = shapes;
        }
        return outputShapes;
    }

    /**
     * Removes the last layer by slicing the symbol at the second-to-last layer name.
     *
     * <p>The old symbol is closed. Parameters that no longer appear in the sliced symbol are closed
     * and unregistered. The cached op and shape caches are discarded because they were derived
     * from the old symbol.
     */
    public void removeLastBlock() {
        List<String> layerNames = getLayerNames();
        // Slice at the second-to-last layer so the final layer is dropped. This offset is
        // unrelated to the serialization VERSION constant.
        String newLastLayer = layerNames.get(layerNames.size() - 2);
        Symbol sliced = symbol.get(newLastLayer);
        symbol.close();
        symbol = sliced;

        // Everything derived from the old (now closed) symbol is stale.
        if (op != null) {
            op.close();
            op = null;
        }
        paramShapes = null;
        outputShapes = null;

        Set<String> remaining = new HashSet<>(Arrays.asList(symbol.getAllNames()));
        // Iterate backwards so removals do not shift the indices still to be visited.
        for (int i = mxNetParams.size() - 1; i >= 0; i--) {
            Parameter parameter = mxNetParams.get(i);
            if (!remaining.contains(parameter.getName())) {
                mxNetParams.remove(i).close();
                parameters.remove(parameter.getName(), parameter);
            }
        }
    }

    /**
     * Returns the inferred shape of a named symbol entry.
     *
     * <p>Shapes are inferred once from {@code inputShapes} (paired positionally with
     * {@code inputNames}) and cached until {@link #removeLastBlock()} resets the cache.
     *
     * @param name the symbol entry name
     * @param inputShapes shapes of the inputs, in {@code inputNames} order
     * @return the inferred shape for {@code name}
     * @throws IllegalArgumentException if {@code name} has no inferred shape
     */
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        if (paramShapes == null) {
            PairList<String, Shape> inputs = new PairList<>();
            for (int i = 0; i < inputNames.size(); i++) {
                inputs.add(inputNames.get(i), inputShapes[i]);
            }
            paramShapes = symbol.inferShape(inputs);
        }
        Shape shape = paramShapes.get(name);
        if (shape == null && !paramShapes.containsKey(name)) {
            throw new IllegalArgumentException("Name " + name + " not found");
        }
        return shape;
    }

    /**
     * Writes the encoding version, the input names, and every registered non-input parameter.
     *
     * @param os the destination stream
     * @throws IOException if writing fails
     */
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(VERSION);
        os.writeInt(inputNames.size());
        for (String name : inputNames) {
            os.writeUTF(name);
        }
        for (Parameter parameter : parameters.values()) {
            if (!inputNames.contains(parameter.getName())) {
                parameter.save(os);
            }
        }
    }

    /**
     * Reads data written by {@link #saveParameters}.
     *
     * <p>The decoded input names replace the current ones. Only parameters already registered with
     * the block (see {@link #setInputNames(List)}) are loaded. Parameters named like inputs are
     * skipped, mirroring {@link #saveParameters}.
     *
     * @param manager not used; parameters are loaded with this block's own manager
     * @param is the source stream
     * @throws IOException if reading fails
     * @throws MalformedModelException if the encoding version or input-name count is invalid
     */
    public void loadParameters(NDManager manager, DataInputStream is)
            throws IOException, MalformedModelException {
        byte version = is.readByte();
        if (version != VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + ((int) version));
        }
        int count = is.readInt();
        if (count < 0) {
            throw new MalformedModelException("Invalid input name count: " + count);
        }
        List<String> names = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            names.add(is.readUTF());
        }
        // Replace rather than append. Appending duplicated names already present (they must be
        // set for any parameter to be registered) and failed on immutable lists.
        this.inputNames = names;
        for (Parameter parameter : parameters.values()) {
            // Same filter as saveParameters, so reads stay aligned with what was written.
            if (!names.contains(parameter.getName())) {
                parameter.load(this.manager, is);
            }
        }
    }

    /**
     * Guesses a parameter's type from the suffix of its name.
     *
     * @param name the parameter name
     * @return the matching {@link ParameterType}, or {@link ParameterType#OTHER}
     */
    private static ParameterType inferType(String name) {
        if (name.endsWith("bias")) {
            return ParameterType.BIAS;
        }
        if (name.endsWith("gamma")) {
            return ParameterType.GAMMA;
        }
        if (name.endsWith("beta")) {
            return ParameterType.BETA;
        }
        if (name.endsWith("moving_mean") || name.endsWith("running_mean")) {
            return ParameterType.RUNNING_MEAN;
        }
        if (name.endsWith("moving_var") || name.endsWith("running_var")) {
            return ParameterType.RUNNING_VAR;
        }
        return ParameterType.OTHER;
    }
}
