package ai.djl.fasttext;

import ai.djl.modality.nlp.embedding.WordEmbedding;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import com.github.jfasttext.FastTextWrapper;

/* loaded from: input_file:ai/djl/fasttext/FtWord2VecWordEmbedding.class */
/**
 * A {@link WordEmbedding} backed by a fastText model.
 *
 * <p>Words are converted to indices through an {@link FtVocabulary}. Embedding an index maps it
 * back to its token and asks the wrapped fastText model for that token's vector.
 */
public class FtWord2VecWordEmbedding implements WordEmbedding {
    private final FtModel model;
    private final FtVocabulary vocabulary;

    /**
     * Creates an embedding over the given fastText model and vocabulary.
     *
     * @param ftModel the fastText model whose word vectors are returned
     * @param ftVocabulary the vocabulary used to convert between words and indices
     */
    public FtWord2VecWordEmbedding(FtModel ftModel, FtVocabulary ftVocabulary) {
        this.model = ftModel;
        this.vocabulary = ftVocabulary;
    }

    /**
     * Reports whether {@code str} can be embedded. This implementation always returns {@code true}.
     *
     * <p>NOTE(review): presumably because fastText can build vectors for out-of-vocabulary words
     * from subword information — confirm against the fastText model in use.
     *
     * @param str the word to check
     * @return always {@code true}
     */
    public boolean vocabularyContains(String str) {
        return true;
    }

    /**
     * Converts a word into the index used by {@link #embedWord(NDManager, long)}.
     *
     * @param str the word to convert
     * @return the word's index in the vocabulary
     */
    public long preprocessWordToEmbed(String str) {
        return this.vocabulary.getIndex(str);
    }

    /**
     * Embeds the word whose index is stored in {@code nDArray}.
     *
     * @param nDArray an array holding exactly one word index (a scalar or a one-element array)
     * @return the embedding vector, created with {@code nDArray}'s manager
     * @throws IllegalArgumentException if {@code nDArray} does not contain exactly one element
     */
    public NDArray embedWord(NDArray nDArray) {
        long elementCount = nDArray.getShape().size();
        // Only a single index can be embedded. Reject empty arrays (which would otherwise fail
        // with ArrayIndexOutOfBoundsException) and multi-element arrays (whose extra indices
        // would otherwise be silently dropped).
        if (elementCount != 1) {
            throw new IllegalArgumentException(
                    "NDArray word index must contain exactly one element, but has "
                            + elementCount);
        }
        return embedWord(nDArray.getManager(), nDArray.toLongArray()[0]);
    }

    /**
     * Embeds the word at vocabulary index {@code j}.
     *
     * <p>The index is mapped back to its token, and the fastText model's vector for that token is
     * copied into a new {@link NDArray}.
     *
     * @param nDManager the manager used to create the result
     * @param j the vocabulary index of the word
     * @return a 1-D float array holding the word's vector
     */
    public NDArray embedWord(NDManager nDManager, long j) {
        FastTextWrapper.RealVector vector = this.model.fta.getVector(this.vocabulary.getToken(j));
        try {
            int size = (int) vector.size();
            float[] fArr = new float[size];
            for (int i = 0; i < size; i++) {
                fArr[i] = vector.get(i);
            }
            return nDManager.create(fArr);
        } finally {
            // NOTE(review): assumes RealVector is a JavaCPP Pointer owning native memory. The
            // values were copied above, so release the native buffer now instead of waiting
            // for the garbage collector.
            vector.deallocate();
        }
    }

    /**
     * Converts a scalar word index back into its word.
     *
     * @param nDArray a scalar array holding a vocabulary index
     * @return the word at that index
     * @throws IllegalArgumentException if {@code nDArray} is not a scalar
     */
    public String unembedWord(NDArray nDArray) {
        if (!nDArray.isScalar()) {
            throw new IllegalArgumentException("NDArray word must be scalar index");
        }
        return this.vocabulary.getToken(nDArray.toLongArray()[0]);
    }
}
