package com.datastax.insight.ml.spark.ml.cluster;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.clustering.BisectingKMeans;
import org.apache.spark.ml.clustering.BisectingKMeansModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Bisecting K-Means: static helpers for configuring, fitting and applying
 * Spark ML's {@link BisectingKMeans} estimator.
 */
public class BisectingKMeansHandler implements DataSetOperator {

    /**
     * Creates a {@link BisectingKMeans} estimator and applies the supplied settings.
     *
     * <p>Each argument is optional: a {@code null} value (or, for {@code featuresCol},
     * a {@code null} or empty string) skips the corresponding setter, so the estimator
     * keeps whatever value a freshly constructed {@link BisectingKMeans} holds.
     *
     * @param featuresCol             name of the features column; skipped when null or empty
     * @param k                       value passed to {@code setK}; skipped when null
     * @param maxIterations           value passed to {@code setMaxIter}; skipped when null
     * @param seed                    value passed to {@code setSeed}; skipped when null
     * @param minDivisibleClusterSize value passed to {@code setMinDivisibleClusterSize};
     *                                skipped when null
     * @return the configured estimator
     */
    public static BisectingKMeans getOperator(String featuresCol,
                                              Integer k,
                                              Integer maxIterations,
                                              Long seed,
                                              Double minDivisibleClusterSize) {
        final BisectingKMeans estimator = new BisectingKMeans();

        // Only override a setting when the caller actually supplied one.
        final boolean hasFeaturesCol = !Strings.isNullOrEmpty(featuresCol);
        if (hasFeaturesCol) {
            estimator.setFeaturesCol(featuresCol);
        }
        if (null != k) {
            estimator.setK(k);
        }
        if (null != maxIterations) {
            estimator.setMaxIter(maxIterations);
        }
        if (null != seed) {
            estimator.setSeed(seed);
        }
        if (null != minDivisibleClusterSize) {
            estimator.setMinDivisibleClusterSize(minDivisibleClusterSize);
        }

        return estimator;
    }

    /**
     * Builds an estimator via
     * {@link #getOperator(String, Integer, Integer, Long, Double)} and fits it to {@code data}.
     *
     * @param data                    dataset to fit on
     * @param featuresCol             see {@code getOperator}
     * @param k                       see {@code getOperator}
     * @param maxIterations           see {@code getOperator}
     * @param seed                    see {@code getOperator}
     * @param minDivisibleClusterSize see {@code getOperator}
     * @return the model produced by fitting the configured estimator
     */
    public static BisectingKMeansModel fit(Dataset<Row> data,
                                           String featuresCol,
                                           Integer k,
                                           Integer maxIterations,
                                           Long seed,
                                           Double minDivisibleClusterSize) {
        return getOperator(featuresCol, k, maxIterations, seed, minDivisibleClusterSize)
                .fit(data);
    }

    /**
     * Fits an already-configured estimator to {@code data}.
     *
     * @param kMeans the estimator to fit
     * @param data   dataset to fit on
     * @return the model returned by {@code kMeans.fit(data)}
     */
    public static BisectingKMeansModel fit(BisectingKMeans kMeans, Dataset<Row> data) {
        final BisectingKMeansModel fitted = kMeans.fit(data);
        return fitted;
    }

    /**
     * Applies a fitted model to {@code data}.
     *
     * @param model the fitted model
     * @param data  dataset to transform
     * @return the dataset returned by {@code model.transform(data)}
     */
    public static Dataset<Row> transform(BisectingKMeansModel model, Dataset<Row> data) {
        final Dataset<Row> transformed = model.transform(data);
        return transformed;
    }
}
