package com.datastax.data.prepare.spark.math

import spire.ClassTag

import scala.annotation.tailrec
import scala.collection.mutable.{ArrayBuffer, ListBuffer}

/** Selection helpers over arrays: the k-th smallest value, and the first k elements (plus ties)
  * under a caller-supplied ordering predicate.
  */
object ArrayCorrelation {

  /** Returns the value at 0-based position `k` of `array` sorted ascending, i.e. the same value as
    * `array.sorted.apply(k)`, using quickselect with the first element as pivot.
    *
    * Every copy of the pivot value is counted in the same pass, so all duplicates of the pivot are
    * resolved in one step. NaN values are not supported: a non-pivot NaN matches none of `<`, `>`,
    * `==` and is dropped during partitioning.
    *
    * @param array values to select from; not modified
    * @param k     0-based rank; must satisfy `0 <= k < array.length`
    * @return the k-th smallest value
    * @throws IllegalArgumentException if `k` is out of range (which includes an empty `array`)
    */
  @tailrec
  def seekNumberOfK(array: Array[Double], k: Int): Double = {
    require(k >= 0 && k < array.length, s"k must be in [0, ${array.length}) but was $k")
    val pivot = array(0)
    val smaller = Array.newBuilder[Double]
    val larger = Array.newBuilder[Double]
    var pivotCount = 1 // number of elements equal to the pivot, including array(0) itself
    for (i <- 1 until array.length) {
      val x = array(i)
      if (x < pivot) smaller += x
      else if (x > pivot) larger += x
      else if (x == pivot) pivotCount += 1
    }
    val below = smaller.result()
    if (k < below.length) {
      seekNumberOfK(below, k)
    } else if (k < below.length + pivotCount) {
      // The pivot value fills sorted positions [below.length, below.length + pivotCount).
      pivot
    } else {
      seekNumberOfK(larger.result(), k - below.length - pivotCount)
    }
  }

  /** Returns the first `k` elements of `array` in the order induced by `f`, followed by any further
    * elements that are `==` to the k-th one (so the result may be longer than `k`).
    *
    * `f(a, b)` must return true when `b` belongs before `a`: `_ < _` yields the largest values in
    * descending order, `_ > _` the smallest values in ascending order. Both strict predicates
    * (`_ < _`) and non-strict ones (`_ <= _`) keep elements tied with the k-th element.
    * Each incoming element may trigger a backward scan and shift of the retained buffer, so this
    * is meant for small `k`.
    *
    * @param array input elements; not modified
    * @param k     number of leading elements to keep; must be non-negative
    * @param f     ordering predicate, `f(a, b)` meaning "`b` goes before `a`"
    * @return the selected elements in `f` order; empty when `array` is empty or `k == 0`
    * @throws IllegalArgumentException if `k` is negative
    */
  def seekPreNumberOfK[T : ClassTag](array: Array[T], k: Int)(f: (T, T) => Boolean): Array[T] = {
    require(k >= 0, s"k must be non-negative but was $k")
    if (array.isEmpty || k == 0) {
      Array.empty[T]
    } else {
      // Kept ordered by `f`; ArrayBuffer gives constant-time indexed reads during the backward scan.
      val preK = new ArrayBuffer[T]
      preK += array(0)
      for (i <- 1 until array.length) {
        val x = array(i)
        if (preK.size < k) {
          // Still filling up: every element is kept, either appended or inserted in order.
          if (f(preK.last, x)) insertOrdered(preK, x)(f) else preK += x
        } else {
          if (f(preK.last, x)) insertOrdered(preK, x)(f)
          // With a strict predicate (f(a, a) is false, e.g. `_ < _` as opposed to `_ <= _`) an
          // element equal to the current last one fails the test above; append it so the tie is
          // not lost.
          if (!f(preK.last, preK.last) && preK.last == x) preK += x
          // The buffer may now exceed k: drop everything past k that is not tied with preK(k - 1).
          truncateKeepingTies(preK, k)
        }
      }
      preK.toArray
    }
  }

  /** Inserts `x` into `sorted`, where the caller has already established `f(sorted.last, x)`.
    * Scans backwards from the second-to-last element past every `e` with `f(e, x)` and inserts `x`
    * right after the first element for which `f` is false (or at the front if there is none).
    */
  private def insertOrdered[T](sorted: ArrayBuffer[T], x: T)(f: (T, T) => Boolean): Unit = {
    var j = sorted.length - 2
    while (j >= 0 && f(sorted(j), x)) j -= 1
    sorted.insert(j + 1, x)
  }

  /** Removes every element after index `k - 1` except the run directly following it whose elements
    * are `==` to `buf(k - 1)`. The caller guarantees `buf.length >= k >= 1`.
    */
  private def truncateKeepingTies[T](buf: ArrayBuffer[T], k: Int): Unit = {
    var end = k
    while (end < buf.length && buf(k - 1) == buf(end)) end += 1
    if (end < buf.length) buf.remove(end, buf.length - end)
  }

  /** Ad-hoc demo: prints a top-k selection, a k-th smallest value, and the elapsed milliseconds. */
  def main(args: Array[String]): Unit = {
    val start = System.currentTimeMillis()
    val array4 = Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0)
    val array5 = Array(1.0, 5.0, 3.0, 2.0, 3.0, 5.0, 5.0, 3.0, 3.0, 4.0, 6.0, 2.0, 1.0, 5.0)
    val d = seekPreNumberOfK(array5, 5)(_ < _)
    println(d.mkString(","))
    println(seekNumberOfK(array4, 12))
    val end = System.currentTimeMillis()
    println(end - start)
  }

}
