package ai.djl.modality.nlp.embedding;

import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.core.Embedding;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Optional;

/* loaded from: input_file:ai/djl/modality/nlp/embedding/TrainableWordEmbedding.class */
/**
 * A {@link WordEmbedding} over {@link String} tokens built on the trainable {@link Embedding}
 * block.
 *
 * <p>Tokens are serialized as UTF-8 bytes (see {@link #encode(String)} and {@link
 * #decode(byte[])}). When {@link #unembedWord(NDArray)} cannot map an index back to a token
 * through this embedding, it asks {@code fallthroughEmbedding} before giving up.
 */
public class TrainableWordEmbedding extends Embedding<String> implements WordEmbedding {

    /** Token used as the default/unknown item when the caller does not supply one. */
    private static final String DEFAULT_UNKNOWN_TOKEN = "<unk>";

    /**
     * Builder for {@link TrainableWordEmbedding}.
     *
     * <p>The embedding item type is always {@link String} and the default item starts as {@code
     * "<unk>"}; use {@link #optUnknownToken(String)} to change it.
     */
    public static class Builder extends Embedding.BaseBuilder<String, Builder> {

        Builder() {
            this.embeddingType = String.class;
            this.defaultItem = TrainableWordEmbedding.DEFAULT_UNKNOWN_TOKEN;
        }

        /**
         * Returns this builder without using {@code cls}: the item type is fixed to {@link
         * String} by the constructor and cannot be changed.
         *
         * @param cls ignored
         * @return this builder
         */
        @Override
        public Builder setType(Class<String> cls) {
            return self();
        }

        /**
         * Returns this builder, typed as {@link Builder} for fluent chaining.
         *
         * @return this builder
         */
        @Override
        public Builder self() {
            return this;
        }

        /**
         * Sets the token used for unknown words; delegates to {@code optDefaultItem}.
         *
         * @param str the unknown-word token
         * @return this builder
         */
        public Builder optUnknownToken(String str) {
            return optDefaultItem(str);
        }

        /**
         * Creates a {@link TrainableWordEmbedding} from this builder's settings.
         *
         * @return the new embedding
         */
        public TrainableWordEmbedding build() {
            return new TrainableWordEmbedding(this);
        }
    }

    /**
     * Constructs an embedding from a configured {@link Builder}.
     *
     * @param builder the builder holding the embedding configuration
     */
    public TrainableWordEmbedding(Builder builder) {
        super(builder);
    }

    /**
     * Constructs an embedding whose items are all tokens of {@code simpleVocabulary}.
     *
     * <p>The vocabulary's unknown token becomes the default item, sparse gradients are disabled
     * ({@code optSparseGrad(false)}) and {@code optUseDefault(false)} is set.
     *
     * @param simpleVocabulary vocabulary providing the tokens and the unknown token
     * @param i the embedding size
     */
    public TrainableWordEmbedding(SimpleVocabulary simpleVocabulary, int i) {
        super(
                builder()
                        .setEmbeddingSize(i)
                        .setItems(simpleVocabulary.getAllTokens())
                        .optSparseGrad(false)
                        .optDefaultItem(simpleVocabulary.getUnknownToken())
                        .optUseDefault(false));
    }

    /**
     * Constructs an embedding from existing embedding values and their items.
     *
     * <p>Indices this embedding cannot unembed fall through to {@code "<unk>"}.
     *
     * @param nDArray the embedding values, passed to the superclass constructor
     * @param list the items, passed to the superclass constructor
     */
    public TrainableWordEmbedding(NDArray nDArray, List<String> list) {
        super(nDArray, list);
        this.fallthroughEmbedding = new Embedding.DefaultItem(DEFAULT_UNKNOWN_TOKEN);
    }

    /**
     * Constructs an embedding from existing embedding values and their items.
     *
     * <p>Indices this embedding cannot unembed fall through to {@code "<unk>"}.
     *
     * @param nDArray the embedding values, passed to the superclass constructor
     * @param list the items, passed to the superclass constructor
     * @param z flag forwarded unchanged to the superclass constructor
     */
    public TrainableWordEmbedding(NDArray nDArray, List<String> list, boolean z) {
        super(nDArray, list, z);
        this.fallthroughEmbedding = new Embedding.DefaultItem(DEFAULT_UNKNOWN_TOKEN);
    }

    /**
     * Reports whether {@code str} is a key of this embedding's {@code embedder} map.
     *
     * @param str the word to look up
     * @return {@code true} if the word is directly known to this embedding
     */
    @Override
    public boolean vocabularyContains(String str) {
        return this.embedder.containsKey(str);
    }

    /**
     * Maps a word to its integer index via {@code embed}.
     *
     * @param str the word to index
     * @return the index produced by {@code embed(str)}
     */
    @Override
    public int preprocessWordToEmbed(String str) {
        return embed(str);
    }

    /**
     * Not supported by this class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public NDArray embedWord(NDManager nDManager, int i) {
        throw new UnsupportedOperationException("This operation is not supported by this class.");
    }

    /**
     * Converts a scalar index array back to its word.
     *
     * <p>This embedding is consulted first, then {@code fallthroughEmbedding}.
     *
     * @param nDArray a scalar array holding the word index
     * @return the word for that index
     * @throws IllegalArgumentException if {@code nDArray} is not scalar, or neither embedding can
     *     unembed the index
     */
    @Override
    public String unembedWord(NDArray nDArray) {
        if (!nDArray.isScalar()) {
            throw new IllegalArgumentException("NDArray word must be scalar index");
        }
        int index = nDArray.toIntArray()[0];
        Optional<String> word = unembed(index);
        if (!word.isPresent()) {
            // Typed as Optional<String> to avoid the raw Optional plus unchecked cast.
            word = this.fallthroughEmbedding.unembed(index);
        }
        return word.orElseThrow(() -> new IllegalArgumentException("Failed to unembed word"));
    }

    /**
     * Serializes a token as UTF-8 bytes.
     *
     * @param str the token
     * @return its UTF-8 encoding
     */
    @Override
    public byte[] encode(String str) {
        return str.getBytes(StandardCharsets.UTF_8);
    }

    /**
     * Deserializes a token from UTF-8 bytes.
     *
     * @param bArr the UTF-8 encoded token
     * @return the decoded token
     */
    @Override
    public String decode(byte[] bArr) {
        return new String(bArr, StandardCharsets.UTF_8);
    }

    /**
     * Creates a new {@link Builder} with item type {@link String} and default item {@code "<unk>"}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }
}
