/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.models.embeddings.word2vec;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.spark.api.java.function.Function;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.spark.models.embeddings.word2vec.Word2VecChange;
import org.deeplearning4j.spark.models.embeddings.word2vec.Word2VecFuncCall;
import org.deeplearning4j.spark.models.embeddings.word2vec.Word2VecParam;
import org.nd4j.common.primitives.Triple;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

@Deprecated
public class SentenceBatch
implements Function<Word2VecFuncCall, Word2VecChange> {
    private AtomicLong nextRandom = new AtomicLong(5L);

    public Word2VecChange call(Word2VecFuncCall sentence) throws Exception {
        Word2VecParam param = (Word2VecParam)sentence.getParam().getValue();
        ArrayList<Triple<Integer, Integer, Integer>> changed = new ArrayList<Triple<Integer, Integer, Integer>>();
        double alpha = Math.max(param.getMinAlpha(), param.getAlpha() * (1.0 - 1.0 * (double)sentence.getWordsSeen().longValue() / (double)param.getTotalWords()));
        this.trainSentence(param, sentence.getSentence(), alpha, changed);
        return new Word2VecChange(changed, param);
    }

    public void trainSentence(Word2VecParam param, List<VocabWord> sentence, double alpha, List<Triple<Integer, Integer, Integer>> changed) {
        if (sentence != null && !sentence.isEmpty()) {
            for (int i = 0; i < sentence.size(); ++i) {
                VocabWord vocabWord = sentence.get(i);
                if (vocabWord == null || !vocabWord.getWord().endsWith("STOP")) continue;
                this.nextRandom.set(this.nextRandom.get() * 25214903917L + 11L);
                this.skipGram(param, i, sentence, (int)this.nextRandom.get() % param.getWindow(), alpha, changed);
            }
        }
    }

    public void skipGram(Word2VecParam param, int i, List<VocabWord> sentence, int b, double alpha, List<Triple<Integer, Integer, Integer>> changed) {
        VocabWord word = sentence.get(i);
        int window = param.getWindow();
        if (word != null && !sentence.isEmpty()) {
            int end = window * 2 + 1 - b;
            for (int a = b; a < end; ++a) {
                int c;
                if (a == window || (c = i - window + a) < 0 || c >= sentence.size()) continue;
                VocabWord lastWord = sentence.get(c);
                this.iterateSample(param, word, lastWord, alpha, changed);
            }
        }
    }

    public void iterateSample(Word2VecParam param, VocabWord w1, VocabWord w2, double alpha, List<Triple<Integer, Integer, Integer>> changed) {
        INDArray neu1e;
        INDArray l1;
        block7: {
            if (w2 == null || w2.getIndex() < 0 || w1.getIndex() == w2.getIndex() || w1.getWord().equals("STOP") || w2.getWord().equals("STOP") || w1.getWord().equals("UNK") || w2.getWord().equals("UNK")) {
                return;
            }
            int vectorLength = param.getVectorLength();
            InMemoryLookupTable weights = param.getWeights();
            boolean useAdaGrad = param.isUseAdaGrad();
            double negative = param.getNegative();
            INDArray table = param.getTable();
            double[] expTable = (double[])param.getExpTable().getValue();
            double MAX_EXP = 6.0;
            int numWords = param.getNumWords();
            l1 = weights.vector(w2.getWord());
            neu1e = Nd4j.create((int)vectorLength);
            for (int i = 0; i < w1.getCodeLength(); ++i) {
                byte code = (Byte)w1.getCodes().get(i);
                int point = (Integer)w1.getPoints().get(i);
                INDArray syn1 = weights.getSyn1().slice((long)point);
                double dot = Nd4j.getBlasWrapper().level1().dot(syn1.length(), 1.0, l1, syn1);
                if (dot < -MAX_EXP || dot >= MAX_EXP) continue;
                int idx = (int)((dot + MAX_EXP) * ((double)expTable.length / MAX_EXP / 2.0));
                double f = expTable[idx];
                double g = ((double)(1 - code) - f) * (useAdaGrad ? w1.getGradient(i, alpha, alpha) : alpha);
                Nd4j.getBlasWrapper().level1().axpy(syn1.length(), g, syn1, neu1e);
                Nd4j.getBlasWrapper().level1().axpy(syn1.length(), g, l1, syn1);
                changed.add((Triple<Integer, Integer, Integer>)new Triple((Object)point, (Object)w1.getIndex(), (Object)-1));
            }
            changed.add((Triple<Integer, Integer, Integer>)new Triple((Object)w1.getIndex(), (Object)w2.getIndex(), (Object)-1));
            if (!(negative > 0.0)) break block7;
            int target = w1.getIndex();
            INDArray syn1Neg = weights.getSyn1Neg().slice((long)target);
            int d = 0;
            while ((double)d < negative + 1.0) {
                block10: {
                    int label;
                    block9: {
                        block8: {
                            if (d != 0) break block8;
                            label = 1;
                            break block9;
                        }
                        this.nextRandom.set(this.nextRandom.get() * 25214903917L + 11L);
                        target = table.getInt(new int[]{(int)(this.nextRandom.get() >> 16) % (int)table.length()});
                        if (target == 0) {
                            target = (int)this.nextRandom.get() % (numWords - 1) + 1;
                        }
                        if (target == w1.getIndex()) break block10;
                        label = 0;
                    }
                    double f = Nd4j.getBlasWrapper().dot(l1, syn1Neg);
                    double g = f > MAX_EXP ? (useAdaGrad ? w1.getGradient(target, (double)(label - 1), alpha) : (double)(label - 1) * alpha) : (f < -MAX_EXP ? (double)label * (useAdaGrad ? w1.getGradient(target, alpha, alpha) : alpha) : (useAdaGrad ? w1.getGradient(target, (double)label - expTable[(int)((f + MAX_EXP) * ((double)expTable.length / MAX_EXP / 2.0))], alpha) : ((double)label - expTable[(int)((f + MAX_EXP) * ((double)expTable.length / MAX_EXP / 2.0))]) * alpha));
                    Nd4j.getBlasWrapper().level1().axpy(l1.length(), g, neu1e, l1);
                    Nd4j.getBlasWrapper().level1().axpy(l1.length(), g, syn1Neg, l1);
                    changed.add((Triple<Integer, Integer, Integer>)new Triple((Object)-1, (Object)-1, (Object)label));
                }
                ++d;
            }
        }
        Nd4j.getBlasWrapper().level1().axpy(l1.length(), 1.0, neu1e, l1);
    }
}

