package com.datastax.insight.ml.spark.ml.recommendation.als;
import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * ALS recommendation: a thin wrapper around Spark ML's {@link ALS} estimator that exposes
 * building, fitting and applying the model as static helpers.
 */
public class ALSWrapper implements DataSetOperator {

    /**
     * Builds an {@link ALS} estimator, applying only the parameters that are supplied.
     *
     * <p>A {@code null} argument (or a {@code null}/empty string for the string parameters) means
     * "not specified": the corresponding setter is not called, so the ALS instance keeps whatever
     * value Spark gives it by default.
     *
     * @param userCol                  user column name, passed to {@link ALS#setUserCol}
     * @param itemCol                  item column name, passed to {@link ALS#setItemCol}
     * @param ratingCol                rating column name, passed to {@link ALS#setRatingCol}
     * @param rank                     passed to {@link ALS#setRank}
     * @param numUserBlocks            passed to {@link ALS#setNumUserBlocks}
     * @param numItemBlocks            passed to {@link ALS#setNumItemBlocks}
     * @param implicitPrefs            passed to {@link ALS#setImplicitPrefs}
     * @param alpha                    passed to {@link ALS#setAlpha}
     * @param maxIterations            passed to {@link ALS#setMaxIter}
     * @param regParam                 passed to {@link ALS#setRegParam}
     * @param nonnegative              passed to {@link ALS#setNonnegative}
     * @param checkpointInterval       passed to {@link ALS#setCheckpointInterval}
     * @param seed                     passed to {@link ALS#setSeed}
     * @param intermediateStorageLevel passed verbatim to {@link ALS#setIntermediateStorageLevel}
     *                                 (presumably a Spark storage-level name; validated by Spark)
     * @param finalStorageLevel        passed verbatim to {@link ALS#setFinalStorageLevel}
     *                                 (presumably a Spark storage-level name; validated by Spark)
     * @return the configured, not-yet-fitted estimator
     */
    public static ALS getOperator(String userCol,
                                  String itemCol,
                                  String ratingCol,
                                  Integer rank,
                                  Integer numUserBlocks,
                                  Integer numItemBlocks,
                                  Boolean implicitPrefs,
                                  Double alpha,
                                  Integer maxIterations,
                                  Double regParam,
                                  Boolean nonnegative,
                                  Integer checkpointInterval,
                                  Long seed,
                                  String intermediateStorageLevel,
                                  String finalStorageLevel) {

        ALS als = new ALS();

        if (!Strings.isNullOrEmpty(userCol)) {
            als.setUserCol(userCol);
        }

        if (!Strings.isNullOrEmpty(itemCol)) {
            als.setItemCol(itemCol);
        }

        if (!Strings.isNullOrEmpty(ratingCol)) {
            als.setRatingCol(ratingCol);
        }

        if (rank != null) {
            als.setRank(rank);
        }

        if (numUserBlocks != null) {
            // Only set the user-block count. setNumBlocks(...) would set BOTH user and item
            // blocks, silently overriding the item-block default whenever numItemBlocks is null.
            als.setNumUserBlocks(numUserBlocks);
        }

        if (numItemBlocks != null) {
            als.setNumItemBlocks(numItemBlocks);
        }

        if (implicitPrefs != null) {
            als.setImplicitPrefs(implicitPrefs);
        }

        if (alpha != null) {
            als.setAlpha(alpha);
        }

        if (maxIterations != null) {
            als.setMaxIter(maxIterations);
        }

        if (regParam != null) {
            als.setRegParam(regParam);
        }

        if (nonnegative != null) {
            als.setNonnegative(nonnegative);
        }

        if (checkpointInterval != null) {
            als.setCheckpointInterval(checkpointInterval);
        }

        if (seed != null) {
            als.setSeed(seed);
        }

        if (!Strings.isNullOrEmpty(intermediateStorageLevel)) {
            als.setIntermediateStorageLevel(intermediateStorageLevel);
        }

        if (!Strings.isNullOrEmpty(finalStorageLevel)) {
            als.setFinalStorageLevel(finalStorageLevel);
        }

        return als;
    }

    /**
     * Builds an estimator via
     * {@link #getOperator(String, String, String, Integer, Integer, Integer, Boolean, Double,
     * Integer, Double, Boolean, Integer, Long, String, String)} and fits it on {@code data}.
     *
     * <p>Because the numeric and boolean parameters here are primitives, they can never be
     * {@code null} and are therefore always applied. Only the string parameters can be left
     * unset, by passing {@code null} or an empty string.
     *
     * @param data training data
     * @return the fitted model
     */
    public static ALSModel fit(Dataset<Row> data, String userCol, String itemCol, String ratingCol,
                               int rank, int numUserBlocks, int numItemBlocks, boolean implicitPrefs,
                               double alpha, int maxIterations, double regParam, boolean nonnegative,
                               int checkpointInterval, long seed, String intermediateStorageLevel,
                               String finalStorageLevel) {
        ALS als = getOperator(userCol, itemCol, ratingCol, rank, numUserBlocks, numItemBlocks,
                implicitPrefs, alpha, maxIterations, regParam, nonnegative, checkpointInterval,
                seed, intermediateStorageLevel, finalStorageLevel);
        return als.fit(data);
    }

    /**
     * Fits an already configured estimator on {@code data}.
     *
     * @param als  the estimator to fit
     * @param data training data
     * @return the fitted model
     */
    public static ALSModel fit(ALS als, Dataset<Row> data) {
        return als.fit(data);
    }

    /**
     * Applies a fitted model to {@code data}.
     *
     * @param model the fitted model
     * @param data  rows to score
     * @return the dataset returned by {@link ALSModel#transform}
     */
    public static Dataset<Row> transform(ALSModel model, Dataset<Row> data) {
        return model.transform(data);
    }
}