package io.dyte.callstats.tests

import io.dyte.callstats.models.DataThroughputTestResults
import io.dyte.callstats.models.RTTStats
import io.dyte.callstats.utils.DependencyProvider
import io.dyte.webrtc.RtcConfiguration
import kotlin.math.roundToLong
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.datetime.Clock
import kotlinx.datetime.Instant

/**
 * Measures data-channel throughput between the two peer connections managed by [CallTest].
 *
 * Sending side:
 * - When `ch1` opens, [sendingStep] sends up to [maxNumberOfPacketsToSend] copies of
 *   [samplePacket] roughly once per second.
 * - A step sends nothing while `ch1.bufferedAmount` is at or above [bytesToKeepBuffered].
 * - Sending stops once [testDurationSeconds] have elapsed.
 *
 * Receiving side (`onMessageDC2`):
 * - It accumulates received bytes.
 * - It takes a bitrate sample (kbit/s) whenever at least one second has passed since the
 *   previous sample.
 * - It reports [DataThroughputTestResults] once sending has stopped and every sent byte has
 *   arrived.
 *
 * NOTE(review): RTT values are currently always 0 — see [getRoundTripTime].
 */
class DataThroughputTest(
  config: RtcConfiguration,
  provider: DependencyProvider,
  done: (Any) -> Unit,
  failed: (String) -> Unit,
  coroutineScope: CoroutineScope
) :
  CallTest(
    config,
    provider,
    isNotHostCandidate,
    done,
    failed,
    "DATA_THROUGHPUT_TEST",
    coroutineScope
  ) {

  /**
   * Scope that drives the periodic send loop. It is the same scope that handles received
   * messages, so cancelling it stops the whole test instead of leaking a GlobalScope job.
   * Captured here because constructor parameters are not visible inside member functions.
   */
  private val sendLoopScope: CoroutineScope = coroutineScope

  /** Payload sent on every step: 1024 ASCII `h` characters, i.e. 1024 bytes. Encoded once. */
  private val samplePacket: ByteArray = "h".repeat(1024).encodeToByteArray()

  /** Set once the test duration has elapsed; no further packets are sent afterwards. */
  private var stopSending = false

  /** Progress in percent (0..100), derived from elapsed time versus [testDurationSeconds]. */
  private var testProgress = 0L

  /** Upper bound on packets sent per one-second step. */
  private val maxNumberOfPacketsToSend = 1

  /** A step sends nothing while `ch1.bufferedAmount` is at or above this many bytes. */
  private val bytesToKeepBuffered = 1024 * maxNumberOfPacketsToSend

  /** How long packets keep being sent, measured from the first [sendingStep]. */
  private val testDurationSeconds: Long = 4L

  private var sentPayloadBytes = 0L
  private var startTime: Instant? = null

  // Initialised by the first sendingStep, which happens before any message can be received.
  private lateinit var lastBitrateMeasureTime: Instant
  private var receivedPayloadBytes = 0L
  private var lastReceivedPayloadBytes = 0L

  /** Sum of all one-second bitrate samples, in kbit/s. */
  private var finalBitrateSum = 0L
  private var bitRateSamples = 0

  override val onOpenDC1 = {
    logger.logDebug("DataThroughputTest: On Open: ${ch1.readyState}")

    logger.logDebug("State change open, calling sending step")
    sendingStep()
  }

  override val onMessageDC2: suspend (ByteArray) -> Unit = { buffer ->
    // Count raw bytes; decoding to a String is unnecessary for throughput accounting.
    coroutineScope.launch { onMessageReceived(buffer.size) }
  }

  /**
   * Accounts for [byteCount] newly received bytes and possibly finishes the test.
   *
   * Records a bitrate sample whenever at least one second has passed since the previous
   * sample. Completes the test once sending has stopped and every sent byte has been received.
   */
  private suspend fun onMessageReceived(byteCount: Int) {
    // Accumulate: completion below compares the running total against everything sent.
    receivedPayloadBytes += byteCount
    val now = Clock.System.now()
    val elapsedMs = (now - lastBitrateMeasureTime).inWholeMilliseconds
    if (elapsedMs >= 1000) {
      val deltaBytes = receivedPayloadBytes - lastReceivedPayloadBytes
      // bytes/ms * 8 == kbit/s. Compute in floating point so low rates are not truncated to 0.
      val bitrateKbps = (deltaBytes * 8.0 / elapsedMs).roundToLong()
      finalBitrateSum += bitrateKbps
      bitRateSamples += 1

      lastReceivedPayloadBytes = receivedPayloadBytes
      lastBitrateMeasureTime = now
    }
    if (stopSending && sentPayloadBytes == receivedPayloadBytes) {
      val rttStats = getRoundTripTime()
      // Guard against division by zero when no full one-second sample was taken.
      val throughput = if (bitRateSamples > 0) finalBitrateSum / bitRateSamples else 0L
      testComplete(
        DataThroughputTestResults(
          RTT = rttStats.rtt,
          backendRTT = rttStats.backendRTT,
          throughput = throughput
        )
      )
    }
  }

  /**
   * Sends one batch of packets on `ch1`.
   *
   * While [testDurationSeconds] has not yet elapsed since the first call, it schedules itself
   * to run again after one second on [sendLoopScope]. Otherwise it marks sending as stopped.
   */
  private fun sendingStep() {
    logger.logDebug("Entered sendingStep, startTime: $startTime")
    val now = Clock.System.now()
    val start = startTime ?: now.also {
      startTime = it
      lastBitrateMeasureTime = it
    }

    for (i in 0 until maxNumberOfPacketsToSend) {
      // Don't queue more data while the channel still holds a full batch in its send buffer.
      if (ch1.bufferedAmount >= bytesToKeepBuffered) {
        break
      }
      sentPayloadBytes += samplePacket.size
      logger.logDebug("Sending data from sendingStep")
      ch1.send(samplePacket)
    }

    val elapsedMs = (now - start).inWholeMilliseconds
    logger.logDebug("sendingStep: Duration: $elapsedMs")
    if (elapsedMs >= 1000 * testDurationSeconds) {
      stopSending = true
      testProgress = 100
    } else {
      // elapsedMs / (1000 * seconds) * 100 == elapsedMs / (10 * seconds)
      testProgress = elapsedMs / (10L * testDurationSeconds)
      // Re-schedule on the test's own scope (not GlobalScope) so the loop is cancelled with it.
      sendLoopScope.launch {
        delay(1000)
        sendingStep()
      }
    }
  }

  /**
   * Placeholder RTT measurement: logs both peer connections' stats and currently always
   * returns zero for both RTT values.
   * TODO: extract RTT from the stats reports.
   */
  private suspend fun getRoundTripTime(): RTTStats {
    val sendStats = pc1.getStats()
    val receiveStats = pc2.getStats()

    logger.logDebug("Send Stats: $sendStats")
    logger.logDebug("Receive Stats: $receiveStats")
    return RTTStats(rtt = 0L, backendRTT = 0L)
  }
}
