/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.models.embeddings.word2vec;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.apache.commons.math3.util.FastMath;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.spark.models.embeddings.word2vec.FirstIterationFunction;
import org.deeplearning4j.spark.models.embeddings.word2vec.MapToPairFunction;
import org.deeplearning4j.spark.text.functions.CountCumSum;
import org.deeplearning4j.spark.text.functions.TextPipeline;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Spark-based Word2Vec trainer.
 *
 * <p>{@link #train(JavaRDD)} tokenizes a corpus of sentences through {@link TextPipeline}, runs
 * {@link FirstIterationFunction} over every partition, collects the resulting per-word syn0
 * updates on the driver, averages them, and stores the result in an {@link InMemoryLookupTable}.
 * Instances are created through {@link Builder}.
 */
public class Word2Vec extends WordVectorsImpl<VocabWord> implements Serializable {
    private INDArray trainedSyn1;
    private static final Logger log = LoggerFactory.getLogger(Word2Vec.class);
    /** Input range of the sigmoid lookup table is [-MAX_EXP, MAX_EXP). */
    private int MAX_EXP = 6;
    private double[] expTable;
    protected VectorsConfiguration configuration;
    private int nGrams = 1;
    private String tokenizer = "org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory";
    private String tokenPreprocessor = "org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor";
    private boolean removeStop = false;
    private long seed = 42L;
    private boolean useUnknown = false;

    /**
     * Creates an instance that keeps a previously trained syn1 matrix and precomputes the sigmoid
     * lookup table.
     *
     * @param trainedSyn1 syn1 weights to retain
     */
    protected Word2Vec(INDArray trainedSyn1) {
        this.trainedSyn1 = trainedSyn1;
        this.expTable = this.initExpTable();
    }

    /** Creates an empty instance and precomputes the sigmoid lookup table. */
    protected Word2Vec() {
        this.expTable = this.initExpTable();
    }

    /**
     * Precomputes a 100000-entry sigmoid table.
     *
     * <p>Entry {@code i} holds {@code sigmoid(x)} for {@code x = (i / length * 2 - 1) * MAX_EXP},
     * i.e. the table samples the logistic function uniformly over {@code [-MAX_EXP, MAX_EXP)}.
     * Note: this method is invoked from the constructors, so overrides run before subclass fields
     * are initialized.
     *
     * @return the sigmoid lookup table
     */
    protected double[] initExpTable() {
        double[] expTable = new double[100000];
        for (int i = 0; i < expTable.length; ++i) {
            double tmp = FastMath.exp(((double) i / (double) expTable.length * 2.0 - 1.0) * (double) this.MAX_EXP);
            // e^x / (e^x + 1) == 1 / (1 + e^-x)
            expTable[i] = tmp / (tmp + 1.0);
        }
        return expTable;
    }

    /**
     * Builds the variable map consumed by {@link TextPipeline} for tokenization and vocabulary
     * construction.
     *
     * <p>A plain {@link HashMap} is used deliberately: an anonymous ("double-brace") subclass would
     * capture {@code Word2Vec.this}, and broadcasting the map would then serialize this whole model.
     *
     * @return a new mutable map with keys {@code numWords, nGrams, tokenizer, tokenPreprocessor,
     *     removeStop, stopWords, useUnk, vectorsConfiguration}
     */
    public Map<String, Object> getTokenizerVarMap() {
        Map<String, Object> vars = new HashMap<String, Object>();
        vars.put("numWords", this.minWordFrequency);
        vars.put("nGrams", this.nGrams);
        vars.put("tokenizer", this.tokenizer);
        vars.put("tokenPreprocessor", this.tokenPreprocessor);
        vars.put("removeStop", this.removeStop);
        vars.put("stopWords", this.stopWords);
        vars.put("useUnk", this.useUnknown);
        vars.put("vectorsConfiguration", this.configuration);
        return vars;
    }

    /**
     * Builds the training-hyperparameter map broadcast to the workers running
     * {@link FirstIterationFunction}.
     *
     * <p>As with {@link #getTokenizerVarMap()}, a plain {@link HashMap} avoids capturing and
     * serializing the enclosing instance.
     *
     * @return a new mutable map with keys {@code vectorLength, useAdaGrad, negative, window, alpha,
     *     minAlpha, iterations, seed, maxExp, batchSize}
     */
    public Map<String, Object> getWord2vecVarMap() {
        Map<String, Object> vars = new HashMap<String, Object>();
        vars.put("vectorLength", this.layerSize);
        vars.put("useAdaGrad", this.useAdeGrad);
        vars.put("negative", this.negative);
        vars.put("window", this.window);
        vars.put("alpha", this.learningRate.get());
        vars.put("minAlpha", this.minLearningRate);
        vars.put("iterations", this.numIterations);
        vars.put("seed", this.seed);
        vars.put("maxExp", this.MAX_EXP);
        vars.put("batchSize", this.batchSize);
        return vars;
    }

    /**
     * Trains word vectors on a corpus of sentences, one sentence per RDD element.
     *
     * <p>Steps: optional repartition to {@code workers} partitions, tokenization and vocab building
     * via {@link TextPipeline}, cumulative sentence counts via {@link CountCumSum}, one pass of
     * {@link FirstIterationFunction} per partition, then driver-side averaging of the collected
     * syn0 updates. On return, {@code vocab} and {@code lookupTable} hold the trained model.
     *
     * @param corpusRDD sentences to train on
     * @throws Exception propagated from the pipeline / Spark jobs
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void train(JavaRDD<String> corpusRDD) throws Exception {
        log.info("Start training ...");
        if (this.workers > 0) {
            // RDDs are immutable: repartition() returns a new RDD, so the result must be kept.
            corpusRDD = corpusRDD.repartition(this.workers);
        }
        JavaSparkContext sc = new JavaSparkContext(corpusRDD.context());
        Map<String, Object> tokenizerVarMap = this.getTokenizerVarMap();
        Map<String, Object> word2vecVarMap = this.getWord2vecVarMap();

        log.info("Tokenization and building VocabCache ...");
        Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
        TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
        pipeline.buildVocabCache();
        pipeline.buildVocabWordListRDD();
        word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
        JavaRDD<AtomicLong> sentenceWordsCountRDD = pipeline.getSentenceCountRDD();
        JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();
        Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast = pipeline.getBroadCastVocabCache();
        VocabCache<VocabWord> vocabCache = vocabCacheBroadcast.getValue();
        log.info("Vocab size: {}", vocabCache.numWords());
        log.info("Building Huffman Tree ...");

        log.info("Calculating cumulative sum of sentence counts ...");
        JavaRDD<Long> sentenceCumSumCountRDD = new CountCumSum(sentenceWordsCountRDD).buildCumSum();

        log.info("Mapping to RDD(vocabWordList, cumulative sentence count) ...");
        JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD =
                vocabWordListRDD.zip(sentenceCumSumCountRDD).setName("vocabWordListSentenceCumSumRDD");

        log.info("Broadcasting word2vec variables to workers ...");
        Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
        Broadcast<double[]> expTableBroadcast = sc.broadcast(this.expTable);

        log.info("Training word2vec sentences ...");
        FirstIterationFunction firstIterFunc =
                new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast);
        // Raw casts retained: the exact type parameters of the two function classes are declared
        // elsewhere. The collected elements are (VocabWord, INDArray) pairs, as the casts in the
        // averaging step rely on.
        JavaRDD indexSyn0UpdateEntryRDD = vocabWordListSentenceCumSumRDD
                .mapPartitions((FlatMapFunction) firstIterFunc)
                .map((Function) new MapToPairFunction());
        List<Pair<VocabWord, INDArray>> syn0UpdateEntries = indexSyn0UpdateEntryRDD.collect();
        INDArray syn0 = Nd4j.zeros(new int[] {vocabCache.numWords(), this.layerSize});

        log.info("Averaging results...");
        int maxRep = averageSyn0Updates(syn0, syn0UpdateEntries);
        log.info("Finished calculations...");

        this.vocab = vocabCache;
        InMemoryLookupTable<VocabWord> inMemoryLookupTable = new InMemoryLookupTable<VocabWord>();
        // NOTE(review): heartbeat report reuses maxRep as "numCores" and 0 as "available memory",
        // as the original code did; these look like placeholder values — confirm intent.
        Environment env = EnvironmentUtils.buildEnvironment();
        env.setNumCores(maxRep);
        env.setAvailableMemory(0L);
        this.update(env, Event.SPARK);
        inMemoryLookupTable.setVocab(vocabCache);
        inMemoryLookupTable.setVectorLength(this.layerSize);
        inMemoryLookupTable.setSyn0(syn0);
        this.lookupTable = inMemoryLookupTable;
        this.modelUtils.init(this.lookupTable);
    }

    /**
     * Adds every collected (word, update) pair into the word's row of {@code syn0}, then divides
     * each row that received more than one update by its update count.
     *
     * @param syn0 matrix with one row per vocabulary index; modified in place
     * @param entries updates collected from the workers
     * @return the largest update count seen for a single word, or 1 if no word got more than one
     */
    private static int averageSyn0Updates(INDArray syn0, List<Pair<VocabWord, INDArray>> entries) {
        Map<VocabWord, Integer> updateCounts = new HashMap<VocabWord, Integer>();
        for (Pair<VocabWord, INDArray> entry : entries) {
            VocabWord word = entry.getFirst();
            syn0.getRow((long) word.getIndex()).addi(entry.getSecond());
            Integer previous = updateCounts.get(word);
            updateCounts.put(word, previous == null ? 1 : previous + 1);
        }
        int maxRep = 1;
        for (Map.Entry<VocabWord, Integer> counted : updateCounts.entrySet()) {
            int count = counted.getValue();
            if (count <= 1) {
                continue;
            }
            maxRep = Math.max(maxRep, count);
            syn0.getRow((long) counted.getKey().getIndex()).divi(Integer.valueOf(count));
        }
        return maxRep;
    }

    /** @return the precomputed sigmoid lookup table (the internal array, not a copy) */
    public double[] getExpTable() {
        return this.expTable;
    }

    /** @return the configuration assigned by {@link Builder#build()}, or {@code null} if none */
    public VectorsConfiguration getConfiguration() {
        return this.configuration;
    }

    /**
     * Builder for {@link Word2Vec}.
     *
     * <p>Initial values are read from a {@link VectorsConfiguration}; {@link #build()} writes the
     * final values back into that same configuration object and attaches it to the result.
     */
    public static class Builder {
        protected int nGrams = 1;
        protected int numIterations = 1;
        protected int minWordFrequency = 1;
        protected int numEpochs = 1;
        protected double learningRate = 0.025;
        protected double minLearningRate = 0.001;
        protected int windowSize = 5;
        protected double negative = 0.0;
        protected double sampling = 1.0E-5;
        protected long seed = 42L;
        protected boolean useAdaGrad = false;
        protected TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        protected VectorsConfiguration configuration = new VectorsConfiguration();
        protected int layerSize;
        protected List<String> stopWords = new ArrayList<String>();
        protected int batchSize = 100;
        protected boolean useUnk = false;
        // NOTE(review): build() copies these unconditionally, so leaving them empty replaces
        // Word2Vec's default tokenizer/preprocessor class names with "" — confirm downstream
        // handles an empty tokenizer name.
        private String tokenizer = "";
        private String tokenPreprocessor = "";
        private int workers = 0;

        /** Creates a builder seeded from a fresh {@link VectorsConfiguration}. */
        public Builder() {
            this(new VectorsConfiguration());
        }

        /**
         * Creates a builder seeded from {@code configuration}.
         *
         * @param configuration source of initial values; also mutated and reused by {@link #build()}
         */
        public Builder(VectorsConfiguration configuration) {
            this.configuration = configuration;
            this.numIterations = configuration.getIterations();
            this.numEpochs = configuration.getEpochs();
            this.minLearningRate = configuration.getMinLearningRate();
            this.learningRate = configuration.getLearningRate();
            this.sampling = configuration.getSampling();
            this.negative = configuration.getNegative();
            this.minWordFrequency = configuration.getMinWordFrequency();
            this.seed = configuration.getSeed();
            this.batchSize = configuration.getBatchSize();
            this.layerSize = configuration.getLayersSize();
            this.useAdaGrad = configuration.isUseAdaGrad();
            this.windowSize = configuration.getWindow();
            if (configuration.getStopList() != null) {
                this.stopWords.addAll(configuration.getStopList());
            }
        }

        /** Sets the context window size. */
        public Builder windowSize(int windowSize) {
            this.windowSize = windowSize;
            return this;
        }

        /** Sets the {@code negative} training parameter. */
        public Builder negative(int negative) {
            this.negative = negative;
            return this;
        }

        /** Sets the {@code sampling} parameter. */
        public Builder sampling(double sampling) {
            this.sampling = sampling;
            return this;
        }

        /** Sets the initial learning rate. */
        public Builder learningRate(double lr) {
            this.learningRate = lr;
            return this;
        }

        /** Sets the minimum learning rate. */
        public Builder minLearningRate(double mlr) {
            this.minLearningRate = mlr;
            return this;
        }

        /** Sets the number of iterations. */
        public Builder iterations(int numIterations) {
            this.numIterations = numIterations;
            return this;
        }

        /** Sets the number of epochs. */
        public Builder epochs(int numEpochs) {
            this.numEpochs = numEpochs;
            return this;
        }

        /** Sets the minimum word frequency (passed to the tokenizer pipeline as {@code numWords}). */
        public Builder minWordFrequency(int minWordFrequency) {
            this.minWordFrequency = minWordFrequency;
            return this;
        }

        /** Enables or disables AdaGrad. */
        public Builder useAdaGrad(boolean reallyUse) {
            this.useAdaGrad = reallyUse;
            return this;
        }

        /** Stores the seed in the configuration; see {@link #build()}. */
        public Builder seed(long seed) {
            this.seed = seed;
            return this;
        }

        /**
         * Records the factory's class name, and its pre-processor's class name (or "" if it has
         * none), for reflective re-creation on the workers.
         *
         * @throws NullPointerException if {@code factory} is null
         */
        public Builder tokenizerFactory(@NonNull TokenizerFactory factory) {
            if (factory == null) {
                throw new NullPointerException("factory is marked non-null but is null");
            }
            this.tokenizer = factory.getClass().getCanonicalName();
            this.tokenPreprocessor = factory.getTokenPreProcessor() != null ? factory.getTokenPreProcessor().getClass().getCanonicalName() : "";
            return this;
        }

        /**
         * Sets the tokenizer factory by class name.
         *
         * @throws NullPointerException if {@code tokenizer} is null
         */
        public Builder tokenizerFactory(@NonNull String tokenizer) {
            if (tokenizer == null) {
                throw new NullPointerException("tokenizer is marked non-null but is null");
            }
            this.tokenizer = tokenizer;
            return this;
        }

        /**
         * Sets the token pre-processor by class name.
         *
         * @throws NullPointerException if {@code tokenPreprocessor} is null
         */
        public Builder tokenPreprocessor(@NonNull String tokenPreprocessor) {
            if (tokenPreprocessor == null) {
                throw new NullPointerException("tokenPreprocessor is marked non-null but is null");
            }
            this.tokenPreprocessor = tokenPreprocessor;
            return this;
        }

        /** Sets the partition count used by {@link Word2Vec#train}; values {@code <= 0} disable repartitioning. */
        public Builder workers(int workers) {
            this.workers = workers;
            return this;
        }

        /** Sets the vector length. */
        public Builder layerSize(int layerSize) {
            this.layerSize = layerSize;
            return this;
        }

        /** Sets the n-gram size passed to the tokenizer pipeline. */
        public Builder setNGrams(int nGrams) {
            this.nGrams = nGrams;
            return this;
        }

        /**
         * Appends stop words not already present.
         *
         * @throws NullPointerException if {@code stopWords} is null
         */
        public Builder stopWords(@NonNull List<String> stopWords) {
            if (stopWords == null) {
                throw new NullPointerException("stopWords is marked non-null but is null");
            }
            for (String word : stopWords) {
                if (this.stopWords.contains(word)) continue;
                this.stopWords.add(word);
            }
            return this;
        }

        /** Sets the batch size. */
        public Builder batchSize(int batchSize) {
            this.batchSize = batchSize;
            return this;
        }

        /** Enables or disables the unknown-word token (passed to the pipeline as {@code useUnk}). */
        public Builder useUnknown(boolean reallyUse) {
            this.useUnk = reallyUse;
            return this;
        }

        /**
         * Writes the builder's values into its {@link VectorsConfiguration} and creates a
         * {@link Word2Vec} carrying the same values.
         *
         * <p>Note: the builder's seed is written only to the configuration; the model's own
         * {@code seed} field keeps its default.
         *
         * @return the configured, untrained model
         */
        public Word2Vec build() {
            Word2Vec ret = new Word2Vec();
            this.configuration.setLearningRate(this.learningRate);
            this.configuration.setLayersSize(this.layerSize);
            this.configuration.setWindow(this.windowSize);
            this.configuration.setMinWordFrequency(this.minWordFrequency);
            this.configuration.setIterations(this.numIterations);
            this.configuration.setSeed(this.seed);
            this.configuration.setMinLearningRate(this.minLearningRate);
            this.configuration.setSampling(this.sampling);
            this.configuration.setUseAdaGrad(this.useAdaGrad);
            this.configuration.setNegative(this.negative);
            this.configuration.setEpochs(this.numEpochs);
            this.configuration.setBatchSize(this.batchSize);
            this.configuration.setStopList(this.stopWords);
            ret.workers = this.workers;
            ret.nGrams = this.nGrams;
            ret.configuration = this.configuration;
            ret.numEpochs = this.numEpochs;
            ret.numIterations = this.numIterations;
            ret.minWordFrequency = this.minWordFrequency;
            ret.learningRate.set(this.learningRate);
            ret.minLearningRate = this.minLearningRate;
            ret.sampling = this.sampling;
            ret.negative = this.negative;
            ret.layerSize = this.layerSize;
            ret.window = this.windowSize;
            ret.useAdeGrad = this.useAdaGrad;
            ret.stopWords = this.stopWords;
            ret.batchSize = this.batchSize;
            ret.useUnknown = this.useUnk;
            ret.tokenizer = this.tokenizer;
            ret.tokenPreprocessor = this.tokenPreprocessor;
            return ret;
        }
    }
}

