package ai.djl.modality.nlp;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/* loaded from: input_file:ai/djl/modality/nlp/SimpleVocabulary.class */
/**
 * An in-memory {@link Vocabulary} that maps tokens to dense {@code long} indices.
 *
 * <p>A vocabulary can be created in two ways:
 *
 * <ul>
 *   <li>From a fixed list of tokens ({@link #SimpleVocabulary(List)}): the unknown token gets
 *       index 0 and every distinct token in the list is indexed in list order.
 *   <li>From a {@link VocabularyBuilder}: reserved tokens (always including the unknown token) are
 *       indexed first. Then every word of every sentence is counted, and a word receives the next
 *       free index once its count reaches {@code minFrequency}.
 * </ul>
 *
 * <p>Words that never reach the threshold are not part of the vocabulary: {@link #contains}
 * returns {@code false} for them and {@link #getIndex} maps them to the unknown token's index.
 * Indices are always contiguous, so {@code getToken(getIndex(t)).equals(t)} holds for every
 * contained token {@code t}. All state is populated in the constructor and not modified afterwards.
 */
public class SimpleVocabulary implements Vocabulary {

    private static final int DEFAULT_MIN_FREQUENCY = 10;
    private static final String DEFAULT_UNKNOWN_TOKEN = "<unk>";

    /** Every token seen, including words whose count never reached {@link #minFrequency}. */
    private final Map<String, TokenInfo> tokens;
    /** Indexed tokens in index order: element {@code i} is the token whose index is {@code i}. */
    private final List<String> indexToToken;
    private final Set<String> reservedTokens;
    private final int minFrequency;
    private final String unknownToken;

    /** Occurrence count of a token and its assigned index ({@code -1} while not yet indexed). */
    private static final class TokenInfo {
        int frequency;
        long index = -1;
    }

    /** Collects sentences and options, then creates a {@link SimpleVocabulary} via {@link #build}. */
    public static class VocabularyBuilder {
        protected List<List<String>> sentences = new ArrayList<>();
        protected Set<String> reservedTokens = new HashSet<>();
        protected int minFrequency = DEFAULT_MIN_FREQUENCY;
        protected String unknownToken = DEFAULT_UNKNOWN_TOKEN;

        /**
         * Sets how many occurrences a word needs before it is assigned an index.
         *
         * @param minFrequency the minimum occurrence count; values {@code <= 1} index every word on
         *     its first occurrence
         * @return this builder
         */
        public VocabularyBuilder optMinFrequency(int minFrequency) {
            this.minFrequency = minFrequency;
            return this;
        }

        /**
         * Sets the token returned for unknown indices and whose index is returned for unknown words.
         *
         * @param unknownToken the unknown token
         * @return this builder
         */
        public VocabularyBuilder optUnknownToken(String unknownToken) {
            this.unknownToken = unknownToken;
            return this;
        }

        /**
         * Adds tokens that are always indexed (before any counted word), regardless of frequency.
         *
         * @param reservedTokens the tokens to reserve
         * @return this builder
         */
        public VocabularyBuilder optReservedTokens(Collection<String> reservedTokens) {
            this.reservedTokens.addAll(reservedTokens);
            return this;
        }

        /**
         * Adds one tokenized sentence whose words are counted when the vocabulary is built.
         *
         * @param sentence the sentence tokens
         * @return this builder
         */
        public VocabularyBuilder add(List<String> sentence) {
            this.sentences.add(sentence);
            return this;
        }

        /**
         * Adds several tokenized sentences whose words are counted when the vocabulary is built.
         *
         * @param sentences the sentences
         * @return this builder
         */
        public VocabularyBuilder addAll(List<List<String>> sentences) {
            this.sentences.addAll(sentences);
            return this;
        }

        /**
         * Creates the vocabulary from the current builder state. The builder itself is not
         * modified, so it may be reused.
         *
         * @return a new {@link SimpleVocabulary}
         */
        public SimpleVocabulary build() {
            return new SimpleVocabulary(this);
        }
    }

    /**
     * Creates a vocabulary by counting the words of the builder's sentences.
     *
     * @param vocabularyBuilder the builder holding sentences and options
     */
    public SimpleVocabulary(VocabularyBuilder vocabularyBuilder) {
        tokens = new ConcurrentHashMap<>();
        indexToToken = new ArrayList<>();
        // Copy so that adding the unknown token below does not mutate the builder's set.
        reservedTokens = new HashSet<>(vocabularyBuilder.reservedTokens);
        minFrequency = vocabularyBuilder.minFrequency;
        unknownToken = vocabularyBuilder.unknownToken;
        reservedTokens.add(unknownToken);
        addTokens(reservedTokens);
        for (List<String> sentence : vocabularyBuilder.sentences) {
            for (String word : sentence) {
                addWord(word);
            }
        }
    }

    /**
     * Creates a vocabulary in which every distinct token of {@code vocabulary} is indexed. The
     * unknown token {@code "<unk>"} is added first.
     *
     * @param vocabulary the tokens to index, in index order; duplicates are ignored
     */
    public SimpleVocabulary(List<String> vocabulary) {
        tokens = new ConcurrentHashMap<>();
        indexToToken = new ArrayList<>();
        reservedTokens = new HashSet<>();
        minFrequency = DEFAULT_MIN_FREQUENCY;
        unknownToken = DEFAULT_UNKNOWN_TOKEN;
        reservedTokens.add(unknownToken);
        addTokens(reservedTokens);
        addTokens(vocabulary);
    }

    /** Counts one occurrence of a word and indexes it once its count reaches the threshold. */
    private void addWord(String word) {
        if (reservedTokens.contains(word)) {
            return;
        }
        TokenInfo info = tokens.computeIfAbsent(word, k -> new TokenInfo());
        info.frequency++;
        // Use ">=" with an "unindexed" guard (not "==") so minFrequency <= 0 still indexes words.
        if (info.index < 0 && info.frequency >= minFrequency) {
            // The index must be the position in indexToToken, not the size of the token map,
            // which also contains words that are still below the threshold.
            info.index = indexToToken.size();
            indexToToken.add(word);
        }
    }

    /** Indexes each not-yet-known token unconditionally, marking it with maximal frequency. */
    private void addTokens(Collection<String> collection) {
        for (String token : collection) {
            if (tokens.containsKey(token)) {
                // Keep the first index; re-adding would duplicate the entry in indexToToken.
                continue;
            }
            TokenInfo info = new TokenInfo();
            info.frequency = Integer.MAX_VALUE;
            info.index = indexToToken.size();
            indexToToken.add(token);
            tokens.put(token, info);
        }
    }

    /**
     * Returns whether {@code token} has an index of its own. Words that were seen but never
     * reached the frequency threshold are not contained.
     */
    @Override
    public boolean contains(String token) {
        TokenInfo info = tokens.get(token);
        return info != null && info.index >= 0;
    }

    /** Returns the token at {@code index}, or the unknown token if the index is out of range. */
    @Override
    public String getToken(long index) {
        if (index < 0 || index >= indexToToken.size()) {
            return unknownToken;
        }
        return indexToToken.get((int) index);
    }

    /** Returns the index of {@code token}, or the unknown token's index if it is not contained. */
    @Override
    public long getIndex(String token) {
        TokenInfo info = tokens.get(token);
        if (info != null && info.index >= 0) {
            return info.index;
        }
        return tokens.get(unknownToken).index;
    }

    /** Returns the number of indexed tokens; valid indices are {@code 0 .. size() - 1}. */
    @Override
    public long size() {
        return indexToToken.size();
    }
}
