package com.datastax.insight.ml.spark.mllib.association;

import com.datastax.insight.spec.RDDOperator;
import com.datastax.insight.core.Consts;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.fpgrowth.FPGrowthUtil;
import org.apache.spark.mllib.fpm.AssociationRules;
import org.apache.spark.mllib.fpm.FPGrowth;
import org.apache.spark.mllib.fpm.FPGrowthModel;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

/**
 * Spark MLlib FP-Growth entry points: trains a frequent-itemset model from delimited text
 * transactions and exposes thin wrappers over {@code FPGrowthUtil}.
 */
public class FPGrowthEvaluator implements RDDOperator {

    /**
     * Trains an FP-Growth model on text transactions, one transaction per input line.
     *
     * <p>Each line is split with {@link String#split(String)}, so {@code delimiter} is interpreted
     * as a regular expression. When {@code delimiter} is null or empty, {@code Consts.DELIMITER} is
     * used instead. Repeated items within a line are collapsed (first occurrence wins, order kept)
     * because MLlib FP-Growth rejects transactions that contain the same item more than once.
     *
     * @param rdd           input lines, one transaction per line
     * @param delimiter     item separator regex; null/empty falls back to {@code Consts.DELIMITER}
     * @param minSupport    minimal support, passed to {@link FPGrowth#setMinSupport(double)}
     * @param numPartitions partition count for FP-Growth; values {@code <= 0} keep MLlib's default
     *                      (the input RDD's partition count) instead of failing validation
     * @param confidence    when {@code > 0}, association rules with at least this confidence are
     *                      collected to the driver and printed to stdout
     * @return the trained model
     */
    public static FPGrowthModel evaluate(JavaRDD<String> rdd, String delimiter, double minSupport,
                                         int numPartitions, double confidence) {
        // Resolve the separator once on the driver rather than once per record.
        final String delim =
                (delimiter == null || delimiter.isEmpty()) ? Consts.DELIMITER : delimiter;
        JavaRDD<List<String>> data = rdd.map(line -> toTransaction(line, delim));

        FPGrowth fpGrowth = new FPGrowth().setMinSupport(minSupport);
        // FPGrowth.setNumPartitions rejects non-positive values; leaving it unset makes MLlib
        // fall back to the input's partitioning.
        if (numPartitions > 0) {
            fpGrowth.setNumPartitions(numPartitions);
        }
        FPGrowthModel<String> model = fpGrowth.run(data);

        if (confidence > 0) {
            // Collects every qualifying rule to the driver; intended for inspection-sized output.
            List<AssociationRules.Rule<String>> rules =
                    model.generateAssociationRules(confidence).toJavaRDD().collect();
            for (AssociationRules.Rule<String> rule : rules) {
                System.out.println(rule.javaAntecedent()+"===>"+rule.javaConsequent()+"，"+rule.confidence());
            }
        }
        // For debugging, printItemset(model.freqItemsets()) dumps every frequent itemset
        // (collected to the driver).

        return model;
    }

    /**
     * Splits one input line into a transaction with unique items, preserving first-occurrence
     * order.
     */
    private static List<String> toTransaction(String line, String delim) {
        Set<String> unique = new LinkedHashSet<>();
        Collections.addAll(unique, line.split(delim));
        return new ArrayList<>(unique);
    }

    /** Collects all frequent itemsets to the driver and prints each with its frequency. */
    private static void printItemset(RDD<FPGrowth.FreqItemset<String>> rdd) {
        List<FPGrowth.FreqItemset<String>> list = rdd.toJavaRDD().collect();
        for (FPGrowth.FreqItemset<String> itemset : list) {
            System.out.println(itemset.javaItems()+"===>"+itemset.freq());
        }
    }

    /**
     * Converts {@code data} to a DataFrame and delegates to {@code FPGrowthUtil.fpgrowth}.
     *
     * <p>NOTE(review): parameter semantics (grouping/target columns, {@code p},
     * {@code minItems}, output {@code uri}/{@code path}) are defined by {@code FPGrowthUtil},
     * which is not visible here. Confirm against that implementation.
     */
    public static <T> void fpgrowth(Dataset<T> data, String groupCol, String targetCol, double minSupport, int numPartitions, long minFreq,
                                    double p, int minItems, String uri, String path) {
        FPGrowthUtil.fpgrowth(data.toDF(), groupCol, targetCol, minSupport, numPartitions, minFreq, p, minItems, uri, path);
    }

    /** Delegates to {@code FPGrowthUtil.mergeRelated}; semantics defined there (not visible here). */
    private static RDD mergeRelated(RDD rdd, Double p, Integer minItems) {
        return FPGrowthUtil.mergeRelated(rdd, p, minItems);
    }

    /**
     * Delegates to {@code FPGrowthUtil.freq2csv}. Presumably writes frequent itemsets as CSV to
     * {@code uri}/{@code path}; confirm in FPGrowthUtil.
     */
    private static void freq2csv(RDD rdd, String uri, String path) {
        FPGrowthUtil.freq2csv(rdd, uri, path);
    }
}
