package com.datastax.insight.ml.spark.ml.cluster;

import com.datastax.insight.spec.DataSetOperator;
import com.datastax.insight.core.Consts;
import com.google.common.base.Strings;
import org.apache.spark.ml.clustering.LDA;
import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Static helpers for building, fitting and applying Spark ML {@link LDA}
 * (Latent Dirichlet Allocation) clustering models.
 */
public class LDACluster implements DataSetOperator {

    /**
     * Builds an {@link LDA} estimator from the given parameters.
     *
     * <p>Every argument is optional. A {@code null} argument (or, for {@code String}
     * arguments, a {@code null} or empty one) leaves the corresponding LDA parameter
     * unset, so the estimator keeps its own default for it.
     *
     * @param featuresCol              name of the input features column
     * @param k                        number of topics
     * @param maxIterations            maximum number of iterations ({@code setMaxIter})
     * @param seed                     random seed
     * @param checkpointInterval       checkpoint interval
     * @param docConcentration         one or more doc-concentration values separated by
     *                                 {@code Consts.DELIMITER}; each is parsed with
     *                                 {@link Double#parseDouble(String)}
     * @param topicConcentration       topic concentration
     * @param optimizer                optimizer name passed to {@code setOptimizer}
     * @param topicDistributionCol     name of the output topic-distribution column
     * @param learningOffset           learning offset
     * @param learningDecay            learning decay
     * @param subsamplingRate          subsampling rate
     * @param optimizeDocConcentration whether to optimize doc concentration
     * @param keepLastCheckpoint       whether to keep the last checkpoint
     * @return a configured, unfitted {@link LDA} estimator
     * @throws NumberFormatException if an entry of {@code docConcentration} is not a valid double
     */
    public static LDA getOperator(String featuresCol,
                                  Integer k,
                                  Integer maxIterations,
                                  Long seed,
                                  Integer checkpointInterval,
                                  String docConcentration,
                                  Double topicConcentration,
                                  String optimizer,
                                  String topicDistributionCol,
                                  Double learningOffset,
                                  Double learningDecay,
                                  Double subsamplingRate,
                                  Boolean optimizeDocConcentration,
                                  Boolean keepLastCheckpoint) {

        LDA lda = new LDA();

        if (!Strings.isNullOrEmpty(featuresCol)) {
            lda.setFeaturesCol(featuresCol);
        }

        if (k != null) {
            lda.setK(k);
        }

        if (maxIterations != null) {
            lda.setMaxIter(maxIterations);
        }

        if (seed != null) {
            lda.setSeed(seed);
        }

        if (checkpointInterval != null) {
            lda.setCheckpointInterval(checkpointInterval);
        }

        if (!Strings.isNullOrEmpty(docConcentration)) {
            lda.setDocConcentration(parseDocConcentration(docConcentration));
        }

        if (topicConcentration != null) {
            lda.setTopicConcentration(topicConcentration);
        }

        if (!Strings.isNullOrEmpty(optimizer)) {
            lda.setOptimizer(optimizer);
        }

        if (!Strings.isNullOrEmpty(topicDistributionCol)) {
            lda.setTopicDistributionCol(topicDistributionCol);
        }

        if (learningOffset != null) {
            lda.setLearningOffset(learningOffset);
        }

        if (learningDecay != null) {
            lda.setLearningDecay(learningDecay);
        }

        if (subsamplingRate != null) {
            lda.setSubsamplingRate(subsamplingRate);
        }

        if (optimizeDocConcentration != null) {
            lda.setOptimizeDocConcentration(optimizeDocConcentration);
        }

        if (keepLastCheckpoint != null) {
            lda.setKeepLastCheckpoint(keepLastCheckpoint);
        }

        return lda;
    }

    /**
     * Parses a delimited list of doc-concentration values into a {@code double[]}.
     *
     * <p>NOTE(review): {@code Consts.DELIMITER} is passed to {@link String#split(String)},
     * so it is interpreted as a regular expression. Confirm that it contains no regex
     * metacharacters (or is meant as a regex).
     *
     * @param docConcentration non-empty delimited list of numbers
     * @return the parsed values, in input order
     * @throws NumberFormatException if any entry is not a valid double; the message names
     *                               the offending entry, and the original exception is the cause
     */
    private static double[] parseDocConcentration(String docConcentration) {
        String[] tokens = docConcentration.split(Consts.DELIMITER);
        double[] values = new double[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            try {
                values[i] = Double.parseDouble(tokens[i]);
            } catch (NumberFormatException e) {
                // Re-throw with the same exception type so existing callers are unaffected,
                // but say which parameter and entry failed to parse.
                NumberFormatException detailed = new NumberFormatException(
                        "Invalid docConcentration entry at index " + i
                                + " ('" + tokens[i] + "') in \"" + docConcentration + "\"");
                detailed.initCause(e);
                throw detailed;
            }
        }
        return values;
    }

    /**
     * Builds an {@link LDA} estimator via
     * {@link #getOperator(String, Integer, Integer, Long, Integer, String, Double, String,
     * String, Double, Double, Double, Boolean, Boolean)} and fits it to {@code data}.
     *
     * <p>Parameters other than {@code data} have the same meaning (and null/empty handling)
     * as in {@code getOperator}.
     *
     * @param data the training dataset
     * @return the fitted {@link LDAModel}
     * @throws NumberFormatException if an entry of {@code docConcentration} is not a valid double
     */
    public static LDAModel fit(Dataset<Row> data,
                               String featuresCol,
                               Integer k,
                               Integer maxIterations,
                               Long seed,
                               Integer checkpointInterval,
                               String docConcentration,
                               Double topicConcentration,
                               String optimizer,
                               String topicDistributionCol,
                               Double learningOffset,
                               Double learningDecay,
                               Double subsamplingRate,
                               Boolean optimizeDocConcentration,
                               Boolean keepLastCheckpoint) {

        LDA lda = getOperator(featuresCol, k, maxIterations, seed, checkpointInterval,
                docConcentration, topicConcentration, optimizer, topicDistributionCol,
                learningOffset, learningDecay, subsamplingRate, optimizeDocConcentration,
                keepLastCheckpoint);
        return fit(lda, data);
    }

    /**
     * Fits an already-configured {@link LDA} estimator to {@code data}.
     *
     * @param lda  the estimator to fit
     * @param data the training dataset
     * @return the fitted {@link LDAModel}
     */
    public static LDAModel fit(LDA lda, Dataset<Row> data) {
        return lda.fit(data);
    }

    /**
     * Applies a fitted {@link LDAModel} to {@code data}.
     *
     * @param model the fitted model
     * @param data  the dataset to transform
     * @return the dataset returned by {@code model.transform(data)}
     */
    public static Dataset<Row> transform(LDAModel model, Dataset<Row> data) {
        return model.transform(data);
    }
}
