package com.datastax.insight.ml.spark.mllib.recommendation.als;

import com.alibaba.fastjson.JSON;
import com.datastax.insight.core.entity.Metrics;
import com.datastax.insight.spec.RDDOperator;
import com.datastax.insight.core.service.PersistService;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.evaluation.RegressionMetrics;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import scala.Tuple2;

/**
 * Computes regression-style quality indicators (MAE, MSE, RMSE, R2) for a
 * {@link MatrixFactorizationModel} by comparing its predicted ratings with
 * known ratings, and hands the result to the persistence layer.
 *
 * Created by datastax on 2017/1/5.
 */
public class RatingMetrics implements RDDOperator {

    /**
     * Evaluates {@code model} on the (user, product) pairs found in {@code data}.
     *
     * <p>The known ratings and the model's predictions are inner-joined on the
     * (user, product) key, so only pairs present on both sides contribute to
     * the metrics.
     *
     * <p>As a side effect the computed metrics are serialized to JSON and passed,
     * together with the current flow id, to {@code saveModelMetrics} on
     * {@code com.datastax.insight.agent.dao.InsightDAO} via
     * {@link PersistService#invoke}.
     *
     * @param model the trained matrix factorization model to evaluate
     * @param data  ratings holding the observed (true) value for each user/product pair
     * @return a {@link Metrics} object whose indicator carries MAE, MSE, RMSE and R2
     */
    public Metrics evaluation(MatrixFactorizationModel model, JavaRDD<Rating> data) {
        // The model predicts per (user, product) pair, so strip the ratings off first.
        JavaPairRDD<Integer, Integer> usersproducts = data.mapToPair(d-> new Tuple2<>(d.user(), d.product()));

        // Key both sides by (user, product) so they can be joined.
        JavaPairRDD<Tuple2<Integer, Integer>, Double> predictions = model.predict(usersproducts)
                .mapToPair(d-> new Tuple2<>(new Tuple2<>(d.user(), d.product()), d.rating()));
        JavaPairRDD<Tuple2<Integer, Integer>, Double> rawData = data
                .mapToPair(d-> new Tuple2<>(new Tuple2<>(d.user(), d.product()), d.rating()));

        // rawData.join(predictions) yields (key, (observed, predicted)).
        // RegressionMetrics expects (prediction, observation) pairs, so the two
        // values are swapped here. MAE/MSE/RMSE are symmetric, but R2 is
        // normalised by the variance of the observations and would be wrong
        // if the order were left as (observed, predicted).
        JavaRDD<Tuple2<Object, Object>> predictedAndTrue = rawData.join(predictions)
                .map(d->new Tuple2<>(d._2()._2(), d._2()._1()));

        RegressionMetrics regressionMetrics = new RegressionMetrics(predictedAndTrue.rdd());

        Metrics metrics = new Metrics();
        metrics.getIndicator().setMae(regressionMetrics.meanAbsoluteError());
        metrics.getIndicator().setMse(regressionMetrics.meanSquaredError());
        metrics.getIndicator().setRmse(regressionMetrics.rootMeanSquaredError());
        metrics.getIndicator().setR2(regressionMetrics.r2());

        // Persist the metrics for the current flow; the target DAO class and
        // method are addressed by name (presumably resolved reflectively by
        // PersistService - confirm against its implementation).
        PersistService.invoke("com.datastax.insight.agent.dao.InsightDAO",
                "saveModelMetrics",
                new String[]{Long.class.getTypeName(), String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), JSON.toJSONString(metrics)});

        return metrics;
    }
}
