/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.models.embeddings.word2vec;

import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.math3.util.FastMath;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Spark {@link VoidFunction} that runs one skip-gram word2vec training pass over a single sentence
 * per {@link #call(Pair)}, updating the vectors held by the supplied {@link InMemoryLookupTable}.
 *
 * <p>Each (word, context) pair is trained with hierarchical softmax over the word's codes/points
 * and, when {@code negative > 0}, with negative sampling against a table deserialized from the
 * Spark configuration. Hyper-parameters are read in {@link #setup(SparkConf)}.
 *
 * <p>NOTE(review): instances mutate their own RNG state and the lookup-table vectors without
 * synchronization; presumably each Spark task works on its own deserialized copy — confirm.
 */
@Deprecated
public class Word2VecPerformer
implements VoidFunction<Pair<List<VocabWord>, AtomicLong>> {
    /** {@link #expTable} holds sigmoid(x) for x in [-MAX_EXP, MAX_EXP). */
    private static final double MAX_EXP = 6.0;
    /** Multiplier of the linear congruential generator used for window shrinking and sampling. */
    private static final long RANDOM_MULTIPLIER = 25214903917L;
    /** Increment of the linear congruential generator. */
    private static final long RANDOM_INCREMENT = 11L;
    private static final Logger log = LoggerFactory.getLogger(Word2VecPerformer.class);

    private boolean useAdaGrad = false;
    private double negative = 5.0;
    private int numWords = 1;
    /** Negative-sampling table; only populated by {@link #setup} when the config provides one. */
    private INDArray table;
    private int window = 5;
    private AtomicLong nextRandom = new AtomicLong(5L);
    private double alpha = 0.025;
    private double minAlpha = 0.01;
    private int totalWords = 1;
    private int lastChecked = 0;
    private Broadcast<AtomicLong> wordCount;
    private InMemoryLookupTable weights;
    private double[] expTable = new double[1000];
    private int vectorLength;

    /**
     * Creates a performer bound to the given shared word counter and weights.
     *
     * @param sc        Spark configuration to read hyper-parameters from
     * @param wordCount broadcast running count of words processed so far (drives alpha decay)
     * @param weights   lookup table whose vectors are trained in place
     */
    public Word2VecPerformer(SparkConf sc, Broadcast<AtomicLong> wordCount, InMemoryLookupTable weights) {
        this.weights = weights;
        this.wordCount = wordCount;
        this.setup(sc);
    }

    /**
     * Reads training hyper-parameters from the Spark configuration, builds the sigmoid lookup
     * table and, if negative sampling is enabled and a table is configured, deserializes it.
     *
     * @param conf configuration holding {@code org.deeplearning4j.scaleout.perform.models.word2vec.*} keys
     */
    public void setup(SparkConf conf) {
        this.useAdaGrad = conf.getBoolean("org.deeplearning4j.scaleout.perform.models.word2vec.adagrad", false);
        this.negative = conf.getDouble("org.deeplearning4j.scaleout.perform.models.word2vec.negative", 5.0);
        this.numWords = conf.getInt("org.deeplearning4j.scaleout.perform.models.word2vec.numwords", 1);
        this.window = conf.getInt("org.deeplearning4j.scaleout.perform.models.word2vec.window", 5);
        this.alpha = conf.getDouble("org.deeplearning4j.scaleout.perform.models.word2vec.alpha", 0.025);
        this.minAlpha = conf.getDouble("org.deeplearning4j.scaleout.perform.models.word2vec.minalpha", 0.01);
        // NOTE(review): totalWords is read from the same "numwords" key as numWords; confirm whether
        // a dedicated total-word-count key was intended.
        this.totalWords = conf.getInt("org.deeplearning4j.scaleout.perform.models.word2vec.numwords", 1);
        this.vectorLength = conf.getInt("org.deeplearning4j.scaleout.perform.models.word2vec.length", 100);
        this.initExpTable();
        if (this.negative > 0.0 && conf.contains("org.deeplearning4j.scaleout.perform.models.word2vec.table")) {
            // NOTE(review): the serialized table is carried as a String and decoded with the platform
            // charset; this only round-trips if the writer used the same charset — confirm.
            ByteArrayInputStream bis = new ByteArrayInputStream(
                    conf.get("org.deeplearning4j.scaleout.perform.models.word2vec.table").getBytes());
            this.table = Nd4j.read(new DataInputStream(bis));
        }
    }

    /**
     * Trains every word of the sentence against its (randomly shrunk) context window.
     * Words whose text ends with {@code "STOP"} are skipped as centre words.
     *
     * @param sentence words of one sentence; {@code null} or empty is a no-op
     * @param alpha    learning rate for this sentence
     */
    public void trainSentence(List<VocabWord> sentence, double alpha) {
        if (sentence == null || sentence.isEmpty()) {
            return;
        }
        for (int i = 0; i < sentence.size(); ++i) {
            if (sentence.get(i).getWord().endsWith("STOP")) {
                continue;
            }
            long random = this.advanceRandom();
            // Unsigned remainder, as in the reference `next_random % window`: a signed remainder of
            // the overflowing LCG state can be negative and would widen the window instead.
            int b = this.window > 0 ? (int) Long.remainderUnsigned(random, this.window) : 0;
            this.skipGram(i, sentence, b, alpha);
        }
    }

    /**
     * Trains word {@code i} against each context word within {@code window - b} positions.
     *
     * @param i        index of the centre word in {@code sentence}
     * @param sentence the sentence
     * @param b        how many positions to shrink the window by on each side
     * @param alpha    learning rate
     */
    public void skipGram(int i, List<VocabWord> sentence, int b, double alpha) {
        if (sentence.isEmpty()) {
            return;
        }
        VocabWord word = sentence.get(i);
        if (word == null) {
            return;
        }
        int end = this.window * 2 + 1 - b;
        for (int a = b; a < end; ++a) {
            if (a == this.window) {
                continue; // the centre word itself
            }
            int c = i - this.window + a;
            if (c < 0 || c >= sentence.size()) {
                continue;
            }
            this.iterateSample(word, sentence.get(c), alpha);
        }
    }

    /**
     * Performs one SGD step for the pair ({@code w1}, {@code w2}): the error from hierarchical
     * softmax and (if enabled) negative sampling is accumulated and then added to {@code w2}'s vector.
     *
     * @param w1    word whose codes/points and index are the prediction targets
     * @param w2    word whose vector is the input; ignored when {@code null} or its index is negative
     * @param alpha learning rate
     */
    public void iterateSample(VocabWord w1, VocabWord w2, double alpha) {
        if (w2 == null || w2.getIndex() < 0) {
            return;
        }
        INDArray l1 = this.weights.vector(w2.getWord());
        INDArray neu1e = Nd4j.create(this.vectorLength);
        this.hierarchicalSoftmax(w1, l1, neu1e, alpha);
        if (this.negative > 0.0) {
            this.negativeSampling(w1, l1, neu1e, alpha);
        }
        axpy(1.0, neu1e, l1);
    }

    /** Hierarchical-softmax step: updates syn1 rows of w1's points and accumulates error into neu1e. */
    private void hierarchicalSoftmax(VocabWord w1, INDArray l1, INDArray neu1e, double alpha) {
        for (int i = 0; i < w1.getCodeLength(); ++i) {
            byte code = (Byte) w1.getCodes().get(i);
            int point = (Integer) w1.getPoints().get(i);
            INDArray syn1 = this.weights.getSyn1().slice((long) point);
            double dot = Nd4j.getBlasWrapper().dot(l1, syn1);
            // Skip out-of-range (and NaN) dot products; written positively so NaN is skipped too.
            if (!(dot >= -MAX_EXP && dot < MAX_EXP)) {
                continue;
            }
            int idx = this.expTableIndex(dot);
            if (idx >= this.expTable.length) {
                continue;
            }
            double f = this.expTable[idx];
            double g = ((double) (1 - code) - f) * (this.useAdaGrad ? w1.getGradient(i, alpha, this.alpha) : alpha);
            Nd4j.getBlasWrapper().level1().axpy(l1.length(), g, syn1, neu1e);
            Nd4j.getBlasWrapper().level1().axpy(l1.length(), g, l1, syn1);
        }
    }

    /** Negative-sampling step: one positive target (w1) plus {@code negative} sampled targets. */
    private void negativeSampling(VocabWord w1, INDArray l1, INDArray neu1e, double alpha) {
        for (int d = 0; (double) d < this.negative + 1.0; ++d) {
            int target;
            int label;
            if (d == 0) {
                target = w1.getIndex();
                label = 1;
            } else {
                target = this.drawNegativeSample();
                if (target == w1.getIndex()) {
                    continue;
                }
                label = 0;
            }
            // Slice per target: each sampled negative needs its own output vector.
            INDArray syn1Neg = this.weights.getSyn1Neg().slice((long) target);
            double f = Nd4j.getBlasWrapper().dot(l1, syn1Neg);
            double g;
            if (f > MAX_EXP) {
                g = this.useAdaGrad ? w1.getGradient(target, (double) (label - 1), this.alpha) : (double) (label - 1) * alpha;
            } else if (f < -MAX_EXP) {
                g = (double) label * (this.useAdaGrad ? w1.getGradient(target, alpha, this.alpha) : alpha);
            } else {
                // Clamp: f == MAX_EXP would otherwise index one past the end of expTable.
                int idx = Math.min(this.expTableIndex(f), this.expTable.length - 1);
                double err = (double) label - this.expTable[idx];
                g = this.useAdaGrad ? w1.getGradient(target, err, this.alpha) : err * alpha;
            }
            // Same update shape as hierarchical softmax: neu1e += g * syn1Neg; syn1Neg += g * l1.
            axpy(g, syn1Neg, neu1e);
            axpy(g, l1, syn1Neg);
        }
    }

    /**
     * Draws a negative-sample word index from {@link #table}, falling back to a uniform index in
     * [1, numWords - 1] when the table yields 0 (and numWords allows it).
     *
     * @throws IllegalStateException if negative sampling runs without a configured table
     */
    private int drawNegativeSample() {
        if (this.table == null) {
            throw new IllegalStateException("Negative sampling is enabled but no sampling table was configured");
        }
        long random = this.advanceRandom();
        // Logical shift keeps the index non-negative even after the LCG state overflows.
        int target = this.table.getInt(new int[]{(int) ((random >>> 16) % this.table.length())});
        if (target == 0 && this.numWords > 1) {
            target = (int) Long.remainderUnsigned(random, this.numWords - 1) + 1;
        }
        return target;
    }

    /** Advances the linear congruential generator and returns its new state. */
    private long advanceRandom() {
        long next = this.nextRandom.get() * RANDOM_MULTIPLIER + RANDOM_INCREMENT;
        this.nextRandom.set(next);
        return next;
    }

    /** Maps a dot product to its (unclamped) position in {@link #expTable}. */
    private int expTableIndex(double dot) {
        return (int) ((dot + MAX_EXP) * ((double) this.expTable.length / MAX_EXP / 2.0));
    }

    /** Computes {@code y += a * x}, using the float overload unless {@code x} holds doubles. */
    private static void axpy(double a, INDArray x, INDArray y) {
        if (x.data().dataType() == DataType.DOUBLE) {
            Nd4j.getBlasWrapper().axpy(a, x, y);
        } else {
            Nd4j.getBlasWrapper().axpy((float) a, x, y);
        }
    }

    /** Fills {@link #expTable} with sigmoid(x) sampled uniformly over [-MAX_EXP, MAX_EXP). */
    private void initExpTable() {
        for (int i = 0; i < this.expTable.length; ++i) {
            double tmp = FastMath.exp(((double) i / (double) this.expTable.length * 2.0 - 1.0) * MAX_EXP);
            this.expTable[i] = tmp / (tmp + 1.0);
        }
    }

    /**
     * Trains one sentence with a learning rate decayed linearly by global progress (floored at
     * {@code minAlpha}), then adds the sentence length to the pair's word counter.
     *
     * @param pair the sentence and a counter of words processed
     */
    @Override
    public void call(Pair<List<VocabWord>, AtomicLong> pair) throws Exception {
        double numWordsSoFar = this.wordCount.getValue().doubleValue();
        List<VocabWord> sentence = pair.getFirst();
        double alpha2 = Math.max(this.minAlpha, this.alpha * (1.0 - numWordsSoFar / (double) this.totalWords));
        this.trainSentence(sentence, alpha2);
        int totalNewWords = sentence == null ? 0 : sentence.size();
        double newWords = (double) totalNewWords + numWordsSoFar;
        double diff = Math.abs(newWords - (double) this.lastChecked);
        if (diff >= 10000.0) {
            this.lastChecked = (int) newWords;
            log.info("Words so far {} out of {}", newWords, this.totalWords);
        }
        pair.getSecond().getAndAdd(totalNewWords);
    }
}

