package org.deeplearning4j.models.word2vec.wordstore;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.invertedindex.InvertedIndex;
import org.deeplearning4j.util.ThreadUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.threadly.concurrent.PriorityScheduler;

/* loaded from: input_file:org/deeplearning4j/models/word2vec/wordstore/VocabConstructor.class */
public class VocabConstructor<T extends SequenceElement> {
    private List<VocabSource<T>> sources;
    private VocabCache<T> cache;
    private Collection<String> stopWords;
    private boolean useAdaGrad;
    private boolean fetchLabels;
    private int limit;
    private AtomicLong seqCount;
    private InvertedIndex<T> index;
    private boolean enableScavenger;
    private T unk;
    private boolean allowParallelBuilder;
    protected static final Logger log = LoggerFactory.getLogger(VocabConstructor.class);

    /* loaded from: input_file:org/deeplearning4j/models/word2vec/wordstore/VocabConstructor$Builder.class */
    /**
     * Fluent builder for {@link VocabConstructor}.
     *
     * <p>Defaults: no sources, an empty stop-word collection, AdaGrad off, label fetching off, scavenger off,
     * parallel tokenization on, entries limit {@code 0}, and no target cache, inverted index or UNK element.
     *
     * @param <T> element type stored in the vocabulary
     */
    public static class Builder<T extends SequenceElement> {
        private boolean allowParallelBuilder = true;
        private boolean enableScavenger = false;
        private boolean fetchLabels = false;
        private boolean useAdaGrad = false;
        private Collection<String> stopWords = new ArrayList<>();
        private List<VocabSource<T>> sources = new ArrayList<>();
        private T unk;
        private int limit;
        private InvertedIndex<T> index;
        private VocabCache<T> cache;

        /**
         * Sets the entries limit. When positive and a Huffman tree is requested from
         * {@code buildJointVocabulary}, ordinary elements whose index exceeds it are removed.
         */
        public Builder<T> setEntriesLimit(int i) {
            limit = i;
            return this;
        }

        /**
         * When {@code false}, each sequence task is awaited before the next one is submitted.
         */
        public Builder<T> allowParallelTokenization(boolean z) {
            allowParallelBuilder = z;
            return this;
        }

        /** Stores the AdaGrad flag on the built constructor (not read by any code in this class). */
        protected Builder<T> useAdaGrad(boolean z) {
            useAdaGrad = z;
            return this;
        }

        /** Sets the cache the build methods populate; when unset, they create a new {@code AbstractCache}. */
        public Builder<T> setTargetVocabCache(@NonNull VocabCache<T> vocabCache) {
            if (vocabCache == null) {
                throw new NullPointerException("cache is marked @NonNull but is null");
            }
            cache = vocabCache;
            return this;
        }

        /**
         * Registers a sequence source.
         *
         * @param sequenceIterator iterator producing the sequences to scan
         * @param i minimum element frequency applied to the vocabulary built from this source
         */
        public Builder<T> addSource(@NonNull SequenceIterator<T> sequenceIterator, int i) {
            if (sequenceIterator == null) {
                throw new NullPointerException("iterator is marked @NonNull but is null");
            }
            VocabSource<T> source = new VocabSource<>(sequenceIterator, i);
            sources.add(source);
            return this;
        }

        /** Sets the labels skipped while counting sequence elements. */
        public Builder<T> setStopWords(@NonNull Collection<String> collection) {
            if (collection == null) {
                throw new NullPointerException("stopWords is marked @NonNull but is null");
            }
            stopWords = collection;
            return this;
        }

        /** When {@code true}, sequence labels are added to the vocabulary as special label elements. */
        public Builder<T> fetchLabels(boolean z) {
            fetchLabels = z;
            return this;
        }

        /** Sets an optional inverted index that is fed while sequences are counted. */
        public Builder<T> setIndex(InvertedIndex<T> invertedIndex) {
            index = invertedIndex;
            return this;
        }

        /** Enables periodic pruning of rare elements while a source is being scanned. */
        public Builder<T> enableScavenger(boolean z) {
            enableScavenger = z;
            return this;
        }

        /** Sets an element that is added, marked special, after the joint vocabulary is built. */
        public Builder<T> setUnk(T t) {
            unk = t;
            return this;
        }

        /** Creates a {@link VocabConstructor} carrying this builder's settings. */
        public VocabConstructor<T> build() {
            VocabConstructor<T> constructor = new VocabConstructor<>();
            constructor.allowParallelBuilder = allowParallelBuilder;
            constructor.enableScavenger = enableScavenger;
            constructor.fetchLabels = fetchLabels;
            constructor.useAdaGrad = useAdaGrad;
            constructor.stopWords = stopWords;
            constructor.sources = sources;
            constructor.unk = unk;
            constructor.limit = limit;
            constructor.index = index;
            constructor.cache = cache;
            return constructor;
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:org/deeplearning4j/models/word2vec/wordstore/VocabConstructor$VocabRunnable.class */
    /**
     * Counts the elements of one {@link Sequence} into a target vocabulary.
     *
     * <p>When label fetching is enabled, sequence labels unknown to the target are added as special label
     * elements. Each non-empty, non-stop-word element label is then handled in one of two ways:
     * <ul>
     *   <li>If the target already has it, its word count is incremented. Its sequence count is also
     *       incremented, but only once per sequence.</li>
     *   <li>Otherwise it is added as a new element with frequency 1.</li>
     * </ul>
     * {@code finalCounter} is incremented exactly once when the task ends, successfully or not. At the same
     * point {@link #awaitDone()} is released.
     */
    public class VocabRunnable implements Runnable {
        private final AtomicLong finalCounter;
        private final Sequence<T> document;
        private final AbstractCache<T> targetVocab;
        private final AtomicLong loopCounter;
        /** Set under this object's monitor once {@link #run()} has finished; read by {@link #awaitDone()}. */
        private boolean done;

        /**
         * @param abstractCache vocabulary the sequence is counted into
         * @param sequence sequence to count
         * @param atomicLong counter incremented once when this task finishes
         * @param atomicLong2 counter incremented for every element newly added to {@code abstractCache}
         */
        public VocabRunnable(@NonNull AbstractCache<T> abstractCache, @NonNull Sequence<T> sequence, @NonNull AtomicLong atomicLong, @NonNull AtomicLong atomicLong2) {
            if (abstractCache == null) {
                throw new NullPointerException("targetVocab is marked @NonNull but is null");
            }
            if (sequence == null) {
                throw new NullPointerException("sequence is marked @NonNull but is null");
            }
            if (atomicLong == null) {
                throw new NullPointerException("finalCounter is marked @NonNull but is null");
            }
            if (atomicLong2 == null) {
                throw new NullPointerException("loopCounter is marked @NonNull but is null");
            }
            this.finalCounter = atomicLong;
            this.document = sequence;
            this.targetVocab = abstractCache;
            this.loopCounter = atomicLong2;
        }

        /**
         * Blocks until {@link #run()} has finished, normally or exceptionally.
         *
         * @throws InterruptedException if the waiting thread is interrupted
         */
        public void awaitDone() throws InterruptedException {
            synchronized (this) {
                while (!this.done) {
                    wait();
                }
            }
        }

        @Override // java.lang.Runnable
        public void run() {
            try {
                if (VocabConstructor.this.fetchLabels && this.document.getSequenceLabels() != null) {
                    registerSequenceLabels();
                }
                countElements();
            } catch (Exception e) {
                throw new RuntimeException(e);
            } finally {
                this.finalCounter.incrementAndGet();
                synchronized (this) {
                    // The flag must be set before notifying; without it awaitDone() never returns.
                    this.done = true;
                    notifyAll();
                }
            }
        }

        /** Adds the sequence labels the target does not know yet, marked as special label elements. */
        private void registerSequenceLabels() {
            for (T label : this.document.getSequenceLabels()) {
                if (!this.targetVocab.hasToken(label.getLabel())) {
                    label.setSpecial(true);
                    label.markAsLabel(true);
                    label.setElementFrequency(1L);
                    this.targetVocab.addToken(label);
                }
            }
        }

        /** Counts every non-empty, non-stop-word element label of the sequence into the target vocabulary. */
        private void countElements() {
            // Labels already seen in this sequence, so each element's sequence count is bumped at most once.
            HashMap<String, Boolean> seenInSequence = new HashMap<>();
            Collection<String> stopWords = VocabConstructor.this.stopWords;
            for (String word : this.document.asLabels()) {
                if (stopWords != null && stopWords.contains(word)) {
                    continue;
                }
                if (word == null || word.isEmpty()) {
                    continue;
                }
                if (this.targetVocab.containsWord(word)) {
                    this.targetVocab.incrementWordCount(word);
                    if (!seenInSequence.containsKey(word)) {
                        seenInSequence.put(word, Boolean.TRUE);
                        this.targetVocab.wordFor(word).incrementSequencesCount();
                    }
                    updateIndex();
                } else {
                    T element = this.document.getElementByLabel(word);
                    element.setElementFrequency(1L);
                    element.setSequencesCount(1L);
                    this.targetVocab.addToken(element);
                    this.loopCounter.incrementAndGet();
                    seenInSequence.put(word, Boolean.TRUE);
                }
            }
        }

        /**
         * Feeds the whole sequence to the inverted index, if one is configured.
         *
         * <p>NOTE(review): this is invoked once per counted occurrence of an already-known word, so a sequence
         * may be added to the index several times. Behavior is kept as-is; confirm this is intended.
         */
        private void updateIndex() {
            InvertedIndex<T> invertedIndex = VocabConstructor.this.index;
            if (invertedIndex == null) {
                return;
            }
            if (this.document.getSequenceLabel() != null) {
                invertedIndex.addWordsToDoc(invertedIndex.numDocuments(), this.document.getElements(), this.document.getSequenceLabel());
            } else {
                invertedIndex.addWordsToDoc(invertedIndex.numDocuments(), this.document.getElements());
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/deeplearning4j/models/word2vec/wordstore/VocabConstructor$VocabSource.class */
    /**
     * Pairs a {@link SequenceIterator} with the minimum element frequency applied to the vocabulary built
     * from it.
     *
     * @param <T> element type produced by the iterator
     */
    public static class VocabSource<T extends SequenceElement> {

        @NonNull
        private SequenceIterator<T> iterator;

        @NonNull
        private int minWordFrequency;

        /**
         * @param sequenceIterator source of sequences; must not be {@code null}
         * @param i minimum element frequency for this source
         */
        public VocabSource(@NonNull SequenceIterator<T> sequenceIterator, @NonNull int i) {
            if (sequenceIterator == null) {
                throw new NullPointerException("iterator is marked @NonNull but is null");
            }
            this.minWordFrequency = i;
            this.iterator = sequenceIterator;
        }

        /** Returns the wrapped sequence iterator. */
        @NonNull
        public SequenceIterator<T> getIterator() {
            return iterator;
        }

        /** Returns the minimum element frequency for this source. */
        @NonNull
        public int getMinWordFrequency() {
            return minWordFrequency;
        }

        /** Replaces the wrapped sequence iterator; {@code null} is rejected. */
        public void setIterator(@NonNull SequenceIterator<T> sequenceIterator) {
            if (sequenceIterator == null) {
                throw new NullPointerException("iterator is marked @NonNull but is null");
            }
            iterator = sequenceIterator;
        }

        /** Replaces the minimum element frequency for this source. */
        public void setMinWordFrequency(@NonNull int i) {
            minWordFrequency = i;
        }

        /** Two sources are equal when their iterators are equal and their minimum frequencies match. */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof VocabSource)) {
                return false;
            }
            VocabSource<?> other = (VocabSource<?>) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            Object ours = getIterator();
            Object theirs = other.getIterator();
            boolean sameIterator = (ours == null) ? (theirs == null) : ours.equals(theirs);
            return sameIterator && getMinWordFrequency() == other.getMinWordFrequency();
        }

        /** Symmetry hook used by {@link #equals(Object)}; subclasses may narrow it. */
        protected boolean canEqual(Object obj) {
            return obj instanceof VocabSource;
        }

        /** Hash over the iterator (43 when {@code null}) and the minimum frequency, using multiplier 59. */
        @Override
        public int hashCode() {
            final int prime = 59;
            SequenceIterator<T> it = getIterator();
            int result = 1;
            result = result * prime + (it != null ? it.hashCode() : 43);
            result = result * prime + getMinWordFrequency();
            return result;
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder("VocabConstructor.VocabSource(iterator=");
            sb.append(getIterator());
            sb.append(", minWordFrequency=");
            sb.append(getMinWordFrequency());
            sb.append(")");
            return sb.toString();
        }
    }

    /**
     * Creates an instance with default settings: no sources, a zeroed sequence counter, and parallel
     * tokenization on. AdaGrad, label fetching and the scavenger are off. Instances are obtained through
     * {@link Builder#build()}.
     */
    private VocabConstructor() {
        allowParallelBuilder = true;
        enableScavenger = false;
        seqCount = new AtomicLong(0L);
        fetchLabels = false;
        useAdaGrad = false;
        sources = new ArrayList<>();
    }

    /**
     * Protected extension point for building an extended lookup table.
     *
     * @return always {@code null} in this implementation
     */
    protected WeightLookupTable<T> buildExtendedLookupTable() {
        // Intentionally unimplemented here.
        return null;
    }

    /**
     * Protected extension point for building an extended vocabulary.
     *
     * @return always {@code null} in this implementation
     */
    protected VocabCache<T> buildExtendedVocabulary() {
        // Intentionally unimplemented here.
        return null;
    }

    /**
     * Merges the vocabulary held by {@code wordVectors} into the target cache.
     *
     * <p>Convenience overload that delegates to {@code buildMergedVocabulary(VocabCache, boolean)} with
     * {@code wordVectors.vocab()}.
     *
     * @param wordVectors model whose vocabulary is merged; must not be {@code null}
     * @param z whether label elements are merged and sequence labels are collected from the sources
     * @return the target cache
     */
    public VocabCache<T> buildMergedVocabulary(@NonNull WordVectors wordVectors, boolean z) {
        if (wordVectors == null) {
            throw new NullPointerException("wordVectors is marked @NonNull but is null");
        }
        boolean includeLabels = z;
        return buildMergedVocabulary(wordVectors.vocab(), includeLabels);
    }

    /**
     * Returns how many sequences the build methods of this instance have read from the sources so far.
     * The counter is never reset, so it accumulates across calls.
     */
    public long getNumberOfSequences() {
        return seqCount.longValue();
    }

    /**
     * Copies the elements of {@code vocabCache} into the target cache. Optionally it also collects
     * sequence labels from all registered sources.
     *
     * <p>Elements are visited by index. Labels are skipped unless {@code z} is {@code true}. When {@code z}
     * is {@code true}, every source is reset and fully scanned. Any sequence label missing from the target
     * is appended with the next free index and marked as a special label element.
     *
     * @param vocabCache vocabulary to copy from; must not be {@code null}
     * @param z whether label elements are copied and sequence labels are collected from the sources
     * @return the target cache (created as a new {@code AbstractCache} if none was configured)
     * @throws IllegalStateException if the target cache is still empty after copying
     */
    public VocabCache<T> buildMergedVocabulary(@NonNull VocabCache<T> vocabCache, boolean z) {
        if (vocabCache == null) {
            throw new NullPointerException("vocabCache is marked @NonNull but is null");
        }
        if (this.cache == null) {
            this.cache = new AbstractCache.Builder<T>().build();
        }
        for (int i = 0; i < vocabCache.numWords(); i++) {
            String word = vocabCache.wordAtIndex(i);
            if (word == null) {
                continue;
            }
            T element = vocabCache.wordFor(word);
            if (z || !element.isLabel()) {
                this.cache.addToken(element);
                this.cache.addWordToIndex(element.getIndex(), element.getLabel());
                this.cache.putVocabWord(element.getLabel());
            }
        }
        if (this.cache.numWords() == 0) {
            throw new IllegalStateException("Source VocabCache has no indexes available, transfer is impossible");
        }
        log.info("Vocab size before labels: {}", this.cache.numWords());
        if (z) {
            for (VocabSource<T> source : this.sources) {
                SequenceIterator<T> iterator = source.getIterator();
                iterator.reset();
                while (iterator.hasMoreSequences()) {
                    Sequence<T> sequence = iterator.nextSequence();
                    this.seqCount.incrementAndGet();
                    if (sequence.getSequenceLabels() == null) {
                        continue;
                    }
                    for (T label : sequence.getSequenceLabels()) {
                        if (this.cache.containsWord(label.getLabel())) {
                            continue;
                        }
                        label.markAsLabel(true);
                        label.setSpecial(true);
                        // Next free index is read before the label is added to the cache.
                        label.setIndex(this.cache.numWords());
                        this.cache.addToken(label);
                        this.cache.addWordToIndex(label.getIndex(), label.getLabel());
                        this.cache.putVocabWord(label.getLabel());
                    }
                }
            }
        }
        log.info("Vocab size after labels: {}", this.cache.numWords());
        return this.cache;
    }

    /**
     * Copies every token of {@code vocabCache} into the target cache, preserving existing indices.
     *
     * <p>Tokens with a negative index are assigned the next free index. That index is read from
     * {@code numWords()} <em>before</em> the token is added, which matches how
     * {@code buildMergedVocabulary} indexes new labels. Reading it afterwards, as before, skips a slot
     * whenever {@code addToken} grows the cache.
     *
     * @param vocabCache vocabulary to copy from; must not be {@code null}
     * @param z if {@code true}, a Huffman tree is built over the result and its indexes are applied
     * @return the configured target cache, or a new {@code AbstractCache} if none was configured
     */
    public VocabCache<T> transferVocabulary(@NonNull VocabCache<T> vocabCache, boolean z) {
        if (vocabCache == null) {
            throw new NullPointerException("vocabCache is marked @NonNull but is null");
        }
        VocabCache<T> target = this.cache != null ? this.cache : new AbstractCache.Builder<T>().build();
        for (T token : vocabCache.tokens()) {
            int nextFreeIndex = target.numWords();
            target.addToken(token);
            int index = token.getIndex() >= 0 ? token.getIndex() : nextFreeIndex;
            target.addWordToIndex(index, token.getLabel());
        }
        if (z) {
            Huffman huffman = new Huffman(target.vocabWords());
            huffman.build();
            huffman.applyIndexes(target);
        }
        return target;
    }

    /**
     * Scans every registered source and builds one joint vocabulary in the target cache.
     *
     * <p>Each sequence is counted by a {@link VocabRunnable} on a {@link PriorityScheduler}, with at most
     * {@code availableProcessors} tasks in flight. When parallel tokenization is disabled, every task is
     * awaited before the next one is submitted. After a source is drained, its vocabulary is truncated to
     * the source's minimum frequency and merged into the joint vocabulary. That vocabulary is then imported
     * into the target cache.
     *
     * @param z reset counters: element frequencies of the result are zeroed and occurrences recomputed
     * @param z2 build a Huffman tree: the result is truncated to {@code limit} first, when that is positive
     * @return the populated target cache
     * @throws IllegalStateException if both {@code z} and {@code z2} are {@code true}
     */
    public VocabCache<T> buildJointVocabulary(boolean z, boolean z2) {
        long startTime = System.currentTimeMillis();
        // State of the previous progress report, used for per-interval throughput.
        long lastReportTime = startTime;
        long lastReportSequences = 0;
        long lastReportElements = 0;
        AtomicLong elementsCounter = new AtomicLong(0L);
        if (z && z2) {
            throw new IllegalStateException("You can't reset counters and build Huffman tree at the same time!");
        }
        if (this.cache == null) {
            this.cache = new AbstractCache.Builder<T>().build();
        }
        log.debug("Target vocab size before building: [{}]", this.cache.numWords());
        AtomicLong newElementsCounter = new AtomicLong(0L);
        AbstractCache<T> jointVocab = new AbstractCache.Builder<T>().minElementFrequency(0).build();
        int availableProcessors = Runtime.getRuntime().availableProcessors();
        PriorityScheduler scheduler = new PriorityScheduler(Math.max(availableProcessors / 2, 2));
        AtomicLong submittedTasks = new AtomicLong(0L);
        AtomicLong finishedTasks = new AtomicLong(0L);
        try {
            int sourceNumber = 0;
            for (VocabSource<T> source : this.sources) {
                SequenceIterator<T> iterator = source.getIterator();
                iterator.reset();
                log.debug("Trying source iterator: [{}]", sourceNumber);
                log.debug("Target vocab size before building: [{}]", this.cache.numWords());
                sourceNumber++;
                AbstractCache<T> sourceVocab = new AbstractCache.Builder<T>().build();
                while (iterator.hasMoreSequences()) {
                    Sequence<T> sequence = iterator.nextSequence();
                    this.seqCount.incrementAndGet();
                    elementsCounter.addAndGet(sequence.size());
                    sourceVocab.incrementTotalDocCount();
                    submittedTasks.incrementAndGet();
                    VocabRunnable task = new VocabRunnable(sourceVocab, sequence, finishedTasks, newElementsCounter);
                    scheduler.execute(task);
                    if (!this.allowParallelBuilder) {
                        try {
                            task.awaitDone();
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                            throw new RuntimeException(e);
                        }
                    }
                    // Back-pressure: keep at most availableProcessors tasks in flight.
                    awaitPendingTasks(submittedTasks, finishedTasks, availableProcessors);

                    if (this.seqCount.get() % 100000 == 0) {
                        long now = System.currentTimeMillis();
                        long sequencesSoFar = this.seqCount.get();
                        long elementsSoFar = elementsCounter.get();
                        double seconds = (now - lastReportTime) / 1000.0d;
                        log.info("Sequences checked: [{}]; Current vocabulary size: [{}]; Sequences/sec: {}; Words/sec: {};",
                                this.seqCount.get(), sourceVocab.numWords(),
                                String.format("%.2f", (sequencesSoFar - lastReportSequences) / seconds),
                                String.format("%.2f", (elementsSoFar - lastReportElements) / seconds));
                        lastReportTime = now;
                        lastReportSequences = sequencesSoFar;
                        lastReportElements = elementsSoFar;
                    }
                    if (this.enableScavenger && newElementsCounter.get() >= 2000000 && sourceVocab.numWords() > 10000000) {
                        log.info("Starting scavenger...");
                        awaitPendingTasks(submittedTasks, finishedTasks, 0);
                        filterVocab(sourceVocab, Math.max(1, source.getMinWordFrequency() / 2));
                        newElementsCounter.set(0L);
                    }
                }
                log.debug("Waiting till all processes stop...");
                awaitPendingTasks(submittedTasks, finishedTasks, 0);
                log.debug("Vocab size before truncation: [{}],  NumWords: [{}], sequences parsed: [{}], counter: [{}]",
                        sourceVocab.numWords(), sourceVocab.totalWordOccurrences(), this.seqCount.get(), elementsCounter.get());
                if (source.getMinWordFrequency() > 0) {
                    filterVocab(sourceVocab, source.getMinWordFrequency());
                }
                log.debug("Vocab size after truncation: [{}],  NumWords: [{}], sequences parsed: [{}], counter: [{}]",
                        sourceVocab.numWords(), sourceVocab.totalWordOccurrences(), this.seqCount.get(), elementsCounter.get());
                jointVocab.importVocabulary(sourceVocab);
            }
        } finally {
            // Release the worker threads even if a source iterator or an interrupt aborts the scan.
            scheduler.shutdown();
        }
        System.gc();
        this.cache.importVocabulary(jointVocab);
        if (this.unk != null) {
            log.info("Adding UNK element to vocab...");
            this.unk.setSpecial(true);
            this.cache.addToken(this.unk);
        }
        if (z) {
            for (T element : this.cache.vocabWords()) {
                element.setElementFrequency(0L);
            }
            this.cache.updateWordsOccurrences();
        }
        if (z2) {
            if (this.limit > 0) {
                truncateToLimit();
            }
            Huffman huffman = new Huffman(this.cache.vocabWords());
            huffman.build();
            huffman.applyIndexes(this.cache);
        }
        System.gc();
        // Overall rate is measured from the start of this call, not from the last progress report.
        double elapsedSeconds = (System.currentTimeMillis() - startTime) / 1000.0d;
        log.info("Sequences checked: [{}], Current vocabulary size: [{}]; Sequences/sec: [{}];",
                this.seqCount.get(), this.cache.numWords(),
                String.format("%.2f", this.seqCount.get() / elapsedSeconds));
        return this.cache;
    }

    /**
     * Spins, sleeping 1 ms between checks, until at most {@code maxPending} submitted tasks are unfinished.
     */
    private static void awaitPendingTasks(AtomicLong submitted, AtomicLong finished, long maxPending) {
        while (submitted.get() - finished.get() > maxPending) {
            ThreadUtils.uncheckedSleep(1L);
        }
    }

    /**
     * Removes ordinary (non-special, non-label) elements of the target cache whose index exceeds
     * {@code limit}.
     */
    private void truncateToLimit() {
        List<T> sortedWords = new ArrayList<>(this.cache.vocabWords());
        Collections.sort(sortedWords);
        // NOTE(review): the removal test uses only each element's own index and flags. So the sort changes
        // the order of removal, not which elements are removed. If the intent is to keep the `limit` most
        // frequent words, this should use the element's rank instead. Confirm.
        for (T element : sortedWords) {
            if (element.getIndex() > this.limit && !element.isSpecial() && !element.isLabel()) {
                this.cache.removeElement(element.getLabel());
            }
        }
    }

    /**
     * Removes every element whose frequency is below {@code i} from {@code abstractCache}, except special
     * elements and labels.
     *
     * @param abstractCache vocabulary to prune in place
     * @param i minimum element frequency an ordinary element needs to be kept
     */
    protected void filterVocab(AbstractCache<T> abstractCache, int i) {
        int wordsBefore = abstractCache.numWords();
        // Collect first, remove afterwards: vocabWords() may be a live view of the cache (assumption,
        // not verifiable here), so removing while iterating it could be unsafe.
        List<String> toRemove = new ArrayList<>();
        for (T element : abstractCache.vocabWords()) {
            if (element.getElementFrequency() < i && !element.isSpecial() && !element.isLabel()) {
                toRemove.add(element.getLabel());
            }
        }
        for (String label : toRemove) {
            abstractCache.removeElement(label);
        }
        log.debug("Scavenger: Words before: {}; Words after: {};", wordsBefore, abstractCache.numWords());
    }
}
