/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.clustering;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.clustering.DistributedLDAModel;
import org.apache.spark.mllib.clustering.LDA;
import org.apache.spark.mllib.clustering.LDAModel;
import org.apache.spark.mllib.clustering.LDAOptimizer;
import org.apache.spark.mllib.clustering.LDASuite;
import org.apache.spark.mllib.clustering.LocalLDAModel;
import org.apache.spark.mllib.clustering.OnlineLDAOptimizer;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.junit.Assert;
import org.junit.Test;
import scala.Tuple2;
import scala.Tuple3;

/**
 * Java-API tests for the RDD-based (mllib) LDA implementation.
 *
 * <p>Exercises {@link LocalLDAModel}, {@link DistributedLDAModel} and the online optimizer
 * through their Java-facing methods, reusing the fixtures defined in {@link LDASuite}.
 */
public class JavaLDASuite
extends SharedSparkSession {
    // Fixtures shared with the Scala LDASuite.
    private static final int TINY_K = LDASuite.tinyK();
    private static final int TINY_VOCAB_SIZE = LDASuite.tinyVocabSize();
    private static final Matrix TINY_TOPICS = LDASuite.tinyTopics();
    private static final Tuple2<int[], double[]>[] TINY_TOPIC_DESCRIPTION =
            LDASuite.tinyTopicDescription();

    /**
     * Keeps only documents whose term-count vector has a non-zero L1 norm.
     *
     * <p>Declared in a static context so the serialized function carries no reference to the
     * enclosing test instance (an anonymous class created inside a test method would).
     */
    private static final Function<Tuple2<Long, Vector>, Boolean> NON_EMPTY_DOCUMENT =
            new Function<Tuple2<Long, Vector>, Boolean>() {
                @Override
                public Boolean call(Tuple2<Long, Vector> doc) {
                    return Vectors.norm(doc._2(), 1.0) != 0.0;
                }
            };

    private JavaPairRDD<Long, Vector> corpus;
    private final LocalLDAModel toyModel = LDASuite.toyModel();
    private final List<Tuple2<Long, Vector>> toyData = LDASuite.javaToyData();

    /** Builds {@link #corpus} as a 2-partition pair RDD from {@code LDASuite.tinyCorpus()}. */
    @Override
    public void setUp() throws IOException {
        super.setUp();
        // The ids in tinyCorpus() are not statically typed as java.lang.Long from Java,
        // hence the explicit cast when rebuilding a typed (id, termCounts) list.
        List<Tuple2<Long, Vector>> tinyCorpus = new ArrayList<>();
        for (Tuple2<?, Vector> doc : LDASuite.tinyCorpus()) {
            tinyCorpus.add(new Tuple2<>((Long) doc._1(), doc._2()));
        }
        JavaRDD<Tuple2<Long, Vector>> tmpCorpus = jsc.parallelize(tinyCorpus, 2);
        corpus = JavaPairRDD.fromJavaRDD(tmpCorpus);
    }

    /** A hand-built LocalLDAModel reports the fixture's shape, topics and topic descriptions. */
    @Test
    public void localLDAModel() {
        Matrix topics = LDASuite.tinyTopics();
        // topicsMatrix is vocabSize x k, so the document concentration needs one entry per
        // topic (column), not one per vocabulary term (row).
        int numTopics = topics.numCols();
        double[] docConcentration = new double[numTopics];
        Arrays.fill(docConcentration, 1.0 / numTopics);
        LocalLDAModel model =
                new LocalLDAModel(topics, Vectors.dense(docConcentration), 1.0, 100.0);

        Assert.assertEquals(TINY_K, model.k());
        Assert.assertEquals(TINY_VOCAB_SIZE, model.vocabSize());
        Assert.assertEquals(TINY_TOPICS, model.topicsMatrix());

        Tuple2<int[], double[]>[] fullTopicSummary = model.describeTopics();
        Assert.assertEquals(TINY_K, fullTopicSummary.length);
        for (int i = 0; i < fullTopicSummary.length; ++i) {
            Assert.assertArrayEquals(TINY_TOPIC_DESCRIPTION[i]._1(), fullTopicSummary[i]._1());
            Assert.assertArrayEquals(
                    TINY_TOPIC_DESCRIPTION[i]._2(), fullTopicSummary[i]._2(), 1.0E-5);
        }
    }

    /** A model trained with the default optimizer agrees with its local conversion. */
    @Test
    public void distributedLDAModel() {
        int k = 3;
        double topicSmoothing = 1.2;
        double termSmoothing = 1.2;
        LDA lda = new LDA()
                .setK(k)
                .setDocConcentration(topicSmoothing)
                .setTopicConcentration(termSmoothing)
                .setMaxIterations(5)
                .setSeed(12345L);
        // The test relies on the default optimizer producing a DistributedLDAModel.
        DistributedLDAModel model = (DistributedLDAModel) lda.run(corpus);
        LocalLDAModel localModel = model.toLocal();

        // Shape and topics must survive the distributed -> local conversion.
        Assert.assertEquals(k, model.k());
        Assert.assertEquals(k, localModel.k());
        Assert.assertEquals(TINY_VOCAB_SIZE, model.vocabSize());
        Assert.assertEquals(TINY_VOCAB_SIZE, localModel.vocabSize());
        Assert.assertEquals(model.topicsMatrix(), localModel.topicsMatrix());

        Tuple2<int[], double[]>[] roundedTopicSummary = model.describeTopics();
        Assert.assertEquals(k, roundedTopicSummary.length);
        Tuple2<int[], double[]>[] roundedLocalTopicSummary = localModel.describeTopics();
        Assert.assertEquals(k, roundedLocalTopicSummary.length);

        // Both are log-probabilities and are expected to be strictly negative.
        Assert.assertTrue(model.logLikelihood() < 0.0);
        Assert.assertTrue(model.logPrior() < 0.0);

        // Expect exactly one topic distribution per non-empty document.
        JavaPairRDD<Long, Vector> topicDistributions = model.javaTopicDistributions();
        JavaPairRDD<Long, Vector> nonEmptyCorpus = corpus.filter(NON_EMPTY_DOCUMENT);
        Assert.assertEquals(nonEmptyCorpus.count(), topicDistributions.count());

        // Java-facing top-topics API: typed tuple of (docId, topicIndices, topicWeights).
        Tuple3<Long, int[], double[]> topTopics = model.javaTopTopicsPerDocument(3).first();
        Long docId = topTopics._1();
        int[] topicIndices = topTopics._2();
        double[] topicWeights = topTopics._3();
        Assert.assertNotNull(docId);
        Assert.assertEquals(3, topicIndices.length);
        Assert.assertEquals(3, topicWeights.length);

        // Java-facing topic-assignment API: one topic index per term index.
        Tuple3<Long, int[], int[]> topicAssignment = model.javaTopicAssignments().first();
        Long docId2 = topicAssignment._1();
        int[] termIndices2 = topicAssignment._2();
        int[] topicIndices2 = topicAssignment._3();
        Assert.assertNotNull(docId2);
        Assert.assertEquals(termIndices2.length, topicIndices2.length);
    }

    /** LDA can be configured with an OnlineLDAOptimizer through the Java API. */
    @Test
    public void onlineOptimizerCompatibility() {
        int k = 3;
        double topicSmoothing = 1.2;
        double termSmoothing = 1.2;
        OnlineLDAOptimizer op = new OnlineLDAOptimizer()
                .setTau0(1024.0)
                .setKappa(0.51)
                .setGammaShape(1.0E40)
                .setMiniBatchFraction(0.5);
        LDA lda = new LDA()
                .setK(k)
                .setDocConcentration(topicSmoothing)
                .setTopicConcentration(termSmoothing)
                .setMaxIterations(5)
                .setSeed(12345L)
                .setOptimizer((LDAOptimizer) op);
        LDAModel model = lda.run(corpus);

        Assert.assertEquals(k, model.k());
        Assert.assertEquals(TINY_VOCAB_SIZE, model.vocabSize());
        Tuple2<int[], double[]>[] roundedTopicSummary = model.describeTopics();
        Assert.assertEquals(k, roundedTopicSummary.length);
    }

    /** LocalLDAModel's Java inference/evaluation methods run and return usable values. */
    @Test
    public void localLdaMethods() {
        JavaRDD<Tuple2<Long, Vector>> docs = jsc.parallelize(toyData, 2);
        JavaPairRDD<Long, Vector> pairedDocs = JavaPairRDD.fromJavaRDD(docs);

        // topicDistributions: one row per input document.
        Assert.assertEquals(pairedDocs.count(), toyModel.topicDistributions(pairedDocs).count());

        // logPerplexity: previously computed but never checked.
        double logPerplexity = toyModel.logPerplexity(pairedDocs);
        assertFinite("logPerplexity", logPerplexity);

        // logLikelihood on a single one-word document.
        List<Tuple2<Long, Vector>> docsSingleWord = new ArrayList<>();
        docsSingleWord.add(new Tuple2<>(0L, Vectors.dense(1.0, 0.0, 0.0)));
        JavaPairRDD<Long, Vector> single =
                JavaPairRDD.fromJavaRDD(jsc.parallelize(docsSingleWord));
        double logLikelihood = toyModel.logLikelihood(single);
        assertFinite("logLikelihood", logLikelihood);
    }

    /**
     * Fails the test if {@code value} is NaN or infinite.
     *
     * @param what name of the quantity, used in the failure message
     * @param value the value to check
     */
    private static void assertFinite(String what, double value) {
        Assert.assertFalse(what + " should be finite but was " + value,
                Double.isNaN(value) || Double.isInfinite(value));
    }
}

