package kafka4m.consumer

import java.util.Properties
import java.{time, util}

import com.typesafe.config.Config
import com.typesafe.scalalogging.StrictLogging
import kafka4m.data.{CommittedStatus, KafkaPartitionInfo, PartitionOffsetState}
import kafka4m.util.{FixedScheduler, Schedulers}
import monix.catnap.ConcurrentQueue
import monix.eval.Task
import monix.execution.{BufferCapacity, Scheduler}
import monix.reactive.Observable
import org.apache.kafka.clients.consumer._
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.serialization.Deserializer

import scala.collection.mutable
import scala.concurrent.duration.Duration
import scala.concurrent.{Future, Promise}
import scala.jdk.CollectionConverters._
import scala.util.Try
import scala.util.control.NonFatal

/**
  * A means of driving a kafka-stream using the consumer (not kafka streaming) API.
  *
  * The underlying [[KafkaConsumer]] must be accessed from the thread on which it was created (see [[unsafePoll]]),
  * so work which touches it is run on `kafkaScheduler`. Callers on other threads should use [[withConsumer]], which
  * enqueues the work on `commandQueue` to be picked up either by the polling loop in [[asObservable]] or by [[execNext]].
  *
  * @param consumer the underlying kafka consumer
  * @param topics the (non-empty, non-blank) topics this consumer reads from
  * @param defaultPollTimeout the timeout used by [[unsafePoll]] when none is supplied
  * @param commandQueue queued commands to run against the consumer on the kafka scheduler
  * @param kafkaScheduler the scheduler used to access the consumer; it is closed when this consumer is closed
  * @param asyncScheduler the scheduler on which records are observed and futures are completed
  * @param startPollingOnStart NOTE(review): not currently read anywhere in this class
  */
final class RichKafkaConsumer[K, V] private (val consumer: KafkaConsumer[K, V],
                                             val topics: Set[String],
                                             val defaultPollTimeout: Duration,
                                             commandQueue: ConcurrentQueue[Task, ExecOnConsumer[K, V, _]],
                                             kafkaScheduler: Scheduler,
                                             val asyncScheduler: Scheduler,
                                             startPollingOnStart: Boolean = true)
    extends AutoCloseable
    with ConsumerAccess
    with StrictLogging { self =>

  override type Key   = K
  override type Value = V

  @volatile private var closed = false

  require(topics.nonEmpty, "empty topic set for consumer")
  require(topics.forall(_.nonEmpty), "blank topic set for consumer")

  private val javaPollDuration: time.Duration = RichKafkaConsumer.asJavaDuration(defaultPollTimeout)

  /**
    * @param limitToOurTopic if true, only entries for our `topics` are returned; otherwise every topic the consumer can list
    * @return the partitions of each topic, as reported by `consumer.listTopics()`
    */
  def partitionsByTopic(limitToOurTopic: Boolean = true): Map[String, List[KafkaPartitionInfo]] = {
    val view = consumer.listTopics().asScala.view.mapValues(_.asScala.map(KafkaPartitionInfo.apply).toList)
    if (limitToOurTopic) {
      view.filterKeys(topics.contains).toMap
    } else {
      view.toMap
    }
  }

  /**
    * Subscribe the underlying consumer to a single topic.
    *
    * Note: kafka's `subscribe` replaces any previous subscription rather than adding to it.
    */
  def subscribe(topic: String, listener: ConsumerRebalanceListener = RebalanceListener): Unit = {
    logger.info(s"Subscribing to $topic")
    consumer.subscribe(java.util.Collections.singletonList(topic), listener)
  }

  /** @return the partitions of all our topics, in topic order */
  def partitions: List[KafkaPartitionInfo] = {
    val byTopic = partitionsByTopic(true)
    topics.toList.flatMap(byTopic.getOrElse(_, Nil))
  }

  /**
    * this poll is unsafe as it will fail if invoked from a different thread from which this consumer was created
    * @param timeout how long to block waiting for records
    * @return the records pulled from Kafka, restricted to our topics
    */
  def unsafePoll(timeout: time.Duration = javaPollDuration): Iterable[ConsumerRecord[K, V]] = {
    try {
      val records: ConsumerRecords[K, V] = consumer.poll(timeout)
      logger.debug(s"Got ${records.count()} records from ${records.partitions().asScala.mkString(s"[", ",", "]")}")
      val forTopic: Iterable[ConsumerRecord[K, V]] = {
        records.asScala.filter { record =>
          topics.contains(record.topic())
        }
      }
      logger.trace(s"Got ${forTopic.size} of ${records.count()} for topic '$topics' records from ${records.partitions().asScala.mkString(s"[", ",", "]")}")
      forTopic
    } catch {
      case NonFatal(e) =>
        logger.warn(s"Poll threw $e")
        throw e
    }
  }

  private val NoResults = Observable.empty[ConsumerRecord[K, V]]

  /**
    * Represent this consumer as an observable.
    *
    * Each step first tries to run one queued command (see [[withConsumer]]). If none is queued, it polls kafka instead.
    * Both run on the kafka scheduler.
    *
    * @param closeOnComplete set to true if the underlying Kafka consumer should be closed when this observable completes
    */
  def asObservable(closeOnComplete: Boolean): Observable[ConsumerRecord[K, V]] = {
    val records: Task[Observable[ConsumerRecord[K, V]]] = {
      commandQueue.tryPoll.flatMap {
        // try and handle any explicit commands, but if none are queued, then fall-back to polling kafka
        case None =>
          Task.eval(unsafePoll()).executeOn(kafkaScheduler).map { returned =>
            if (returned.isEmpty) {
              NoResults
            } else {
              Observable.fromIterable(returned)
            }
          }
        case Some(exec: ExecOnConsumer[K, V, _]) =>
          Task(exec.run(self)).executeOn(kafkaScheduler).map(_ => NoResults)
      }
    }

    val obs: Observable[ConsumerRecord[K, V]] = Observable.repeatEvalF(records).flatten.observeOn(asyncScheduler)
    if (closeOnComplete) {
      obs.guarantee(Task.delay(close()).executeOn(kafkaScheduler))
    } else {
      obs
    }
  }

  /**
    * @return a task which runs the next queued exec command (if there is one) on our kafka scheduler
    * @throws IllegalArgumentException (eagerly, when called) if this consumer has already been closed
    */
  def execNext(): Task[Unit] = {
    require(!closed, "RichKafkaConsumer is already closed")
    commandQueue.tryPoll.flatMap {
      case Some(exec: ExecOnConsumer[K, V, _]) =>
        Task(exec.run(self)).executeOn(kafkaScheduler).void
      case _ => Task.unit
    }
  }

  /**
    * Asynchronously commit the given offsets.
    *
    * @param state the offsets to commit; an empty state completes immediately without contacting kafka
    * @return a future of the committed offsets, failed if kafka reports a commit error
    */
  def commitAsync(state: PartitionOffsetState): Future[Map[TopicPartition, OffsetAndMetadata]] = {
    val promise: Promise[Map[TopicPartition, OffsetAndMetadata]] = Promise[Map[TopicPartition, OffsetAndMetadata]]()

    if (state.nonEmpty) {
      object callback extends OffsetCommitCallback {
        override def onComplete(offsets: util.Map[TopicPartition, OffsetAndMetadata], exception: Exception): Unit = {
          logger.debug(s"commitAsync($offsets, $exception)")
          if (exception != null) {
            promise.tryFailure(exception)
          } else {
            promise.trySuccess(offsets.asScala.toMap)
          }
        }
      }
      logger.debug(s"commitAsync($state)")
      consumer.commitAsync(state.asTopicPartitionMapJava, callback)
    } else {
      logger.trace(s"NOT committing empty state")
      promise.trySuccess(Map.empty)
    }
    promise.future
  }

  /** Run the side-effect, capturing any non-fatal failure as a Failure rather than throwing */
  private def swallow(thunk: => Unit): Try[Boolean] = {
    Try(thunk).map(_ => true)
  }

  /** Seek the given partition number of each of our topics to its beginning */
  def seekToBeginning(partition: Int): Try[Boolean] = swallow {
    logger.info(s"seekToBeginning(${partition})")
    topics.foreach { topic =>
      val tp = new TopicPartition(topic, partition)
      consumer.seekToBeginning(java.util.Collections.singletonList(tp))
    }
  }

  /** Seek each of our currently assigned partitions to its beginning */
  def seekToBeginning(): Try[Boolean] = swallow {
    logger.info(s"seekToBeginning")
    // seek exactly what is assigned: pairing every assigned partition number with every topic would
    // target unassigned partitions when reading more than one topic, which kafka rejects
    consumer.seekToBeginning(validatedAssignments().asJava)
  }

  /** Seek each of our currently assigned partitions to its end */
  def seekToEnd(): Try[Boolean] = swallow {
    logger.info("seekToEndUnsafe")
    consumer.seekToEnd(validatedAssignments().asJava)
  }

  /**
    * Manually assign this consumer to every partition of our topics.
    * @return the assigned partitions, or a Failure if the assignment was rejected
    */
  def assignToTopics(): Try[Set[TopicPartition]] = {
    val pbt = partitionsByTopic()
    val allTopicPartitions = topics.flatMap { topic =>
      val topicPartitions = pbt.get(topic).map { partitions: List[KafkaPartitionInfo] =>
        partitions.map(_.asTopicPartition)
      }
      topicPartitions.getOrElse(Nil)
    }
    swallow(consumer.assign(allTopicPartitions.asJava)).map { _ =>
      allTopicPartitions
    }
  }

  /** Seek every partition of our topics to the same offset */
  def seekToOffset(offset: Long): Try[Boolean] = seekToCustom(_ => offset)

  /** Seek every partition of our topics to the offset computed for it */
  def seekToCustom(computeOffset: KafkaPartitionInfo => Long): Try[Boolean] = swallow {
    partitionsByTopic(true).foreach {
      case (topic, topicPartitions) =>
        topicPartitions.foreach { pi: KafkaPartitionInfo =>
          val offset = computeOffset(pi)
          consumer.seek(new TopicPartition(topic, pi.partition), offset)
        }
    }
  }

  /** Seek to the offsets in the given state, ignoring any entries for topics other than ours */
  def seekTo(topicPartitionState: PartitionOffsetState): Try[Boolean] = swallow {
    logger.info(s"seekToUnsafe(${topicPartitionState})")
    for {
      topic               <- topics
      topicPartitions     <- topicPartitionState.offsetByPartitionByTopic.get(topic).toSeq
      (partition, offset) <- topicPartitions
    } {
      consumer.seek(new TopicPartition(topic, partition), offset)
    }
  }

  /** @return the current position of the given partition number, for each of our topics */
  def positionsFor(partition: Int): Map[String, Long] = {
    val byTopic = topics.map { topic =>
      topic -> consumer.position(new TopicPartition(topic, partition))
    }
    byTopic.toMap
  }

  /** @return the last committed offset of the given partition number, for each of our topics */
  def committed(partition: Int): Map[String, OffsetAndMetadata] = {
    val byTopic = topics.map { topic =>
      topic -> consumer.committed(new TopicPartition(topic, partition))
    }
    byTopic.toMap
  }

  /**
    * @return the current assignments, having checked they all belong to our topics
    * @throws IllegalArgumentException if any assignment is for a topic other than ours
    */
  private def validatedAssignments(): List[TopicPartition] = {
    val assigned = assignments()
    assigned.foreach { tp =>
      require(topics.contains(tp.topic()), s"consumer for topics $topics has assignment on ${tp.topic()}")
    }
    assigned
  }

  /**
    * @return the partition numbers currently assigned to this consumer (across all our topics)
    * @throws IllegalArgumentException if any assignment is for a topic other than ours
    */
  def assignmentPartitions: List[Int] = validatedAssignments().map(_.partition())

  /** @return the partitions currently assigned to the underlying consumer */
  def assignments(): List[TopicPartition] = consumer.assignment().asScala.toList

  /**
    * @param verbose if true, include the committed offsets of each assigned partition
    * @return a human-readable description of each of our topics and our assignments
    */
  def status(verbose: Boolean): String = {
    val byTopic = partitionsByTopic()

    val topicStatuses = topics.map { topic =>
      byTopic.get(topic).fold(s"topic '${topic}' doesn't exist") { partitions =>
        val ourAssignments = {
          val all: List[Int] = assignmentPartitions
          val detail = if (verbose) {
            val committedStatus: Seq[Map[String, OffsetAndMetadata]] = all.map(committed)
            committedStatus.mkString("\n\tCommit status:\n\t", "\n\t", "\n")
          } else {
            ""
          }
          all.mkString(s"assigned to ${all.size}: [", ",", s"]$detail")
        }

        s"'$topic' status (one of ${topics.size} topics [${topics.mkString("\n\t", "\n\t", "\n\t")}])\ncurrently $ourAssignments\n${TopicStatus(topic, partitions).toString}"
      }
    }
    topicStatuses.mkString("\n")
  }

  /**
    * @return a scala-friendly data structure containing the commit status of the kafka cluster
    */
  def committedStatus(): List[CommittedStatus] = {
    // compare whole TopicPartitions: a bare partition number may be assigned for one topic but not another
    val assigned: Set[TopicPartition] = validatedAssignments().toSet
    partitionsByTopic().map {
      case (topic, kafkaPartitions) =>
        val weAreSubscribed: Boolean = topics.contains(topic)

        val partitionStats: List[(KafkaPartitionInfo, Boolean)] = kafkaPartitions.map { kpi =>
          (kpi, assigned.contains(kpi.asTopicPartition))
        }

        val requested = partitionStats.map(_._1.asTopicPartition).toSet.asJava
        val commitStats: mutable.Map[TopicPartition, OffsetAndMetadata] = consumer.committed(requested).asScala

        // kafka reports a null OffsetAndMetadata for partitions which have nothing committed
        val actuallyCommitted = commitStats.toMap.filter {
          case (_, offsetAndMetadata) => offsetAndMetadata != null
        }
        val pos = PartitionOffsetState.fromKafka(actuallyCommitted)
        CommittedStatus(topic, weAreSubscribed, partitionStats, pos)
    }.toList
  }

  def isClosed(): Boolean = closed

  /** Close the underlying consumer and our kafka scheduler. The scheduler is closed even if the consumer fails to close */
  override def close(): Unit = {
    closed = true
    try {
      consumer.close()
    } finally {
      Schedulers.close(kafkaScheduler)
    }
  }

  /**
    * Run some work against the consumer on the kafka scheduler.
    *
    * The work is queued. It is then executed either here or by a concurrently running [[asObservable]] loop.
    *
    * @return a future of the work's result
    */
  override def withConsumer[A](withConsumer: RichKafkaConsumer[K, V] => A): Future[A] = {
    val cmd = ExecOnConsumer[K, V, A](withConsumer)
    import cats.syntax.apply._
    val task = commandQueue.offer(cmd) *> execNext()

    task
      .runToFuture(asyncScheduler)
      .flatMap { _ =>
        cmd.promise.future
      }(asyncScheduler)
  }
}

object RichKafkaConsumer extends StrictLogging {

  /**
    * Convert a scala duration into a java duration suitable for `KafkaConsumer.poll`.
    *
    * @param d the duration to convert; finite values are truncated to whole milliseconds
    * @return the equivalent java duration, or `Long.MaxValue` milliseconds (effectively 'forever') for non-finite durations
    */
  def asJavaDuration(d: Duration): time.Duration = {
    if (d.isFinite) {
      java.time.Duration.ofMillis(d.toMillis)
    } else {
      // Duration.ofDays(Long.MaxValue) overflows converting days to seconds and throws an ArithmeticException;
      // Long.MaxValue millis is representable and converts back via toMillis without overflow
      java.time.Duration.ofMillis(Long.MaxValue)
    }
  }

  /** Create a consumer of String keys and Array[Byte] values */
  private[consumer] def byteArrayValues(rootConfig: Config, kafkaScheduler: Scheduler, ioSched: Scheduler) = {
    val keyDeserializer   = new org.apache.kafka.common.serialization.StringDeserializer
    val valueDeserializer = new org.apache.kafka.common.serialization.ByteArrayDeserializer
    apply(rootConfig, keyDeserializer, valueDeserializer, kafkaScheduler)(ioSched)
  }

  /**
    * Create a [[RichKafkaConsumer]] from the 'kafka4m.consumer' section of the given config.
    *
    * @param rootConfig the root configuration containing a 'kafka4m.consumer' section
    * @param keyDeserializer the deserializer for record keys
    * @param valueDeserializer the deserializer for record values
    * @param kafkaScheduler the scheduler used to access the kafka consumer
    * @param ioSched the scheduler on which results are observed
    * @return a new consumer, subscribed to its topics if 'subscribeOnConnect' is set
    */
  private[consumer] def apply[K, V](rootConfig: Config,
                                    keyDeserializer: Deserializer[K],
                                    valueDeserializer: Deserializer[V],
                                    kafkaScheduler: Scheduler = FixedScheduler().scheduler)(implicit ioSched: Scheduler): RichKafkaConsumer[K, V] = {

    import args4c.implicits._
    val consumerConfig = rootConfig.getConfig("kafka4m.consumer")
    val topics         = kafka4m.consumerTopics(rootConfig)

    // read all remaining settings up-front, so a bad config fails before we've opened a KafkaConsumer (which would otherwise leak)
    val pollTimeout = rootConfig.asDuration("kafka4m.consumer.pollTimeout")

    val capacity: BufferCapacity = rootConfig.getInt("kafka4m.consumer.commandBufferCapacity") match {
      case n if n <= 0 => BufferCapacity.Unbounded()
      case n           => BufferCapacity.Bounded(n)
    }

    val subscribeOnConnect = consumerConfig.getBoolean("subscribeOnConnect")

    val props: Properties = {
      val properties = kafka4m.util.Props.propertiesForConfig(consumerConfig)

      //
      // subscribe to our topic
      // .. properties.asScala.mkString()  is broken as it tries to cast things as strings, and some values are integers
      def propString = {
        val keys = properties.propertyNames.asScala
        keys
          .map { key =>
            s"$key : ${properties.getProperty(key.toString)}"
          }
          .mkString("\n\t", "\n\t", "\n\n")
      }

      logger.info(s"Creating consumer for '${topics.mkString(",")}', properties are:\n${propString}")
      properties
    }

    val consumer: KafkaConsumer[K, V] = new KafkaConsumer[K, V](props, keyDeserializer, valueDeserializer)

    try {
      val queue = ConcurrentQueue.unsafe[Task, ExecOnConsumer[K, V, _]](capacity)

      val richConsumer = new RichKafkaConsumer(consumer, topics, pollTimeout, queue, kafkaScheduler, ioSched)

      if (subscribeOnConnect) {
        topics.foreach(t => richConsumer.subscribe(t))
      } else {
        logger.debug("subscribeOnConnect is false")
      }

      richConsumer
    } catch {
      case NonFatal(e) =>
        // don't leak the open consumer; keep the original error if closing fails as well
        Try(consumer.close()).failed.foreach { closeError =>
          logger.warn(s"Error closing consumer after failed setup: $closeError")
        }
        throw e
    }
  }
}
