package com.datastax.insight.ml.spark.ml.cluster;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Thin static facade over Spark ML's {@link KMeans} estimator: build a configured
 * estimator, fit it to a dataset, and apply a fitted model to a dataset.
 *
 * <p>NOTE(review): this class implements {@link DataSetOperator}, presumably so the
 * surrounding framework can discover these static entry points — confirm against callers.
 */
public class SimpleKMeans implements DataSetOperator {

    /**
     * Creates a fresh {@link KMeans} estimator and copies in every supplied setting.
     *
     * <p>All arguments are optional. A {@code null} value (or, for the string arguments,
     * a {@code null} or empty value) means the matching setter is not called, so the
     * estimator keeps whatever value a newly constructed {@link KMeans} carries.
     *
     * @param featuresCol   passed to {@link KMeans#setFeaturesCol}; skipped if null/empty
     * @param k             passed to {@link KMeans#setK}; skipped if null
     * @param maxIterations passed to {@link KMeans#setMaxIter}; skipped if null
     * @param initMode      passed to {@link KMeans#setInitMode}; skipped if null/empty
     * @param initSteps     passed to {@link KMeans#setInitSteps}; skipped if null
     * @param tol           passed to {@link KMeans#setTol}; skipped if null
     * @param seed          passed to {@link KMeans#setSeed}; skipped if null
     * @return the configured, not-yet-fitted estimator
     */
    public static KMeans getOperator(String featuresCol,
                                     Integer k,
                                     Integer maxIterations,
                                     String initMode,
                                     Integer initSteps,
                                     Double tol,
                                     Long seed) {
        final KMeans estimator = new KMeans();

        if (isNonEmpty(featuresCol)) {
            estimator.setFeaturesCol(featuresCol);
        }
        if (k != null) {
            estimator.setK(k.intValue());
        }
        if (maxIterations != null) {
            estimator.setMaxIter(maxIterations.intValue());
        }
        if (isNonEmpty(initMode)) {
            estimator.setInitMode(initMode);
        }
        if (initSteps != null) {
            estimator.setInitSteps(initSteps.intValue());
        }
        if (tol != null) {
            estimator.setTol(tol.doubleValue());
        }
        if (seed != null) {
            estimator.setSeed(seed.longValue());
        }

        return estimator;
    }

    /**
     * Builds an estimator from the given settings (see {@link #getOperator}) and fits it.
     *
     * @param data the dataset to fit on
     * @return the model produced by fitting the configured estimator to {@code data}
     */
    public static KMeansModel fit(Dataset<Row> data,
                                  String featuresCol,
                                  Integer k,
                                  Integer maxIterations,
                                  String initMode,
                                  Integer initSteps,
                                  Double tol,
                                  Long seed) {
        return fit(getOperator(featuresCol, k, maxIterations, initMode, initSteps, tol, seed), data);
    }

    /**
     * Fits an already-configured estimator to a dataset.
     *
     * @param kMeans the estimator to fit
     * @param data   the dataset to fit on
     * @return the fitted model
     */
    public static KMeansModel fit(KMeans kMeans, Dataset<Row> data) {
        final KMeansModel fitted = kMeans.fit(data);
        return fitted;
    }

    /**
     * Applies a fitted model to a dataset.
     *
     * @param model the fitted model
     * @param data  the dataset to transform
     * @return the dataset returned by {@code model.transform(data)}
     */
    public static Dataset<Row> transform(KMeansModel model, Dataset<Row> data) {
        final Dataset<Row> result = model.transform(data);
        return result;
    }

    /** Returns {@code true} when {@code value} is non-null and contains at least one character. */
    private static boolean isNonEmpty(String value) {
        return value != null && !value.isEmpty();
    }
}
