package org.deeplearning4j.models.glove;

import akka.actor.ActorSystem;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.bagofwords.vectorizer.TextVectorizer;
import org.deeplearning4j.bagofwords.vectorizer.TfidfVectorizer;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.glove.CoOccurrences;
import org.deeplearning4j.models.glove.GloveWeightLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.deeplearning4j.parallel.Parallelization;
import org.deeplearning4j.text.invertedindex.LuceneInvertedIndex;
import org.deeplearning4j.text.movingwindow.Util;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.stopwords.StopWords;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/glove/Glove.class */
public class Glove extends WordVectorsImpl {
    // Sentence source handed to the vectorizer and the co-occurrence counter in fit().
    private transient SentenceIterator sentenceIterator;
    // Vectorizer used to build the vocabulary; fit() creates a TfidfVectorizer when this is
    // null and the vocab cache had to be created there as well.
    private transient TextVectorizer textVectorizer;
    // Tokenizer factory passed to both the vectorizer and the co-occurrence counter.
    private transient TokenizerFactory tokenizerFactory;
    // Learning rate passed to the GloveWeightLookupTable builder via lr(...).
    private double learningRate;
    // Stored from the constructor; not read anywhere else in this class.
    private double xMax;
    // Window size passed to the CoOccurrences builder.
    private int windowSize;
    // Co-occurrence statistics; built and fitted in fit() when not supplied.
    private CoOccurrences coOccurrences;
    // Stemming flag passed to the TfidfVectorizer builder.
    private boolean stem;
    // Jobs of (iteration index, batch of word pairs); filled and drained inside doIteration().
    protected Queue<Pair<Integer, List<Pair<VocabWord, VocabWord>>>> jobQueue;
    // Number of co-occurrence pairs per job (partition size used in doIteration()).
    private int batchSize;
    // Passed to the TfidfVectorizer builder via minWords(...).
    private int minWordFrequency;
    // Passed to the GloveWeightLookupTable builder via maxCount(...).
    private double maxCount;
    // Not referenced elsewhere in this class; presumably the unknown-word token — TODO confirm.
    public static final String UNK = "UNK";
    // Number of passes fit() makes over the co-occurrence list.
    private int iterations;
    private static final Logger log = LoggerFactory.getLogger(Glove.class);
    // Passed to the CoOccurrences builder via symmetric(...).
    private boolean symmetric;
    // Assigned from the constructor's Random argument; not read anywhere else in this class.
    private transient Random gen;
    // When true, the co-occurrence list is shuffled in fit() and again in each doIteration().
    private boolean shuffle;
    // Set to Nd4j.getRandom() by the public constructor; not read anywhere else in this class.
    private transient Random shuffleRandom;
    // Number of worker runnables started by doIteration().
    private int numWorkers;

    /* loaded from: input_file:org/deeplearning4j/models/glove/Glove$Builder.class */
    /**
     * Fluent builder for {@link Glove}.
     *
     * <p>Every setter stores its argument and returns this builder so calls can be
     * chained; {@link #build()} forwards the collected values to the public
     * {@code Glove} constructor. Fields left untouched keep the defaults declared
     * below.
     */
    public static class Builder {
        private VocabCache vocabCache;
        private SentenceIterator sentenceIterator;
        private TextVectorizer textVectorizer;
        private GloveWeightLookupTable weightLookupTable;
        private CoOccurrences coOccurrences;
        private TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        private int layerSize = 300;
        private double learningRate = 0.05d;
        private double xMax = 0.75d;
        private int windowSize = 5;
        private List<String> stopWords = StopWords.getStopWords();
        private boolean stem = false;
        private int batchSize = 100;
        private int minWordFrequency = 5;
        private double maxCount = 100.0d;
        private int iterations = 5;
        private boolean symmetric = true;
        private boolean shuffle = true;
        private long seed = 123;
        private int numWorkers = Runtime.getRuntime().availableProcessors();
        private Random gen = Nd4j.getRandom();

        /** Sets the number of workers (default: available processors). */
        public Builder numWorkers(int i) {
            numWorkers = i;
            return this;
        }

        /** Sets the seed forwarded to the constructor (default: 123). */
        public Builder seed(long j) {
            seed = j;
            return this;
        }

        /** Enables or disables shuffling of the co-occurrence list (default: true). */
        public Builder shuffle(boolean z) {
            shuffle = z;
            return this;
        }

        /** Sets the random number generator (default: {@code Nd4j.getRandom()}). */
        public Builder rng(Random random) {
            gen = random;
            return this;
        }

        /** Sets the symmetric flag forwarded to the co-occurrence counter (default: true). */
        public Builder symmetric(boolean z) {
            symmetric = z;
            return this;
        }

        /** Sets the number of training iterations (default: 5). */
        public Builder iterations(int i) {
            iterations = i;
            return this;
        }

        /** Sets the max count value (default: 100.0). */
        public Builder maxCount(double d) {
            maxCount = d;
            return this;
        }

        /** Sets the minimum word frequency (default: 5). */
        public Builder minWordFrequency(int i) {
            minWordFrequency = i;
            return this;
        }

        /** Sets the vocabulary cache to use. */
        public Builder cache(VocabCache vocabCache) {
            this.vocabCache = vocabCache;
            return this;
        }

        /** Sets the sentence iterator to read from. */
        public Builder iterate(SentenceIterator sentenceIterator) {
            this.sentenceIterator = sentenceIterator;
            return this;
        }

        /** Sets the text vectorizer to use. */
        public Builder vectorizer(TextVectorizer textVectorizer) {
            this.textVectorizer = textVectorizer;
            return this;
        }

        /** Sets the tokenizer factory (default: {@link DefaultTokenizerFactory}). */
        public Builder tokenizer(TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        /** Sets a pre-built weight lookup table. */
        public Builder weights(GloveWeightLookupTable gloveWeightLookupTable) {
            weightLookupTable = gloveWeightLookupTable;
            return this;
        }

        /** Sets the layer size (default: 300). */
        public Builder layerSize(int i) {
            layerSize = i;
            return this;
        }

        /** Sets the learning rate (default: 0.05). */
        public Builder learningRate(double d) {
            learningRate = d;
            return this;
        }

        /** Sets the xMax value (default: 0.75). */
        public Builder xMax(double d) {
            xMax = d;
            return this;
        }

        /** Sets the window size (default: 5). */
        public Builder windowSize(int i) {
            windowSize = i;
            return this;
        }

        /** Sets pre-computed co-occurrences. */
        public Builder coOccurrences(CoOccurrences coOccurrences) {
            this.coOccurrences = coOccurrences;
            return this;
        }

        /** Sets the stop-word list (default: {@code StopWords.getStopWords()}). */
        public Builder stopWords(List<String> list) {
            stopWords = list;
            return this;
        }

        /** Enables or disables stemming (default: false). */
        public Builder stem(boolean z) {
            stem = z;
            return this;
        }

        /** Sets the batch size (default: 100). */
        public Builder batchSize(int i) {
            batchSize = i;
            return this;
        }

        /**
         * Creates a {@link Glove} from the values collected so far.
         *
         * @return a new, not yet fitted model
         */
        public Glove build() {
            return new Glove(
                    vocabCache,
                    sentenceIterator,
                    textVectorizer,
                    tokenizerFactory,
                    weightLookupTable,
                    layerSize,
                    learningRate,
                    xMax,
                    windowSize,
                    coOccurrences,
                    stopWords,
                    stem,
                    batchSize,
                    minWordFrequency,
                    maxCount,
                    iterations,
                    symmetric,
                    gen,
                    shuffle,
                    seed,
                    numWorkers);
        }
    }

    /**
     * Creates an instance that carries only the built-in default settings.
     * Within this class it is used by {@link #load(InputStream, InputStream)}.
     */
    private Glove() {
        // Infrastructure.
        this.jobQueue = new LinkedBlockingDeque<>();
        this.numWorkers = Runtime.getRuntime().availableProcessors();

        // Hyper-parameter defaults.
        this.learningRate = 0.05d;
        this.xMax = 0.75d;
        this.maxCount = 100.0d;
        this.windowSize = 15;
        this.batchSize = 1000;
        this.minWordFrequency = 5;
        this.iterations = 5;

        // Flags.
        this.stem = false;
        this.symmetric = true;
        this.shuffle = true;
    }

    /**
     * Creates a fully configured model; normally reached through {@link Builder#build()}.
     *
     * @param vocabCache             vocabulary cache (may be null; {@link #fit()} then creates one)
     * @param sentenceIterator       sentence source
     * @param textVectorizer         vectorizer (may be null)
     * @param tokenizerFactory       tokenizer factory
     * @param gloveWeightLookupTable lookup table (may be null; {@link #fit()} then builds one)
     * @param i                      layer size
     * @param d                      learning rate
     * @param d2                     xMax
     * @param i2                     window size
     * @param coOccurrences          co-occurrences (may be null; {@link #fit()} then computes them)
     * @param list                   stop words
     * @param z                      stem flag
     * @param i3                     batch size
     * @param i4                     minimum word frequency
     * @param d3                     max count
     * @param i5                     number of iterations
     * @param z2                     symmetric flag
     * @param random                 random number generator stored in {@code gen}
     * @param z3                     shuffle flag
     * @param j                      seed; accepted but not used by this constructor
     * @param i6                     number of workers
     */
    public Glove(VocabCache vocabCache, SentenceIterator sentenceIterator, TextVectorizer textVectorizer, TokenizerFactory tokenizerFactory, GloveWeightLookupTable gloveWeightLookupTable, int i, double d, double d2, int i2, CoOccurrences coOccurrences, List<String> list, boolean z, int i3, int i4, double d3, int i5, boolean z2, Random random, boolean z3, long j, int i6) {
        // The job queue is the only member that is not supplied by an argument.
        this.jobQueue = new LinkedBlockingDeque<>();

        // Collaborators.
        this.vocab = vocabCache;
        this.sentenceIterator = sentenceIterator;
        this.textVectorizer = textVectorizer;
        this.tokenizerFactory = tokenizerFactory;
        this.lookupTable = gloveWeightLookupTable;
        this.coOccurrences = coOccurrences;
        this.stopWords = list;

        // Numeric settings.
        this.layerSize = i;
        this.learningRate = d;
        this.xMax = d2;
        this.windowSize = i2;
        this.batchSize = i3;
        this.minWordFrequency = i4;
        this.maxCount = d3;
        this.iterations = i5;
        this.numWorkers = i6;

        // Flags.
        this.stem = z;
        this.symmetric = z2;
        this.shuffle = z3;

        // Randomness.
        this.gen = random;
        this.shuffleRandom = Nd4j.getRandom();
    }

    /**
     * Trains the model.
     *
     * <p>Steps, each performed only when the corresponding piece was not supplied:
     * <ol>
     *   <li>create an in-memory vocabulary cache;</li>
     *   <li>build and fit a TF-IDF vectorizer (only when the cache was created here);</li>
     *   <li>build and fit the co-occurrence counts;</li>
     *   <li>build the weight lookup table and initialise its weights.</li>
     * </ol>
     * Afterwards the co-occurrence list is processed {@code iterations} times via
     * {@link #doIteration}.
     */
    public void fit() {
        boolean vocabCreatedHere = false;
        if (vocab() == null) {
            vocabCreatedHere = true;
            setVocab(new InMemoryLookupCache());
        }
        if (this.textVectorizer == null && vocabCreatedHere) {
            this.textVectorizer = new TfidfVectorizer.Builder().tokenize(this.tokenizerFactory).index(new LuceneInvertedIndex(vocab(), false, "glove-index")).cache(vocab()).iterate(this.sentenceIterator).minWords(this.minWordFrequency).stopWords(this.stopWords).stem(this.stem).build();
            this.textVectorizer.fit();
        }
        if (this.sentenceIterator != null) {
            this.sentenceIterator.reset();
        }
        if (this.coOccurrences == null) {
            this.coOccurrences = new CoOccurrences.Builder().cache(vocab()).iterate(this.sentenceIterator).symmetric(this.symmetric).tokenizer(this.tokenizerFactory).windowSize(this.windowSize).build();
            this.coOccurrences.fit();
        }
        if (this.lookupTable == null) {
            // A caller may supply a vocab without a vectorizer; in that case no vectorizer
            // is created above, so fall back to the model's own vocab instead of
            // dereferencing a null vectorizer.
            VocabCache tableVocab = this.textVectorizer != null ? this.textVectorizer.vocab() : vocab();
            this.lookupTable = new GloveWeightLookupTable.Builder().cache(tableVocab).lr(this.learningRate).vectorLength(this.layerSize).maxCount(this.maxCount).build();
        }
        if (lookupTable().getSyn0() == null) {
            lookupTable().resetWeights();
        }
        List<Pair<String, String>> pairs = this.coOccurrences.coOccurrenceList();
        if (this.shuffle) {
            Collections.shuffle(pairs, new java.util.Random());
        }
        // Running total of processed pairs across all iterations.
        AtomicInteger processed = new AtomicInteger(0);
        // Error accumulated per iteration, keyed by iteration index.
        Counter<Integer> errorPerIteration = Util.parallelCounter();
        log.info("Processing # of co occurrences " + this.coOccurrences.numCoOccurrences());
        for (int iteration = 0; iteration < this.iterations; iteration++) {
            // Each iteration gets a fresh countdown of pairs still to be handled.
            doIteration(iteration, pairs, errorPerIteration, new AtomicInteger(this.coOccurrences.numCoOccurrences()), processed);
            log.info("Processed " + processed.doubleValue() + " out of " + (pairs.size() * this.iterations) + " error was " + errorPerIteration.getCount(Integer.valueOf(iteration)));
        }
    }

    /**
     * Runs one pass over the co-occurrence list.
     *
     * <p>The list is (optionally) shuffled and partitioned into batches of
     * {@code batchSize}; every batch is resolved to vocabulary words and put on
     * {@code jobQueue}. {@code numWorkers} runnables then drain the queue and
     * update the lookup table for each pair whose co-occurrence count is positive.
     *
     * @param i              index of this iteration; used as the key in {@code counter}
     * @param list           co-occurrence pairs; shuffled in place when shuffling is enabled
     * @param counter        receives the value returned by {@code iterateSample} per pair
     * @param atomicInteger  countdown of pairs still to process; decremented once per pair
     * @param atomicInteger2 running total of processed pairs; incremented once per pair
     */
    public void doIteration(final int i, List<Pair<String, String>> list, final Counter<Integer> counter, final AtomicInteger atomicInteger, final AtomicInteger atomicInteger2) {
        log.info("Iteration " + i);
        if (this.shuffle) {
            Collections.shuffle(list, new java.util.Random());
        }
        List<List<Pair<String, String>>> batches = Lists.partition(list, this.batchSize);
        ActorSystem actorSystem = ActorSystem.create();
        try {
            Parallelization.iterateInParallel(batches, new Parallelization.RunnableWithParams<List<Pair<String, String>>>() {
                @Override
                public void run(List<Pair<String, String>> batch, Object[] args) {
                    List<Pair<VocabWord, VocabWord>> words = new ArrayList<>();
                    for (Pair<String, String> pair : batch) {
                        words.add(new Pair<>(Glove.this.vocab().wordFor(pair.getFirst()), Glove.this.vocab().wordFor(pair.getSecond())));
                    }
                    Glove.this.jobQueue.add(new Pair<>(Integer.valueOf(i), words));
                }
            }, actorSystem);
        } finally {
            // Always release the actor system, even when batching fails.
            actorSystem.shutdown();
        }
        Parallelization.runInParallel(this.numWorkers, new Runnable() {
            @Override
            public void run() {
                // Keep going while pairs are still expected or jobs are still queued.
                while (atomicInteger.get() > 0 || !Glove.this.jobQueue.isEmpty()) {
                    Pair<Integer, List<Pair<VocabWord, VocabWord>>> job = Glove.this.jobQueue.poll();
                    if (job == null) {
                        // Another worker took the last job; re-check the exit condition.
                        continue;
                    }
                    for (Pair<VocabWord, VocabWord> pair : job.getSecond()) {
                        VocabWord first = pair.getFirst();
                        VocabWord second = pair.getSecond();
                        double count = Glove.this.getCount(first.getWord(), second.getWord());
                        boolean skip = count <= 0.0d;
                        if (!skip) {
                            counter.incrementCount(job.getFirst(), Glove.this.lookupTable().iterateSample(first, second, count));
                        }
                        // Use the value returned by the increment so that concurrent workers
                        // report each 10000 boundary exactly once.
                        int done = atomicInteger2.incrementAndGet();
                        if (!skip && done % 10000 == 0) {
                            log.info("Processed " + done + " co occurrences");
                        }
                        atomicInteger.decrementAndGet();
                    }
                }
            }
        }, true);
    }

    /**
     * Loads a model from a text vector file and a bias file.
     *
     * <p>Every non-empty line of {@code inputStream} is split on single spaces; the
     * first token is the word and the remaining tokens are its vector components.
     * The vector length is taken from the first non-empty line. Words are indexed
     * in the order they are read.
     *
     * @param inputStream  UTF-8 text stream with one "word v1 v2 ..." entry per line
     * @param inputStream2 stream with the bias values, read via {@code Nd4j.readTxt} with a
     *                     single-space separator
     * @return a model whose vocabulary, weights and bias are populated
     * @throws IOException if reading fails or the vector stream has no non-empty line
     */
    public static Glove load(InputStream inputStream, InputStream inputStream2) throws IOException {
        Glove glove = new Glove();
        Map<String, float[]> vectors = new HashMap<>();
        LineIterator lines = IOUtils.lineIterator(inputStream, "UTF-8");
        try {
            int index = 0;
            while (lines.hasNext()) {
                String line = lines.nextLine().trim();
                if (line.isEmpty()) {
                    continue;
                }
                String[] tokens = line.split(" ");
                String word = tokens[0];
                if (glove.vocab() == null) {
                    glove.setVocab(new InMemoryLookupCache());
                }
                if (glove.lookupTable() == null) {
                    // The first non-empty line fixes the vector length for the whole table.
                    glove.lookupTable = new GloveWeightLookupTable.Builder().cache(glove.vocab()).vectorLength(tokens.length - 1).build();
                }
                if (word.isEmpty()) {
                    continue;
                }
                float[] vector = read(tokens, glove.lookupTable().getVectorLength());
                if (vector.length >= 1) {
                    VocabWord vocabWord = new VocabWord(1.0d, word);
                    vocabWord.setIndex(index);
                    glove.vocab().addToken(vocabWord);
                    glove.vocab().addWordToIndex(index, word);
                    glove.vocab().putVocabWord(word);
                    vectors.put(word, vector);
                    index++;
                }
            }
        } finally {
            // Release the reader even when a line fails to parse.
            lines.close();
        }
        if (glove.lookupTable() == null) {
            // No non-empty line was seen, so there is nothing to build a table from.
            throw new IOException("No word vectors found in the vector input stream");
        }
        glove.lookupTable().setSyn0(weights(glove, vectors));
        glove.lookupTable().setBias(Nd4j.readTxt(inputStream2, " "));
        return glove;
    }

    /**
     * Assembles the weight matrix from the per-word vectors.
     *
     * <p>The matrix has one row per map entry and as many columns as the lookup
     * table's vector length. A vector is copied into the row given by the word's
     * vocabulary index; vectors whose length differs from the table's vector length,
     * or whose index is not below the number of entries, are left out.
     *
     * @param glove model supplying the vocabulary and the vector length
     * @param map   word to vector components
     * @return the populated matrix
     */
    private static INDArray weights(Glove glove, Map<String, float[]> map) {
        final int rowCount = map.size();
        final int vectorLength = glove.lookupTable().getVectorLength();
        INDArray matrix = Nd4j.create(rowCount, vectorLength);
        for (Map.Entry<String, float[]> entry : map.entrySet()) {
            INDArray row = Nd4j.create(Nd4j.createBuffer(entry.getValue()));
            if (row.length() != vectorLength) {
                continue;
            }
            int rowIndex = glove.vocab().indexOf(entry.getKey());
            if (rowIndex < rowCount) {
                matrix.putRow(rowIndex, row);
            }
        }
        return matrix;
    }

    /**
     * Parses the numeric tokens of one vector line.
     *
     * <p>Token 0 is the word and is skipped. At most {@code i} values are parsed:
     * lines with fewer values leave the tail of the result at zero, and surplus
     * tokens are ignored rather than overrunning the result array.
     *
     * @param strArr tokens of the line, word first
     * @param i      expected vector length
     * @return an array of exactly {@code i} floats
     * @throws NumberFormatException if one of the parsed tokens is not a valid float
     */
    private static float[] read(String[] strArr, int i) {
        float[] values = new float[i];
        // Never read more values than the result can hold.
        int available = Math.min(strArr.length - 1, i);
        for (int k = 0; k < available; k++) {
            values[k] = Float.parseFloat(strArr[k + 1]);
        }
        return values;
    }

    /**
     * Looks up the co-occurrence count stored for a pair of words.
     *
     * @param str  first word
     * @param str2 second word
     * @return the count held by the co-occurrence counts for ({@code str}, {@code str2})
     */
    public double getCount(String str, String str2) {
        return this.coOccurrences.getCoOCurreneCounts().getCount(str, str2);
    }

    /**
     * Returns the co-occurrences held by this model.
     *
     * @return the current co-occurrences; null until supplied or computed by {@link #fit()}
     */
    public CoOccurrences getCoOccurrences() {
        return this.coOccurrences;
    }

    /**
     * Replaces the co-occurrences held by this model.
     *
     * @param coOccurrences the co-occurrences to use
     */
    public void setCoOccurrences(CoOccurrences coOccurrences) {
        this.coOccurrences = coOccurrences;
    }

    /**
     * Returns the lookup table cast to its GloVe-specific type.
     *
     * @return the lookup table as a {@link GloveWeightLookupTable}, or null when none is set
     * @throws ClassCastException if the stored table is not a {@link GloveWeightLookupTable}
     */
    @Override // org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl, org.deeplearning4j.models.embeddings.wordvectors.WordVectors
    public GloveWeightLookupTable lookupTable() {
        return (GloveWeightLookupTable) this.lookupTable;
    }

    /**
     * Replaces the lookup table used by this model.
     *
     * @param gloveWeightLookupTable the lookup table to use
     */
    public void setLookupTable(GloveWeightLookupTable gloveWeightLookupTable) {
        this.lookupTable = gloveWeightLookupTable;
    }
}
