package org.apache.spark.mllib.fpgrowth

import java.io.{IOException, OutputStreamWriter, PrintWriter}
import java.nio.charset.StandardCharsets

import com.datastax.insight.spec.RDDOperator
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.mllib.fpm.FPGrowth
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{Dataset, Row, SparkSession}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

object FPGrowthUtil extends Serializable {

  /**
   * Mines frequent itemsets separately for every distinct non-null value of `groupCol`, merges
   * overlapping itemsets within each group (see [[mergeRelated]]) and writes the result as CSV to
   * `path` on the file system addressed by `uri` (see [[freq2csv]]).
   *
   * @param data          input dataset; it is cached for the duration of the call
   * @param groupCol      column whose distinct non-null values define the groups
   * @param targetCol     column holding the items of each row, read as a `Seq`; null rows are skipped
   * @param minSupport    FP-growth minimum support, in (0, 1]
   * @param numPartitions number of partitions used by FP-growth, > 0
   * @param minFreq       only itemsets whose frequency is strictly greater than this are kept
   * @param p             overlap ratio in [0, 1] that decides whether itemsets are merged
   * @param minItems      itemsets with fewer items than this are dropped before merging
   * @param uri           file-system URI, set as `fs.defaultFS`
   * @param path          output file path on that file system
   * @throws IllegalArgumentException if an argument is invalid or a column does not exist
   * @throws java.io.IOException      if the output file cannot be written
   */
  def fpgrowth(data: Dataset[Row], groupCol: String, targetCol: String, minSupport: Double, numPartitions: Int,
               minFreq: Long, p: Double, minItems: Int, uri: String, path: String): Unit = {
    validateArgs(data, groupCol, targetCol, minSupport, numPartitions)
    // Validate merge/output arguments up front so bad values fail before the expensive mining step.
    require(p >= 0.0 && p <= 1.0)
    require(uri != null && uri.trim.length != 0)
    require(path != null && path.trim.length != 0)
    data.cache()
    try {
      val freqs = mineByGroup(data, groupCol, targetCol, minSupport, numPartitions, minFreq)
      freq2csv(mergeRelated(freqs, p, minItems), uri, path)
    } finally {
      // The itemset RDDs are lazy and their lineage reads `data`; keep it cached until freq2csv has
      // materialised them, otherwise every group's query is re-evaluated from the uncached source.
      data.unpersist()
    }
  }

  /**
   * Mines frequent itemsets separately for every distinct non-null value of `groupCol`.
   *
   * `data` is cached while FP-growth runs and unpersisted before returning. The returned RDD is lazy,
   * so evaluating it later re-reads `data` through its lineage.
   *
   * @param data          input dataset
   * @param groupCol      column whose distinct non-null values define the groups
   * @param targetCol     column holding the items of each row, read as a `Seq`; null rows are skipped
   * @param minSupport    FP-growth minimum support, in (0, 1]
   * @param numPartitions number of partitions used by FP-growth, > 0
   * @param minFreq       only itemsets whose frequency is strictly greater than this are kept
   * @return pairs of (group value as string, frequent itemset); an empty RDD when there are no groups
   * @throws IllegalArgumentException if an argument is invalid or a column does not exist
   */
  def fpgrowth1(data: Dataset[Row], groupCol: String, targetCol: String, minSupport: Double, numPartitions: Int,
                minFreq: Long): RDD[(String, Array[Any])] = {
    validateArgs(data, groupCol, targetCol, minSupport, numPartitions)
    data.cache()
    try {
      mineByGroup(data, groupCol, targetCol, minSupport, numPartitions, minFreq)
    } finally {
      data.unpersist()
    }
  }

  /** Argument checks shared by [[fpgrowth]] and [[fpgrowth1]]. */
  private def validateArgs(data: Dataset[Row], groupCol: String, targetCol: String, minSupport: Double,
                           numPartitions: Int): Unit = {
    require(data != null)
    require(groupCol != null && groupCol.trim.length != 0)
    require(targetCol != null && targetCol.trim.length != 0)
    require(minSupport > 0.0 && minSupport <= 1.0)
    require(numPartitions > 0)
    if (!checkExist(data.schema.fieldNames, groupCol, targetCol)) {
      // Message: "one or more of the given column names do not exist in the dataset".
      throw new IllegalArgumentException("数据集中不存在参数中的一个或者多个列名")
    }
  }

  /** Returns true only if every name in `cols` is one of `fields`. */
  private def checkExist(fields: Array[String], cols: String*): Boolean =
    cols.forall(c => fields.contains(c))

  /**
   * Runs FP-growth once per distinct non-null `groupCol` value and unions the per-group results.
   * The FP-growth jobs run eagerly here, while the returned itemset RDD itself is lazy.
   */
  private def mineByGroup(data: Dataset[Row], groupCol: String, targetCol: String, minSupport: Double,
                          numPartitions: Int, minFreq: Long): RDD[(String, Array[Any])] = {
    import org.apache.spark.sql.functions.col
    val groupValues = data.select(groupCol).distinct().collect().map(_.get(0)).filter(_ != null)
    val fpg = new FPGrowth().setMinSupport(minSupport).setNumPartitions(numPartitions)

    // List.map is strict, so all FP-growth runs happen before the caller unpersists `data`.
    val perGroup: Seq[RDD[(String, Array[Any])]] = groupValues.toList.map { value =>
      val key = value.toString
      val transactions = data.where(col(groupCol).equalTo(value)).select(targetCol).rdd
        .filter(r => !r.isNullAt(0))
        // FP-growth rejects transactions containing duplicate items.
        .map(r => r.getAs[Seq[Any]](0).toArray.distinct)
      fpg.run(transactions).freqItemsets.filter(_.freq > minFreq).map(f => (key, f.items))
    }

    val sc = data.sparkSession.sparkContext
    // One n-ary union keeps the lineage flat instead of nesting n binary unions.
    if (perGroup.isEmpty) sc.emptyRDD[(String, Array[Any])] else sc.union(perGroup)
  }

  /**
   * Within each key, greedily clusters itemsets that overlap strongly (see [[mergeGroup]]).
   *
   * @param rdd      (key, itemset) pairs; itemsets are assumed to be duplicate-free
   * @param p        overlap ratio in [0, 1]; an itemset joins a cluster sharing >= p * |itemset| items
   * @param minItems itemsets with fewer items than this are ignored
   * @return (key, merged cluster) pairs
   */
  def mergeRelated(rdd: RDD[(String, Array[Any])], p: Double, minItems: Int): RDD[(String, Array[Any])] = {
    require(p >= 0.0 && p <= 1.0)
    require(rdd != null)
    rdd.groupByKey().flatMap { case (key, itemsets) =>
      mergeGroup(itemsets, p, minItems).map(cluster => (key, cluster))
    }
  }

  /**
   * Greedy clustering of one key's itemsets. Each itemset with at least `minItems` items is compared
   * against the current clusters; a cluster matches when it shares at least `p * itemset.length`
   * items with it. The itemset is merged into the first matching cluster, every further matching
   * cluster is folded into that same cluster, and an itemset matching nothing starts a new cluster.
   * The result depends on the order of `itemsets`, which `groupByKey` does not guarantee.
   */
  private def mergeGroup(itemsets: Iterable[Array[Any]], p: Double, minItems: Int): ArrayBuffer[Array[Any]] = {
    val clusters = new ArrayBuffer[Array[Any]]
    itemsets.iterator.filter(_.length >= minItems).foreach { items =>
      var target = -1 // index of the cluster this itemset was merged into, -1 if none yet
      var i = 0
      while (i < clusters.length) {
        val cluster = clusters(i)
        // |cluster ∩ items| via inclusion-exclusion; valid because both sides are duplicate-free.
        val overlap = cluster.length + items.length - (cluster ++ items).distinct.length
        if (overlap >= p * items.length) {
          if (target < 0) {
            clusters(i) = (cluster ++ items).distinct
            target = i
            i += 1
          } else {
            // Fold the other matching cluster into the target instead of dropping its items.
            // Do not advance i: remove() shifts the next cluster into slot i.
            clusters(target) = (clusters(target) ++ cluster).distinct
            clusters.remove(i)
          }
        } else {
          i += 1
        }
      }
      if (target < 0) {
        clusters += items
      }
    }
    clusters
  }

  /**
   * Writes a key -> clusters map as CSV with header `stock,class_name,class_member`; each row is
   * key, cluster index and the cluster's items joined with ';'. Values are not quoted or escaped.
   */
  private def freq2csvTemp(map: mutable.HashMap[String, Array[Array[Any]]], uri: String, path: String): Unit = {
    require(map != null)
    require(uri != null && uri.trim.length != 0)
    require(path != null && path.trim.length != 0)
    writeToFs(uri, path) { writer =>
      writer.write("stock,class_name,class_member\n")
      for ((key, clusters) <- map; (cluster, i) <- clusters.zipWithIndex) {
        // Terminate every row so that rows of consecutive keys do not run together.
        writer.write(key + "," + i + "," + cluster.mkString(";") + "\n")
      }
    }
  }

  /**
   * Writes (key, items) pairs as CSV with header `stock,class_member`; each row is the key followed by
   * the items joined with ';'. Values are not quoted or escaped, so a key containing ',' corrupts its row.
   *
   * @param rdd  rows to write, streamed to the driver one partition at a time
   * @param uri  file-system URI, set as `fs.defaultFS`
   * @param path output file path on that file system (overwritten if it exists)
   * @throws java.io.IOException if the file cannot be written
   */
  def freq2csv(rdd: RDD[(String, Array[Any])], uri: String, path: String): Unit = {
    require(rdd != null)
    require(uri != null && uri.trim.length != 0)
    require(path != null && path.trim.length != 0)
    writeToFs(uri, path) { writer =>
      writer.write("stock,class_member\n")
      // toLocalIterator avoids holding the whole result on the driver at once, as collect() would.
      rdd.toLocalIterator.foreach { case (key, items) =>
        writer.write(key + "," + items.mkString(";") + "\n")
      }
    }
  }

  /**
   * Creates `path` (overwriting it if present) on the file system whose `fs.defaultFS` is `uri`,
   * passes a UTF-8 writer for it to `write`, and closes both the writer and the file system.
   *
   * @throws java.io.IOException if opening, writing or closing the file fails
   */
  private def writeToFs(uri: String, path: String)(write: PrintWriter => Unit): Unit = {
    val conf = new Configuration()
    conf.set("fs.defaultFS", uri)
    // FileSystem.get returns a JVM-wide cached instance that Spark and other code may share; closing
    // it would break them, so use a private instance that is safe to close.
    val fileSys = FileSystem.newInstance(conf)
    try {
      val writer = new PrintWriter(new OutputStreamWriter(fileSys.create(new Path(path)), StandardCharsets.UTF_8))
      try {
        write(writer)
      } finally {
        writer.close()
      }
      // PrintWriter swallows IOExceptions (including on close); surface them instead of leaving a truncated file.
      if (writer.checkError()) {
        throw new IOException("Failed to write " + path)
      }
    } finally {
      fileSys.close()
    }
  }
}
