package ai.djl.tensorflow.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import ai.djl.tensorflow.engine.javacpp.JavacppUtils;
import ai.djl.training.ParameterStore;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.internal.c_api.TFE_TensorHandle;
import org.tensorflow.internal.c_api.TF_Graph;
import org.tensorflow.internal.c_api.TF_Operation;
import org.tensorflow.internal.c_api.TF_Session;
import org.tensorflow.internal.c_api.TF_Tensor;
import org.tensorflow.proto.SignatureDef;
import org.tensorflow.proto.TensorInfo;

/**
 * A symbol block that executes one signature of a TensorFlow SavedModel.
 *
 * <p>The block's inputs and outputs are the entries of the selected {@link SignatureDef}, ordered
 * by sorted key name. Each forward pass runs the bundle's session once, feeding the input tensors
 * and fetching the output tensors. The block closes the {@link SavedModelBundle} passed to its
 * constructor when {@link #close()} is called.
 */
public class TfSymbolBlock extends AbstractSymbolBlock implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(TfSymbolBlock.class);

    /** The loaded model; set to {@code null} once {@link #close()} has run. */
    private SavedModelBundle bundle;

    private final TF_Graph graphHandle;
    private final TF_Session sessionHandle;

    /** The signature whose inputs/outputs this block feeds and fetches. */
    private final SignatureDef servingDefault;

    private PairList<String, Shape> inputDescriptions;
    private PairList<String, Shape> outputDescriptions;

    // Graph endpoints (operation + output index) for each signature input/output, index-aligned
    // with inputDescriptions/outputDescriptions respectively.
    private TF_Operation[] inputOpHandles;
    private int[] inputOpIndices;
    private TF_Operation[] outputOpHandles;
    private int[] outputOpIndices;
    private TF_Operation[] targetOpHandles;

    /**
     * Creates a block for one signature of a loaded SavedModel.
     *
     * <p>If {@code signatureDefKey} is not present in the bundle, a warning is logged and the first
     * key returned by the signature map's iterator is used instead.
     *
     * @param savedModelBundle the loaded model; this block closes it in {@link #close()}
     * @param signatureDefKey the key of the signature to execute
     * @throws IllegalArgumentException if the bundle defines no signatures at all
     */
    public TfSymbolBlock(SavedModelBundle savedModelBundle, String signatureDefKey) {
        this.bundle = savedModelBundle;
        this.graphHandle = savedModelBundle.getGraph();
        this.sessionHandle = savedModelBundle.getSession();
        this.targetOpHandles = savedModelBundle.getTargetOpHandles();
        Map<String, SignatureDef> signatureDefMap =
                savedModelBundle.getMetaGraphDef().getSignatureDefMap();
        if (signatureDefMap.containsKey(signatureDefKey)) {
            this.servingDefault = signatureDefMap.get(signatureDefKey);
        } else {
            if (signatureDefMap.isEmpty()) {
                // Previously failed with an opaque NoSuchElementException from iterator().next().
                throw new IllegalArgumentException(
                        "SignatureDefKey: "
                                + signatureDefKey
                                + " not found and Saved Model Bundle contains no SignatureDef.");
            }
            String fallbackKey = signatureDefMap.keySet().iterator().next();
            logger.warn(
                    "SignatureDefKey: {} not found in Saved Model Bundle. Available keys: {}."
                            + " Please use .optOption(\"SignatureDefKey\", \"value\") with"
                            + " Criteria.builder to load the model. Normally the value is"
                            + " \"default\" for TF1.x models and \"serving_default\" for TF2.x"
                            + " models (see https://www.tensorflow.org/guide/saved_model)."
                            + " Loading the model using next available key: {}.",
                    signatureDefKey,
                    String.join(" ", signatureDefMap.keySet()),
                    fallbackKey);
            this.servingDefault = signatureDefMap.get(fallbackKey);
        }
        // Populate the cached descriptions and the graph endpoint arrays eagerly.
        describeInput();
        describeOutput();
    }

    /**
     * Not supported by this engine.
     *
     * @throws UnsupportedOperationException always
     */
    public void removeLastBlock() {
        throw new UnsupportedOperationException("Not supported for TensorFlow Engine");
    }

    /**
     * Runs the signature once on {@code inputs}.
     *
     * <p>For each signature input (in sorted-name order) the array at the same position is fed if
     * it is unnamed or already carries the expected name; otherwise the list is searched by name.
     * Outputs are returned in sorted-name order, each named after its signature key.
     *
     * @param parameterStore unused
     * @param inputs the input arrays; expected to be {@link TfNDArray}s
     * @param training unused
     * @param params unused
     * @return the output arrays, attached to the manager of {@code inputs.head()}
     * @throws IllegalArgumentException if a required input cannot be matched by name
     */
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        int inputCount = inputDescriptions.size();
        TF_Tensor[] inputTensors = new TF_Tensor[inputCount];
        TF_Tensor[] outputTensors = null;
        try {
            for (int i = 0; i < inputCount; i++) {
                String inputName = inputDescriptions.get(i).getKey();
                TfNDArray array = selectInput(inputs, i, inputName);
                inputTensors[i] =
                        JavacppUtils.resolveTFETensor((TFE_TensorHandle) array.getHandle());
            }
            outputTensors =
                    JavacppUtils.runSession(
                            sessionHandle,
                            null,
                            inputTensors,
                            inputOpHandles,
                            inputOpIndices,
                            outputOpHandles,
                            outputOpIndices,
                            targetOpHandles);
            TfNDManager manager = (TfNDManager) inputs.head().getManager();
            NDList results = new NDList();
            for (int i = 0; i < outputTensors.length; i++) {
                TfNDArray result =
                        new TfNDArray(manager, JavacppUtils.createTFETensor(outputTensors[i]));
                result.setName(outputDescriptions.get(i).getKey());
                results.add(result);
            }
            return results;
        } finally {
            // Release the intermediate native tensors even when the session or wrapping fails.
            closeAll(inputTensors);
            closeAll(outputTensors);
        }
    }

    /**
     * Chooses the array to feed for the signature input at {@code index}.
     *
     * <p>The positional array is used when it is unnamed or already has the expected name;
     * otherwise the last array in {@code inputs} whose name equals {@code inputName} is used.
     */
    private static TfNDArray selectInput(NDList inputs, int index, String inputName) {
        TfNDArray positional = (TfNDArray) inputs.get(index);
        String name = positional.getName();
        if (name == null || name.isEmpty() || name.equals(inputName)) {
            return positional;
        }
        TfNDArray match = null;
        for (NDArray candidate : inputs) {
            // inputName is non-null, so unnamed candidates are simply skipped.
            if (inputName.equals(candidate.getName())) {
                match = (TfNDArray) candidate;
            }
        }
        if (match == null) {
            throw new IllegalArgumentException(
                    "No input array named \""
                            + inputName
                            + "\" found: the array at index "
                            + index
                            + " is named \""
                            + name
                            + "\" and no other array has the expected name.");
        }
        return match;
    }

    /**
     * Always fails: the block is fully defined by the loaded SavedModel.
     *
     * @throws IllegalStateException always
     */
    public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
        throw new IllegalStateException("TfSymbolBlock can't be initialized");
    }

    /**
     * Returns whether the block still holds its model bundle.
     *
     * @return {@code true} until {@link #close()} has been called
     */
    public boolean isInitialized() {
        return bundle != null;
    }

    /**
     * Describes the signature inputs as (name, shape) pairs in sorted-name order.
     *
     * <p>The first call also resolves each input to its graph operation and output index.
     *
     * @return the cached input descriptions
     */
    public final PairList<String, Shape> describeInput() {
        if (inputDescriptions == null) {
            Map<String, TensorInfo> inputsMap = servingDefault.getInputsMap();
            List<String> names = sortedKeys(inputsMap);
            PairList<String, Shape> descriptions = new PairList<>();
            TF_Operation[] opHandles = new TF_Operation[names.size()];
            int[] opIndices = new int[names.size()];
            for (int i = 0; i < names.size(); i++) {
                TensorInfo info = inputsMap.get(names.get(i));
                descriptions.add(names.get(i), toShape(info));
                Pair<TF_Operation, Integer> op =
                        JavacppUtils.getGraphOperationByName(graphHandle, info.getName());
                opHandles[i] = op.getKey();
                opIndices[i] = op.getValue();
            }
            // Publish only fully built state, so a failure above does not leave a partial cache.
            inputOpHandles = opHandles;
            inputOpIndices = opIndices;
            inputDescriptions = descriptions;
        }
        return inputDescriptions;
    }

    /**
     * Not yet supported for this engine.
     *
     * @throws UnsupportedOperationException always
     */
    public ParameterList getDirectParameters() {
        throw new UnsupportedOperationException("Not yet supported");
    }

    /**
     * Describes the signature outputs as (name, shape) pairs in sorted-name order.
     *
     * <p>The first call also resolves each output to its graph operation and output index.
     *
     * @return the cached output descriptions
     */
    public final PairList<String, Shape> describeOutput() {
        if (outputDescriptions == null) {
            Map<String, TensorInfo> outputsMap = servingDefault.getOutputsMap();
            List<String> names = sortedKeys(outputsMap);
            PairList<String, Shape> descriptions = new PairList<>();
            TF_Operation[] opHandles = new TF_Operation[names.size()];
            int[] opIndices = new int[names.size()];
            for (int i = 0; i < names.size(); i++) {
                TensorInfo info = outputsMap.get(names.get(i));
                descriptions.add(names.get(i), toShape(info));
                Pair<TF_Operation, Integer> op =
                        JavacppUtils.getGraphOperationByName(graphHandle, info.getName());
                opHandles[i] = op.getKey();
                opIndices[i] = op.getValue();
            }
            outputOpHandles = opHandles;
            outputOpIndices = opIndices;
            outputDescriptions = descriptions;
        }
        return outputDescriptions;
    }

    /** Returns the keys of {@code map} in natural (lexicographic) order. */
    private static List<String> sortedKeys(Map<String, TensorInfo> map) {
        List<String> keys = new ArrayList<>(map.keySet());
        Collections.sort(keys);
        return keys;
    }

    /**
     * Converts the tensor shape recorded in {@code info} to a DJL {@link Shape}.
     *
     * <p>Dimension sizes are copied as-is. NOTE(review): unknown dimensions presumably appear as -1
     * in the proto; confirm.
     */
    private static Shape toShape(TensorInfo info) {
        long[] dims =
                info.getTensorShape().getDimList().stream().mapToLong(d -> d.getSize()).toArray();
        return new Shape(dims);
    }

    /**
     * Does not compute output shapes for this engine.
     *
     * @return an empty array
     */
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        return new Shape[0];
    }

    /**
     * Closes the model bundle and the cached graph operation handles.
     *
     * <p>Safe to call more than once; later calls do nothing.
     */
    @Override // java.lang.AutoCloseable
    public void close() {
        if (bundle != null) {
            bundle.close();
            bundle = null;
        }
        closeAll(inputOpHandles);
        inputOpHandles = null;
        closeAll(outputOpHandles);
        outputOpHandles = null;
        closeAll(targetOpHandles);
        targetOpHandles = null;
    }

    /** Closes every non-null tensor in {@code tensors}; tolerates a {@code null} array. */
    private static void closeAll(TF_Tensor[] tensors) {
        if (tensors == null) {
            return;
        }
        for (TF_Tensor tensor : tensors) {
            if (tensor != null) {
                tensor.close();
            }
        }
    }

    /** Closes every non-null operation handle in {@code ops}; tolerates a {@code null} array. */
    private static void closeAll(TF_Operation[] ops) {
        if (ops == null) {
            return;
        }
        for (TF_Operation op : ops) {
            if (op != null) {
                op.close();
            }
        }
    }
}
