package com.datastax.data.prepare.spark.dataset
import java.util

import com.datastax.data.prepare.util.Consts
import com.datastax.insight.core.driver.SparkContextBuilder
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

import scala.collection.mutable

object BasicOperation {

  /**
   * Appends a column computed from `xColumn` and either a second column or a numeric constant.
   *
   * @param df            input DataFrame
   * @param xColumn       name of the left-hand operand column
   * @param method        one of "plus", "minus", "multiply", "divide"; any other value copies `xColumn`
   * @param yColumn       right-hand operand: a column name, or a number (as text) when `valueType` is "constant"
   * @param valueType     "constant" to treat `yColumn` as a numeric literal; anything else treats it as a column name
   * @param newColumnName name of the column to add; must not be null
   * @return `df` with `newColumnName` added (via `withColumn`)
   * @throws IllegalArgumentException if `newColumnName` is null
   */
  def mathCompute(df: DataFrame, xColumn: String, method: String, yColumn: String, valueType: String, newColumnName: String): DataFrame = {
    require(newColumnName != null, "newColumnName must not be null")
    df.withColumn(newColumnName, getColumn(method, df, xColumn, yColumn, valueType))
  }

  /**
   * Builds the arithmetic expression `xCol <method> y`. Here `y` is `yCol.toDouble` when `valueType`
   * is "constant", and the column `yCol` otherwise. Unknown (or null) methods return `df(xCol)` unchanged.
   */
  private def getColumn(method: String, df: DataFrame, xCol: String, yCol: String, valueType: String): Column = {
    val left = df(xCol)
    // Lazy so that an unsupported method never parses yCol or resolves it against df.
    lazy val right: Any = if ("constant".equals(valueType)) yCol.toDouble else df(yCol)
    method match {
      case "plus"     => left + right
      case "minus"    => left - right
      case "multiply" => left * right
      case "divide"   => left / right
      case _          => left
    }
  }

  /**
   * Keeps only the rows for which `compare` applied to `oldColumn` and `threshold` holds.
   *
   * @param df        input DataFrame
   * @param oldColumn name of the column to test; must not be null
   * @param compare   comparison whose `compare(column, threshold)` result is passed to `df.filter`
   * @param threshold value the column is compared against
   * @throws IllegalArgumentException if `oldColumn` is null
   */
  def groupFilter(df: DataFrame, oldColumn: String, compare: Consts.MathCompare, threshold: Double): DataFrame = {
    require(oldColumn != null)
    import org.apache.spark.sql.functions._
    df.filter(compare.compare(col(oldColumn), threshold))
  }

  /**
   * Builds an all-string DataFrame from comma-separated text lines.
   *
   * @param data    one comma-separated line per row; each line must have exactly as many fields as `columns`
   * @param columns comma-separated column names; every column is typed `StringType`
   * @return a DataFrame created on the session from `SparkContextBuilder.getSession`
   * @throws IllegalArgumentException if an argument is null or a line's field count differs from the column count
   */
  def createData(data: mutable.Buffer[String], columns: String): DataFrame = {
    require(data != null, "data must not be null")
    require(columns != null, "columns must not be null")
    val spark = SparkContextBuilder.getSession
    val schema = StructType(columns.split(",").map(fieldName => StructField(fieldName, StringType)))
    val rows = new util.ArrayList[Row](data.size)
    data.foreach { line =>
      // limit = -1 keeps trailing empty fields, so "1,," yields three values rather than one.
      val values = line.split(",", -1)
      require(values.length == schema.length,
        s"row '$line' has ${values.length} fields but columns '$columns' define ${schema.length}")
      rows.add(Row.fromSeq(values))
    }
    spark.createDataFrame(rows, schema)
  }
}
