package ai.djl.basicdataset.utils;

import ai.djl.modality.nlp.DefaultVocabulary;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.embedding.EmbeddingException;
import ai.djl.modality.nlp.embedding.TextEmbedding;
import ai.djl.modality.nlp.embedding.TrainableTextEmbedding;
import ai.djl.modality.nlp.embedding.TrainableWordEmbedding;
import ai.djl.modality.nlp.preprocess.LowerCaseConvertor;
import ai.djl.modality.nlp.preprocess.PunctuationSeparator;
import ai.djl.modality.nlp.preprocess.SimpleTokenizer;
import ai.djl.modality.nlp.preprocess.TextProcessor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.AbstractBlock;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;

/* loaded from: input_file:ai/djl/basicdataset/utils/TextData.class */
/**
 * Holds a list of raw texts, their processed and vocabulary-mapped tokens, and one {@link NDArray}
 * per text produced by {@link #preprocess(NDManager, List)}.
 *
 * <p>Instances are configured through {@link Configuration}. If no vocabulary or text embedding is
 * configured, {@link #preprocess(NDManager, List)} builds one from the data.
 */
public class TextData {

    /** One array per text, filled by {@link #preprocess}; {@code null} until then. */
    private List<NDArray> textEmbeddingList;
    /** The texts passed to the last {@link #preprocess} call; {@code null} until then. */
    private List<String> rawText;
    private List<TextProcessor> textProcessors;
    private List<String> reservedTokens;
    private TextEmbedding textEmbedding;
    private Vocabulary vocabulary;
    private String unknownToken;
    /** Only used when {@link #preprocess} has to create a trainable embedding itself. */
    private int embeddingSize;
    private int size;

    /**
     * Builder-style settings for a {@link TextData}. Fields that are never set stay {@code null}.
     */
    public static final class Configuration {
        private List<TextProcessor> textProcessors;
        private TextEmbedding textEmbedding;
        private Vocabulary vocabulary;
        private Integer embeddingSize;
        private String unknownToken;
        private List<String> reservedTokens;

        /**
         * Sets the processors applied, in order, to every raw text.
         *
         * @param textProcessors the processors to apply
         * @return this configuration
         */
        public Configuration setTextProcessors(List<TextProcessor> textProcessors) {
            this.textProcessors = textProcessors;
            return this;
        }

        /**
         * Sets the text embedding. If left unset, a trainable embedding is created during
         * preprocessing.
         *
         * @param textEmbedding the text embedding
         * @return this configuration
         */
        public Configuration setTextEmbedding(TextEmbedding textEmbedding) {
            this.textEmbedding = textEmbedding;
            return this;
        }

        /**
         * Sets the vocabulary. If left unset, one is built from the data during preprocessing.
         *
         * @param vocabulary the vocabulary
         * @return this configuration
         */
        public Configuration setVocabulary(Vocabulary vocabulary) {
            this.vocabulary = vocabulary;
            return this;
        }

        /**
         * Sets the size used when a trainable word embedding has to be created.
         *
         * @param embeddingSize the embedding size
         * @return this configuration
         */
        public Configuration setEmbeddingSize(int embeddingSize) {
            this.embeddingSize = embeddingSize;
            return this;
        }

        /**
         * Sets the token that out-of-vocabulary tokens are mapped to.
         *
         * @param unknownToken the unknown token
         * @return this configuration
         */
        public Configuration setUnknownToken(String unknownToken) {
            this.unknownToken = unknownToken;
            return this;
        }

        /**
         * Sets the reserved tokens passed to a vocabulary built during preprocessing.
         *
         * @param reservedTokens the reserved tokens
         * @return this configuration
         */
        public Configuration setReservedTokens(List<String> reservedTokens) {
            this.reservedTokens = reservedTokens;
            return this;
        }

        /**
         * Overlays {@code configuration} onto this one. Every field that is non-null in {@code
         * configuration} replaces the corresponding field here; the others are kept.
         *
         * @param configuration the settings to merge in
         * @return this configuration
         */
        public Configuration update(Configuration configuration) {
            if (configuration.textProcessors != null) {
                textProcessors = configuration.textProcessors;
            }
            if (configuration.textEmbedding != null) {
                textEmbedding = configuration.textEmbedding;
            }
            if (configuration.vocabulary != null) {
                vocabulary = configuration.vocabulary;
            }
            if (configuration.embeddingSize != null) {
                embeddingSize = configuration.embeddingSize;
            }
            if (configuration.unknownToken != null) {
                unknownToken = configuration.unknownToken;
            }
            if (configuration.reservedTokens != null) {
                reservedTokens = configuration.reservedTokens;
            }
            return this;
        }
    }

    /**
     * Creates a {@code TextData} from the given configuration.
     *
     * @param configuration the settings to use; its embedding size must be set
     * @throws NullPointerException if the configuration's embedding size was never set
     */
    public TextData(Configuration configuration) {
        this.textProcessors = configuration.textProcessors;
        this.textEmbedding = configuration.textEmbedding;
        this.vocabulary = configuration.vocabulary;
        this.embeddingSize = configuration.embeddingSize.intValue();
        this.unknownToken = configuration.unknownToken;
        this.reservedTokens = configuration.reservedTokens;
    }

    /**
     * Returns the default configuration, which has these settings:
     *
     * <ul>
     *   <li>embedding size 15;
     *   <li>a {@link SimpleTokenizer}, an English {@link LowerCaseConvertor}, and a {@link
     *       PunctuationSeparator}, applied in that order;
     *   <li>unknown token {@code "<unk>"};
     *   <li>reserved tokens {@code "<bos>"}, {@code "<eos>"} and {@code "<pad>"}.
     * </ul>
     *
     * <p>No vocabulary or text embedding is set.
     *
     * @return a new default configuration
     */
    public static Configuration getDefaultConfiguration() {
        return new Configuration()
                .setEmbeddingSize(15)
                .setTextProcessors(
                        Arrays.asList(
                                new SimpleTokenizer(),
                                new LowerCaseConvertor(Locale.ENGLISH),
                                new PunctuationSeparator()))
                .setUnknownToken("<unk>")
                .setReservedTokens(Arrays.asList("<bos>", "<eos>", "<pad>"));
    }

    /**
     * Processes the given texts and computes one {@link NDArray} per text.
     *
     * <p>Each text runs through the configured text processors. If no vocabulary is set, one is
     * built from all the processed texts. If no text embedding is set, a trainable one is created
     * with the configured embedding size. Each token is then mapped through the vocabulary, so
     * out-of-vocabulary tokens become the vocabulary's unknown token. Finally each text is turned
     * into an array by the text embedding.
     *
     * @param manager the manager used to create the arrays
     * @param newTextData the raw texts to process
     * @throws EmbeddingException if the text embedding fails to embed a text
     */
    public void preprocess(NDManager manager, List<String> newTextData) throws EmbeddingException {
        rawText = newTextData;

        List<List<String>> textData = new ArrayList<>(newTextData.size());
        for (String text : newTextData) {
            // Copy into a mutable list: the processors' output (or the singleton list itself when
            // there are no processors) may be immutable, and the tokens are rewritten below.
            textData.add(new ArrayList<>(applyProcessors(text)));
        }

        if (vocabulary == null) {
            vocabulary = buildVocabulary(textData);
        }
        if (textEmbedding == null) {
            textEmbedding =
                    new TrainableTextEmbedding(new TrainableWordEmbedding(vocabulary, embeddingSize));
        }

        size = textData.size();
        textEmbeddingList = new ArrayList<>(size);
        final Vocabulary vocab = vocabulary;
        for (List<String> tokens : textData) {
            // Round-trip through the vocabulary so unknown tokens become the unknown token.
            tokens.replaceAll(token -> vocab.getToken(vocab.getIndex(token)));
            textEmbeddingList.add(embed(manager, tokens));
        }
    }

    /** Runs {@code text} through every configured text processor, in order. */
    private List<String> applyProcessors(String text) {
        List<String> tokens = Collections.singletonList(text);
        for (TextProcessor processor : textProcessors) {
            tokens = processor.preprocess(tokens);
        }
        return tokens;
    }

    /** Builds a vocabulary (minimum frequency 3) from all processed texts. */
    private Vocabulary buildVocabulary(List<List<String>> textData) {
        DefaultVocabulary.Builder builder = DefaultVocabulary.builder();
        builder.optMinFrequency(3).optReservedTokens(reservedTokens).optUnknownToken(unknownToken);
        for (List<String> tokens : textData) {
            builder.add(tokens);
        }
        return builder.build();
    }

    /**
     * Converts one token list into an array. A text embedding that is an {@link AbstractBlock}
     * only gets the indices from {@code preprocessTextToEmbed}; any other embedding is applied
     * directly. Presumably the block form is embedded later by the model (TODO confirm).
     */
    private NDArray embed(NDManager manager, List<String> tokens) throws EmbeddingException {
        if (textEmbedding instanceof AbstractBlock) {
            return manager.create(textEmbedding.preprocessTextToEmbed(tokens));
        }
        return textEmbedding.embedText(manager, tokens);
    }

    /**
     * Sets the text processors used by later {@link #preprocess} and {@link #getProcessedText}
     * calls.
     *
     * @param textProcessors the processors to apply, in order
     */
    public void setTextProcessors(List<TextProcessor> textProcessors) {
        this.textProcessors = textProcessors;
    }

    /**
     * Sets the text embedding used by later {@link #preprocess} calls.
     *
     * @param textEmbedding the text embedding
     */
    public void setTextEmbedding(TextEmbedding textEmbedding) {
        this.textEmbedding = textEmbedding;
    }

    /**
     * Returns the text embedding. This is {@code null} until {@link #preprocess} runs, unless an
     * embedding was configured or set.
     *
     * @return the text embedding
     */
    public TextEmbedding getTextEmbedding() {
        return textEmbedding;
    }

    /**
     * Sets the embedding size. It only takes effect if it is set before {@link #preprocess} and
     * no text embedding is set.
     *
     * @param embeddingSize the embedding size
     */
    public void setEmbeddingSize(int embeddingSize) {
        this.embeddingSize = embeddingSize;
    }

    /**
     * Returns the vocabulary. This is either the configured vocabulary or the one built by
     * {@link #preprocess}.
     *
     * @return the vocabulary
     * @throws IllegalStateException if no vocabulary was configured and {@link #preprocess} has
     *     not been called
     */
    public Vocabulary getVocabulary() {
        if (vocabulary == null) {
            throw new IllegalStateException(
                    "This method must be called after preprocess is called on this object");
        }
        return vocabulary;
    }

    /**
     * Returns a duplicate of the array computed for the text at {@code index}. The duplicate is
     * attached to {@code manager}; the stored array is left unchanged.
     *
     * @param manager the manager that the returned copy is attached to
     * @param index the index of the text
     * @return a copy of the text's array
     * @throws IllegalStateException if {@link #preprocess} has not been called
     */
    public NDArray getEmbedding(NDManager manager, long index) {
        if (textEmbeddingList == null) {
            throw new IllegalStateException(
                    "This method must be called after preprocess is called on this object");
        }
        NDArray embedding = textEmbeddingList.get(Math.toIntExact(index)).duplicate();
        embedding.attach(manager);
        return embedding;
    }

    /**
     * Returns the raw text at {@code index}, as passed to {@link #preprocess}.
     *
     * @param index the index of the text
     * @return the raw text
     * @throws IllegalStateException if {@link #preprocess} has not been called
     */
    public String getRawText(long index) {
        if (rawText == null) {
            throw new IllegalStateException(
                    "This method must be called after preprocess is called on this object");
        }
        return rawText.get(Math.toIntExact(index));
    }

    /**
     * Re-runs the current text processors on the raw text at {@code index}. The result is not
     * mapped through the vocabulary, so out-of-vocabulary tokens are returned as-is.
     *
     * @param index the index of the text
     * @return the processed tokens
     * @throws IllegalStateException if {@link #preprocess} has not been called
     */
    public List<String> getProcessedText(long index) {
        return applyProcessors(getRawText(index));
    }

    /**
     * Returns the number of texts processed by the last {@link #preprocess} call, or 0 if it has
     * not been called.
     *
     * @return the number of texts
     */
    public int getSize() {
        return size;
    }
}
