/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.text.functions;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.spark.Accumulator;
import org.apache.spark.AccumulatorParam;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.spark.text.accumulators.WordFreqAccumulator;
import org.deeplearning4j.spark.text.functions.GetSentenceCountFunction;
import org.deeplearning4j.spark.text.functions.ReduceSentenceCount;
import org.deeplearning4j.spark.text.functions.TokenizerFunction;
import org.deeplearning4j.spark.text.functions.UpdateWordFreqAccumulatorFunction;
import org.deeplearning4j.spark.text.functions.WordsListToVocabWordsFunction;
import org.nd4j.common.primitives.AtomicDouble;
import org.nd4j.common.primitives.Counter;
import org.nd4j.common.primitives.Pair;

/**
 * Spark pipeline that turns an RDD of raw corpus strings into the artifacts used for
 * distributed word-vector training: a tokenized RDD, a word-frequency {@link Counter}
 * gathered through a Spark accumulator, a Huffman-indexed {@link VocabCache} (also
 * broadcast to executors), an RDD of per-sentence {@link VocabWord} lists, and the total
 * word count.
 *
 * <p>Intended call order: construct with the corpus and tokenizer variable map, call
 * {@link #buildVocabCache()}, then {@link #buildVocabWordListRDD()}, then read results
 * through the getters (which throw {@link IllegalStateException} if a stage has not run).
 */
public class TextPipeline {
    /** Unknown-token label used when no {@link VectorsConfiguration} is available. */
    private static final String FALLBACK_UNK = "UNK";

    private JavaRDD<String> corpusRDD;
    /** Minimum frequency a token needs to keep its own vocab entry (read from key "numWords"). */
    private int numWords;
    private int nGrams;
    private String tokenizer;
    private String tokenizerPreprocessor;
    private List<String> stopWords = new ArrayList<String>();
    private JavaSparkContext sc;
    private Accumulator<Counter<String>> wordFreqAcc;
    private Broadcast<List<String>> stopWordBroadCast;
    private JavaRDD<Pair<List<String>, AtomicLong>> sentenceWordsCountRDD;
    private VocabCache<VocabWord> vocabCache = new AbstractCache();
    private Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast;
    private JavaRDD<List<VocabWord>> vocabWordListRDD;
    private JavaRDD<AtomicLong> sentenceCountRDD;
    private long totalWordCount;
    /** Whether below-threshold tokens are collapsed into an UNK vocab entry (true) or dropped (false). */
    private boolean useUnk;
    private VectorsConfiguration configuration;

    /**
     * Creates an empty pipeline. Note that {@link #setRDDVarMap} alone does not create the
     * Spark context wrapper, word-frequency accumulator or stop-word broadcast; only the
     * two-argument constructor does.
     */
    public TextPipeline() {
    }

    /**
     * Creates a pipeline over {@code corpusRDD}, reads tokenizer settings from the broadcast
     * map and initialises the accumulator and stop-word broadcast.
     *
     * @param corpusRDD               one string per sentence/document
     * @param broadcasTokenizerVarMap broadcast settings map (see {@link #setRDDVarMap})
     * @throws Exception declared for API compatibility
     */
    public TextPipeline(JavaRDD<String> corpusRDD, Broadcast<Map<String, Object>> broadcasTokenizerVarMap) throws Exception {
        this.setRDDVarMap(corpusRDD, broadcasTokenizerVarMap);
        this.setup();
    }

    /**
     * Stores the corpus and copies settings out of the broadcast map. Keys read:
     * "numWords", "nGrams", "tokenizer", "tokenPreprocessor", "useUnk",
     * "vectorsConfiguration", "stopWords". The primitive-typed entries ("numWords",
     * "nGrams", "useUnk") are unboxed, so a missing key throws NullPointerException.
     */
    public void setRDDVarMap(JavaRDD<String> corpusRDD, Broadcast<Map<String, Object>> broadcasTokenizerVarMap) {
        Map tokenizerVarMap = (Map)broadcasTokenizerVarMap.getValue();
        this.corpusRDD = corpusRDD;
        this.numWords = (Integer)tokenizerVarMap.get("numWords");
        this.nGrams = (Integer)tokenizerVarMap.get("nGrams");
        this.tokenizer = (String)tokenizerVarMap.get("tokenizer");
        this.tokenizerPreprocessor = (String)tokenizerVarMap.get("tokenPreprocessor");
        this.useUnk = (Boolean)tokenizerVarMap.get("useUnk");
        this.configuration = (VectorsConfiguration)tokenizerVarMap.get("vectorsConfiguration");
        this.stopWords = (List)tokenizerVarMap.get("stopWords");
    }

    /** Wraps the corpus's SparkContext and creates the word-frequency accumulator and stop-word broadcast. */
    private void setup() {
        this.sc = new JavaSparkContext(this.corpusRDD.context());
        this.wordFreqAcc = this.sc.accumulator((Object)new Counter(), (AccumulatorParam)new WordFreqAccumulator());
        this.stopWordBroadCast = this.sc.broadcast(this.stopWords);
    }

    /**
     * Lazily maps each corpus string to its token list.
     *
     * @throws IllegalStateException if no corpus has been assigned
     */
    public JavaRDD<List<String>> tokenize() {
        if (this.corpusRDD == null) {
            throw new IllegalStateException("corpusRDD not assigned. Define TextPipeline with corpusRDD assigned.");
        }
        return this.corpusRDD.map((Function)new TokenizerFunction(this.tokenizer, this.tokenizerPreprocessor, this.nGrams));
    }

    /**
     * Maps each tokenized sentence through the accumulator-updating function and forces
     * evaluation with {@code count()} so the word-frequency accumulator is populated. The
     * returned RDD is not cached; any further action on it re-runs the map (and re-applies
     * its accumulator updates).
     */
    public JavaRDD<Pair<List<String>, AtomicLong>> updateAndReturnAccumulatorVal(JavaRDD<List<String>> tokenizedRDD) {
        JavaRDD<Pair<List<String>, AtomicLong>> countedRDD = this.mapWithWordFreqAccumulator(tokenizedRDD);
        // map() is lazy: an action is required before the accumulator holds any counts.
        countedRDD.count();
        return countedRDD;
    }

    /** Builds (without evaluating) the RDD that updates {@link #wordFreqAcc} per sentence. */
    private JavaRDD<Pair<List<String>, AtomicLong>> mapWithWordFreqAccumulator(JavaRDD<List<String>> tokenizedRDD) {
        UpdateWordFreqAccumulatorFunction accumulatorClassFunction = new UpdateWordFreqAccumulatorFunction(this.stopWordBroadCast, this.wordFreqAcc);
        return tokenizedRDD.map((Function)accumulatorClassFunction);
    }

    /** Label substituted for below-threshold tokens; configuration value, or "UNK" without one. */
    private String unkLabel() {
        return this.configuration != null ? this.configuration.getUNK() : FALLBACK_UNK;
    }

    /** Returns the UNK label when {@code tokenCount < numWords}, otherwise the token unchanged. */
    private String filterMinWord(String stringToken, double tokenCount) {
        return tokenCount < (double)this.numWords ? this.unkLabel() : stringToken;
    }

    /**
     * Adds {@code stringToken} to the vocab cache with the given frequency, or increases the
     * frequency of the existing entry (by the count truncated to int) if already present.
     * New entries get index {@code vocabCache.numWords()} at insertion time.
     */
    private void addTokenToVocabCache(String stringToken, Float tokenCount) {
        VocabWord actualToken;
        if (this.vocabCache.hasToken(stringToken)) {
            actualToken = (VocabWord)this.vocabCache.tokenFor(stringToken);
            actualToken.increaseElementFrequency(tokenCount.intValue());
        } else {
            actualToken = new VocabWord((double)tokenCount.floatValue(), stringToken);
        }
        boolean vocabContainsWord = this.vocabCache.containsWord(stringToken);
        if (!vocabContainsWord) {
            int idx = this.vocabCache.numWords();
            this.vocabCache.addToken((SequenceElement)actualToken);
            actualToken.setIndex(idx);
            this.vocabCache.putVocabWord(stringToken);
        }
    }

    /**
     * Populates the vocab cache from a word-frequency counter. Tokens below the minimum
     * frequency are replaced by the UNK label; if {@code useUnk} is false, entries equal to
     * that label are skipped instead of added.
     *
     * @throws IllegalStateException if {@code wordFreq} is empty
     */
    public void filterMinWordAddVocab(Counter<String> wordFreq) {
        if (wordFreq.isEmpty()) {
            throw new IllegalStateException("IllegalStateException: wordFreqCounter has nothing. Check accumulator updating");
        }
        // Compare against the same label filterMinWord substitutes, not a hard-coded literal.
        String unk = this.unkLabel();
        for (Map.Entry entry : wordFreq.entrySet()) {
            double tokenCount = ((AtomicDouble)entry.getValue()).doubleValue();
            String stringToken = this.filterMinWord((String)entry.getKey(), tokenCount);
            if (!this.useUnk && stringToken.equals(unk)) {
                continue;
            }
            this.addTokenToVocabCache(stringToken, (float)tokenCount);
        }
    }

    /**
     * Tokenizes the corpus, counts word frequencies via the accumulator, fills the vocab
     * cache, applies Huffman indexes and broadcasts the cache. Leaves the cached
     * sentence/word-count RDD in {@link #sentenceWordsCountRDD} for
     * {@link #buildVocabWordListRDD()}.
     */
    public void buildVocabCache() {
        JavaRDD<List<String>> tokenizedRDD = this.tokenize();
        // Mark as cached BEFORE the first action so later actions normally reuse the cached
        // partitions instead of recomputing the map, which would re-apply accumulator updates
        // and re-tokenize the corpus.
        this.sentenceWordsCountRDD = this.mapWithWordFreqAccumulator(tokenizedRDD).cache();
        this.sentenceWordsCountRDD.count();
        Counter wordFreqCounter = (Counter)this.wordFreqAcc.value();
        this.filterMinWordAddVocab((Counter<String>)wordFreqCounter);
        Huffman huffman = new Huffman(this.vocabCache.vocabWords());
        huffman.build();
        huffman.applyIndexes(this.vocabCache);
        this.vocabCacheBroadcast = this.sc.broadcast(this.vocabCache);
    }

    /**
     * Derives the per-sentence VocabWord list RDD and sentence-count RDD (both cached),
     * computes the total word count, then unpersists the intermediate RDD.
     *
     * @throws IllegalStateException if {@link #buildVocabCache()} has not run
     */
    public void buildVocabWordListRDD() {
        if (this.sentenceWordsCountRDD == null) {
            throw new IllegalStateException("SentenceWordCountRDD must be defined first. Run buildVocabCache first.");
        }
        this.vocabWordListRDD = this.sentenceWordsCountRDD.map((Function)new WordsListToVocabWordsFunction(this.vocabCacheBroadcast)).setName("vocabWordListRDD").cache();
        this.sentenceCountRDD = this.sentenceWordsCountRDD.map((Function)new GetSentenceCountFunction()).setName("sentenceCountRDD").cache();
        // Materialise the cached list RDD before the parent RDD is unpersisted below.
        this.vocabWordListRDD.count();
        this.totalWordCount = ((AtomicLong)this.sentenceCountRDD.reduce((Function2)new ReduceSentenceCount())).get();
        this.sentenceWordsCountRDD.unpersist();
    }

    /** @throws IllegalStateException if the accumulator was never created */
    public Accumulator<Counter<String>> getWordFreqAcc() {
        if (this.wordFreqAcc != null) {
            return this.wordFreqAcc;
        }
        throw new IllegalStateException("IllegalStateException: wordFreqAcc not set at TextPipline.");
    }

    /**
     * Returns the broadcast vocab cache.
     *
     * @throws IllegalStateException if the cache is empty or has not been broadcast yet
     */
    public Broadcast<VocabCache<VocabWord>> getBroadCastVocabCache() throws IllegalStateException {
        if (this.vocabCacheBroadcast != null && this.vocabCache.numWords() > 0) {
            return this.vocabCacheBroadcast;
        }
        throw new IllegalStateException("IllegalStateException: VocabCache not set at TextPipline.");
    }

    /** @throws IllegalStateException if the vocab cache is null or empty */
    public VocabCache<VocabWord> getVocabCache() throws IllegalStateException {
        if (this.vocabCache != null && this.vocabCache.numWords() > 0) {
            return this.vocabCache;
        }
        throw new IllegalStateException("IllegalStateException: VocabCache not set at TextPipline.");
    }

    /** @throws IllegalStateException if {@link #buildVocabCache()} has not run */
    public JavaRDD<Pair<List<String>, AtomicLong>> getSentenceWordsCountRDD() {
        if (this.sentenceWordsCountRDD != null) {
            return this.sentenceWordsCountRDD;
        }
        throw new IllegalStateException("IllegalStateException: sentenceWordsCountRDD not set at TextPipline.");
    }

    /** @throws IllegalStateException if {@link #buildVocabWordListRDD()} has not run */
    public JavaRDD<List<VocabWord>> getVocabWordListRDD() throws IllegalStateException {
        if (this.vocabWordListRDD != null) {
            return this.vocabWordListRDD;
        }
        throw new IllegalStateException("IllegalStateException: vocabWordListRDD not set at TextPipline.");
    }

    /** @throws IllegalStateException if {@link #buildVocabWordListRDD()} has not run */
    public JavaRDD<AtomicLong> getSentenceCountRDD() throws IllegalStateException {
        if (this.sentenceCountRDD != null) {
            return this.sentenceCountRDD;
        }
        throw new IllegalStateException("IllegalStateException: sentenceCountRDD not set at TextPipline.");
    }

    /**
     * Returns the total word count computed by {@link #buildVocabWordListRDD()}.
     *
     * @throws IllegalStateException if the count is zero (not computed, or genuinely zero)
     */
    public Long getTotalWordCount() {
        if (this.totalWordCount != 0L) {
            return this.totalWordCount;
        }
        throw new IllegalStateException("IllegalStateException: totalWordCount not set at TextPipline.");
    }
}

