package ai.djl.mxnet.zoo.nlp.embedding;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.modality.nlp.embedding.TrainableWordEmbedding;
import ai.djl.mxnet.zoo.MxModelZoo;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.core.Embedding;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.BaseModelLoader;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import ai.djl.util.Progress;
import ai.djl.util.Utils;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.Map;

/**
 * A model loader for GloVe word embeddings published under the {@code ai.djl.mxnet} group of the
 * MXNet model zoo.
 *
 * <p>The model block is a {@link TrainableWordEmbedding} built from two parts of the artifact:
 *
 * <ul>
 *   <li>the {@code idx_to_token} vocabulary file;
 *   <li>the {@code dimensions} property.
 * </ul>
 *
 * <p>A translator factory is registered for {@link String} inputs and {@link NDList} outputs.
 */
public class GloveWordEmbeddingModelLoader extends BaseModelLoader<NDList, NDList> {

    private static final Application APPLICATION = Application.NLP.WORD_EMBEDDING;
    private static final String GROUP_ID = "ai.djl.mxnet";
    private static final String ARTIFACT_ID = "glove";
    private static final String VERSION = "0.0.1";

    /** Argument and model-property key naming the token used for out-of-vocabulary words. */
    private static final String UNKNOWN_TOKEN_KEY = "unknownToken";

    /**
     * Creates a loader for the GloVe embedding artifacts in the given repository.
     *
     * @param repository the repository to load the model artifacts from
     */
    public GloveWordEmbeddingModelLoader(Repository repository) {
        super(repository, MRL.model(APPLICATION, GROUP_ID, ARTIFACT_ID), VERSION, new MxModelZoo());
        factories.put(new Pair<>(String.class, NDList.class), new FactoryImpl());
    }

    /** {@inheritDoc} */
    @Override
    public Application getApplication() {
        return APPLICATION;
    }

    /**
     * Creates a new model and sets its block to a GloVe {@link TrainableWordEmbedding} built from
     * the artifact.
     *
     * @param name the model name
     * @param device the device the model should use
     * @param artifact the artifact describing the embedding (vocabulary file and dimensions)
     * @param arguments loading arguments; {@code unknownToken} names the out-of-vocabulary token
     * @param engine the engine name passed to {@link Model#newInstance}
     * @return the configured model
     * @throws IOException if the vocabulary file cannot be read
     */
    @Override
    protected Model createModel(
            String name,
            Device device,
            Artifact artifact,
            Map<String, Object> arguments,
            String engine)
            throws IOException {
        return customGloveBlock(Model.newInstance(name, device, engine), artifact, arguments);
    }

    /**
     * Attaches a GloVe word embedding block to the model.
     *
     * @param model the model to configure
     * @param artifact the artifact providing the {@code dimensions} property and the {@code
     *     idx_to_token} vocabulary file
     * @param arguments loading arguments; {@code unknownToken} names the out-of-vocabulary token
     * @return the same model, with its block and {@code unknownToken} property set
     * @throws IOException if the vocabulary file cannot be read
     */
    private Model customGloveBlock(Model model, Artifact artifact, Map<String, Object> arguments)
            throws IOException {
        // Parsed first so a missing/invalid property fails before any stream is opened.
        int dimensions = Integer.parseInt((String) artifact.getProperties().get("dimensions"));

        Artifact.Item vocabularyFile = (Artifact.Item) artifact.getFiles().get("idx_to_token");
        List<String> idxToToken;
        // Close the repository stream explicitly instead of relying on readLines to do it.
        try (InputStream is = resource.getRepository().openStream(vocabularyFile, (String) null)) {
            idxToToken = Utils.readLines(is);
        }

        String unknownToken = (String) arguments.get(UNKNOWN_TOKEN_KEY);
        TrainableWordEmbedding wordEmbedding =
                TrainableWordEmbedding.builder()
                        .setEmbeddingSize(dimensions)
                        .setVocabulary(new SimpleVocabulary(idxToToken))
                        .optUnknownToken(unknownToken)
                        .optUseDefault(true)
                        .optSparseGrad(false)
                        .build();
        model.setBlock(wordEmbedding);
        model.setProperty(UNKNOWN_TOKEN_KEY, unknownToken);
        return model;
    }

    /**
     * Loads the GloVe embedding model that matches the given filters.
     *
     * @param filters the search filters to match against the loaded model
     * @param device the device the loaded model should use
     * @param progress the progress tracker to update while loading the model
     * @return the loaded model
     * @throws IOException for various exceptions loading data from the repository
     * @throws ModelNotFoundException if no model with the specified criteria is found
     * @throws MalformedModelException if the model data is malformed
     */
    public ZooModel<NDList, NDList> loadModel(
            Map<String, String> filters, Device device, Progress progress)
            throws IOException, ModelNotFoundException, MalformedModelException {
        Criteria<NDList, NDList> criteria =
                Criteria.builder()
                        .setTypes(NDList.class, NDList.class)
                        .optApplication(APPLICATION)
                        .optFilters(filters)
                        .optDevice(device)
                        .optProgress(progress)
                        .build();
        return loadModel(criteria);
    }

    /** Creates {@link TranslatorImpl} instances configured from the loading arguments. */
    private static final class FactoryImpl implements TranslatorFactory<String, NDList> {

        /** {@inheritDoc} */
        @Override
        public Translator<String, NDList> newInstance(Model model, Map<String, Object> arguments) {
            return new TranslatorImpl((String) arguments.get(UNKNOWN_TOKEN_KEY));
        }
    }

    /**
     * Translates a word into an {@link NDList} wrapping {@code embedding.embed(word)}.
     *
     * <p>Words the embedding does not contain fall back to the configured unknown token.
     */
    private static final class TranslatorImpl implements Translator<String, NDList> {

        private final String unknownToken;
        private Embedding<String> embedding;

        /**
         * Creates a translator.
         *
         * @param unknownToken the token looked up in place of out-of-vocabulary words
         */
        TranslatorImpl(String unknownToken) {
            this.unknownToken = unknownToken;
        }

        /**
         * Captures the model's block as the embedding used for lookups.
         *
         * @throws IllegalArgumentException if the model's block is not an {@link Embedding}
         */
        @Override
        public void prepare(NDManager manager, Model model) {
            Object block = model.getBlock();
            // Explicit type check instead of catching ClassCastException; also rejects null.
            if (!(block instanceof Embedding)) {
                throw new IllegalArgumentException("The model was not an embedding");
            }
            @SuppressWarnings("unchecked") // element type is erased; this loader builds String vocabularies
            Embedding<String> stringEmbedding = (Embedding<String>) block;
            embedding = stringEmbedding;
        }

        /** {@inheritDoc} */
        @Override
        public NDList processInput(TranslatorContext ctx, String input) {
            String token = embedding.hasItem(input) ? input : unknownToken;
            return new NDList(ctx.getNDManager().create(embedding.embed(token)));
        }

        /** Returns the model output unchanged. */
        @Override
        public NDList processOutput(TranslatorContext ctx, NDList list) {
            return list;
        }

        /** {@inheritDoc} */
        @Override
        public Batchifier getBatchifier() {
            return Batchifier.STACK;
        }
    }
}
