package com.datastax.insight.ml.spark.ml.classification;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Decision tree classifier.
 *
 * <p>Static facade over Spark ML's {@link DecisionTreeClassifier}. It builds a configured
 * (unfitted) estimator, trains it on a dataset, and applies a trained model to a dataset.
 */
public class DecisionTreeClassifierWrapper implements DataSetOperator {

    /**
     * Creates a new {@link DecisionTreeClassifier} and applies only the settings that were supplied.
     *
     * <p>String arguments are applied when they are neither {@code null} nor empty. All other
     * arguments are applied when they are non-{@code null}. Any setting that is not applied keeps
     * whatever value a freshly constructed classifier has.
     *
     * @param labelCol            label column name; {@code null}/empty leaves it unset
     * @param featuresCol         features column name; {@code null}/empty leaves it unset
     * @param maxDepth            value for {@link DecisionTreeClassifier#setMaxDepth}, or {@code null}
     * @param maxBins             value for {@link DecisionTreeClassifier#setMaxBins}, or {@code null}
     * @param minInstancesPerNode value for {@link DecisionTreeClassifier#setMinInstancesPerNode}, or {@code null}
     * @param minInfoGain         value for {@link DecisionTreeClassifier#setMinInfoGain}, or {@code null}
     * @param maxMemoryInMB       value for {@link DecisionTreeClassifier#setMaxMemoryInMB}, or {@code null}
     * @param cacheNodeIds        value for {@link DecisionTreeClassifier#setCacheNodeIds}, or {@code null}
     * @param checkpointInterval  value for {@link DecisionTreeClassifier#setCheckpointInterval}, or {@code null}
     * @param impurity            impurity name for {@link DecisionTreeClassifier#setImpurity};
     *                            {@code null}/empty leaves it unset
     * @return a new, not-yet-fitted classifier carrying the supplied settings
     */
    public static DecisionTreeClassifier getOperator(
            String labelCol,
            String featuresCol,
            Integer maxDepth,
            Integer maxBins,
            Integer minInstancesPerNode,
            Double minInfoGain,
            Integer maxMemoryInMB,
            Boolean cacheNodeIds,
            Integer checkpointInterval,
            String impurity) {
        DecisionTreeClassifier estimator = new DecisionTreeClassifier();

        // Input/output column bindings: blank names mean "keep the estimator's own".
        if (!Strings.isNullOrEmpty(labelCol)) {
            estimator.setLabelCol(labelCol);
        }
        if (!Strings.isNullOrEmpty(featuresCol)) {
            estimator.setFeaturesCol(featuresCol);
        }

        // Tree-growth limits. Boxed types let callers omit a setting by passing null.
        if (maxDepth != null) {
            estimator.setMaxDepth(maxDepth);
        }
        if (maxBins != null) {
            estimator.setMaxBins(maxBins);
        }
        if (minInstancesPerNode != null) {
            estimator.setMinInstancesPerNode(minInstancesPerNode);
        }
        if (minInfoGain != null) {
            estimator.setMinInfoGain(minInfoGain);
        }

        // Memory, caching and checkpointing knobs.
        if (maxMemoryInMB != null) {
            estimator.setMaxMemoryInMB(maxMemoryInMB);
        }
        if (cacheNodeIds != null) {
            estimator.setCacheNodeIds(cacheNodeIds);
        }
        if (checkpointInterval != null) {
            estimator.setCheckpointInterval(checkpointInterval);
        }

        // Split criterion.
        if (!Strings.isNullOrEmpty(impurity)) {
            estimator.setImpurity(impurity);
        }

        return estimator;
    }

    /**
     * Builds a classifier with {@link #getOperator} using the given settings and fits it to
     * {@code data}.
     *
     * @param data training dataset
     * @return the model produced by {@code DecisionTreeClassifier.fit(data)}
     * @see #getOperator
     */
    public static DecisionTreeClassificationModel fit(
            Dataset<Row> data,
            String labelCol,
            String featuresCol,
            Integer maxDepth,
            Integer maxBins,
            Integer minInstancesPerNode,
            Double minInfoGain,
            Integer maxMemoryInMB,
            Boolean cacheNodeIds,
            Integer checkpointInterval,
            String impurity) {
        // Configure first, then reuse the estimator-based overload for the actual training call.
        return fit(
                getOperator(labelCol, featuresCol, maxDepth, maxBins, minInstancesPerNode,
                        minInfoGain, maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity),
                data);
    }

    /**
     * Fits an already configured classifier to {@code data}.
     *
     * @param classifier estimator to train
     * @param data       training dataset
     * @return the trained model returned by {@code classifier.fit(data)}
     */
    public static DecisionTreeClassificationModel fit(DecisionTreeClassifier classifier, Dataset<Row> data) {
        return classifier.fit(data);
    }

    /**
     * Applies a trained model to {@code data}.
     *
     * @param model trained decision tree model
     * @param data  dataset to transform
     * @return the dataset returned by {@code model.transform(data)}
     */
    public static Dataset<Row> transform(DecisionTreeClassificationModel model, Dataset<Row> data) {
        return model.transform(data);
    }
}
