package com.datastax.insight.ml.spark.ml.classification;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.classification.RandomForestClassificationModel;
import org.apache.spark.ml.classification.RandomForestClassifier;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Random forest classifier.
 *
 * <p>Static convenience wrapper around Spark ML's {@link RandomForestClassifier}:
 * builds a configured estimator, fits it on a dataset, and applies a fitted
 * model to a dataset.
 */
public class RandomForestClassifierWrapper implements DataSetOperator {

    /**
     * Creates a {@link RandomForestClassifier} and applies every supplied setting to it.
     *
     * <p>Each argument is optional. A {@code null} value (or, for the string
     * arguments, a {@code null} or empty string) is skipped, so the corresponding
     * setter is never called and the newly constructed classifier keeps whatever
     * value it already had for that parameter.
     *
     * @param labelCol              value for {@code setLabelCol}; skipped when null or empty
     * @param featuresCol           value for {@code setFeaturesCol}; skipped when null or empty
     * @param maxDepth              value for {@code setMaxDepth}; skipped when null
     * @param maxBins               value for {@code setMaxBins}; skipped when null
     * @param minInstancesPerNode   value for {@code setMinInstancesPerNode}; skipped when null
     * @param minInfoGain           value for {@code setMinInfoGain}; skipped when null
     * @param maxMemoryInMB         value for {@code setMaxMemoryInMB}; skipped when null
     * @param cacheNodeIds          value for {@code setCacheNodeIds}; skipped when null
     * @param checkpointInterval    value for {@code setCheckpointInterval}; skipped when null
     * @param impurity              value for {@code setImpurity}; skipped when null or empty
     * @param subsamplingRate       value for {@code setSubsamplingRate}; skipped when null
     * @param numTrees              value for {@code setNumTrees}; skipped when null
     * @param featureSubsetStrategy value for {@code setFeatureSubsetStrategy}; skipped when null or empty
     * @return the configured, not yet fitted, classifier
     */
    public static RandomForestClassifier getOperator(String labelCol,
                                                     String featuresCol,
                                                     Integer maxDepth,
                                                     Integer maxBins,
                                                     Integer minInstancesPerNode,
                                                     Double minInfoGain,
                                                     Integer maxMemoryInMB,
                                                     Boolean cacheNodeIds,
                                                     Integer checkpointInterval,
                                                     String impurity,
                                                     Double subsamplingRate,
                                                     Integer numTrees,
                                                     String featureSubsetStrategy) {
        final RandomForestClassifier forest = new RandomForestClassifier();

        // Input / output column names.
        if (!Strings.isNullOrEmpty(labelCol)) {
            forest.setLabelCol(labelCol);
        }
        if (!Strings.isNullOrEmpty(featuresCol)) {
            forest.setFeaturesCol(featuresCol);
        }

        // Per-tree growth settings.
        if (maxDepth != null) {
            forest.setMaxDepth(maxDepth);
        }
        if (maxBins != null) {
            forest.setMaxBins(maxBins);
        }
        if (minInstancesPerNode != null) {
            forest.setMinInstancesPerNode(minInstancesPerNode);
        }
        if (minInfoGain != null) {
            forest.setMinInfoGain(minInfoGain);
        }

        // Training resource settings.
        if (maxMemoryInMB != null) {
            forest.setMaxMemoryInMB(maxMemoryInMB);
        }
        if (cacheNodeIds != null) {
            forest.setCacheNodeIds(cacheNodeIds);
        }
        if (checkpointInterval != null) {
            forest.setCheckpointInterval(checkpointInterval);
        }

        // Split criterion and ensemble-level settings.
        if (!Strings.isNullOrEmpty(impurity)) {
            forest.setImpurity(impurity);
        }
        if (!Strings.isNullOrEmpty(featureSubsetStrategy)) {
            forest.setFeatureSubsetStrategy(featureSubsetStrategy);
        }
        if (subsamplingRate != null) {
            forest.setSubsamplingRate(subsamplingRate);
        }
        if (numTrees != null) {
            forest.setNumTrees(numTrees);
        }

        return forest;
    }

    /**
     * Builds a classifier from the given settings and fits it on {@code data}.
     *
     * <p>The settings are interpreted exactly as in
     * {@code getOperator}: null (or empty-string) arguments are skipped.
     *
     * @param data                  training dataset passed to the classifier's {@code fit}
     * @param labelCol              label column name; skipped when null or empty
     * @param featuresCol           features column name; skipped when null or empty
     * @param maxDepth              skipped when null
     * @param maxBins               skipped when null
     * @param minInstancesPerNode   skipped when null
     * @param minInfoGain           skipped when null
     * @param maxMemoryInMB         skipped when null
     * @param cacheNodeIds          skipped when null
     * @param checkpointInterval    skipped when null
     * @param impurity              skipped when null or empty
     * @param subsamplingRate       skipped when null
     * @param numTrees              skipped when null
     * @param featureSubsetStrategy skipped when null or empty
     * @return the fitted model
     */
    public static RandomForestClassificationModel fit(Dataset<Row> data,
                                                      String labelCol,
                                                      String featuresCol,
                                                      Integer maxDepth,
                                                      Integer maxBins,
                                                      Integer minInstancesPerNode,
                                                      Double minInfoGain,
                                                      Integer maxMemoryInMB,
                                                      Boolean cacheNodeIds,
                                                      Integer checkpointInterval,
                                                      String impurity,
                                                      Double subsamplingRate,
                                                      Integer numTrees,
                                                      String featureSubsetStrategy) {
        final RandomForestClassifier forest = getOperator(
                labelCol,
                featuresCol,
                maxDepth,
                maxBins,
                minInstancesPerNode,
                minInfoGain,
                maxMemoryInMB,
                cacheNodeIds,
                checkpointInterval,
                impurity,
                subsamplingRate,
                numTrees,
                featureSubsetStrategy);
        return fit(forest, data);
    }

    /**
     * Fits an already configured classifier on {@code data}.
     *
     * @param classifier the estimator to train
     * @param data       training dataset
     * @return the fitted model
     */
    public static RandomForestClassificationModel fit(RandomForestClassifier classifier, Dataset<Row> data) {
        final RandomForestClassificationModel trained = classifier.fit(data);
        return trained;
    }

    /**
     * Applies a fitted model to {@code data}.
     *
     * @param model the fitted random forest model
     * @param data  dataset to transform
     * @return the dataset returned by the model's {@code transform}
     */
    public static Dataset<Row> transform(RandomForestClassificationModel model, Dataset<Row> data) {
        final Dataset<Row> transformed = model.transform(data);
        return transformed;
    }
}
