package ai.djl.tensorflow.engine;

import ai.djl.BaseModel;
import ai.djl.Device;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.nn.Block;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.translate.Translator;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.RunOptions;

/* loaded from: input_file:ai/djl/tensorflow/engine/TfModel.class */
public class TfModel extends BaseModel {
    private AtomicBoolean first;
    private NDManager manager;

    /* JADX INFO: Access modifiers changed from: package-private */
    public TfModel(String str, Device device) {
        super(str);
        Device defaultIfNull = Device.defaultIfNull(device);
        this.properties = new ConcurrentHashMap();
        this.manager = TfNDManager.getSystemManager().mo6newSubManager(defaultIfNull);
        this.first = new AtomicBoolean(true);
    }

    public void load(Path path, String str, Map<String, Object> map) throws FileNotFoundException {
        this.modelDir = path.toAbsolutePath();
        if (str == null) {
            str = this.modelName;
        }
        Path findModelDir = findModelDir(str);
        if (findModelDir == null) {
            findModelDir = findModelDir("saved_model.pb");
            if (findModelDir == null) {
                throw new FileNotFoundException("No TensorFlow model found in: " + this.modelDir);
            }
        }
        String[] strArr = null;
        ConfigProto configProto = null;
        RunOptions runOptions = null;
        String str2 = "serving_default";
        if (map != null) {
            strArr = (String[]) map.get("Tags");
            configProto = (ConfigProto) map.get("ConfigProto");
            runOptions = (RunOptions) map.get("RunOptions");
            str2 = (String) map.get("SignatureDefKey");
        }
        if (strArr == null) {
            strArr = new String[]{"serve"};
        }
        SavedModelBundle.Loader withTags = SavedModelBundle.loader(findModelDir.toString()).withTags(strArr);
        if (configProto != null) {
            withTags.withConfigProto(configProto);
        }
        if (runOptions != null) {
            withTags.withRunOptions(runOptions);
        }
        this.block = new TfSymbolBlock(withTags.load(), str2);
    }

    private Path findModelDir(String str) {
        Path resolve = this.modelDir.resolve(str);
        if (!Files.exists(resolve, new LinkOption[0])) {
            return null;
        }
        if (Files.isRegularFile(resolve, new LinkOption[0])) {
            return this.modelDir;
        }
        if (!Files.isDirectory(resolve, new LinkOption[0])) {
            return null;
        }
        Path resolve2 = resolve.resolve("saved_model.pb");
        if (Files.exists(resolve2, new LinkOption[0]) && Files.isRegularFile(resolve2, new LinkOption[0])) {
            return resolve;
        }
        return null;
    }

    public void save(Path path, String str) {
        throw new UnsupportedOperationException("Not supported for TensorFlow Engine");
    }

    public Block getBlock() {
        return this.block;
    }

    public void setBlock(Block block) {
        throw new UnsupportedOperationException("Not supported for TensorFlow Engine");
    }

    public Trainer newTrainer(TrainingConfig trainingConfig) {
        throw new UnsupportedOperationException("Not supported for TensorFlow Engine");
    }

    public <I, O> Predictor<I, O> newPredictor(Translator<I, O> translator) {
        return new Predictor<>(this, translator, this.first.getAndSet(false));
    }

    public NDManager getNDManager() {
        return this.manager;
    }

    public String[] getArtifactNames() {
        try {
            List<Path> list = (List) Files.walk(this.modelDir, new FileVisitOption[0]).filter(path -> {
                return Files.isRegularFile(path, new LinkOption[0]);
            }).collect(Collectors.toList());
            ArrayList arrayList = new ArrayList(list.size());
            for (Path path2 : list) {
                if (!path2.toFile().getName().endsWith(".pb")) {
                    arrayList.add(this.modelDir.relativize(path2).toString());
                }
            }
            return (String[]) arrayList.toArray(new String[0]);
        } catch (IOException e) {
            throw new AssertionError("Failed list files", e);
        }
    }

    public void cast(DataType dataType) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    public void close() {
        this.manager.close();
        if (this.block != null) {
            this.block.clear();
        }
    }
}
