package com.datastax.insight.ml.spark.mllib.recommendation.als;

import com.datastax.insight.core.driver.SparkContextBuilder;
import com.datastax.insight.spec.RDDOperator;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.rdd.RDD;
import scala.Tuple2;

/**
 * Facade over Spark MLlib's {@link ALS} collaborative-filtering trainer and the resulting
 * {@link MatrixFactorizationModel}: training (from an RDD or a rating text file), model
 * persistence, point/batch prediction and top-N recommendation.
 *
 * <p>The {@code train} overloads that do not take a regularization parameter use
 * {@code DEFAULT_LAMBDA} (0.01).
 */
public class ALSWrapper implements RDDOperator {
    /** Regularization parameter used by the {@code train} overloads that do not accept one. */
    private static final double DEFAULT_LAMBDA=0.01;

    /**
     * Trains an ALS matrix-factorization model.
     *
     * @param ratings    the training ratings
     * @param rank       number of latent factors
     * @param iterations number of ALS iterations to run
     * @param lambda     regularization parameter
     * @param blocks     level of parallelism (number of blocks) passed to {@link ALS#train}
     * @return the trained model
     */
    public static MatrixFactorizationModel train(JavaRDD<Rating> ratings,
                                                 int rank,
                                                 int iterations,
                                                 double lambda,
                                                 int blocks){
        return ALS.train(ratings.rdd(),rank,iterations,lambda,blocks);
    }

    /**
     * Trains an ALS model using {@code DEFAULT_LAMBDA} as the regularization parameter.
     *
     * @see #train(JavaRDD, int, int, double, int)
     */
    public static MatrixFactorizationModel train(JavaRDD<Rating> ratings,
                                                 int rank,
                                                 int iterations,
                                                 int blocks){
        return train(ratings,rank,iterations,DEFAULT_LAMBDA,blocks);
    }

    /**
     * Loads ratings from a text file via {@code RatingRDDLoader.fromTextFile(String)} and trains
     * an ALS model with {@code DEFAULT_LAMBDA}.
     *
     * @param ratingFilePath path of the rating file; its expected format is whatever
     *                       {@code RatingRDDLoader}'s default parsing accepts
     * @see #train(JavaRDD, int, int, double, int)
     */
    public static MatrixFactorizationModel train(String ratingFilePath,
                                                 int rank,
                                                 int iterations,
                                                 int blocks){
        JavaRDD<Rating> ratings= RatingRDDLoader.fromTextFile(ratingFilePath);
        return train(ratings,rank,iterations,DEFAULT_LAMBDA,blocks);
    }

    /**
     * Loads ratings from a delimited text file and trains an ALS model with
     * {@code DEFAULT_LAMBDA}.
     *
     * @param ratingFilePath path of the rating file
     * @param delimiter      field delimiter, forwarded to {@code RatingRDDLoader}
     * @param propOrders     column-order description, forwarded to {@code RatingRDDLoader}
     *                       (its exact semantics are defined there)
     * @see #train(JavaRDD, int, int, double, int)
     */
    public static MatrixFactorizationModel train(String ratingFilePath,
                                                 String delimiter,
                                                 String[] propOrders,
                                                 int rank,
                                                 int iterations,
                                                 int blocks){
        JavaRDD<Rating> ratings= RatingRDDLoader.fromTextFile(ratingFilePath,delimiter,propOrders);
        return train(ratings,rank,iterations,DEFAULT_LAMBDA,blocks);
    }

    /**
     * Loads a model previously written with {@link #saveModel}.
     *
     * <p>The context is deliberately NOT stopped here. The returned model's user/product
     * feature RDDs are bound to it, so stopping it would make every subsequent
     * {@code predict}/{@code recommend*} call on the model fail. Stopping the context is left
     * to the application.
     *
     * @param path directory the model was saved to
     * @return the loaded model
     */
    public MatrixFactorizationModel loadModel(String path){
        SparkContext sc= SparkContextBuilder.getContext();
        return MatrixFactorizationModel.load(sc,path);
    }

    /**
     * Persists {@code model} to {@code path}.
     *
     * <p>The context obtained from {@code SparkContextBuilder} is not stopped. It is presumably
     * the application-wide context (TODO confirm against SparkContextBuilder), and tearing it
     * down as a side effect of saving would break any further Spark work by the caller.
     *
     * @param model the model to save
     * @param path  target directory
     */
    public static void saveModel(MatrixFactorizationModel model,String path){
        SparkContext sc= SparkContextBuilder.getContext();
        model.save(sc,path);
    }

    /**
     * Predicts the rating {@code user} would give {@code product}, and prints it to stdout.
     *
     * @return the predicted rating
     */
    public double predict(MatrixFactorizationModel model,int user,int product){
        // A predicted rating, not the ALS "rank" hyper-parameter.
        double score= model.predict(user,product);
        System.out.println("用户【"+user+"】对商品【"+product+"】的预测评分为：【"+score+"】");
        return score;
    }

    /**
     * Predicts ratings for every (user, product) pair in {@code userProducts}.
     *
     * @return one predicted {@link Rating} per input pair
     */
    public JavaRDD<Rating> predict(MatrixFactorizationModel model,JavaPairRDD<Integer,Integer> userProducts){
        return model.predict(userProducts);
    }

    /** Returns the top {@code num} product recommendations for {@code user}. */
    public Rating[] recommendProducts(MatrixFactorizationModel model,int user,int num){
        return model.recommendProducts(user,num);
    }

    /** Returns the top {@code num} users to whom {@code product} should be recommended. */
    public Rating[] recommendUsers(MatrixFactorizationModel model,int product,int num){
        return model.recommendUsers(product, num);
    }

    /** Returns the top {@code num} product recommendations for every user, keyed by user id. */
    public RDD<Tuple2<Object, Rating[]>> recommendProductsForUsers(MatrixFactorizationModel model,int num){
        return model.recommendProductsForUsers(num);
    }

    /** Returns the top {@code num} user recommendations for every product, keyed by product id. */
    public RDD<Tuple2<Object,Rating[]>> recommendUsersForProducts(MatrixFactorizationModel model,int num){
        return model.recommendUsersForProducts(num);
    }
}