package com.datastax.insight.ml.spark.mllib.classification;

import com.datastax.insight.spec.RDDOperator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.GradientBoostedTrees;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
import scala.Tuple2;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
 * Thin wrapper around Spark MLlib's {@link GradientBoostedTrees} that trains gradient-boosted
 * tree ensembles on {@link LabeledPoint} RDDs and scores them.
 *
 * <p>All features are treated as continuous: an empty categorical-features map is always
 * installed on the tree strategy.
 */
public class GradientBoosting implements RDDOperator {

    /** Algorithm name passed to {@link BoostingStrategy#defaultParams(String)} for classification. */
    private static final String CLASSIFICATION = "Classification";

    /** Algorithm name passed to {@link BoostingStrategy#defaultParams(String)} for regression. */
    private static final String REGRESSION = "Regression";

    /**
     * Trains a gradient-boosted tree classifier.
     *
     * @param data          training points; must not be {@code null}
     * @param numIterations number of boosting iterations (trees); must be at least 1
     * @param numClasses    number of label classes, forwarded to the tree strategy
     *                      (NOTE(review): Spark's boosting documents binary-only classification,
     *                      so values other than 2 are expected to be rejected by Spark — confirm)
     * @param maxDepth      maximum depth of each tree, forwarded to the tree strategy
     * @return the trained model
     * @throws NullPointerException     if {@code data} is {@code null}
     * @throws IllegalArgumentException if {@code numIterations < 1}; Spark may also throw it when
     *                                  validating the resulting strategy
     */
    public static GradientBoostedTreesModel trainClassifier(JavaRDD<LabeledPoint> data,
                                                            int numIterations, int numClasses, int maxDepth) {
        return train(data, CLASSIFICATION, numIterations, numClasses, maxDepth);
    }

    /**
     * Trains a gradient-boosted tree regressor.
     *
     * @param data          training points; must not be {@code null}
     * @param numIterations number of boosting iterations (trees); must be at least 1
     * @param numClasses    forwarded to the tree strategy unchanged
     *                      (NOTE(review): presumably ignored for regression — confirm)
     * @param maxDepth      maximum depth of each tree, forwarded to the tree strategy
     * @return the trained model
     * @throws NullPointerException     if {@code data} is {@code null}
     * @throws IllegalArgumentException if {@code numIterations < 1}; Spark may also throw it when
     *                                  validating the resulting strategy
     */
    public static GradientBoostedTreesModel trainRegressor(JavaRDD<LabeledPoint> data,
                                                           int numIterations, int numClasses, int maxDepth) {
        return train(data, REGRESSION, numIterations, numClasses, maxDepth);
    }

    /**
     * Trains a gradient-boosted tree model starting from Spark's default boosting strategy for the
     * given algorithm name, then overriding iterations, class count and tree depth.
     *
     * @param data          training points; must not be {@code null}
     * @param defaultParams algorithm name handed to {@link BoostingStrategy#defaultParams(String)},
     *                      e.g. {@code "Classification"} or {@code "Regression"}
     * @param numIterations number of boosting iterations (trees); must be at least 1
     * @param numClasses    number of label classes, set on the tree strategy
     * @param maxDepth      maximum depth of each tree, set on the tree strategy
     * @return the trained model
     * @throws NullPointerException     if {@code data} is {@code null}
     * @throws IllegalArgumentException if {@code numIterations < 1}; Spark may also throw it for an
     *                                  unrecognised algorithm name or an invalid strategy
     */
    public static GradientBoostedTreesModel train(JavaRDD<LabeledPoint> data, String defaultParams,
                                                  int numIterations, int numClasses, int maxDepth) {
        Objects.requireNonNull(data, "data");
        if (numIterations < 1) {
            // Fail fast: otherwise Spark's boosting loop sizes its per-iteration arrays from this
            // value and blows up mid-training with an unhelpful index/array-size error.
            throw new IllegalArgumentException("numIterations must be >= 1, got " + numIterations);
        }

        BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(defaultParams);
        boostingStrategy.setNumIterations(numIterations);
        boostingStrategy.getTreeStrategy().setNumClasses(numClasses);
        boostingStrategy.getTreeStrategy().setMaxDepth(maxDepth);

        // An empty categoricalFeaturesInfo map marks every feature as continuous.
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
        boostingStrategy.getTreeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo);

        return GradientBoostedTrees.train(data, boostingStrategy);
    }

    /**
     * Scores every point with {@code model}.
     *
     * @param data  points to score
     * @param model trained model; captured by the mapping function and shipped to executors
     * @return pairs of {@code (prediction, label)}, one per input point
     */
    public static JavaPairRDD<Double, Double> predict(JavaRDD<LabeledPoint> data, GradientBoostedTreesModel model) {
        // PairFunction extends Serializable, so this lambda (and the captured model) is serializable.
        return data.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
    }
}
