package com.datastax.insight.ml.spark.data.rdd;

import com.datastax.insight.core.Consts;
import com.datastax.insight.spec.RDDOperator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Static helpers exposing common {@link JavaRDD} parsing, transformation and action
 * operations as operators.
 */
public class RDDTransformation implements RDDOperator {

    /**
     * Parses every line of {@code rdd} into a dense vector of doubles.
     *
     * @param rdd       input lines; each line becomes one vector
     * @param delimiter field separator, interpreted as a regular expression by
     *                  {@link String#split(String)}; {@link Consts#DELIMITER} is used when
     *                  this is {@code null} or empty
     * @return one dense vector per input line (a non-numeric field makes the job fail
     *         when the RDD is evaluated)
     */
    public static JavaRDD<Vector> denseVector(JavaRDD<String> rdd, String delimiter) {
        // Resolve once on the driver rather than once per record on the executors.
        final String delim = resolveDelimiter(delimiter);
        return rdd.map(
                (Function<String, Vector>) s -> Vectors.dense(parseDoubles(s.split(delim), 0))
        );
    }

    /**
     * Parses every line of {@code rdd} into a {@link LabeledPoint}. The first field is the
     * label; all remaining fields form the dense feature vector.
     *
     * @param rdd       input lines; each line becomes one labeled point
     * @param delimiter field separator, interpreted as a regular expression by
     *                  {@link String#split(String)}; {@link Consts#DELIMITER} is used when
     *                  this is {@code null} or empty
     * @return one labeled point per input line (a non-numeric field makes the job fail
     *         when the RDD is evaluated)
     */
    public static JavaRDD<LabeledPoint> lpRDD(JavaRDD<String> rdd, String delimiter) {
        final String delim = resolveDelimiter(delimiter);
        return rdd.map(
                (Function<String, LabeledPoint>) s -> {
                    String[] fields = s.split(delim);
                    double label = Double.parseDouble(fields[0]);
                    // Features are the fields AFTER the label only; sizing the array to the
                    // full field count would leave a spurious 0.0 in feature slot 0.
                    return new LabeledPoint(label, Vectors.dense(parseDoubles(fields, 1)));
                }
        );
    }

    /** Returns {@code delimiter}, or {@link Consts#DELIMITER} when it is null or empty. */
    private static String resolveDelimiter(String delimiter) {
        return (delimiter == null || delimiter.isEmpty()) ? Consts.DELIMITER : delimiter;
    }

    /** Parses {@code fields[from..]} into a new array of length {@code fields.length - from}. */
    private static double[] parseDoubles(String[] fields, int from) {
        double[] values = new double[Math.max(0, fields.length - from)];
        for (int i = 0; i < values.length; i++) {
            values[i] = Double.parseDouble(fields[from + i]);
        }
        return values;
    }

    /**
     * Randomly splits {@code data} according to {@code weights}.
     *
     * @param data    the RDD to split
     * @param weights split weights separated by {@link Consts#DELIMITER} (e.g. {@code "0.7,0.3"}
     *                if that delimiter is a comma); surrounding whitespace is ignored
     * @return one RDD per weight, as produced by {@link JavaRDD#randomSplit(double[])}
     */
    public static JavaRDD<LabeledPoint>[] split(JavaRDD<LabeledPoint> data, String weights) {
        String[] texts = weights.split(Consts.DELIMITER);
        double[] ws = new double[texts.length];
        for (int i = 0; i < ws.length; i++) {
            ws[i] = Double.parseDouble(texts[i].trim());
        }
        return data.randomSplit(ws);
    }

    /**
     * Removes duplicate elements.
     *
     * @param numPartitions target partition count; values &lt;= 0 use Spark's default
     */
    public static <T> JavaRDD<T> distinct(JavaRDD<T> rdd, int numPartitions) {
        if (numPartitions > 0) {
            return rdd.distinct(numPartitions);
        } else {
            return rdd.distinct();
        }
    }

    /** Reduces (or, with {@code shuffle}, changes) the number of partitions of {@code rdd}. */
    public static <T> JavaRDD<T> coalesce(JavaRDD<T> rdd, int numPartitions, boolean shuffle) {
        return rdd.coalesce(numPartitions, shuffle);
    }

    /** Reshuffles {@code rdd} into exactly {@code numPartitions} partitions. */
    public static <T> JavaRDD<T> repartition(JavaRDD<T> rdd, int numPartitions) {
        return rdd.repartition(numPartitions);
    }

    /** Returns a sampled subset of {@code rdd} using the given fraction and seed. */
    public static <T> JavaRDD<T> sample(JavaRDD<T> rdd, boolean withReplacement, double fraction, long seed) {
        return rdd.sample(withReplacement, fraction, seed);
    }

    /** Returns the union of both RDDs (duplicates are kept). */
    public static <T> JavaRDD<T> union(JavaRDD<T> rdd, JavaRDD<T> other) {
        return rdd.union(other);
    }

    /** Returns the elements present in both RDDs. */
    public static <T> JavaRDD<T> intersection(JavaRDD<T> rdd, JavaRDD<T> other) {
        return rdd.intersection(other);
    }

    /**
     * Returns the elements of {@code rdd} that are not in {@code other}.
     *
     * @param numPartitions target partition count; values &lt;= 0 use Spark's default,
     *                      consistent with {@link #distinct(JavaRDD, int)}
     */
    public static <T> JavaRDD<T> subtract(JavaRDD<T> rdd, JavaRDD<T> other, int numPartitions) {
        if (numPartitions > 0) {
            return rdd.subtract(other, numPartitions);
        }
        return rdd.subtract(other);
    }

    /** Assigns a display name to {@code rdd} and returns it. */
    public static <T> JavaRDD<T> setName(JavaRDD<T> rdd, String name) {
        return rdd.setName(name);
    }

    /** Returns an RDD whose elements are the contents of each partition as a list. */
    public static <T> JavaRDD<List<T>> glom(JavaRDD<T> rdd) {
        return rdd.glom();
    }

    /**
     * Pipes each partition through an external command.
     *
     * @param command the command and its arguments, passed as a list (not a shell string)
     */
    public static JavaRDD<String> pipe(JavaRDD<String> rdd, String... command) {
        return rdd.pipe(new ArrayList<>(Arrays.asList(command)));
    }

    /**
     * Pairs every element with a {@code Long} id.
     *
     * @param type {@code "uniqueid"} for {@code zipWithUniqueId}, {@code "index"} for
     *             {@code zipWithIndex}
     * @return the zipped RDD, or {@code null} when {@code type} is null or unrecognised
     */
    public static <T> JavaPairRDD<T, Long> zip(JavaRDD<T> rdd, String type) {
        // Constant-first comparison so a null type falls through instead of throwing NPE.
        if ("uniqueid".equals(type)) {
            return rdd.zipWithUniqueId();
        }
        if ("index".equals(type)) {
            return rdd.zipWithIndex();
        }
        return null;
    }

    /** Returns all elements of {@code rdd} to the driver. */
    public static <T> List<T> collect(JavaRDD<T> rdd) {
        return rdd.collect();
    }

    /**
     * Returns the contents of selected partitions to the driver.
     *
     * @param pid partition ids separated by {@link Consts#DELIMITER}; surrounding whitespace
     *            is ignored
     */
    public static <T> List<T>[] collectPartitions(JavaRDD<T> rdd, String pid) {
        String[] pids = pid.split(Consts.DELIMITER);
        int[] partitionIds = new int[pids.length];
        for (int i = 0; i < pids.length; i++) {
            // Integer.parseInt, unlike Double.parseDouble, rejects surrounding whitespace.
            partitionIds[i] = Integer.parseInt(pids[i].trim());
        }
        return rdd.collectPartitions(partitionIds);
    }

    /**
     * Counts elements, either exactly or approximately.
     *
     * @param byValue    count occurrences per distinct value instead of total elements
     * @param timeout    when &gt; 0, use the approximate variant with this timeout
     * @param confidence confidence level for the approximate variant
     * @return {@code countByValueApprox}/{@code countApprox} result when {@code timeout > 0},
     *         otherwise the {@code countByValue} map or the boxed {@code count}
     */
    public static <T> Object count(JavaRDD<T> rdd, boolean byValue, long timeout, double confidence) {
        if (byValue) {
            if (timeout > 0) {
                return rdd.countByValueApprox(timeout, confidence);
            } else {
                return rdd.countByValue();
            }
        } else {
            if (timeout > 0) {
                return rdd.countApprox(timeout, confidence);
            } else {
                return rdd.count();
            }
        }
    }

    /** Returns the first {@code num} elements. */
    public static <T> List<T> take(JavaRDD<T> rdd, int num) {
        return rdd.take(num);
    }

    /** Returns the {@code num} smallest elements by natural ordering. */
    public static <T> List<T> takeOrdered(JavaRDD<T> rdd, int num) {
        return rdd.takeOrdered(num);
    }

    /** Returns the {@code num} largest elements by natural ordering. */
    public static <T> List<T> top(JavaRDD<T> rdd, int num) {
        return rdd.top(num);
    }

    /** Returns a fixed-size random sample of {@code num} elements. */
    public static <T> List<T> takeSample(JavaRDD<T> rdd, boolean withReplacement, int num, long seed) {
        return rdd.takeSample(withReplacement, num, seed);
    }

    /** Returns the first element of {@code rdd}. */
    public static <T> T first(JavaRDD<T> rdd) {
        return rdd.first();
    }

}
