package com.datastax.insight.ml.spark.mllib.evaluator;

import com.alibaba.fastjson.JSON;
import com.datastax.insight.core.entity.Metrics;
import com.datastax.insight.spec.RDDOperator;
import com.datastax.insight.core.service.PersistService;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.evaluation.RankingMetrics;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import scala.Tuple2;

import java.util.ArrayList;
import java.util.List;

/**
 * Evaluates a {@link MatrixFactorizationModel} (ALS recommender) with Spark MLlib
 * {@link RankingMetrics}.
 *
 * <p>For every user, the model's top {@value #TOP_K} recommended products are compared against
 * the products that user rated with a positive value in the supplied data. The resulting mean
 * average precision is stored in a {@link Metrics} object, persisted through
 * {@link PersistService}, and returned.
 *
 * <p>Created by huangping on 17-1-16.
 */
public class RankingMetricsWrapper implements RDDOperator {

    /** Number of recommendations requested per user from the model. */
    private static final int TOP_K = 10;

    /**
     * Computes ranking metrics for {@code model} against the observed ratings in {@code data}
     * and persists them for the current flow.
     *
     * @param model trained matrix-factorization model to evaluate
     * @param data  observed ratings; a rating {@code > 0} marks the product as relevant
     *              for that user
     * @return the computed metrics (currently the mean average precision)
     */
    public Metrics evaluation(MatrixFactorizationModel model, JavaRDD<Rating> data) {

        // Predicted ranking per user: the ids of the model's top-k recommendations, in the
        // model's order. Only product ids feed the ranking metrics, so the predicted rating
        // values are not transformed.
        JavaRDD<Tuple2<Object, Rating[]>> userRecs = model.recommendProductsForUsers(TOP_K).toJavaRDD();
        JavaPairRDD<Object, List<Integer>> recommendedByUser = JavaPairRDD.fromJavaRDD(userRecs)
                .mapValues((Function<Rating[], List<Integer>>) recs -> {
                    List<Integer> products = new ArrayList<>(recs.length);
                    for (Rating r : recs) {
                        products.add(r.product());
                    }
                    return products;
                });

        // Ground truth per user: every product the user rated with a positive value.
        // Grouping over all ratings (before filtering) keeps users whose ratings are all <= 0,
        // with an empty relevant list.
        JavaPairRDD<Object, List<Integer>> relevantByUser = data
                .groupBy((Function<Rating, Object>) r -> r.user())
                .mapValues((Function<Iterable<Rating>, List<Integer>>) ratings -> {
                    List<Integer> products = new ArrayList<>();
                    for (Rating r : ratings) {
                        if (r.rating() > 0.0) {
                            products.add(r.product());
                        }
                    }
                    return products;
                });

        // RankingMetrics.of expects (predicted ranking, ground-truth set) pairs in that order,
        // so the recommendations must be the left side of the join.
        JavaRDD<Tuple2<List<Integer>, List<Integer>>> predictionAndLabels =
                recommendedByUser.join(relevantByUser).values();

        RankingMetrics<Integer> rankingMetrics = RankingMetrics.of(predictionAndLabels);
        Metrics metrics = new Metrics();
        metrics.getIndicator().setMeanAveragePrecision(rankingMetrics.meanAveragePrecision());

        PersistService.invoke("com.datastax.insight.agent.dao.InsightDAO",
                "saveModelMetrics",
                new String[]{Long.class.getTypeName(), String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), JSON.toJSONString(metrics)});

        return metrics;
    }
}
