package com.datastax.insight.ml.spark.mllib.regression;

import com.datastax.insight.spec.RDDOperator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.IsotonicRegression;
import org.apache.spark.mllib.regression.IsotonicRegressionModel;
import org.apache.spark.mllib.regression.LabeledPoint;
import scala.Tuple2;
import scala.Tuple3;

/**
 * Static helpers for training and applying a Spark MLlib isotonic regression
 * model on {@link LabeledPoint} data.
 *
 * <p>Isotonic regression here is univariate: both methods read only the
 * feature at index 0 of each point's feature vector; any further features are
 * ignored.
 */
public class IsotonicRegressor implements RDDOperator {
    /**
     * Fits an isotonic (non-decreasing) regression model.
     *
     * <p>Each point is converted to the {@code (label, feature, weight)} triple
     * expected by {@link IsotonicRegression}, using the feature at index 0 and
     * a weight of {@code 1.0} for every sample.
     *
     * @param data training points; each feature vector must contain at least
     *             one element
     * @return the fitted model
     */
    public static IsotonicRegressionModel train(JavaRDD<LabeledPoint> data) {
        JavaRDD<Tuple3<Double, Double, Double>> weightedSamples = data.map(
                new Function<LabeledPoint, Tuple3<Double, Double, Double>>() {
                    @Override
                    public Tuple3<Double, Double, Double> call(LabeledPoint point) {
                        // Autoboxing replaces the deprecated Double(double) constructor.
                        return new Tuple3<>(point.label(), point.features().apply(0), 1.0);
                    }
                }
        );

        return new IsotonicRegression().setIsotonic(true).run(weightedSamples);
    }

    /**
     * Applies a trained model to every point and pairs each prediction with
     * the point's own label.
     *
     * @param data  points to score; each feature vector must contain at least
     *              one element
     * @param model the trained isotonic regression model
     * @return pairs of {@code (predictedLabel, actualLabel)}, one per input point
     */
    public static JavaPairRDD<Double, Double> predict(JavaRDD<LabeledPoint> data,
                                                      final IsotonicRegressionModel model) {
        // Map straight from the labeled point: the prediction only needs the
        // first feature, so no intermediate (label, feature, weight) RDD is built.
        return data.mapToPair(
                new PairFunction<LabeledPoint, Double, Double>() {
                    @Override
                    public Tuple2<Double, Double> call(LabeledPoint point) {
                        double predictedLabel = model.predict(point.features().apply(0));
                        return new Tuple2<>(predictedLabel, point.label());
                    }
                }
        );
    }
}
