package com.datastax.insight.ml.spark.ml.classification;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.GBTClassifier;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Gradient-boosted decision tree classifier.
 *
 * <p>Static helpers that build a Spark ML {@link GBTClassifier} from optional
 * hyper-parameters, fit it on a dataset, and apply the resulting
 * {@link GBTClassificationModel} to a dataset.
 */
public class GradientBoostedTreeClassifier implements DataSetOperator {

    /**
     * Creates a {@link GBTClassifier} and applies every hyper-parameter that was supplied.
     *
     * <p>Each argument is optional: when a {@code String} argument is {@code null} or empty,
     * or any other argument is {@code null}, the corresponding setter is not called and the
     * freshly constructed classifier keeps whatever value it already had for that parameter.
     *
     * @param labelCol            name of the label column
     * @param featuresCol         name of the features column
     * @param maxIterations       passed to {@code setMaxIter}
     * @param maxDepth            passed to {@code setMaxDepth}
     * @param maxBins             passed to {@code setMaxBins}
     * @param minInstancesPerNode passed to {@code setMinInstancesPerNode}
     * @param minInfoGain         passed to {@code setMinInfoGain}
     * @param maxMemoryInMB       passed to {@code setMaxMemoryInMB}
     * @param cacheNodeIds        passed to {@code setCacheNodeIds}
     * @param checkpointInterval  passed to {@code setCheckpointInterval}
     * @param impurity            passed to {@code setImpurity}
     *                            (NOTE(review): some Spark versions ignore impurity for
     *                            gradient-boosted trees - confirm against the Spark version in use)
     * @param subsamplingRate     passed to {@code setSubsamplingRate}
     * @param seed                passed to {@code setSeed}
     * @param stepSize            passed to {@code setStepSize}
     * @param lossType            passed to {@code setLossType}
     * @return the configured, not yet fitted, classifier
     */
    public static GBTClassifier getOperator(String labelCol,
                                            String featuresCol,
                                            Integer maxIterations,
                                            Integer maxDepth,
                                            Integer maxBins,
                                            Integer minInstancesPerNode,
                                            Double minInfoGain,
                                            Integer maxMemoryInMB,
                                            Boolean cacheNodeIds,
                                            Integer checkpointInterval,
                                            String impurity,
                                            Double subsamplingRate,
                                            Long seed,
                                            Double stepSize,
                                            String lossType) {

        GBTClassifier classifier = new GBTClassifier();

        // Column names.
        if (!Strings.isNullOrEmpty(labelCol)) {
            classifier.setLabelCol(labelCol);
        }

        if (!Strings.isNullOrEmpty(featuresCol)) {
            classifier.setFeaturesCol(featuresCol);
        }

        // Boosting and tree hyper-parameters; a null argument means "do not override".
        if (maxIterations != null) {
            classifier.setMaxIter(maxIterations);
        }

        if (maxDepth != null) {
            classifier.setMaxDepth(maxDepth);
        }

        if (maxBins != null) {
            classifier.setMaxBins(maxBins);
        }

        if (minInstancesPerNode != null) {
            classifier.setMinInstancesPerNode(minInstancesPerNode);
        }

        if (minInfoGain != null) {
            classifier.setMinInfoGain(minInfoGain);
        }

        if (maxMemoryInMB != null) {
            classifier.setMaxMemoryInMB(maxMemoryInMB);
        }

        if (cacheNodeIds != null) {
            classifier.setCacheNodeIds(cacheNodeIds);
        }

        if (checkpointInterval != null) {
            classifier.setCheckpointInterval(checkpointInterval);
        }

        if (!Strings.isNullOrEmpty(impurity)) {
            classifier.setImpurity(impurity);
        }

        // Applied exactly once; the original repeated this guard and setter verbatim.
        if (subsamplingRate != null) {
            classifier.setSubsamplingRate(subsamplingRate);
        }

        if (seed != null) {
            classifier.setSeed(seed);
        }

        if (stepSize != null) {
            classifier.setStepSize(stepSize);
        }

        if (!Strings.isNullOrEmpty(lossType)) {
            classifier.setLossType(lossType);
        }

        return classifier;
    }

    /**
     * Builds a classifier via
     * {@link #getOperator(String, String, Integer, Integer, Integer, Integer, Double, Integer,
     * Boolean, Integer, String, Double, Long, Double, String)} and fits it on {@code data}.
     *
     * <p>The hyper-parameter arguments are forwarded unchanged, so the same
     * "null or empty means not set" rule applies.
     *
     * @param data                training dataset
     * @param labelCol            name of the label column
     * @param featuresCol         name of the features column
     * @param maxIterations       passed to {@code setMaxIter}
     * @param maxDepth            passed to {@code setMaxDepth}
     * @param maxBins             passed to {@code setMaxBins}
     * @param minInstancesPerNode passed to {@code setMinInstancesPerNode}
     * @param minInfoGain         passed to {@code setMinInfoGain}
     * @param maxMemoryInMB       passed to {@code setMaxMemoryInMB}
     * @param cacheNodeIds        passed to {@code setCacheNodeIds}
     * @param checkpointInterval  passed to {@code setCheckpointInterval}
     * @param impurity            passed to {@code setImpurity}
     * @param subsamplingRate     passed to {@code setSubsamplingRate}
     * @param seed                passed to {@code setSeed}
     * @param stepSize            passed to {@code setStepSize}
     * @param lossType            passed to {@code setLossType}
     * @return the model returned by {@code classifier.fit(data)}
     */
    public static GBTClassificationModel fit(Dataset<Row> data,
                                             String labelCol,
                                             String featuresCol,
                                             Integer maxIterations,
                                             Integer maxDepth,
                                             Integer maxBins,
                                             Integer minInstancesPerNode,
                                             Double minInfoGain,
                                             Integer maxMemoryInMB,
                                             Boolean cacheNodeIds,
                                             Integer checkpointInterval,
                                             String impurity,
                                             Double subsamplingRate,
                                             Long seed,
                                             Double stepSize,
                                             String lossType) {
        GBTClassifier classifier = getOperator(
                labelCol, featuresCol,
                maxIterations, maxDepth, maxBins,
                minInstancesPerNode, minInfoGain, maxMemoryInMB,
                cacheNodeIds, checkpointInterval,
                impurity, subsamplingRate, seed, stepSize, lossType);
        return classifier.fit(data);
    }

    /**
     * Fits an already configured classifier on {@code data}.
     *
     * @param classifier the classifier to fit
     * @param data       training dataset
     * @return the model returned by {@code classifier.fit(data)}
     */
    public static GBTClassificationModel fit(GBTClassifier classifier, Dataset<Row> data) {
        return classifier.fit(data);
    }

    /**
     * Applies a fitted model to {@code data}.
     *
     * @param model the fitted model
     * @param data  dataset to transform
     * @return the dataset returned by {@code model.transform(data)}
     */
    public static Dataset<Row> transform(GBTClassificationModel model, Dataset<Row> data) {
        return model.transform(data);
    }
}
