package com.datastax.data.prepare.spark.dataset

import java.io.PrintWriter

import com.datastax.insight.core.driver.SparkContextBuilder
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.mllib.fpm.FPGrowth
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{StringType, StructField, StructType}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

object FPGrowthUtil extends Serializable {

  /**
    * Mines frequent itemsets separately for every distinct non-null value of `groupCol` and
    * writes the result of each group as a headered CSV under `uri + path + "/" + <group value>`.
    *
    * For each group, the rows of `data` whose `groupCol` equals the group value are selected, the
    * sequence stored in `targetCol` is de-duplicated into one transaction per row, and MLlib
    * FP-growth is run on those transactions. Itemsets with a frequency of at least `minFreq` are
    * kept and written with the columns `stock` (the group value) and `class_member` (the itemset
    * joined with ";"). Existing output is overwritten.
    *
    * The `mergeRelated` post-processing step is currently disabled, so `p` and `minItems` are
    * accepted but not used.
    *
    * @param spark         session used to build the output datasets
    * @param data          input dataset; must contain `groupCol` and `targetCol`
    * @param groupCol      name of the string column whose distinct values define the groups
    * @param targetCol     name of the column holding the item sequence of each transaction
    * @param minSupport    FP-growth minimum support, in (0, 1]
    * @param numPartitions FP-growth partition count; only applied when positive
    * @param minFreq       minimum (inclusive) frequency an itemset needs to be written
    * @param p             unused while the merge step is disabled
    * @param minItems      unused while the merge step is disabled
    * @param uri           prefix of the output location
    * @param path          path appended directly to `uri`
    * @throws IllegalArgumentException if an argument check fails or `checkExist` reports a missing column
    */
  def fpgrowth(spark: SparkSession, data: Dataset[Row], groupCol: String, targetCol: String, minSupport: Double, numPartitions: Int, minFreq: Long, p: Double, minItems: Int, uri: String, path: String): Unit = {
    println("fpgrowth开始....")
    import spark.implicits._
    require(data != null)
    require(groupCol != null && groupCol.trim.nonEmpty)
    require(targetCol != null && targetCol.trim.nonEmpty)
    require(minSupport > 0.0 && minSupport <= 1.0)
    val columnsPresent = checkExist(data.schema.fieldNames, groupCol, targetCol)
    if (!columnsPresent) {
      throw new IllegalArgumentException("数据集中不存在参数中的一个或者多个列名")
    }

    // Distinct group values are pulled to the driver; null keys are ignored.
    val groupKeys =
      for (row <- data.select(groupCol).distinct().collect() if row.get(0) != null) yield row.getString(0)

    val miner = new FPGrowth().setMinSupport(minSupport)
    if (numPartitions > 0) {
      miner.setNumPartitions(numPartitions)
    }

    import org.apache.spark.sql.functions._
    for (key <- groupKeys) {
      val startedAt = System.currentTimeMillis()
      println(s"$key 开始执行FPGrowth...")

      // One transaction per row: the de-duplicated items of the target column.
      val transactions = data
        .where(col(groupCol).equalTo(key))
        .select(targetCol)
        .rdd
        .map(row => row.getAs[Seq[Any]](0).toArray.distinct)

      val itemsets: RDD[(String, Array[Any])] =
        miner.run(transactions).freqItemsets.filter(_.freq >= minFreq).map(itemset => (key, itemset.items))

      // Raw itemsets are written as-is; merging of related itemsets is not applied here.
      val rendered = spark
        .createDataset(itemsets.map(pair => (pair._1, pair._2.map(_.toString))))
        .toDF("stock", "class_member")
        .select(col("stock"), concat_ws(";", col("class_member")).alias("class_member"))
      rendered
        .repartition(1)
        .write
        .option("header", true)
        .mode(SaveMode.Overwrite)
        .csv(uri + path + "/" + key)

      println(s"$key FPGrowth结束")
      val finishedAt = System.currentTimeMillis()
      println(s"$key FPGrowth执行消耗时间为：${finishedAt - startedAt}")
    }
    println("fpgrowth结束")
  }


  /**
    * Checks that every requested column name is present in `fields`.
    *
    * The previous implementation evaluated `false` inside a `foreach` lambda, where the value is
    * discarded, and therefore always returned `true`.
    *
    * @param fields available column names
    * @param cols   column names that must all be present (matched exactly, case-sensitively)
    * @return `true` when every name in `cols` occurs in `fields`; `true` for an empty `cols`
    */
  private def checkExist(fields: Array[String], cols: String*): Boolean =
    cols.forall(c => fields.contains(c))

  /**
    * Merges overlapping itemsets that share the same key into larger clusters.
    *
    * The itemsets are grouped by key and each group is folded by `mergeGroup`. The returned RDD is
    * lazy, so the closing log line is printed before any merging has actually been computed.
    *
    * @param rdd      (key, itemset) pairs; must not be null
    * @param p        overlap ratio in [0, 1]: an itemset joins a cluster when the number of items
    *                 they share is at least `p` times the itemset's size
    * @param minItems itemsets with fewer items than this are dropped; must be >= 0
    * @return (key, cluster) pairs, one per cluster remaining after merging
    * @throws IllegalArgumentException if an argument check fails
    */
  def mergeRelated(rdd: RDD[(String, Array[Any])], p: Double, minItems: Int): RDD[(String, Array[Any])] = {
    println("mergeRelated开始。。。")
    require(p >= 0.0 && p <= 1.0, "p小于0或者大于1")
    require(rdd != null, "rdd为空")
    require(minItems >= 0, "minItems小于0")

    val merged = rdd.groupBy(t => t._1).flatMap(group => {
      val key = group._1
      val clusters: Array[(String, Array[Any])] =
        mergeGroup(group._2.map(_._2), p, minItems).map(cluster => (key, cluster))
      clusters
    })
    println("mergeRelated结束。")
    merged
  }

  /**
    * Folds the itemsets of one key into clusters, in iteration order.
    *
    * Each itemset of at least `minItems` items is compared with every existing cluster. The first
    * cluster it overlaps sufficiently is replaced by the union of the two; any further cluster
    * that also overlaps sufficiently is removed. An itemset matching no cluster starts a new one.
    * The overlap is computed as `|cluster| + |itemset| - |distinct union|`.
    */
  private def mergeGroup(itemsets: Iterable[Array[Any]], p: Double, minItems: Int): Array[Array[Any]] = {
    val clusters = new ArrayBuffer[Array[Any]]
    itemsets.foreach { items =>
      if (items.length >= minItems) {
        var absorbed = false
        var i = 0
        while (i < clusters.length) {
          val union = (clusters(i) ++ items).distinct
          val overlap = clusters(i).length + items.length - union.length
          if (overlap >= p * items.length) {
            if (!absorbed) {
              clusters(i) = union
              absorbed = true
              i += 1
            } else {
              // Removing shifts the following cluster into slot i, so the index must not advance;
              // advancing here used to skip that cluster entirely.
              // NOTE(review): later matching clusters are discarded rather than unioned into the
              // first match; behaviour kept as-is - confirm this is intended.
              clusters.remove(i)
            }
          } else {
            i += 1
          }
        }
        if (!absorbed) {
          clusters += items
        }
      }
    }
    clusters.toArray
  }


  /**
    * Collects the itemsets to the driver and writes them as a single CSV file through the Hadoop
    * FileSystem API.
    *
    * The file starts with the header `stock,class_member`; every following line holds the key and
    * the itemset joined with ";". Values are written unescaped.
    *
    * A private FileSystem instance is obtained with `FileSystem.newInstance` so that closing it
    * here cannot close the cached instance that `FileSystem.get` shares with the rest of the JVM.
    * The file system is also closed when creating the output file fails.
    *
    * @param rdd  (key, itemset) pairs; collected in full on the driver
    * @param uri  value used for `fs.defaultFS`
    * @param path path of the file to create
    * @throws IllegalArgumentException if an argument is null or blank
    */
  def freq2csvTemp2(rdd: RDD[(String, Array[Any])], uri: String, path: String): Unit = {
    require(rdd != null)
    require(uri != null && uri.trim.length != 0)
    require(path != null && path.trim.length != 0)
    val conf = new Configuration()
    conf.set("fs.defaultFS", uri)
    val fileSys = FileSystem.newInstance(conf)
    try {
      val writer = new PrintWriter(fileSys.create(new Path(path)))
      try {
        writer.write("stock,class_member\n")
        rdd.collect().foreach(entry => {
          writer.write(entry._1 + "," + entry._2.mkString(";") + "\n")
        })
      } finally {
        writer.close()
      }
    } finally {
      fileSys.close()
    }
  }

  /**
    * Writes the itemsets as a headered CSV with the columns `stock` and `class_member`.
    *
    * Each itemset is joined with ";" into one string. The data is repartitioned to a single
    * partition and written to `uri + "/" + path`, overwriting existing output. The session is
    * obtained from `SparkContextBuilder.getSession`.
    *
    * @param rdd  (key, itemset) pairs to write
    * @param uri  prefix of the output location
    * @param path path appended to `uri` after a "/"
    * @throws IllegalArgumentException if an argument is null or blank
    */
  def freq2csv(rdd: RDD[(String, Array[Any])], uri: String, path: String): Unit = {
    println("freq2csv开始。。。")
    require(rdd != null)
    require(uri != null && uri.trim.nonEmpty)
    require(path != null && path.trim.nonEmpty)

    val rows = rdd.map(entry => Row(entry._1, entry._2.mkString(";")))
    val session = SparkContextBuilder.getSession
    val csvSchema = StructType(Seq(
      StructField("stock", StringType, nullable = true),
      StructField("class_member", StringType, nullable = true)
    ))

    session
      .createDataFrame(rows, csvSchema)
      .repartition(1)
      .write
      .option("header", true)
      .mode(SaveMode.Overwrite)
      .csv(s"$uri/$path")
    println("freq2csv结束。")
  }

  /**
    * Builds a sample transactions DataFrame from a hard-coded local CSV file.
    *
    * The headerless CSV is read with ten explicitly named columns. Account ids (`shr_acct`) are
    * collected per (Sec_code, Trad_date, Trad_dirc); those per-date lists are collected again per
    * (Sec_code, Trad_dirc), passed through the `mergeDate` UDF and exploded, so that every output
    * row holds one security code and one list of account ids.
    *
    * @param spark session used to read the CSV
    * @return DataFrame with the columns `Sec_code` and `result`
    */
  @deprecated
  def dataset(spark: SparkSession): DataFrame = {
    // Hard-coded absolute path on a local machine.
    val path1 = "/home/keqc/dataexa_work/sati_data/stock/data/SUB_REAL_KG_DATA.csv"
    // The file has no header row, so the column names are assigned here.
    val data = spark.read.option("header", false).csv(path1).toDF("Mkt_type", "shr_acct", "area_code",
      "Sec_code", "seat_code", "Trad_date", "Trad_time", "Trad_dirc", "Trad_vol", "Trad_price")
    data.cache()
    // Window size, in number of entries, used by mergeDate.
    val days = 1
    // With `days` fixed at 1 the first branch is always taken and `ids` is returned unchanged.
    // Otherwise, for every window of `days` consecutive positions, the id lists at those
    // positions are concatenated into one list.
    val mergeDate = udf((dates: Seq[String], ids: mutable.Seq[Seq[String]]) => {
      if(days <= 1 || days > dates.length) {
        ids
      } else {
        val buffer1 = Seq.newBuilder[Seq[String]]
        Range(0, dates.length - days + 1).foreach(i => {
          var buffer2 = Seq.newBuilder[String]
          Range(0, days).foreach(j => {
            buffer2 ++= ids(i + j)
          })
          buffer1 += buffer2.result()
        })
        buffer1.result()
      }
    })
    import org.apache.spark.sql.functions._
    import spark.implicits._
    // The intermediate columns are addressed by the names Spark generates for the aggregates
    // ("collect_list(Shr_acct)", "collect_list(Trad_date)", ...) before being renamed.
    // NOTE(review): the sort on Trad_date happens before a second groupBy; whether collect_list
    // preserves that order afterwards is presumably not guaranteed - confirm.
    val df = data.groupBy(col("Sec_code"), col("Trad_date"), col("Trad_dirc")).agg(collect_list("Shr_acct")).sort(col("Trad_date"))
      .groupBy($"Sec_code", $"Trad_dirc").agg(collect_list("Trad_date"), collect_list("collect_list(Shr_acct)"))
      .withColumnRenamed("collect_list(Trad_date)", "dates").withColumnRenamed("collect_list(collect_list(Shr_acct))", "ids")
      .select("Sec_code", "Trad_dirc", "dates", "ids").withColumn("merge", mergeDate(col("dates"), col("ids")))
      .withColumn("result", explode(col("merge"))).sort("Sec_code").select("Sec_code", "result")
    //    df.show(30, 500)
    // NOTE(review): no action is run on `df` in this method, so `data` is unpersisted before the
    // caller evaluates `df`; the cache() above likely has no effect - confirm.
    data.unpersist()
    df
  }

  /**
    * Runs FP-growth separately for every distinct non-null value of `groupCol` and returns all
    * frequent itemsets tagged with their group value.
    *
    * Unlike the previous implementation this never returns `null`: when there is no group value
    * an empty RDD is returned. The per-group results are combined with a single
    * `SparkContext.union` instead of a chain of nested binary unions.
    *
    * Note that itemsets are kept only when their frequency is strictly greater than `minFreq`,
    * whereas `fpgrowth` uses an inclusive comparison.
    *
    * @param data          input dataset; must contain `groupCol` and `targetCol`
    * @param groupCol      name of the string column whose distinct values define the groups
    * @param targetCol     name of the column holding the item sequence of each transaction
    * @param minSupport    FP-growth minimum support, in (0, 1]
    * @param numPartitions FP-growth partition count; must be positive
    * @param minFreq       exclusive lower bound on itemset frequency
    * @return (group value, itemset) pairs for all groups
    * @throws IllegalArgumentException if an argument check fails or `checkExist` reports a missing column
    */
  @deprecated
  def fpgrowth1(data: Dataset[Row], groupCol: String, targetCol: String, minSupport: Double, numPartitions: Int, minFreq: Long): RDD[(String, Array[Any])] = {
    require(data != null)
    require(groupCol != null && groupCol.trim.length != 0)
    require(targetCol != null && targetCol.trim.length != 0)
    require(minSupport > 0.0 && minSupport <= 1.0)
    require(numPartitions > 0)
    if(!checkExist(data.schema.fieldNames, groupCol, targetCol)) {
      throw new IllegalArgumentException("数据集中不存在参数中的一个或者多个列名")
    }
    val grouped = data.select(groupCol).distinct().collect().filter(_.get(0) != null).map(_.getString(0))
    val fpg = new FPGrowth().setMinSupport(minSupport).setNumPartitions(numPartitions)

    import org.apache.spark.sql.functions.col
    // Array.toSeq is strict, so FP-growth still runs eagerly, one group after another.
    val perGroup: Seq[RDD[(String, Array[Any])]] = grouped.toSeq.map { g =>
      val transactions = data.where(col(groupCol).equalTo(g)).select(targetCol).rdd.map(r => r.getAs[Seq[Any]](0).toArray.distinct)
      val frequent: RDD[(String, Array[Any])] =
        fpg.run(transactions).freqItemsets.filter(_.freq > minFreq).map(f => (g, f.items))
      frequent
    }

    val sc = data.sparkSession.sparkContext
    if (perGroup.isEmpty) {
      sc.emptyRDD[(String, Array[Any])]
    } else {
      sc.union(perGroup)
    }
  }

  /**
    * Writes the clusters of every key as a single CSV file through the Hadoop FileSystem API.
    *
    * The file starts with the header `stock,class_name,class_member`; each following line holds
    * the key, the index of the cluster within that key, and the cluster joined with ";". Lines
    * are separated by "\n" with no trailing newline. Values are written unescaped.
    *
    * Previously the last line of one key was not terminated, so the first line of the next key
    * was appended to it. Rows are now joined across all keys.
    *
    * A private FileSystem instance is obtained with `FileSystem.newInstance` so that closing it
    * here cannot close the cached instance that `FileSystem.get` shares with the rest of the JVM.
    *
    * @param map  clusters per key
    * @param uri  value used for `fs.defaultFS`
    * @param path path of the file to create
    * @throws IllegalArgumentException if an argument is null or blank
    */
  @deprecated
  private def freq2csvTemp(map: mutable.HashMap[String, Array[Array[Any]]], uri: String, path: String): Unit = {
    require(map != null)
    require(uri != null && uri.trim.length != 0)
    require(path != null && path.trim.length != 0)
    val conf = new Configuration()
    conf.set("fs.defaultFS", uri)
    val fileSys = FileSystem.newInstance(conf)
    try {
      val writer = new PrintWriter(fileSys.create(new Path(path)))
      try {
        writer.write("stock,class_name,class_member\n")
        val rows = for {
          entry <- map.iterator
          i <- entry._2.indices.iterator
        } yield entry._1 + "," + i + "," + entry._2(i).mkString(";")
        writer.write(rows.mkString("\n"))
      } finally {
        writer.close()
      }
    } finally {
      fileSys.close()
    }
  }

  /**
    * Local smoke run: builds the sample dataset, mines frequent itemsets per security code,
    * merges related itemsets and writes the result to a hard-coded HDFS location.
    *
    * The session is now closed in a `finally` block so it is released even when one of the
    * pipeline steps throws.
    *
    * @param args ignored
    */
  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().appName("keqc").master("local").getOrCreate()
    try {
      val df = dataset(spark)

      val groupCol = "Sec_code"
      val targetCol = "result"
      val minSupport = 0.3
      val numPartitions = 4
      val minFreq = 0
      val rdd = fpgrowth1(df, groupCol, targetCol, minSupport, numPartitions, minFreq)

      // Overlap ratio 0.5, itemsets of fewer than 2 items are dropped.
      val t = mergeRelated(rdd, 0.5, 2)

      val uri = "hdfs://dataexa-cdh-80:8020"
      val path = "/dataexa/sati/stock/group/test1.csv"
      freq2csv(t, uri, path)
    } finally {
      spark.close()
    }
  }
}
