package com.datastax.insight.ml.spark.mllib.regression;

import com.datastax.insight.spec.RDDOperator;
import java.util.Objects;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.regression.LabeledPoint;
import scala.Tuple2;

/**
 * Static helpers for fitting a Spark MLlib logistic regression model and for
 * applying a fitted model to labeled data.
 */
public class LogisticRegression implements RDDOperator {

    /**
     * Fits a logistic regression model to the given labeled points using
     * {@link LogisticRegressionWithLBFGS}.
     *
     * @param data       labeled training points; must not be {@code null}
     * @param numClasses number of classes, forwarded unchanged to
     *                   {@code setNumClasses} on the trainer
     * @return the model returned by the trainer's {@code run} call
     * @throws NullPointerException if {@code data} is {@code null}
     */
    public static LogisticRegressionModel train(JavaRDD<LabeledPoint> data, int numClasses) {
        Objects.requireNonNull(data, "data must not be null");
        return new LogisticRegressionWithLBFGS()
                .setNumClasses(numClasses)
                .run(data.rdd());
    }

    /**
     * Pairs the model's prediction for every labeled point with that point's label.
     *
     * @param data  labeled points to score; must not be {@code null}
     * @param model fitted model used to score each point's features; must not be {@code null}
     * @return an RDD of {@code (prediction, label)} tuples, one per input point
     * @throws NullPointerException if {@code data} or {@code model} is {@code null}
     */
    public static JavaRDD<Tuple2<Object, Object>> predict(JavaRDD<LabeledPoint> data, LogisticRegressionModel model) {
        Objects.requireNonNull(data, "data must not be null");
        // The model is only dereferenced inside the mapping function, which Spark
        // runs when the RDD is evaluated rather than here; checking now reports a
        // null model at the call that actually passed it.
        Objects.requireNonNull(model, "model must not be null");

        return data.map(point -> {
            // Keep the prediction primitive; it is boxed once when stored in the tuple.
            double prediction = model.predict(point.features());
            return new Tuple2<Object, Object>(prediction, point.label());
        });
    }
}
