
package com.datastax.data.prepare.spark.dataset.hierarchicalCluster;

import com.datastax.insight.core.driver.SparkContextBuilder;
import com.datastax.data.prepare.spark.dataset.hierarchicalCluster.entry.Edge;
import com.datastax.data.prepare.spark.dataset.hierarchicalCluster.entry.Point;
import com.datastax.data.prepare.spark.dataset.hierarchicalCluster.writable.EdgeWritable;
import com.datastax.data.prepare.spark.dataset.hierarchicalCluster.writable.PointWritable;
import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.mapred.SequenceFileOutputFormat;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import scala.Tuple2;

import java.net.URI;
import java.util.*;


/**
 * Turns an edge-list {@link Dataset} into the inputs used by the {@code hierarchicalCluster} code.
 *
 * <p>On construction, every distinct value found in the dataset's first two columns is assigned a
 * consecutive integer id starting at 0. {@link #getIdUserMap()} exposes the id -> user mapping.
 * Edges can then be written as Hadoop SequenceFiles of {@code (NullWritable, EdgeWritable)} pairs
 * via {@link #saveHadoopFileEdge(Dataset, int, String)}.
 *
 * <p>NOTE(review): the first two columns are read with {@code Row.getString}, so they are assumed
 * to hold string user ids, and the third column is assumed to hold a numeric edge weight. Confirm
 * this against callers.
 */
public class DataSplitter {
    private final JavaSparkContext sc;

    /** {@code _1}: user id -> integer id; {@code _2}: integer id -> user id. */
    private final Tuple2<Map<String, Integer>, Map<Integer, String>> userIdMaps;

    /**
     * Obtains the Spark context and builds the user id mappings for {@code dataset}.
     *
     * @param dataset edge list whose first two columns hold the endpoint user ids
     * @throws IllegalArgumentException if {@code dataset} has fewer than two columns
     */
    <T> DataSplitter(Dataset<T> dataset) {
        this.sc = initSparkContext();
        this.userIdMaps = getUserIdMap(dataset);
    }

    /**
     * Converts every row of {@code dataset} to an {@link Edge} and saves the edges as a SequenceFile.
     *
     * @param dataset   edge list (source id, target id, weight); see {@link ToEdge}
     * @param numSplits number of partitions (and therefore output part files) to write
     * @param outputDir Hadoop output directory
     */
    public <T> void saveHadoopFileEdge(Dataset<T> dataset, int numSplits, String outputDir) {
        // NOTE(review): the whole user->id map is serialized into every task closure; a broadcast
        // variable may be cheaper for large graphs.
        JavaRDD<Edge> edges = dataset.toJavaRDD().map(new ToEdge<T>(userIdMaps._1()));
        saveAsSequenceFileEdge(edges, numSplits, outputDir);
    }

    /**
     * Assigns integer ids 0..n-1 to the distinct values of the first two columns.
     *
     * <p>The distinct values of each column are collected to the driver. Ids follow the iteration
     * order of a {@link HashSet}.
     *
     * @return {@code (user -> id, id -> user)}
     * @throws IllegalArgumentException if {@code dataset} has fewer than two columns
     */
    private <T> Tuple2<Map<String, Integer>, Map<Integer, String>> getUserIdMap(Dataset<T> dataset) {
        String[] columns = dataset.columns();
        if (columns.length < 2) {
            throw new IllegalArgumentException(
                    "Expected at least two columns (source and target user ids), got "
                            + Arrays.toString(columns));
        }
        Set<String> userIds = new HashSet<>();
        addDistinctStrings(dataset.select(columns[0]), userIds);
        addDistinctStrings(dataset.select(columns[1]), userIds);

        Map<String, Integer> user2id = new HashMap<>();
        Map<Integer, String> id2user = new HashMap<>();
        int id = 0;
        for (String userId : userIds) {
            user2id.put(userId, id);
            id2user.put(id, userId);
            id++;
        }
        return new Tuple2<Map<String, Integer>, Map<Integer, String>>(user2id, id2user);
    }

    /** Collects the distinct string values of the single-column {@code column} into {@code sink}. */
    private static void addDistinctStrings(Dataset<Row> column, Set<String> sink) {
        for (Row row : column.distinct().collectAsList()) {
            sink.add(row.getString(0));
        }
    }

    /**
     * Returns the shared Spark context from {@code SparkContextBuilder}.
     *
     * <p>Called from the constructor, so overriding it in a subclass runs before subclass fields
     * are initialized.
     */
    public JavaSparkContext initSparkContext() {
        return SparkContextBuilder.getJContext();
    }

    /** Returns the Spark context obtained at construction time. */
    public JavaSparkContext getSparkContext() {
        return sc;
    }

    /**
     * Returns the number of distinct user ids (graph vertices).
     *
     * <p>The misspelled name is kept for compatibility with existing callers.
     */
    public int getNumPonits() {
        return userIdMaps._1().size();
    }

    /** Returns the mapping from integer id ({@code Integer}) back to user id ({@code String}). */
    public Map getIdUserMap() {
        return userIdMaps._2();
    }

    /**
     * Writes one text partition per subgraph, each containing that subgraph's index.
     *
     * <p>Partition {@code i} holds the line {@code "i"}. Failures are reported on stderr and
     * otherwise ignored, so this method never throws from the Spark write.
     *
     * @param fileLoc      output directory for {@code saveAsTextFile}
     * @param numSubGraphs number of subgraph ids (and partitions) to create
     */
    public void createPartitionFiles(String fileLoc, int numSubGraphs) {
        List<String> idSubgraphs = Lists.newArrayListWithCapacity(numSubGraphs);
        for (int i = 0; i < numSubGraphs; i++) {
            idSubgraphs.add(String.valueOf(i));
        }
        System.out.println("create idSubgraph files: " + fileLoc);
        try {
            sc.parallelize(idSubgraphs, numSubGraphs).saveAsTextFile(fileLoc);
        } catch (Exception e) {
            // Best-effort by design (callers expect no exception), but never fail silently.
            System.err.println("Failed to create idSubgraph files at " + fileLoc + ": " + e);
            e.printStackTrace();
        }
    }

    /**
     * Shuffles {@code javaRDD} into {@code numSplits} partitions and saves it as a SequenceFile of
     * {@code (NullWritable, EdgeWritable)}.
     *
     * <p>Save failures (for example, an existing output directory) are reported on stderr and
     * otherwise ignored.
     */
    private void saveAsSequenceFileEdge(JavaRDD<Edge> javaRDD, int numSplits, String outputDir) {
        JavaPairRDD<NullWritable, EdgeWritable> pairsToWrite =
                javaRDD.coalesce(numSplits, true).mapToPair(new ToPairFunctionEdge());
        try {
            pairsToWrite.saveAsHadoopFile(outputDir, NullWritable.class, EdgeWritable.class,
                    SequenceFileOutputFormat.class);
        } catch (Exception e) {
            // Best-effort by design (callers expect no exception), but never fail silently.
            System.err.println("Failed to save edges to " + outputDir + ": " + e);
            e.printStackTrace();
        }
    }

    /**
     * Maps a (source, target, weight) record to an {@link Edge} of integer vertex ids.
     *
     * <p>For Spark {@link Row}s the first three fields are read directly, so user ids containing
     * commas, brackets or surrounding whitespace still match the keys collected by the enclosing
     * class. Any other element type falls back to parsing its {@code toString()} form
     * {@code "[source,target,weight]"}.
     *
     * <p>The stored weight is the negated input weight. NOTE(review): presumably this makes larger
     * similarities behave as smaller distances downstream; confirm.
     */
    public static final class ToEdge<T> implements Function<T, Edge> {
        private static final long serialVersionUID = 1L;
        private final Map map;

        /** @param map user id -> integer id (values are converted via {@code toString()}) */
        public ToEdge(Map map) {
            this.map = map;
        }

        /**
         * @throws IllegalArgumentException if a user id has no entry in the map, or a non-Row
         *                                  record has fewer than three comma-separated fields
         * @throws NumberFormatException    if the weight is not a parseable number
         */
        @Override
        public Edge call(T o) throws CloneNotSupportedException {
            Object source;
            Object target;
            Double weight;
            if (o instanceof Row) {
                Row row = (Row) o;
                source = row.get(0);
                target = row.get(1);
                // Parse the textual form so results match the string-based path exactly.
                weight = -Double.parseDouble(String.valueOf(row.get(2)));
            } else {
                String str = o.toString().replace("[", "").replace("]", "").trim();
                String[] split = str.split(",");
                if (split.length < 3) {
                    throw new IllegalArgumentException(
                            "Expected 'source,target,weight' but got: " + o);
                }
                source = split[0];
                target = split[1];
                weight = -Double.parseDouble(split[2]);
            }
            return new Edge(lookupId(source), lookupId(target), weight);
        }

        /** Resolves {@code userId} to its integer id, failing loudly when it is unknown. */
        private Integer lookupId(Object userId) {
            Object id = map.get(userId);
            if (id == null) {
                throw new IllegalArgumentException("No id assigned for user id: " + userId);
            }
            return Integer.valueOf(id.toString());
        }
    }

    /** Wraps a clone of each {@link Point} as {@code (NullWritable, PointWritable)}. */
    public static final class ToPairFunction implements PairFunction<Point, NullWritable, PointWritable> {
        private static final long serialVersionUID = 1L;

        @Override
        public Tuple2<NullWritable, PointWritable> call(Point row) throws CloneNotSupportedException {
            return new Tuple2<NullWritable, PointWritable>(NullWritable.get(), new PointWritable(row.clone()));
        }
    }

    /** Wraps a clone of each {@link Edge} as {@code (NullWritable, EdgeWritable)}. */
    public static final class ToPairFunctionEdge implements PairFunction<Edge, NullWritable, EdgeWritable> {
        private static final long serialVersionUID = 1L;

        @Override
        public Tuple2<NullWritable, EdgeWritable> call(Edge row) throws CloneNotSupportedException {
            return new Tuple2<NullWritable, EdgeWritable>(NullWritable.get(), new EdgeWritable(row.clone()));
        }
    }

    /**
     * Recursively deletes {@code hdfsFile} if it exists.
     *
     * <p>Any failure (bad URI, I/O error) is printed and otherwise ignored.
     *
     * @param hdfsFile path/URI of the file or directory to delete
     */
    public static void deleteHdfsFile(String hdfsFile) {
        try {
            FileSystem hdfs = FileSystem.get(new URI(hdfsFile), new Configuration());
            Path path = new Path(hdfsFile);
            if (hdfs.exists(path)) {
                hdfs.delete(path, true);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
