package ai.djl.repository.zoo;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.NDList;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.Resource;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.ServingTranslatorFactory;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import ai.djl.util.Progress;
import java.io.IOException;
import java.lang.reflect.Type;
import java.nio.file.Path;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/* loaded from: input_file:ai/djl/repository/zoo/BaseModelLoader.class */
/**
 * A {@link ModelLoader} that matches an {@link Artifact} in a {@link Repository}, prepares it locally
 * and loads it into a {@link ZooModel}, using either the {@link TranslatorFactory} supplied by the
 * {@link Criteria} or a default factory looked up by the criteria's input/output classes.
 */
public class BaseModelLoader implements ModelLoader {

    /**
     * Default translator factories keyed by (input type, output type). The constructor registers a
     * no-op translator for {@code NDList -> NDList} and a {@link ServingTranslatorFactory} for
     * {@code Input -> Output}; subclasses may register more.
     */
    protected Map<Pair<Type, Type>, TranslatorFactory<?, ?>> factories = new ConcurrentHashMap<>();

    /** Model zoo used to infer an engine when the criteria names none; may be {@code null}. */
    protected ModelZoo modelZoo;

    /** The repository resource (repository, MRL and version) artifacts are read from. */
    protected Resource resource;

    /**
     * Constructs a loader for the given repository resource.
     *
     * <p>NOTE(review): the decompiler reports this constructor was originally {@code protected};
     * it is kept {@code public} here so existing callers keep compiling.
     *
     * @param repository the repository to load the model from
     * @param mrl the MRL identifying the model
     * @param str the version passed to {@link Resource}; {@link #listModels()} treats a {@code null}
     *     version as "any version"
     * @param modelZoo the model zoo used to infer a supported engine, or {@code null}
     */
    public BaseModelLoader(Repository repository, MRL mrl, String str, ModelZoo modelZoo) {
        this.resource = new Resource(repository, mrl, str);
        this.modelZoo = modelZoo;
        // The explicit cast gives the lambda a concrete target type; against the wildcard map value
        // type alone it would be typed as TranslatorFactory<Object, Object>.
        factories.put(
                new Pair<>(NDList.class, NDList.class),
                (TranslatorFactory<NDList, NDList>) (model, map) -> new NoopTranslator());
        factories.put(new Pair<>(Input.class, Output.class), new ServingTranslatorFactory());
    }

    /** {@inheritDoc} */
    @Override
    public String getArtifactId() {
        return resource.getMrl().getArtifactId();
    }

    /** {@inheritDoc} */
    @Override
    public Application getApplication() {
        return resource.getMrl().getApplication();
    }

    /**
     * Matches, prepares and loads a model satisfying the given criteria.
     *
     * @param criteria the search and loading criteria
     * @param <I> the translator input type
     * @param <O> the translator output type
     * @return the loaded model wrapped with its translator
     * @throws IOException if preparing or reading the artifact fails
     * @throws ModelNotFoundException if no artifact, translator or usable engine matches
     * @throws MalformedModelException if the model files cannot be loaded
     */
    @Override
    public <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria)
            throws IOException, ModelNotFoundException, MalformedModelException {
        Artifact artifact = resource.match(criteria.getFilters());
        if (artifact == null) {
            throw new ModelNotFoundException("Model not found.");
        }
        Progress progress = criteria.getProgress();
        Map<String, Object> arguments = artifact.getArguments(criteria.getArguments());
        try {
            TranslatorFactory<I, O> factory = criteria.getTranslatorFactory();
            if (factory == null) {
                factory = getTranslatorFactory(criteria);
                if (factory == null) {
                    throw new ModelNotFoundException("No matching default translator found.");
                }
            }

            resource.prepare(artifact, progress);
            if (progress != null) {
                progress.reset("Loading", 2L);
                progress.update(1L);
            }
            Path modelDir = resource.getRepository().getResourceDirectory(artifact);

            String engine = resolveEngine(criteria.getEngine());
            String modelName = criteria.getModelName();
            if (modelName == null) {
                modelName = artifact.getName();
            }

            Model model = createModel(modelName, criteria.getDevice(), artifact, arguments, engine);
            if (criteria.getBlock() != null) {
                model.setBlock(criteria.getBlock());
            }
            model.load(modelDir, null, criteria.getOptions());

            // Expose a concrete application to the translator factory; guard against a null value.
            Application application = criteria.getApplication();
            if (application != null && !Application.UNDEFINED.equals(application)) {
                arguments.put("application", application.getPath());
            }
            return new ZooModel<>(model, factory.newInstance(model, arguments));
        } catch (TranslateException e) {
            throw new ModelNotFoundException("No matching translator found.", e);
        } finally {
            // Always close out progress reporting exactly once, on success or failure.
            if (progress != null) {
                progress.end();
            }
        }
    }

    /**
     * Picks the engine used to create the model.
     *
     * @param requested the engine named by the criteria, or {@code null}
     * @return the engine name, or {@code null} when none was requested and no model zoo is set
     * @throws ModelNotFoundException if the zoo supports no available engine, or the chosen engine
     *     is not available
     */
    private String resolveEngine(String requested) throws ModelNotFoundException {
        String engine = requested;
        if (engine == null && modelZoo != null) {
            // Prefer the default engine if the zoo supports it; otherwise keep the last supported
            // engine (in iteration order) that is available.
            String defaultEngine = Engine.getInstance().getEngineName();
            for (String candidate : modelZoo.getSupportedEngines()) {
                if (candidate.equals(defaultEngine)) {
                    engine = candidate;
                    break;
                }
                if (Engine.hasEngine(candidate)) {
                    engine = candidate;
                }
            }
            if (engine == null) {
                throw new ModelNotFoundException(
                        "No supported engine available for model zoo: " + modelZoo.getGroupId());
            }
        }
        if (engine != null && !Engine.hasEngine(engine)) {
            throw new ModelNotFoundException(engine + " is not supported.");
        }
        return engine;
    }

    /**
     * Lists the artifacts of this resource, restricted to the configured version if one is set.
     *
     * @return the matching artifacts
     * @throws IOException if the artifact metadata cannot be read
     */
    @Override
    public List<Artifact> listModels() throws IOException {
        List<Artifact> artifacts = resource.listArtifacts();
        String version = resource.getVersion();
        return artifacts.stream()
                .filter(a -> version == null || version.equals(a.getVersion()))
                .collect(Collectors.toList());
    }

    /**
     * Creates the (not yet loaded) model instance. Subclasses may override to customize creation.
     *
     * @param str the model name
     * @param device the device from the criteria, passed through to {@link Model#newInstance}
     * @param artifact the matched artifact (unused by this implementation)
     * @param map the resolved model arguments (unused by this implementation)
     * @param str2 the engine name; may be {@code null} when no engine was requested or inferred
     * @return a new model instance
     * @throws IOException declared for overriding implementations
     */
    protected Model createModel(
            String str, Device device, Artifact artifact, Map<String, Object> map, String str2)
            throws IOException {
        return Model.newInstance(str, device, str2);
    }

    /**
     * Describes this loader as {@code repo:group:artifact [ ... ]}, with one artifact per line.
     *
     * @return a human-readable description
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append(resource.getRepository().getName())
                .append(':')
                .append(resource.getMrl().getGroupId())
                .append(':')
                .append(resource.getMrl().getArtifactId())
                .append(" [\n");
        try {
            for (Artifact artifact : listModels()) {
                sb.append('\t').append(artifact).append('\n');
            }
        } catch (IOException e) {
            sb.append("\tFailed load metadata.");
        }
        sb.append("\n]");
        return sb.toString();
    }

    /**
     * Looks up a default translator factory for the criteria's input and output classes.
     *
     * @param criteria the criteria; must declare both input and output classes
     * @return the registered factory, or {@code null} if none matches
     * @throws IllegalArgumentException if the input or output class is missing
     */
    @SuppressWarnings("unchecked") // keys are (I, O) class pairs, so the stored factory is <I, O>
    private <I, O> TranslatorFactory<I, O> getTranslatorFactory(Criteria<I, O> criteria) {
        if (criteria.getInputClass() == null) {
            throw new IllegalArgumentException("The criteria must set an input class");
        }
        if (criteria.getOutputClass() == null) {
            throw new IllegalArgumentException("The criteria must set an output class");
        }
        Pair<Type, Type> key = new Pair<>(criteria.getInputClass(), criteria.getOutputClass());
        return (TranslatorFactory<I, O>) factories.get(key);
    }
}
