package org.deeplearning4j.models.word2vec.wordstore.inmemory;

import it.unimi.dsi.util.XorShift64StarRandomGenerator;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.plot.Tsne;
import org.deeplearning4j.plot.dropwizard.RenderApplication;
import org.deeplearning4j.text.movingwindow.Util;
import org.deeplearning4j.util.Index;
import org.deeplearning4j.util.SerializationUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/models/word2vec/wordstore/inmemory/InMemoryLookupCache.class */
/**
 * Heap-resident {@link VocabCache}: keeps the vocabulary ({@code vocabs}), the token table
 * ({@code tokens}), word and document frequency counters, the word index, and the weight
 * matrices {@code syn0} / {@code syn1} in memory.
 *
 * <p>Instances are persisted with Java serialization by {@link #saveVocab()} and the class
 * declares no explicit {@code serialVersionUID}; the names of the instance fields and of the
 * non-private members are therefore kept unchanged so existing {@code cache.ser} files stay
 * readable.
 */
public class InMemoryLookupCache implements VocabCache, Serializable {
    /** Seed for the default random generators (private static, so it does not affect the default SUID). */
    private static final long DEFAULT_SEED = 123L;

    private Index wordIndex = new Index();
    private boolean useAdaGrad = false;
    private Counter<String> wordFrequencies = Util.parallelCounter();
    private Counter<String> docFrequencies = Util.parallelCounter();
    private Map<String, VocabWord> vocabs = new ConcurrentHashMap<String, VocabWord>();
    private Map<String, VocabWord> tokens = new ConcurrentHashMap<String, VocabWord>();
    private Map<Integer, INDArray> codes = new ConcurrentHashMap<Integer, INDArray>();
    private INDArray syn0;
    private INDArray syn1;
    private int vectorLength = 50;
    private transient RandomGenerator rng = new XorShift64StarRandomGenerator(DEFAULT_SEED);
    private AtomicInteger totalWordOccurrences = new AtomicInteger(0);
    private double lr = 0.10000000149011612d;
    /** Precomputed logistic-sigmoid lookup table over [-MAX_EXP, MAX_EXP); filled by initExpTable(). */
    double[] expTable = new double[1000];
    /** Dot products outside [-MAX_EXP, MAX_EXP) are skipped by {@link #iterate}. */
    static double MAX_EXP = 6.0d;
    private long seed = DEFAULT_SEED;
    private int numDocs = 0;

    /**
     * Creates a cache with AdaGrad enabled, the default learning rate and RNG, and an
     * {@code "UNK"} entry added at index 0.
     *
     * @param vectorLength length of each word vector (number of columns of the weight matrices)
     */
    public InMemoryLookupCache(int vectorLength) {
        // The delegated constructor already builds expTable; a second call was redundant.
        this(vectorLength, true);
    }

    /**
     * Creates a cache whose {@code syn0} is pre-filled with {@code Nd4j.rand(rows, vectorLength)}.
     *
     * @param vectorLength length of each word vector
     * @param rows         number of rows of the initial {@code syn0}
     */
    public InMemoryLookupCache(int vectorLength, int rows) {
        this.vectorLength = vectorLength;
        this.syn0 = Nd4j.rand(rows, vectorLength);
        // Previously missing: without it iterate() would read an all-zero sigmoid table.
        initExpTable();
    }

    /**
     * Creates a cache with learning rate 0.025 (as a widened float) and a seeded default RNG,
     * then registers {@code "UNK"} at index 0.
     *
     * @param vectorLength length of each word vector
     * @param useAdaGrad   whether {@link #iterate} uses the averaged-gradient branch
     */
    public InMemoryLookupCache(int vectorLength, boolean useAdaGrad) {
        this(vectorLength, useAdaGrad, 0.02500000037252903d, new XorShift64StarRandomGenerator(DEFAULT_SEED));
        addWordToIndex(0, "UNK");
        this.wordIndex.add("UNK");
    }

    /**
     * Fully specified constructor.
     *
     * @param vectorLength    length of each word vector
     * @param useAdaGrad      whether {@link #iterate} uses the averaged-gradient branch
     * @param lr              learning rate applied to each gradient step
     * @param randomGenerator random generator stored as this cache's RNG
     */
    public InMemoryLookupCache(int vectorLength, boolean useAdaGrad, double lr, RandomGenerator randomGenerator) {
        this.vectorLength = vectorLength;
        this.useAdaGrad = useAdaGrad;
        this.lr = lr;
        this.rng = randomGenerator;
        initExpTable();
    }

    /**
     * Same as the four-argument constructor with a freshly seeded default RNG.
     *
     * @param vectorLength length of each word vector
     * @param useAdaGrad   whether {@link #iterate} uses the averaged-gradient branch
     * @param lr           learning rate applied to each gradient step
     */
    public InMemoryLookupCache(int vectorLength, boolean useAdaGrad, double lr) {
        this(vectorLength, useAdaGrad, lr, new XorShift64StarRandomGenerator(DEFAULT_SEED));
    }

    /**
     * Fills {@link #expTable} with sigmoid(x) for x sampled uniformly over [-MAX_EXP, MAX_EXP).
     */
    private void initExpTable() {
        for (int i = 0; i < expTable.length; i++) {
            // The cast is essential: plain int division made this term 0 for every i,
            // so the whole table collapsed to sigmoid(-MAX_EXP).
            double e = FastMath.exp(((double) i / expTable.length * 2.0d - 1.0d) * MAX_EXP);
            expTable[i] = e / (e + 1.0d);
        }
    }

    /**
     * Performs one gradient step driven by the codes/points of {@code w1} against the
     * {@code syn0} row of {@code w2} (word2vec hierarchical-softmax style; updates
     * {@code syn1} rows and the {@code syn0} row in place).
     *
     * @param w1 word whose codes and points select the {@code syn1} rows to update
     * @param w2 word whose {@code syn0} row is the input vector
     * @throws IllegalStateException if a point is not smaller than {@code syn0.rows()}
     */
    @Override
    public void iterate(VocabWord w1, VocabWord w2) {
        INDArray l1 = syn0.slice(w2.getIndex());
        INDArray neu1e = Nd4j.create(vectorLength);
        // Loop-invariant: selects the double or float BLAS overloads below.
        boolean isDouble = syn0.data().dataType().equals("double");
        double gradSum = 0.0d;
        for (int i = 0; i < w1.getCodeLength(); i++) {
            int code = w1.getCodes()[i];
            int point = w1.getPoints()[i];
            if (point >= syn0.rows()) {
                throw new IllegalStateException("Illegal point " + point);
            }
            INDArray syn1Row = syn1.slice(point);
            double dot = Nd4j.getBlasWrapper().dot(l1, syn1Row);
            // Written as a positive range test so NaN dot products are skipped, as before.
            if (dot >= -MAX_EXP && dot < MAX_EXP) {
                int idx = (int) ((dot + MAX_EXP) * ((expTable.length / MAX_EXP) / 2.0d));
                if (idx < expTable.length) {
                    double g = ((1 - code) - expTable[idx]) * lr;
                    gradSum += g;
                    if (isDouble) {
                        Nd4j.getBlasWrapper().axpy(g, syn1Row, neu1e);
                        Nd4j.getBlasWrapper().axpy(g, l1, syn1Row);
                    } else {
                        Nd4j.getBlasWrapper().axpy((float) g, syn1Row, neu1e);
                        Nd4j.getBlasWrapper().axpy((float) g, l1, syn1Row);
                    }
                }
            }
        }
        if (useAdaGrad) {
            // NOTE(review): averages over getCodes().length, not getCodeLength(); kept as-is.
            double avg = gradSum / w1.getCodes().length;
            if (isDouble) {
                Nd4j.getBlasWrapper().axpy(avg, neu1e, l1);
            } else {
                Nd4j.getBlasWrapper().axpy((float) avg, neu1e, l1);
            }
            return;
        }
        if (isDouble) {
            Nd4j.getBlasWrapper().axpy(1.0d, neu1e, l1);
        } else {
            Nd4j.getBlasWrapper().axpy(1.0f, neu1e, l1);
        }
    }

    /** Returns the live key set of the vocabulary map (not a copy). */
    @Override
    public synchronized Collection<String> words() {
        return this.vocabs.keySet();
    }

    /**
     * Re-seeds the RNG with {@code seed} and re-initialises {@code syn0} to
     * (rand - 0.5) / vectorLength over a vocab-size x vectorLength matrix; {@code syn1} to zeros.
     */
    @Override
    public void resetWeights() {
        this.rng = new MersenneTwister(this.seed);
        this.syn0 = Nd4j.rand(new int[]{this.vocabs.size(), this.vectorLength}, this.rng)
                .subi(Double.valueOf(0.5d))
                .divi(Integer.valueOf(this.vectorLength));
        this.syn1 = Nd4j.create(this.syn0.shape());
    }

    /** Increments the counts for {@code str} by one. */
    @Override
    public synchronized void incrementWordCount(String str) {
        incrementWordCount(str, 1);
    }

    /**
     * Increments the frequency of {@code str}, its token (if one is registered) and the total
     * occurrence count by {@code i}.
     *
     * @param str word to count
     * @param i   amount to add
     */
    @Override
    public synchronized void incrementWordCount(String str, int i) {
        // Previously always +1, which disagreed with totalWordOccurrences for i != 1.
        this.wordFrequencies.incrementCount(str, i);
        VocabWord token = tokenFor(str);
        // Previously a throw-away VocabWord was created (and never stored) for unknown tokens.
        if (token != null) {
            token.increment(i);
        }
        this.totalWordOccurrences.addAndGet(i);
    }

    /** Returns the frequency counter value for {@code str}, truncated to int. */
    @Override
    public int wordFrequency(String str) {
        return (int) this.wordFrequencies.getCount(str);
    }

    /** Returns whether {@code str} has been added to the vocabulary via {@link #putVocabWord}. */
    @Override
    public boolean containsWord(String str) {
        return this.vocabs.containsKey(str);
    }

    /** Returns the word stored at index {@code i} of the word index. */
    @Override
    public String wordAtIndex(int i) {
        return (String) this.wordIndex.get(i);
    }

    /** Returns the word-index position of {@code str} as reported by {@link Index#indexOf}. */
    @Override
    public int indexOf(String str) {
        return this.wordIndex.indexOf(str);
    }

    /** Stores {@code iNDArray} under code key {@code i}. */
    @Override
    public void putCode(int i, INDArray iNDArray) {
        this.codes.put(Integer.valueOf(i), iNDArray);
    }

    /** Returns the rows of {@code syn1} selected by {@code iArr}. */
    @Override
    public INDArray loadCodes(int[] iArr) {
        return this.syn1.getRows(iArr);
    }

    /** Returns the live collection of vocabulary entries. */
    @Override
    public Collection<VocabWord> vocabWords() {
        return this.vocabs.values();
    }

    /** Returns the total number of occurrences recorded by {@link #incrementWordCount}. */
    @Override
    public int totalWordOccurrences() {
        return this.totalWordOccurrences.get();
    }

    /**
     * Copies {@code iNDArray} into the {@code syn0} row of {@code str}.
     *
     * @throws IllegalArgumentException if either argument is null
     */
    @Override
    public void putVector(String str, INDArray iNDArray) {
        if (str == null) {
            throw new IllegalArgumentException("No null words allowed");
        }
        if (iNDArray == null) {
            throw new IllegalArgumentException("No null vectors allowed");
        }
        this.syn0.slice(indexOf(str)).assign(iNDArray);
    }

    /** Returns the {@code syn0} row of {@code str}, or null when {@code str} is null. */
    @Override
    public INDArray vector(String str) {
        if (str == null) {
            return null;
        }
        return this.syn0.getRow(indexOf(str));
    }

    /** Returns the vocabulary entry for {@code str}, or null if absent. */
    @Override
    public VocabWord wordFor(String str) {
        return this.vocabs.get(str);
    }

    /**
     * Registers {@code str} at index {@code i}, seeding its frequency to 1 if it has none.
     */
    @Override
    public synchronized void addWordToIndex(int i, String str) {
        if (!this.wordFrequencies.containsKey(str)) {
            this.wordFrequencies.incrementCount(str, 1.0d);
        }
        this.wordIndex.add(str, i);
    }

    /**
     * Promotes an existing token into the vocabulary and the word index.
     *
     * @throws IllegalStateException if {@code str} is not a registered token
     */
    @Override
    public synchronized void putVocabWord(String str) {
        VocabWord vocabWord = tokenFor(str);
        // Guard first: previously vocabWord.getIndex() threw an NPE before this check ran.
        if (vocabWord == null) {
            throw new IllegalStateException("Unable to add token " + str + " when not already a token");
        }
        // addWordToIndex already adds (str, index) to wordIndex; the second add was a duplicate.
        addWordToIndex(vocabWord.getIndex(), str);
        this.vocabs.put(str, vocabWord);
    }

    /** Returns the vocabulary size. */
    @Override
    public synchronized int numWords() {
        return this.vocabs.size();
    }

    /** Returns the document-frequency counter value for {@code str}, truncated to int. */
    @Override
    public int docAppearedIn(String str) {
        return (int) this.docFrequencies.getCount(str);
    }

    /** Adds {@code i} to the document frequency of {@code str}. */
    @Override
    public void incrementDocCount(String str, int i) {
        this.docFrequencies.incrementCount(str, i);
    }

    /** Sets the document frequency of {@code str} to {@code i}. */
    @Override
    public void setCountForDoc(String str, int i) {
        this.docFrequencies.setCount(str, i);
    }

    /** Returns the document counter. */
    @Override
    public int totalNumberOfDocs() {
        return this.numDocs;
    }

    /** Increments the document counter by one (not synchronized). */
    @Override
    public void incrementTotalDocCount() {
        this.numDocs++;
    }

    /** Adds {@code i} to the document counter (not synchronized). */
    @Override
    public void incrementTotalDocCount(int i) {
        this.numDocs += i;
    }

    /** Returns the live collection of registered tokens. */
    @Override
    public Collection<VocabWord> tokens() {
        return this.tokens.values();
    }

    /** Registers {@code vocabWord} as a token keyed by its word. */
    @Override
    public void addToken(VocabWord vocabWord) {
        this.tokens.put(vocabWord.getWord(), vocabWord);
    }

    /** Returns the token for {@code str}, or null if none is registered. */
    @Override
    public VocabWord tokenFor(String str) {
        return this.tokens.get(str);
    }

    /** Returns whether a token is registered for {@code str}. */
    @Override
    public boolean hasToken(String str) {
        return tokenFor(str) != null;
    }

    /** Serializes this whole cache to {@code cache.ser} in the working directory. */
    @Override
    public void saveVocab() {
        SerializationUtils.saveObject(this, new File("cache.ser"));
    }

    /** Returns whether {@code cache.ser} exists in the working directory. */
    @Override
    public boolean vocabExists() {
        return new File("cache.ser").exists();
    }

    /**
     * Plots {@code syn0} in 2 dimensions with the vocabulary words as labels using the given
     * {@link Tsne}, then starts {@link RenderApplication} (its failures are only printed).
     */
    @Override
    public void plotVocab(Tsne tsne) {
        try {
            tsne.plot(this.syn0, 2, new ArrayList<String>(words()));
            try {
                RenderApplication.main(null);
            } catch (Exception e) {
                // Best-effort rendering: a failure here should not abort the plot itself.
                e.printStackTrace();
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Plots {@code syn0} in 2 dimensions with a default-configured {@link Tsne}. */
    @Override
    public void plotVocab() {
        Tsne tsne = new Tsne.Builder()
                .normalize(false)
                .setFinalMomentum(0.800000011920929d)
                .setMaxIter(1000)
                .build();
        try {
            tsne.plot(this.syn0, 2, new ArrayList<String>(words()));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Replaces codes, vocabulary, vector length, word frequencies, word index and tokens with
     * those read from {@code cache.ser}.
     * NOTE(review): syn0, syn1, docFrequencies, numDocs and totalWordOccurrences are not
     * restored — confirm this is intended.
     */
    @Override
    public void loadVocab() {
        InMemoryLookupCache loaded = (InMemoryLookupCache) SerializationUtils.readObject(new File("cache.ser"));
        this.codes = loaded.codes;
        this.vocabs = loaded.vocabs;
        this.vectorLength = loaded.vectorLength;
        this.wordFrequencies = loaded.wordFrequencies;
        this.wordIndex = loaded.wordIndex;
        this.tokens = loaded.tokens;
    }

    /**
     * Writes one line per row of {@code iNDArray} that has a word in the model's cache:
     * comma-separated values, a comma, the word, a space, and a newline.
     *
     * @param word2Vec model whose cache (must be an {@code InMemoryLookupCache}) maps rows to words
     * @param iNDArray matrix whose rows are written
     * @param file     destination file (overwritten)
     * @throws Exception on I/O failure or if the cache is not an {@code InMemoryLookupCache}
     */
    public static void writeTsneFormat(Word2Vec word2Vec, INDArray iNDArray, File file) throws Exception {
        // Cast before opening the file so a ClassCastException cannot leak the writer.
        InMemoryLookupCache cache = (InMemoryLookupCache) word2Vec.getCache();
        BufferedWriter writer = new BufferedWriter(new FileWriter(file));
        try {
            for (int i = 0; i < iNDArray.rows(); i++) {
                String word = cache.wordAtIndex(i);
                if (word == null) {
                    continue;
                }
                StringBuilder line = new StringBuilder();
                INDArray row = iNDArray.getRow(i);
                for (int j = 0; j < row.length(); j++) {
                    line.append(row.getDouble(j));
                    if (j < row.length() - 1) {
                        line.append(",");
                    }
                }
                line.append(",").append(word).append(" ").append("\n");
                writer.write(line.toString());
            }
            // NOTE(review): message always says "0"; kept byte-identical for log compatibility.
            System.out.println("Wrote 0 with size of " + word2Vec.getLayerSize());
            writer.flush();
        } finally {
            writer.close();
        }
    }

    public RandomGenerator getRng() {
        return this.rng;
    }

    public void setRng(RandomGenerator randomGenerator) {
        this.rng = randomGenerator;
    }

    public INDArray getSyn0() {
        return this.syn0;
    }

    public void setSyn0(INDArray iNDArray) {
        this.syn0 = iNDArray;
    }

    public INDArray getSyn1() {
        return this.syn1;
    }

    public void setSyn1(INDArray iNDArray) {
        this.syn1 = iNDArray;
    }
}
