package ai.djl.tensorflow.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.BlockList;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.proto.framework.TensorShapeProto;

/* loaded from: input_file:ai/djl/tensorflow/engine/TfSymbolBlock.class */
/**
 * A {@link SymbolBlock} backed by a TensorFlow {@link SavedModelBundle}.
 *
 * <p>The block resolves one {@link SignatureDef} from the bundle at construction time and uses
 * its input/output maps to translate between signature keys (the names exposed by
 * {@link #describeInput()} / {@link #describeOutput()}) and the graph tensor names that are fed
 * to and fetched from the TensorFlow {@link Session}.
 *
 * <p>Operations that only make sense for trainable blocks (initializers, parameters, casting,
 * children, serialization) throw {@link UnsupportedOperationException}.
 */
public class TfSymbolBlock implements SymbolBlock {

    private static final Logger logger = LoggerFactory.getLogger(TfSymbolBlock.class);

    /** Message used by every operation this block does not support. */
    private static final String UNSUPPORTED = "Not supported for TensorFlow Engine";

    private final SavedModelBundle bundle;
    private final Session session;
    private final SignatureDef servingDefault;

    /** Signature input keys (sorted) paired with their declared shapes; filled lazily. */
    private PairList<String, Shape> inputDescriptions;

    /** Signature output keys (sorted, non-string dtypes only) paired with shapes; filled lazily. */
    private PairList<String, Shape> outputDescriptions;

    /** Maps a signature input/output key to the tensor name reported by its {@link TensorInfo}. */
    private final Map<String, String> inputOutputNames = new ConcurrentHashMap<>();

    /**
     * Creates a block around a loaded saved model.
     *
     * <p>If {@code str} is not a key of the bundle's signature map, a warning is logged and the
     * first key returned by the map's key set is used instead.
     *
     * @param savedModelBundle the loaded TensorFlow saved model; its session is used for inference
     * @param str the SignatureDef key to use for describing inputs and outputs
     * @throws java.util.NoSuchElementException if {@code str} is not found and the bundle has no
     *     signatures at all
     */
    public TfSymbolBlock(SavedModelBundle savedModelBundle, String str) {
        this.bundle = savedModelBundle;
        this.session = savedModelBundle.session();
        Map<String, SignatureDef> signatures = savedModelBundle.metaGraphDef().getSignatureDefMap();
        if (signatures.containsKey(str)) {
            this.servingDefault = signatures.get(str);
        } else {
            Set<String> availableKeys = signatures.keySet();
            // Placeholders keep the sentences separated and avoid building the message eagerly.
            logger.warn(
                    "SignatureDefKey: {} not found in Saved Model Bundle. Available keys: {}. "
                            + "Please use .optOptions(\"SignatureDefKey\", \"value\") with "
                            + "Criteria.builder to load the model. Normally the value is "
                            + "\"default\" for TF1.x models and \"serving_default\" for TF2.x "
                            + "models. Loading the model using next available key. "
                            + "Refer to: https://www.tensorflow.org/guide/saved_model",
                    str,
                    String.join(" ", availableKeys));
            this.servingDefault = signatures.get(availableKeys.iterator().next());
        }
        // Both methods are final, so calling them from the constructor is safe; they populate
        // the descriptions and the key -> tensor-name map used by forward().
        describeInput();
        describeOutput();
    }

    /**
     * Not supported by this block.
     *
     * @throws UnsupportedOperationException always
     */
    public void removeLastBlock() {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }

    /**
     * Runs the saved model's session on the given inputs.
     *
     * <p>For each described input, the array at the same position in {@code nDList} is fed when
     * it is unnamed (null or empty name) or its name equals the signature key. Otherwise every
     * array in {@code nDList} whose name equals the signature key is fed. All described outputs
     * are fetched, converted to arrays owned by the manager of the first input, and named after
     * their signature keys.
     *
     * @param parameterStore not used by this implementation
     * @param nDList the input arrays; must be {@code TfNDArray} instances
     * @param z not used by this implementation
     * @param pairList not used by this implementation
     * @return the fetched outputs, in the order of {@link #describeOutput()}
     */
    public NDList forward(ParameterStore parameterStore, NDList nDList, boolean z, PairList<String, Object> pairList) {
        Session.Runner runner = this.session.runner();
        for (int i = 0; i < this.inputDescriptions.size(); i++) {
            String inputKey = this.inputDescriptions.get(i).getKey();
            String tensorName = this.inputOutputNames.get(inputKey);
            NDArray positional = nDList.get(i);
            String arrayName = positional.getName();
            if (arrayName == null || arrayName.isEmpty() || arrayName.equals(inputKey)) {
                // Unnamed or correctly named: trust the positional order.
                runner.feed(tensorName, ((TfNDArray) positional).getTensor());
            } else {
                // Named differently: look the input up by name anywhere in the list.
                // Every match is fed (no early exit), matching the previous behaviour.
                for (NDArray candidate : nDList) {
                    if (inputKey.equals(candidate.getName())) {
                        runner.feed(tensorName, ((TfNDArray) candidate).getTensor());
                    }
                }
            }
        }
        for (int i = 0; i < this.outputDescriptions.size(); i++) {
            runner.fetch(this.inputOutputNames.get(this.outputDescriptions.get(i).getKey()));
        }

        List<Tensor<?>> results = runner.run();
        TfNDManager manager = (TfNDManager) nDList.head().getManager();
        NDList outputs = new NDList();
        for (int i = 0; i < results.size(); i++) {
            // The fetched tensor is closed as soon as it has been handed to the manager.
            try (Tensor<?> tensor = results.get(i)) {
                NDArray array = manager.create(tensor);
                array.setName(this.outputDescriptions.get(i).getKey());
                outputs.add(array);
            }
        }
        return outputs;
    }

    /**
     * Not supported by this block.
     *
     * @param initializer ignored
     * @throws UnsupportedOperationException always
     */
    public void setInitializer(Initializer initializer) {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }

    /**
     * Not supported by this block.
     *
     * @param initializer ignored
     * @param str ignored
     * @throws UnsupportedOperationException always
     */
    public void setInitializer(Initializer initializer, String str) {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }

    /**
     * Does nothing; the block performs no initialization work here.
     *
     * @param nDManager ignored
     * @param dataType ignored
     * @param shapeArr ignored
     * @return an empty array
     */
    public Shape[] initialize(NDManager nDManager, DataType dataType, Shape... shapeArr) {
        return new Shape[0];
    }

    /**
     * Reports whether this block holds a saved model bundle.
     *
     * @return {@code true} if the bundle reference is non-null
     */
    public boolean isInitialized() {
        return this.bundle != null;
    }

    /**
     * Not supported by this block.
     *
     * @param dataType ignored
     * @throws UnsupportedOperationException always
     */
    public void cast(DataType dataType) {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }

    /** Closes the session and then the saved model bundle, when present. */
    public void clear() {
        if (this.session != null) {
            this.session.close();
        }
        if (this.bundle != null) {
            this.bundle.close();
        }
    }

    /**
     * Returns the signature's inputs as (key, shape) pairs, sorted by key.
     *
     * <p>The list is built on first call; building it also records each key's tensor name for
     * use by {@code forward}.
     *
     * @return the input descriptions
     */
    public final PairList<String, Shape> describeInput() {
        if (this.inputDescriptions == null) {
            this.inputDescriptions = new PairList<>();
            Map<String, TensorInfo> inputs = this.servingDefault.getInputsMap();
            List<String> keys = new ArrayList<>(inputs.keySet());
            Collections.sort(keys);
            for (String key : keys) {
                TensorInfo info = inputs.get(key);
                this.inputOutputNames.put(key, info.getName());
                this.inputDescriptions.add(key, toShape(info.getTensorShape()));
            }
        }
        return this.inputDescriptions;
    }

    /**
     * Returns the signature's outputs as (key, shape) pairs, sorted by key.
     *
     * <p>Outputs whose dtype is {@code DT_STRING} are skipped and therefore never fetched by
     * {@code forward}. The list is built on first call; building it also records each key's
     * tensor name.
     *
     * @return the output descriptions
     */
    public final PairList<String, Shape> describeOutput() {
        if (this.outputDescriptions == null) {
            this.outputDescriptions = new PairList<>();
            Map<String, TensorInfo> outputs = this.servingDefault.getOutputsMap();
            List<String> keys = new ArrayList<>(outputs.keySet());
            Collections.sort(keys);
            for (String key : keys) {
                TensorInfo info = outputs.get(key);
                if (info.getDtype() == org.tensorflow.proto.framework.DataType.DT_STRING) {
                    continue;
                }
                this.inputOutputNames.put(key, info.getName());
                this.outputDescriptions.add(key, toShape(info.getTensorShape()));
            }
        }
        return this.outputDescriptions;
    }

    /**
     * Not supported by this block.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public BlockList getChildren() {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }

    /**
     * Not supported by this block.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public ParameterList getDirectParameters() {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }

    /**
     * Not supported by this block.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public ParameterList getParameters() {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }

    /**
     * Not supported by this block.
     *
     * @param str ignored
     * @param shapeArr ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public Shape getParameterShape(String str, Shape[] shapeArr) {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }

    /**
     * Returns no output shapes; this implementation does not compute them.
     *
     * @param nDManager ignored
     * @param shapeArr ignored
     * @return an empty array
     */
    public Shape[] getOutputShapes(NDManager nDManager, Shape[] shapeArr) {
        return new Shape[0];
    }

    /**
     * Not supported by this block.
     *
     * @param dataOutputStream ignored
     * @throws UnsupportedOperationException always
     */
    public void saveParameters(DataOutputStream dataOutputStream) {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }

    /**
     * Not supported by this block.
     *
     * @param nDManager ignored
     * @param dataInputStream ignored
     * @throws UnsupportedOperationException always
     */
    public void loadParameters(NDManager nDManager, DataInputStream dataInputStream) {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }

    /** Converts a TensorFlow shape proto into a {@link Shape} using each dimension's size. */
    private static Shape toShape(TensorShapeProto shapeProto) {
        return new Shape(
                shapeProto.getDimList().stream().mapToLong(TensorShapeProto.Dim::getSize).toArray());
    }
}
