package com.datastax.insight.ml.spark.ml.regression;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.regression.GBTRegressionModel;
import org.apache.spark.ml.regression.GBTRegressor;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Gradient-boosted tree (GBT) regression.
 *
 * <p>Static helpers that configure a Spark ML {@link GBTRegressor}, train a
 * {@link GBTRegressionModel} from it, and apply a trained model to a dataset.
 */
public class GradientBoostedTreeRegressor implements DataSetOperator {

    /**
     * Creates a new {@link GBTRegressor} and applies every supplied setting to it.
     *
     * <p>A {@code null} argument (and, for the string arguments, a {@code null} or empty
     * string) is skipped. The corresponding parameter then keeps whatever value a newly
     * constructed {@code GBTRegressor} has.
     *
     * @param labelCol            label column name, passed to {@code setLabelCol}
     * @param featuresCol         features column name, passed to {@code setFeaturesCol}
     * @param maxIterations       passed to {@code setMaxIter}
     * @param maxDepth            passed to {@code setMaxDepth}
     * @param maxBins             passed to {@code setMaxBins}
     * @param minInstancesPerNode passed to {@code setMinInstancesPerNode}
     * @param minInfoGain         passed to {@code setMinInfoGain}
     * @param maxMemoryInMB       passed to {@code setMaxMemoryInMB}
     * @param cacheNodeIds        passed to {@code setCacheNodeIds}
     * @param checkpointInterval  passed to {@code setCheckpointInterval}
     * @param impurity            passed to {@code setImpurity}
     * @param subsamplingRate     passed to {@code setSubsamplingRate}
     * @param seed                passed to {@code setSeed}
     * @param stepSize            passed to {@code setStepSize}
     * @param lossType            passed to {@code setLossType}
     * @return the configured, untrained regressor
     */
    public static GBTRegressor getOperator(String labelCol,
                                           String featuresCol,
                                           Integer maxIterations,
                                           Integer maxDepth,
                                           Integer maxBins,
                                           Integer minInstancesPerNode,
                                           Double minInfoGain,
                                           Integer maxMemoryInMB,
                                           Boolean cacheNodeIds,
                                           Integer checkpointInterval,
                                           String impurity,
                                           Double subsamplingRate,
                                           Long seed,
                                           Double stepSize,
                                           String lossType) {
        final GBTRegressor estimator = new GBTRegressor();

        // Input/output column names.
        if (isNonEmpty(labelCol)) {
            estimator.setLabelCol(labelCol);
        }
        if (isNonEmpty(featuresCol)) {
            estimator.setFeaturesCol(featuresCol);
        }

        // Boosting iterations and per-tree growth limits.
        if (maxIterations != null) {
            estimator.setMaxIter(maxIterations);
        }
        if (maxDepth != null) {
            estimator.setMaxDepth(maxDepth);
        }
        if (maxBins != null) {
            estimator.setMaxBins(maxBins);
        }
        if (minInstancesPerNode != null) {
            estimator.setMinInstancesPerNode(minInstancesPerNode);
        }
        if (minInfoGain != null) {
            estimator.setMinInfoGain(minInfoGain);
        }

        // Training resource settings.
        if (maxMemoryInMB != null) {
            estimator.setMaxMemoryInMB(maxMemoryInMB);
        }
        if (cacheNodeIds != null) {
            estimator.setCacheNodeIds(cacheNodeIds);
        }
        if (checkpointInterval != null) {
            estimator.setCheckpointInterval(checkpointInterval);
        }

        // Split criterion, sampling, randomness and boosting behaviour.
        if (isNonEmpty(impurity)) {
            estimator.setImpurity(impurity);
        }
        if (subsamplingRate != null) {
            estimator.setSubsamplingRate(subsamplingRate);
        }
        if (seed != null) {
            estimator.setSeed(seed);
        }
        if (stepSize != null) {
            estimator.setStepSize(stepSize);
        }
        if (isNonEmpty(lossType)) {
            estimator.setLossType(lossType);
        }

        return estimator;
    }

    /**
     * Builds a regressor with {@link #getOperator} and trains it on {@code data}.
     *
     * <p>Every argument after {@code data} has the same meaning as the matching argument of
     * {@link #getOperator}. A {@code null} or empty value leaves that setting untouched.
     *
     * @param data                training dataset
     * @param labelCol            label column name
     * @param featuresCol         features column name
     * @param maxIterations       passed to {@code setMaxIter}
     * @param maxDepth            passed to {@code setMaxDepth}
     * @param maxBins             passed to {@code setMaxBins}
     * @param minInstancesPerNode passed to {@code setMinInstancesPerNode}
     * @param minInfoGain         passed to {@code setMinInfoGain}
     * @param maxMemoryInMB       passed to {@code setMaxMemoryInMB}
     * @param cacheNodeIds        passed to {@code setCacheNodeIds}
     * @param checkpointInterval  passed to {@code setCheckpointInterval}
     * @param impurity            passed to {@code setImpurity}
     * @param subsamplingRate     passed to {@code setSubsamplingRate}
     * @param seed                passed to {@code setSeed}
     * @param stepSize            passed to {@code setStepSize}
     * @param lossType            passed to {@code setLossType}
     * @return the model produced by {@code GBTRegressor.fit}
     */
    public static GBTRegressionModel fit(Dataset<Row> data,
                                           String labelCol,
                                           String featuresCol,
                                           Integer maxIterations,
                                           Integer maxDepth,
                                           Integer maxBins,
                                           Integer minInstancesPerNode,
                                           Double minInfoGain,
                                           Integer maxMemoryInMB,
                                           Boolean cacheNodeIds,
                                           Integer checkpointInterval,
                                           String impurity,
                                           Double subsamplingRate,
                                           Long seed,
                                           Double stepSize,
                                           String lossType) {
        final GBTRegressor configured = getOperator(
                labelCol, featuresCol,
                maxIterations, maxDepth, maxBins, minInstancesPerNode, minInfoGain,
                maxMemoryInMB, cacheNodeIds, checkpointInterval,
                impurity, subsamplingRate, seed, stepSize, lossType);
        // Delegate to the two-argument overload so there is a single training call site.
        return fit(configured, data);
    }

    /**
     * Trains an already-configured regressor on {@code data}.
     *
     * @param regressor the estimator to train
     * @param data      training dataset
     * @return the model produced by {@code regressor.fit(data)}
     */
    public static GBTRegressionModel fit(GBTRegressor regressor, Dataset<Row> data) {
        return regressor.fit(data);
    }

    /**
     * Applies a trained model to {@code data}.
     *
     * @param model a trained GBT regression model
     * @param data  dataset to transform
     * @return the dataset returned by {@code model.transform(data)}
     */
    public static Dataset<Row> transform(GBTRegressionModel model, Dataset<Row> data) {
        return model.transform(data);
    }

    /** Returns {@code true} when {@code value} is neither {@code null} nor the empty string. */
    private static boolean isNonEmpty(String value) {
        return !Strings.isNullOrEmpty(value);
    }
}
