package com.datastax.insight.ml.spark.mllib.cluster;

import com.datastax.insight.spec.RDDOperator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.clustering.LDA;
import org.apache.spark.mllib.clustering.LDAModel;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import scala.Tuple2;

public class LDACluster implements RDDOperator {
    public static LDAModel train(JavaRDD<Vector> data,int numClasses,boolean cached){
        JavaPairRDD<Long, Vector> corpus =
                JavaPairRDD.fromJavaRDD(data.zipWithIndex().map(
                        new Function<Tuple2<Vector, Long>, Tuple2<Long, Vector>>() {
                            public Tuple2<Long, Vector> call(Tuple2<Vector, Long> doc_id) {
                                return doc_id.swap();
                            }
                        }
                        )
                );
        if(cached) {
            corpus.cache();
        }

        // Cluster the documents into three topics using LDA
        LDAModel ldaModel = new LDA().setK(numClasses).run(corpus);

//        LogUtil logUtil=new LogUtil(LDACluster.class);
//        logUtil.logUserOutputStart("train");

        System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize()
                + " words):");
        Matrix topics = ldaModel.topicsMatrix();
        for (int topic = 0; topic < 3; topic++) {
            System.out.print("Topic " + topic + ":");
            for (int word = 0; word < ldaModel.vocabSize(); word++) {
                System.out.print(" " + topics.apply(word, topic));
            }
            System.out.println();
        }

        //logUtil.logUserOutputEnd("train");

        return ldaModel;
    }
}
