package org.apache.spark.ml.classification

import org.apache.spark.ml.linalg.Vector
import org.apache.spark.mllib.evaluation.BinaryClassificationMetricsExt
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.{DataFrame, Row}

/**
 * Extension of [[BinaryLogisticRegressionSummary]] that additionally exposes gain, lift and
 * KS-style curves as DataFrames, computed by
 * [[org.apache.spark.mllib.evaluation.BinaryClassificationMetricsExt]].
 *
 * @param predictions    scored DataFrame; must contain `probabilityCol` and `labelCol`
 * @param probabilityCol column holding the probability vector; element at index 1 is used as
 *                       the score, so every vector must have at least two entries
 * @param labelCol       label column; it is cast to `DoubleType` before metrics are computed
 * @param featuresCol    features column, forwarded unchanged to the parent summary
 * @param numBins        second constructor argument passed to `BinaryClassificationMetricsExt`
 *                       (defaults to 100, the value that used to be hard-coded). Presumably the
 *                       number of bins used to down-sample the curves — confirm against
 *                       `BinaryClassificationMetricsExt`.
 */
class BinaryLogisticRegressionSummaryExt private[classification] (
    @transient override val predictions: DataFrame,
    override val probabilityCol: String,
    override val labelCol: String,
    override val featuresCol: String,
    numBins: Int = 100)
  extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol) {

  private val sparkSession = predictions.sparkSession

  // Brings the `toDF` conversions used by the lazy vals below into scope.
  import sparkSession.implicits._

  // Metrics backing every curve exposed by this class, built from (score, label) pairs.
  // The pattern-matching function passed to `map` is partial: a row whose probability is not
  // a `Vector`, or whose label is null after the cast, raises a MatchError once the RDD is
  // evaluated.
  @transient private val binaryMetrics = new BinaryClassificationMetricsExt(
    predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType)).rdd.map {
      case Row(score: Vector, label: Double) => (score(1), label)
    },
    numBins
  )

  /** Result of `binaryMetrics.gains()` as a DataFrame with columns `reach` and `recall`. */
  @transient lazy val gain: DataFrame = binaryMetrics.gains().toDF("reach", "recall")

  /** Result of `binaryMetrics.lift()` as a DataFrame with columns `reach` and `lift`. */
  @transient lazy val lift: DataFrame = binaryMetrics.lift().toDF("reach", "lift")

  /** Result of `binaryMetrics.ksMinus()` as a DataFrame with columns `reach` and `FPR`. */
  @transient lazy val ksMinus: DataFrame = binaryMetrics.ksMinus().toDF("reach", "FPR")

  /** Result of `binaryMetrics.ksPlus()` as a DataFrame with columns `reach` and `TPR`. */
  @transient lazy val ksPlus: DataFrame = binaryMetrics.ksPlus().toDF("reach", "TPR")
}
