package com.datastax.insight.ml.spark.ml.regression;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.regression.RandomForestRegressionModel;
import org.apache.spark.ml.regression.RandomForestRegressor;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Random forest regression.
 *
 * <p>Static helpers that configure a Spark ML {@link RandomForestRegressor},
 * fit it on a dataset, and apply a fitted {@link RandomForestRegressionModel}.
 */
public class RandomForestRegressionWrapper implements DataSetOperator {

    /**
     * Creates a {@link RandomForestRegressor} and applies every supplied setting.
     *
     * <p>A setting is forwarded to the regressor only when its argument is
     * present: string arguments must be non-null and non-empty, all other
     * arguments must be non-null. Absent arguments leave the corresponding
     * regressor parameter untouched.
     *
     * @param labelCol              value for {@code setLabelCol}
     * @param featuresCol           value for {@code setFeaturesCol}
     * @param maxDepth              value for {@code setMaxDepth}
     * @param maxBins               value for {@code setMaxBins}
     * @param minInstancesPerNode   value for {@code setMinInstancesPerNode}
     * @param minInfoGain           value for {@code setMinInfoGain}
     * @param maxMemoryInMB         value for {@code setMaxMemoryInMB}
     * @param cacheNodeIds          value for {@code setCacheNodeIds}
     * @param checkpointInterval    value for {@code setCheckpointInterval}
     * @param impurity              value for {@code setImpurity}
     * @param subsamplingRate       value for {@code setSubsamplingRate}
     * @param numTrees              value for {@code setNumTrees}
     * @param featureSubsetStrategy value for {@code setFeatureSubsetStrategy}
     * @return the newly created, configured regressor
     */
    public static RandomForestRegressor getOperator(String labelCol,
                                                    String featuresCol,
                                                    Integer maxDepth,
                                                    Integer maxBins,
                                                    Integer minInstancesPerNode,
                                                    Double minInfoGain,
                                                    Integer maxMemoryInMB,
                                                    Boolean cacheNodeIds,
                                                    Integer checkpointInterval,
                                                    String impurity,
                                                    Double subsamplingRate,
                                                    Integer numTrees,
                                                    String featureSubsetStrategy) {

        final RandomForestRegressor forest = new RandomForestRegressor();

        // Column names.
        if (labelCol != null && !labelCol.isEmpty()) {
            forest.setLabelCol(labelCol);
        }
        if (featuresCol != null && !featuresCol.isEmpty()) {
            forest.setFeaturesCol(featuresCol);
        }

        // Per-tree settings.
        if (maxDepth != null) {
            forest.setMaxDepth(maxDepth);
        }
        if (maxBins != null) {
            forest.setMaxBins(maxBins);
        }
        if (minInstancesPerNode != null) {
            forest.setMinInstancesPerNode(minInstancesPerNode);
        }
        if (minInfoGain != null) {
            forest.setMinInfoGain(minInfoGain);
        }
        if (maxMemoryInMB != null) {
            forest.setMaxMemoryInMB(maxMemoryInMB);
        }
        if (cacheNodeIds != null) {
            forest.setCacheNodeIds(cacheNodeIds);
        }
        if (checkpointInterval != null) {
            forest.setCheckpointInterval(checkpointInterval);
        }
        if (impurity != null && !impurity.isEmpty()) {
            forest.setImpurity(impurity);
        }

        // Ensemble settings.
        if (subsamplingRate != null) {
            forest.setSubsamplingRate(subsamplingRate);
        }
        if (numTrees != null) {
            forest.setNumTrees(numTrees);
        }
        if (featureSubsetStrategy != null && !featureSubsetStrategy.isEmpty()) {
            forest.setFeatureSubsetStrategy(featureSubsetStrategy);
        }

        return forest;
    }

    /**
     * Builds a regressor from the given settings and fits it on {@code data}.
     *
     * <p>The settings are handled exactly as in
     * {@code getOperator}: only non-null (and, for strings, non-empty)
     * arguments are applied.
     *
     * @param data                  dataset passed to the regressor's {@code fit}
     * @param labelCol              value for {@code setLabelCol}
     * @param featuresCol           value for {@code setFeaturesCol}
     * @param maxDepth              value for {@code setMaxDepth}
     * @param maxBins               value for {@code setMaxBins}
     * @param minInstancesPerNode   value for {@code setMinInstancesPerNode}
     * @param minInfoGain           value for {@code setMinInfoGain}
     * @param maxMemoryInMB         value for {@code setMaxMemoryInMB}
     * @param cacheNodeIds          value for {@code setCacheNodeIds}
     * @param checkpointInterval    value for {@code setCheckpointInterval}
     * @param impurity              value for {@code setImpurity}
     * @param subsamplingRate       value for {@code setSubsamplingRate}
     * @param numTrees              value for {@code setNumTrees}
     * @param featureSubsetStrategy value for {@code setFeatureSubsetStrategy}
     * @return the model returned by fitting the configured regressor
     */
    public static RandomForestRegressionModel fit(Dataset<Row> data,
                                                  String labelCol,
                                                  String featuresCol,
                                                  Integer maxDepth,
                                                  Integer maxBins,
                                                  Integer minInstancesPerNode,
                                                  Double minInfoGain,
                                                  Integer maxMemoryInMB,
                                                  Boolean cacheNodeIds,
                                                  Integer checkpointInterval,
                                                  String impurity,
                                                  Double subsamplingRate,
                                                  Integer numTrees,
                                                  String featureSubsetStrategy) {
        // Configure first, then hand off to the two-argument overload.
        return fit(
                getOperator(
                        labelCol, featuresCol,
                        maxDepth, maxBins, minInstancesPerNode, minInfoGain,
                        maxMemoryInMB, cacheNodeIds, checkpointInterval,
                        impurity, subsamplingRate, numTrees, featureSubsetStrategy),
                data);
    }

    /**
     * Fits an already configured regressor on {@code data}.
     *
     * @param regressor the regressor to fit
     * @param data      dataset passed to {@code regressor.fit}
     * @return the model returned by {@code regressor.fit(data)}
     */
    public static RandomForestRegressionModel fit(RandomForestRegressor regressor, Dataset<Row> data) {
        final RandomForestRegressionModel fitted = regressor.fit(data);
        return fitted;
    }

    /**
     * Applies a fitted model to {@code data}.
     *
     * @param model the fitted model
     * @param data  dataset passed to {@code model.transform}
     * @return the dataset returned by {@code model.transform(data)}
     */
    public static Dataset<Row> transform(RandomForestRegressionModel model, Dataset<Row> data) {
        final Dataset<Row> transformed = model.transform(data);
        return transformed;
    }
}
