package com.datastax.insight.ml.spark.mllib.classification;

import com.datastax.insight.spec.RDDOperator;
import com.datastax.insight.core.util.LogUtil;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import scala.Tuple2;

import java.util.HashMap;
import java.util.List;

/**
 * Static helpers for training a Spark MLlib random forest classifier and for
 * scoring labeled data with a trained model.
 */
public class RandomForestClassifier implements RDDOperator {
    private static final LogUtil logUtil = new LogUtil(RandomForestClassifier.class);

    /**
     * Trains a random forest classifier using {@code "gini"} as the impurity measure.
     *
     * @param data                  labeled training points
     * @param numClasses            number of target classes
     * @param numTrees              number of trees in the forest
     * @param featureSubsetStrategy feature subset strategy name, passed through to Spark
     * @param maxDepth              maximum depth of each tree
     * @param maxBins               maximum number of bins, passed through to Spark
     * @param seed                  random seed
     * @return the trained model
     */
    public static RandomForestModel train(JavaRDD<LabeledPoint> data, int numClasses,
                                          int numTrees, String featureSubsetStrategy,
                                          int maxDepth, int maxBins, int seed) {
        return train(data, numClasses, numTrees, featureSubsetStrategy, "gini", maxDepth, maxBins, seed);
    }

    /**
     * Trains a random forest classifier with an explicit impurity measure.
     *
     * @param data                  labeled training points
     * @param numClasses            number of target classes
     * @param numTrees              number of trees in the forest
     * @param featureSubsetStrategy feature subset strategy name, passed through to Spark
     * @param impurity              impurity measure name, passed through to Spark
     * @param maxDepth              maximum depth of each tree
     * @param maxBins               maximum number of bins, passed through to Spark
     * @param seed                  random seed
     * @return the trained model
     */
    public static RandomForestModel train(JavaRDD<LabeledPoint> data, int numClasses,
                                          int numTrees, String featureSubsetStrategy, String impurity,
                                          int maxDepth, int maxBins, int seed) {
        // No categorical features are declared; per the Spark MLlib documentation,
        // features absent from this map are treated as continuous.
        HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
        return RandomForest.trainClassifier(data, numClasses,
                categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
                seed);
    }

    /**
     * Scores every point in {@code data} with {@code model}, prints each
     * (prediction, label) pair and the overall error rate to standard output,
     * and returns the pairs as an RDD.
     *
     * <p>All pairs are collected to the driver once; both the per-row printout and
     * the error rate are derived from that single collection, so the prediction
     * lineage is evaluated by only one Spark job inside this method. The returned
     * RDD is not cached here.
     *
     * @param data  labeled points to score
     * @param model trained random forest model
     * @return pair RDD whose key is the predicted label and whose value is the actual label
     */
    public static JavaPairRDD<Double, Double> predict(JavaRDD<LabeledPoint> data, RandomForestModel model) {
        JavaPairRDD<Double, Double> predictionAndLabel =
                data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
                    @Override
                    public Tuple2<Double, Double> call(LabeledPoint p) {
                        return new Tuple2<>(model.predict(p.features()), p.label());
                    }
                });

        // Single action: everything below works on the driver-side copy.
        List<Tuple2<Double, Double>> collected = predictionAndLabel.collect();

        printPredictions(collected);

        double testErr = errorRate(collected);

        logUtil.logUserOutputStart("predict");
        System.out.println("Test Error: " + testErr);
        logUtil.logUserOutputEnd("predict");

        return predictionAndLabel;
    }

    /**
     * Returns the fraction of pairs whose prediction differs from its label.
     * For an empty list the result is {@code NaN} (0.0 / 0).
     */
    private static double errorRate(List<Tuple2<Double, Double>> pairs) {
        long mismatches = 0;
        for (Tuple2<Double, Double> pair : pairs) {
            if (!pair._1().equals(pair._2())) {
                mismatches++;
            }
        }
        return 1.0 * mismatches / pairs.size();
    }

    /**
     * Prints each (prediction, label) pair on its own line, wrapped in the
     * "predict" user-output markers.
     */
    private static void printPredictions(List<Tuple2<Double, Double>> pairs) {
        logUtil.logUserOutputStart("predict");

        for (Tuple2<Double, Double> pair : pairs) {
            System.out.println(pair.toString());
        }

        logUtil.logUserOutputEnd("predict");
    }
}
