package com.datastax.insight.ml.spark.ml.cluster;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.clustering.GaussianMixture;
import org.apache.spark.ml.clustering.GaussianMixtureModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Gaussian mixture clustering helpers.
 *
 * <p>Thin static wrappers around Spark ML's {@link GaussianMixture} estimator. They build a
 * configured estimator, fit it to a dataset, and apply a fitted model to a dataset.
 */
public class GaussianMixturer implements DataSetOperator {

    /**
     * Builds a {@link GaussianMixture} estimator from optional settings.
     *
     * <p>Every argument is optional. A {@code null} value, or a null/empty string for
     * {@code featuresCol}, means the matching setter is not called. That setting keeps
     * whatever value a freshly constructed {@link GaussianMixture} has.
     *
     * @param featuresCol   name of the input features column, or null/empty to skip
     * @param k             number of mixture components, or {@code null} to skip
     * @param maxIterations maximum iteration count, or {@code null} to skip
     * @param tol           convergence tolerance, or {@code null} to skip
     * @param seed          random seed, or {@code null} to skip
     * @return the configured, not-yet-fitted estimator
     */
    public static GaussianMixture getOperator(String featuresCol,
                                              Integer k,
                                              Integer maxIterations,
                                              Double tol,
                                              Long seed) {
        GaussianMixture mixture = new GaussianMixture();

        // Only override settings the caller actually supplied.
        boolean hasFeaturesCol = !Strings.isNullOrEmpty(featuresCol);
        if (hasFeaturesCol) {
            mixture.setFeaturesCol(featuresCol);
        }
        if (k != null) {
            mixture.setK(k);
        }
        if (maxIterations != null) {
            mixture.setMaxIter(maxIterations);
        }
        if (tol != null) {
            mixture.setTol(tol);
        }
        if (seed != null) {
            mixture.setSeed(seed);
        }

        return mixture;
    }

    /**
     * Builds an estimator from the given settings and fits it to {@code data}.
     *
     * <p>The settings are interpreted exactly as in
     * {@link #getOperator(String, Integer, Integer, Double, Long)}.
     *
     * @param data          training dataset
     * @param featuresCol   name of the input features column, or null/empty to skip
     * @param k             number of mixture components, or {@code null} to skip
     * @param maxIterations maximum iteration count, or {@code null} to skip
     * @param tol           convergence tolerance, or {@code null} to skip
     * @param seed          random seed, or {@code null} to skip
     * @return the fitted model
     */
    public static GaussianMixtureModel fit(Dataset<Row> data,
                                           String featuresCol,
                                           Integer k,
                                           Integer maxIterations,
                                           Double tol,
                                           Long seed) {
        GaussianMixture estimator = getOperator(featuresCol, k, maxIterations, tol, seed);
        return fit(estimator, data);
    }

    /**
     * Fits an already-configured estimator to {@code data}.
     *
     * @param gussian the estimator to fit
     * @param data    training dataset
     * @return the fitted model
     */
    public static GaussianMixtureModel fit(GaussianMixture gussian, Dataset<Row> data) {
        GaussianMixtureModel fitted = gussian.fit(data);
        return fitted;
    }

    /**
     * Applies a fitted model to {@code data} by delegating to
     * {@link GaussianMixtureModel#transform}.
     *
     * @param model the fitted model
     * @param data  dataset to transform
     * @return the dataset returned by the model's {@code transform}
     */
    public static Dataset<Row> transform(GaussianMixtureModel model, Dataset<Row> data) {
        Dataset<Row> transformed = model.transform(data);
        return transformed;
    }
}
