package org.allenai.nlpstack.parse.poly.fsm

import org.allenai.nlpstack.parse.poly.ml.{ FeatureVector, FeatureName }

import reming.LazyFormat
import reming.DefaultJsonProtocol._

/** Maps a State to the costs of the Transitions that could potentially be applied to it.
  *
  * Costs are real-valued, and lower is better: generally speaking, the cheaper a transition,
  * the more it is preferred.
  *
  * Most implementations derive these costs from a feature representation of the State, but
  * that is not a requirement. For instance, the GuidedCostFunction in
  * [[org.allenai.nlpstack.parse.poly.polyparser.ArcEagerGuidedCostFunction]] bases its costs
  * on a gold parse tree instead.
  */
abstract class StateCostFunction extends (State => Map[StateTransition, Float]) {

  /** The transition system associated with this cost function. */
  def transitionSystem: TransitionSystem

  /** Finds the cheapest transition for a state.
    *
    * @param state the state to evaluate
    * @return the transition with the smallest cost, or None if this cost function assigns
    * no costs for the state
    */
  def lowestCostTransition(state: State): Option[StateTransition] = {
    val costs = apply(state)
    if (costs.nonEmpty) {
      val (cheapestTransition, _) = costs.minBy { case (_, cost) => cost }
      Some(cheapestTransition)
    } else {
      None
    }
  }
}

/** A StateCostFunction whose costs are negative log probabilities derived from per-task
  * transition classifiers.
  *
  * For a given state, the transition system identifies a classification task, and the
  * classifier registered for that task supplies a distribution over transitions. Any
  * probability that the classifier assigns to the `Fallback` transition is redistributed
  * according to `baseCostFunction`.
  *
  * @param transitionSystem used to identify the task of a state and to compute its features
  * @param transitions the transitions that share the uniform distribution when no classifier
  * is registered for a task
  * @param taskClassifierList pairs each classification task with its classifier
  * @param marbleBlock stored with the cost function, but not read by this class
  * @param baseCostFunction consulted whenever a classifier proposes `Fallback`; must be defined
  * in that case
  */
case class ClassifierBasedCostFunction(
    transitionSystem: TransitionSystem, transitions: Seq[StateTransition],
    taskClassifierList: List[(ClassificationTask, TransitionClassifier)],
    marbleBlock: MarbleBlock,
    baseCostFunction: Option[StateCostFunction]
) extends StateCostFunction {

  /** Classifier lookup by task, built on first use from `taskClassifierList`. */
  @transient
  lazy val taskClassifiers: Map[ClassificationTask, TransitionClassifier] =
    taskClassifierList.toMap

  /** Returns the cost (negative log probability) of each candidate transition for the state.
    *
    * No probability threshold is applied: the minimum probability passed on is zero.
    */
  override def apply(state: State): Map[StateTransition, Float] = {
    transitionCosts(state, 0.0f)
  }

  /** Returns a distribution over all possible transitions, according to the classifier associated
    * with the given task.
    *
    * The return value will be a map from transitions to their probabilities. The `minProb`
    * argument tells this function not to bother including transitions to which the classifier
    * assigns a probability less than `minProb`. The threshold is applied to the classifier's
    * output only: it is not applied to the uniform distribution, nor re-applied after fallback
    * probability has been mixed in.
    *
    * If there is no classifier for the task (e.g. there were no training examples for it), then
    * the uniform distribution over `transitions` is returned.
    *
    * If the resulting distribution contains `Fallback`, then its probability mass is spread over
    * the transitions of `baseCostFunction`, each weighted by `exp(-cost)`, and `Fallback` itself
    * is removed from the result.
    *
    * If the transition system identifies no task for the state, the result is empty.
    *
    * @param state the parser state
    * @param minProb only include transitions in the returned map if the classifier's
    * probability for them is at least this bound
    * @return a map from transitions to their probabilities
    * @throws IllegalArgumentException if `Fallback` is proposed but `baseCostFunction` is None
    */
  private def transitionDistribution(
    state: State,
    minProb: Float
  ): Map[StateTransition, Float] = {

    transitionSystem.taskIdentifier(state) match {
      case Some(task) =>
        val featureVector: FeatureVector = transitionSystem.computeFeature(state)
        val topLevelDistribution: Map[StateTransition, Float] =
          taskClassifiers.get(task) match {
            case Some(classifier) =>
              classifier.getDistribution(featureVector) filter {
                case (_, prob) => prob >= minProb
              }
            case None =>
              // No classifier is registered for this task: use the uniform distribution.
              val uniformProb: Float = 1.0f / transitions.size
              (transitions map { transition => (transition, uniformProb) }).toMap
          }
        topLevelDistribution.get(Fallback) match {
          case Some(fallbackProb) =>
            val baseFunction: StateCostFunction = baseCostFunction.getOrElse {
              throw new IllegalArgumentException(
                "requirement failed: a classifier proposed Fallback, " +
                  "but no baseCostFunction is defined"
              )
            }
            val baseCosts: Map[StateTransition, Float] = baseFunction(state)
            // Build a strict map: `mapValues` would yield a lazy view that re-evaluates
            // `exp` on every lookup.
            val baseDistribution: Map[StateTransition, Float] = baseCosts map {
              case (transition, cost) => (transition, Math.exp(-cost).toFloat)
            }
            val topLevelDistributionWithoutFallback: Map[StateTransition, Float] =
              topLevelDistribution - Fallback
            // Mix the two distributions: the classifier's own mass plus the fallback mass
            // weighted by the base distribution. Keys seen twice collapse in `toMap`.
            (for {
              key <- baseCosts.keys ++ topLevelDistributionWithoutFallback.keys
            } yield (
              key,
              topLevelDistributionWithoutFallback.getOrElse(key, 0.0f) +
                fallbackProb * baseDistribution.getOrElse(key, 0.0f)
            )).toMap
          case None =>
            topLevelDistribution
        }
      case None => Map()
    }
  }

  /** Returns negative log of [[transitionDistribution()]].
    *
    * This is the negative log of a distribution over all possible transitions,
    * according to the classifier associated with the given task.
    *
    * The return value will be a map from transitions to their costs, i.e. the negative
    * (natural) log of their probabilities. A transition with probability zero therefore has
    * infinite cost. The `minProb` argument is passed through to [[transitionDistribution()]],
    * which uses it to drop transitions whose classifier probability is less than `minProb`.
    *
    * If there is no classifier for the task, then the uniform distribution is used.
    *
    * @param state the parser state
    * @param minProb only include transitions in the returned map if the classifier's
    * probability for them is at least this bound
    * @return a map from transitions to the negative log of their probabilities
    */
  private def transitionCosts(
    state: State,
    minProb: Float
  ): Map[StateTransition, Float] = {

    // A strict `map` (rather than `mapValues`) so that the log is computed once per
    // transition instead of on every access to the returned map.
    transitionDistribution(state, minProb) map {
      case (transition, prob) => (transition, -Math.log(prob).toFloat)
    }
  }

}

/** A factory for [[StateCostFunction]] instances. */
trait StateCostFunctionFactory {
  /** Builds a cost function from a marble block and a set of transition constraints.
    *
    * @param marbleBlock the marble block to build the cost function for
    * @param constraints the transition constraints to build the cost function with
    * @return the constructed cost function
    */
  def buildCostFunction(
    marbleBlock: MarbleBlock,
    constraints: Set[TransitionConstraint]
  ): StateCostFunction
}

/** Companion object holding the serialization format for [[StateCostFunctionFactory]]. */
object StateCostFunctionFactory {
  /** Serialization format for the StateCostFunctionFactory hierarchy.
    *
    * The only child format registered here is the one for ClassifierBasedCostFunctionFactory.
    * That class holds an optional StateCostFunctionFactory of its own, so this format refers
    * back to itself.
    * NOTE(review): presumably that self-reference is why LazyFormat is used -- confirm against
    * the reming documentation.
    */
  implicit object StateCostFunctionFactoryFormat extends LazyFormat[StateCostFunctionFactory] {
    // Format for the four-field ClassifierBasedCostFunctionFactory case class.
    // NOTE(review): presumably picked up implicitly by the childFormat call below -- confirm.
    private implicit val classifierBasedCostFunctionFactoryFormat =
      jsonFormat4(ClassifierBasedCostFunctionFactory.apply)

    // The parent format dispatches to the registered child formats.
    override val delegate = parentFormat[StateCostFunctionFactory](
      childFormat[ClassifierBasedCostFunctionFactory, StateCostFunctionFactory]
    )
  }
}

/** A factory that assembles [[ClassifierBasedCostFunction]] instances.
  *
  * @param transitionSystemFactory builds the transition system of each cost function
  * @param transitions passed unchanged to each cost function that is built
  * @param taskClassifierList passed unchanged to each cost function that is built
  * @param baseCostFunctionFactory if defined, builds the base cost function of each cost
  * function that is built
  */
case class ClassifierBasedCostFunctionFactory(
    transitionSystemFactory: TransitionSystemFactory,
    transitions: Seq[StateTransition],
    taskClassifierList: List[(ClassificationTask, TransitionClassifier)],
    baseCostFunctionFactory: Option[StateCostFunctionFactory] = None
) extends StateCostFunctionFactory {

  /** Builds a classifier-based cost function for the given marble block and constraints.
    *
    * The transition system is built first, followed by the optional base cost function; both
    * receive the same marble block and constraints.
    *
    * @param marbleBlock the marble block to build the cost function for
    * @param constraints the transition constraints to build the cost function with
    * @return the constructed cost function
    */
  def buildCostFunction(
    marbleBlock: MarbleBlock,
    constraints: Set[TransitionConstraint]
  ): StateCostFunction = {
    val system = transitionSystemFactory.buildTransitionSystem(marbleBlock, constraints)
    val fallbackCostFunction =
      baseCostFunctionFactory.map(_.buildCostFunction(marbleBlock, constraints))
    ClassifierBasedCostFunction(
      system,
      transitions,
      taskClassifierList,
      marbleBlock,
      fallbackCostFunction
    )
  }
}
