package com.datastax.insight.ml.spark.mllib.regression;

import com.datastax.insight.spec.RDDOperator;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
import scala.Tuple2;

/**
 * Static helpers that wrap Spark MLlib's SGD-based linear regression:
 * one to fit a model and one to score a data set with a fitted model.
 */
public class LinearRegression implements RDDOperator {

    /**
     * Fits a linear regression model by delegating to
     * {@link LinearRegressionWithSGD#train}.
     *
     * @param data          labeled training points
     * @param numIterations iteration count forwarded to the trainer
     * @param stepSize      step size forwarded to the trainer
     * @return the model produced by the trainer
     */
    public static LinearRegressionModel train(JavaRDD<LabeledPoint> data, int numIterations, double stepSize) {
        // The MLlib trainer takes a Scala RDD, so hand it the RDD wrapped by the Java API object.
        return LinearRegressionWithSGD.train(data.rdd(), numIterations, stepSize);
    }

    /**
     * Applies {@code model} to the features of every point in {@code data}.
     *
     * @param data  labeled points to score
     * @param model model used to compute each prediction
     * @return one pair per input point; the first element is the model's
     *         prediction and the second element is the point's own label
     */
    public static JavaRDD<Tuple2<Double, Double>> predict(JavaRDD<LabeledPoint> data, LinearRegressionModel model) {
        // Pair order is (prediction, label).
        return data.map(sample -> new Tuple2<>(model.predict(sample.features()), sample.label()));
    }
}
