package ai.djl.fasttext.zoo.nlp.word_embedding;

import ai.djl.Model;
import ai.djl.fasttext.FtAbstractBlock;
import ai.djl.fasttext.FtModel;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.embedding.WordEmbedding;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.zoo.ZooModel;

/* loaded from: input_file:ai/djl/fasttext/zoo/nlp/word_embedding/FtWord2VecWordEmbedding.class */
/**
 * A {@link WordEmbedding} backed by a fastText {@link FtAbstractBlock}.
 *
 * <p>Word indices are translated to and from tokens through the supplied {@link Vocabulary}. The
 * embedding for a token is computed by the fastText block and copied into a new {@link NDArray}
 * owned by the caller-supplied {@link NDManager}.
 */
public class FtWord2VecWordEmbedding implements WordEmbedding {

    /** The fastText block that computes word vectors; fixed at construction. */
    private final FtAbstractBlock embedding;

    /** Maps between word indices and tokens; fixed at construction. */
    private final Vocabulary vocabulary;

    /**
     * Creates an embedding from a loaded model.
     *
     * @param model an {@link FtModel}, or a {@link ZooModel} wrapping one
     * @param vocabulary the vocabulary used to map between word indices and tokens
     * @throws IllegalArgumentException if {@code model} is not, and does not wrap, an {@link FtModel}
     */
    public FtWord2VecWordEmbedding(Model model, Vocabulary vocabulary) {
        // Unwrap zoo models so the check below sees the underlying model type.
        model = model instanceof ZooModel ? ((ZooModel) model).getWrappedModel() : model;
        if (!(model instanceof FtModel)) {
            throw new IllegalArgumentException("The FtWord2VecWordEmbedding requires an FtModel");
        }
        this.embedding = ((FtModel) model).m0getBlock();
        this.vocabulary = vocabulary;
    }

    /**
     * Creates an embedding directly from a fastText block.
     *
     * @param ftAbstractBlock the fastText block that computes word vectors
     * @param vocabulary the vocabulary used to map between word indices and tokens
     */
    public FtWord2VecWordEmbedding(FtAbstractBlock ftAbstractBlock, Vocabulary vocabulary) {
        this.embedding = ftAbstractBlock;
        this.vocabulary = vocabulary;
    }

    /**
     * Reports whether a word can be embedded.
     *
     * <p>This implementation always returns {@code true} and does not consult the vocabulary.
     * NOTE(review): presumably intentional because fastText can build vectors for
     * out-of-vocabulary words from subword information. TODO confirm.
     *
     * @param str the word to check (ignored)
     * @return always {@code true}
     */
    public boolean vocabularyContains(String str) {
        return true;
    }

    /**
     * Converts a word to its vocabulary index.
     *
     * @param str the word to look up
     * @return the index reported by the vocabulary for {@code str}
     */
    public long preprocessWordToEmbed(String str) {
        return this.vocabulary.getIndex(str);
    }

    /**
     * Embeds the single word index held in {@code nDArray}.
     *
     * @param nDArray an array holding exactly one word index, for example a scalar or a shape-(1)
     *     array, readable via {@link NDArray#toLongArray()}
     * @return the embedding vector, allocated on {@code nDArray}'s manager
     * @throws IllegalArgumentException if {@code nDArray} does not contain exactly one element
     */
    public NDArray embedWord(NDArray nDArray) {
        // Validate up front. Otherwise an empty array fails with an opaque index error and a
        // multi-element array is silently truncated to its first index.
        long elementCount = nDArray.size();
        if (elementCount != 1) {
            throw new IllegalArgumentException(
                    "NDArray word must contain exactly one index, but has " + elementCount);
        }
        return embedWord(nDArray.getManager(), nDArray.toLongArray()[0]);
    }

    /**
     * Embeds the word at vocabulary index {@code j}.
     *
     * @param nDManager the manager that allocates the returned array
     * @param j the vocabulary index of the word
     * @return the embedding vector for the word
     * @throws IllegalArgumentException if the vocabulary returns no token for {@code j}
     */
    public NDArray embedWord(NDManager nDManager, long j) {
        String token = this.vocabulary.getToken(j);
        if (token == null) {
            // Fail with a clear message rather than handing null to the fastText block.
            throw new IllegalArgumentException("Word index not found in the vocabulary: " + j);
        }
        return nDManager.create(this.embedding.embedWord(token));
    }

    /**
     * Converts a scalar word index back to its token.
     *
     * @param nDArray a scalar array holding a word index
     * @return the token the vocabulary maps that index to
     * @throws IllegalArgumentException if {@code nDArray} is not a scalar
     */
    public String unembedWord(NDArray nDArray) {
        if (nDArray.isScalar()) {
            return this.vocabulary.getToken(nDArray.toLongArray()[0]);
        }
        throw new IllegalArgumentException("NDArray word must be scalar index");
    }
}
