package org.deeplearning4j.models.embeddings.wordvectors;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.stopwords.StopWords;
import org.deeplearning4j.util.MathUtils;
import org.deeplearning4j.util.SetUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.class */
/**
 * Default {@link WordVectors} implementation that answers similarity and nearest-neighbour queries
 * from a {@link WeightLookupTable} (word vectors) and a {@link VocabCache} (word/index mapping).
 *
 * <p>When the lookup table is an {@link InMemoryLookupTable}, neighbour queries work directly on its
 * {@code syn0} matrix (without modifying it). Otherwise they fall back to scoring every vocabulary
 * word with {@link Transforms#cosineSim}. On the {@code syn0} path, {@code null} vocabulary entries,
 * {@value #UNK} and {@code "STOP"} are never returned as neighbours.
 */
public class WordVectorsImpl implements WordVectors {
    protected WeightLookupTable lookupTable;
    protected VocabCache vocab;
    /** Token whose vector {@link #getWordVector} falls back to for out-of-vocabulary words. */
    public static final String UNK = "UNK";
    /** Sentinel token excluded from neighbour results alongside {@link #UNK}. */
    private static final String STOP = "STOP";
    /** Counter keys used by {@link #accuracy}. */
    private static final String CORRECT = "correct";
    private static final String WRONG = "wrong";
    // NOTE(review): minWordFrequency, layerSize and stopWords are not read anywhere in this class;
    // presumably subclasses or builders consume them — confirm before removing.
    protected int minWordFrequency = 5;
    protected int layerSize = 100;
    protected List<String> stopWords = StopWords.getStopWords();

    /**
     * Orders {@code Double[]} entries by their first element, ascending. {@link #getTopN} stores
     * {score, index} pairs and relies on this ordering to keep a min-heap of the best scores.
     */
    public static class ArrayComparator implements Comparator<Double[]> {
        private ArrayComparator() {
        }

        @Override // java.util.Comparator
        public int compare(Double[] a, Double[] b) {
            return Double.compare(a[0].doubleValue(), b[0].doubleValue());
        }
    }

    /** Returns whether {@code word} has a non-negative index in the vocabulary. */
    @Override // org.deeplearning4j.models.embeddings.wordvectors.WordVectors
    public boolean hasWord(String word) {
        return vocab().indexOf(word) >= 0;
    }

    /**
     * Returns up to {@code top} words nearest to the sum of the {@code positive} vectors minus the
     * {@code negative} vectors.
     *
     * <p>On the {@code syn0} path, the query words themselves and special tokens are skipped. The
     * fallback path scores every vocabulary word (query words included) by cosine similarity. It
     * returns the top keys as the counter's key set rather than as a ranked list.
     *
     * @param positive words whose vectors are added
     * @param negative words whose vectors are subtracted
     * @param top maximum number of words to return
     * @return the nearest words
     */
    @Override // org.deeplearning4j.models.embeddings.wordvectors.WordVectors
    public Collection<String> wordsNearestSum(Collection<String> positive, Collection<String> negative, int top) {
        INDArray query = Nd4j.create(lookupTable().layerSize());
        for (String word : positive) {
            query.addi(lookupTable().vector(word));
        }
        for (String word : negative) {
            query.addi(lookupTable().vector(word).mul(-1));
        }
        if (!(lookupTable() instanceof InMemoryLookupTable)) {
            return nearestByCosine(query, top, null);
        }
        Set<String> queryWords = SetUtils.union(new HashSet<String>(positive), new HashSet<String>(negative));
        INDArray syn0 = ((InMemoryLookupTable) lookupTable()).getSyn0();
        INDArray ranked = Nd4j.sortWithIndices(columnScaled(syn0, query).sum(new int[]{1}), 0, false)[0];
        return collectNearest(ranked, top, queryWords);
    }

    /**
     * Returns up to {@code top} words nearest to the given vector.
     *
     * <p>On the {@code syn0} path, rows are ranked by the sum over dimension 1 of
     * {@link #columnScaled}; special tokens are skipped. Otherwise a cosine-similarity scan is used.
     *
     * @param words the query vector
     * @param top maximum number of words to return
     * @return the nearest words
     */
    @Override // org.deeplearning4j.models.embeddings.wordvectors.WordVectors
    public Collection<String> wordsNearestSum(INDArray words, int top) {
        if (!(lookupTable() instanceof InMemoryLookupTable)) {
            return nearestByCosine(words, top, null);
        }
        INDArray syn0 = ((InMemoryLookupTable) lookupTable()).getSyn0();
        INDArray ranked = Nd4j.sortWithIndices(columnScaled(syn0, words).sum(new int[]{1}), 0, false)[0];
        return collectNearest(ranked, top, null);
    }

    /**
     * Returns up to {@code top} words nearest to the given vector.
     *
     * <p>Same as {@link #wordsNearestSum(INDArray, int)}, except that the {@code syn0} path ranks by
     * the mean over dimension 1 instead of the sum.
     *
     * @param words the query vector
     * @param top maximum number of words to return
     * @return the nearest words
     */
    @Override // org.deeplearning4j.models.embeddings.wordvectors.WordVectors
    public Collection<String> wordsNearest(INDArray words, int top) {
        if (!(lookupTable() instanceof InMemoryLookupTable)) {
            return nearestByCosine(words, top, null);
        }
        INDArray syn0 = ((InMemoryLookupTable) lookupTable()).getSyn0();
        INDArray ranked = Nd4j.sortWithIndices(columnScaled(syn0, words).mean(new int[]{1}), 0, false)[0];
        return collectNearest(ranked, top, null);
    }

    /**
     * Returns up to {@code top} words nearest to {@code word}, excluding {@code word} itself.
     *
     * @param word the query word
     * @param top maximum number of words to return
     * @return the nearest words; empty when the lookup table has no vector for {@code word}
     */
    @Override // org.deeplearning4j.models.embeddings.wordvectors.WordVectors
    public Collection<String> wordsNearestSum(String word, int top) {
        INDArray vector = getWordVectorMatrix(word);
        if (vector == null) {
            return new ArrayList<String>();
        }
        INDArray unitVec = Transforms.unitVec(vector);
        if (!(lookupTable() instanceof InMemoryLookupTable)) {
            if (unitVec == null) {
                return new ArrayList<String>();
            }
            return nearestByCosine(unitVec, top, word);
        }
        INDArray syn0 = ((InMemoryLookupTable) lookupTable()).getSyn0();
        INDArray ranked = Nd4j.sortWithIndices(columnScaled(syn0, unitVec).sum(new int[]{1}), 0, false)[0];
        // Exclude by word rather than by vocab().wordFor(word).getIndex(), which NPEs for
        // out-of-vocabulary words.
        Set<String> self = new HashSet<String>();
        self.add(word);
        return collectNearest(ranked, top, self);
    }

    /**
     * Evaluates word-analogy questions and returns the percentage answered correctly per section.
     *
     * <p>A line starting with {@code ':'} opens a new section and is used as that section's key.
     * Any other line is read as whitespace-separated {@code a b c d}. It is answered correctly when
     * the nearest word to {@code b - a + c} (via
     * {@link #wordsNearest(Collection, Collection, int)}) equals {@code d}. Lines with fewer than
     * four tokens are skipped. Questions the model cannot answer (the query returns no word, e.g.
     * because a word is out of vocabulary) count as wrong. A section with no questions maps to NaN.
     * Questions appearing before the first header are not reported.
     *
     * @param questions the analogy file, one line per element
     * @return section header to percentage correct (0-100)
     */
    @Override // org.deeplearning4j.models.embeddings.wordvectors.WordVectors
    public Map<String, Double> accuracy(List<String> questions) {
        Map<String, Double> accuracy = new HashMap<String, Double>();
        Counter<String> counter = new Counter<String>();
        String section = null;
        for (String line : questions) {
            if (line.startsWith(":")) {
                // Close out the previous section under its own header before starting the next.
                if (section != null) {
                    accuracy.put(section, Double.valueOf(percentCorrect(counter)));
                }
                section = line;
                counter.clear();
                continue;
            }
            String[] split = line.trim().split("\\s+");
            if (split.length < 4) {
                continue;
            }
            Collection<String> nearest =
                    wordsNearest(Arrays.asList(split[1], split[2]), Arrays.asList(split[0]), 1);
            String predicted = nearest.isEmpty() ? null : nearest.iterator().next();
            if (split[3].equals(predicted)) {
                counter.incrementCount(CORRECT, 1.0d);
            } else {
                counter.incrementCount(WRONG, 1.0d);
            }
        }
        if (section != null) {
            accuracy.put(section, Double.valueOf(percentCorrect(counter)));
        }
        return accuracy;
    }

    /** Returns {@code 100 * correct / (correct + wrong)}; NaN when both counts are zero. */
    private static double percentCorrect(Counter<String> counter) {
        double correct = counter.getCount(CORRECT);
        double wrong = counter.getCount(WRONG);
        return 100.0d * correct / (correct + wrong);
    }

    /** Returns the vocabulary index of {@code word}, as reported by the vocab cache. */
    @Override // org.deeplearning4j.models.embeddings.wordvectors.WordVectors
    public int indexOf(String word) {
        return vocab().indexOf(word);
    }

    /**
     * Returns every vocabulary word whose {@link MathUtils#stringSimilarity} to {@code word} is at
     * least {@code accuracy}.
     */
    @Override // org.deeplearning4j.models.embeddings.wordvectors.WordVectors
    public List<String> similarWordsInVocabTo(String word, double accuracy) {
        List<String> similar = new ArrayList<String>();
        for (String candidate : vocab().words()) {
            if (MathUtils.stringSimilarity(new String[]{word, candidate}) >= accuracy) {
                similar.add(candidate);
            }
        }
        return similar;
    }

    /**
     * Returns a copy of {@code word}'s vector as a {@code double[]}. Falls back to the
     * {@value #UNK} vector when the word is not in the vocabulary.
     */
    @Override // org.deeplearning4j.models.embeddings.wordvectors.WordVectors
    public double[] getWordVector(String word) {
        String key = vocab().indexOf(word) < 0 ? UNK : word;
        return lookupTable().vector(key).dup().data().asDouble();
    }

    /**
     * Returns {@code word}'s vector divided by its L2 norm. For out-of-vocabulary words, returns
     * the {@value #UNK} vector as stored (not normalised).
     */
    @Override // org.deeplearning4j.models.embeddings.wordvectors.WordVectors
    public INDArray getWordVectorMatrixNormalized(String word) {
        if (vocab().indexOf(word) < 0) {
            return lookupTable().vector(UNK);
        }
        INDArray vector = lookupTable().vector(word);
        return vector.div(Double.valueOf(Nd4j.getBlasWrapper().nrm2(vector)));
    }

    /** Returns the lookup table's vector for {@code word}. */
    @Override // org.deeplearning4j.models.embeddings.wordvectors.WordVectors
    public INDArray getWordVectorMatrix(String word) {
        return lookupTable().vector(word);
    }

    /**
     * Returns up to {@code top} words nearest to the mean of the {@code positive} vectors and the
     * negated {@code negative} vectors, excluding the query words and special tokens.
     *
     * @param positive words pulled towards
     * @param negative words pushed away from
     * @param top maximum number of words to return
     * @return the nearest words; empty when both inputs are empty or any query word is not in the
     *     vocabulary
     */
    @Override // org.deeplearning4j.models.embeddings.wordvectors.WordVectors
    public Collection<String> wordsNearest(Collection<String> positive, Collection<String> negative, int top) {
        if (positive.isEmpty() && negative.isEmpty()) {
            return new ArrayList<String>();
        }
        Set<String> queryWords = SetUtils.union(new HashSet<String>(positive), new HashSet<String>(negative));
        for (String word : queryWords) {
            if (!vocab().containsWord(word)) {
                return new ArrayList<String>();
            }
        }
        WeightLookupTable table = lookupTable();
        INDArray rows = Nd4j.create(positive.size() + negative.size(), table.layerSize());
        int row = 0;
        for (String word : positive) {
            rows.putRow(row++, table.vector(word));
        }
        for (String word : negative) {
            rows.putRow(row++, table.vector(word).mul(-1));
        }
        INDArray mean = rows.isMatrix() ? rows.mean(new int[]{0}) : rows;
        if (!(table instanceof InMemoryLookupTable)) {
            return nearestByCosine(mean, top, null);
        }
        INDArray syn0 = ((InMemoryLookupTable) table).getSyn0();
        // divRowVector, not diviRowVector: dividing in place would overwrite the trained syn0
        // weights seen by every other method of this class.
        INDArray normalized = syn0.divRowVector(syn0.norm2(new int[]{0}));
        // Over-fetch by the query words plus the two special tokens, since both are filtered out below.
        List<Double> candidates =
                getTopN(Transforms.unitVec(mean).mmul(normalized.transpose()), top + queryWords.size() + 2);
        List<String> nearest = new ArrayList<String>();
        for (Double index : candidates) {
            if (nearest.size() >= top) {
                break;
            }
            String word = vocab().wordAtIndex(index.intValue());
            if (!isSpecialWord(word) && !queryWords.contains(word)) {
                nearest.add(word);
            }
        }
        return nearest;
    }

    /**
     * Returns the linear indices of the {@code n} largest entries of {@code scores}, best first
     * (as {@code Double}s). Returns an empty list when {@code n <= 0}.
     */
    private static List<Double> getTopN(INDArray scores, int n) {
        if (n <= 0) {
            return new ArrayList<Double>();
        }
        ArrayComparator comparator = new ArrayComparator();
        int capacity = Math.max(1, Math.min(n, scores.length()));
        PriorityQueue<Double[]> heap = new PriorityQueue<Double[]>(capacity, comparator);
        for (int i = 0; i < scores.length(); i++) {
            Double[] entry = {Double.valueOf(scores.getDouble(i)), Double.valueOf(i)};
            if (heap.size() < n) {
                heap.add(entry);
            } else if (comparator.compare(entry, heap.peek()) > 0) {
                // Min-heap: evict the weakest kept score when a better one arrives.
                heap.poll();
                heap.add(entry);
            }
        }
        List<Double> indices = new ArrayList<Double>(heap.size());
        while (!heap.isEmpty()) {
            indices.add(heap.poll()[1]);
        }
        return Lists.reverse(indices);
    }

    /** Returns up to {@code top} words nearest to {@code word}, excluding the word itself. */
    @Override // org.deeplearning4j.models.embeddings.wordvectors.WordVectors
    public Collection<String> wordsNearest(String word, int top) {
        return wordsNearest(Arrays.asList(word), new ArrayList<String>(), top);
    }

    /**
     * Returns the dot product of the two words' unit vectors. Returns 1.0 when the strings are
     * equal, and -1.0 when either unit vector is {@code null}.
     */
    @Override // org.deeplearning4j.models.embeddings.wordvectors.WordVectors
    public double similarity(String word, String word2) {
        if (word.equals(word2)) {
            return 1.0d;
        }
        INDArray unitVec = Transforms.unitVec(getWordVectorMatrix(word));
        INDArray unitVec2 = Transforms.unitVec(getWordVectorMatrix(word2));
        if (unitVec == null || unitVec2 == null) {
            return -1.0d;
        }
        return Nd4j.getBlasWrapper().dot(unitVec, unitVec2);
    }

    /** Returns the vocabulary cache. */
    @Override // org.deeplearning4j.models.embeddings.wordvectors.WordVectors
    public VocabCache vocab() {
        return this.vocab;
    }

    /** Returns the weight lookup table. */
    @Override // org.deeplearning4j.models.embeddings.wordvectors.WordVectors
    public WeightLookupTable lookupTable() {
        return this.lookupTable;
    }

    public void setLookupTable(WeightLookupTable weightLookupTable) {
        this.lookupTable = weightLookupTable;
    }

    public void setVocab(VocabCache vocabCache) {
        this.vocab = vocabCache;
    }

    /**
     * Fallback for lookup tables that are not {@link InMemoryLookupTable}. Scores every vocabulary
     * word (except {@code exclude}, which may be {@code null}) by cosine similarity to
     * {@code target}, keeps the {@code top} best, and returns them as the counter's key set.
     */
    private Collection<String> nearestByCosine(INDArray target, int top, String exclude) {
        Counter<String> counter = new Counter<String>();
        for (String word : vocab().words()) {
            if (exclude != null && word.equals(exclude)) {
                continue;
            }
            counter.incrementCount(word, Transforms.cosineSim(target, getWordVectorMatrix(word)));
        }
        counter.keepTopNKeys(top);
        return counter.keySet();
    }

    /**
     * Returns a new matrix: {@code syn0} with each column multiplied by the matching entry of
     * {@code target} divided by {@code syn0}'s norm2 along dimension 0. {@code syn0} is not
     * modified.
     * NOTE(review): this normalises along dimension 0 (columns), not per word row, so the ranking
     * is not exact cosine similarity — confirm this is intended.
     */
    private static INDArray columnScaled(INDArray syn0, INDArray target) {
        return syn0.mulRowVector(syn0.norm2(new int[]{0}).rdivi(1).muli(target));
    }

    /**
     * Walks {@code rankedIndices} (vocabulary indices, best first) and collects up to {@code top}
     * words. Skips special tokens and anything in {@code exclude}, which may be {@code null}.
     * Scanning stops only when {@code top} words are collected or the candidates are exhausted, so
     * skipped entries never cut the result short.
     */
    private List<String> collectNearest(INDArray rankedIndices, int top, Set<String> exclude) {
        List<String> nearest = new ArrayList<String>();
        int candidates = rankedIndices.length();
        for (int pos = 0; pos < candidates && nearest.size() < top; pos++) {
            // Linear indexing reads the pos-th entry regardless of row/column vector layout.
            String word = vocab().wordAtIndex((int) rankedIndices.getDouble(pos));
            if (isSpecialWord(word) || (exclude != null && exclude.contains(word))) {
                continue;
            }
            nearest.add(word);
        }
        return nearest;
    }

    /** Returns whether {@code word} is {@code null}, {@value #UNK} or {@code "STOP"}. */
    private static boolean isSpecialWord(String word) {
        return word == null || UNK.equals(word) || STOP.equals(word);
    }
}
