package io.taig.taigless.validation

import scala.collection.Factory
import scala.collection.immutable.HashMap

import cats.Applicative
import cats.data.NonEmptyList
import cats.kernel.CommutativeSemigroup
import cats.syntax.all._

/** A non-empty collection of messages grouped by field.
  *
  * `head` and `tail` together hold every field/messages pair. Keeping `head` outside the map
  * guarantees that at least one entry exists.
  *
  * @tparam A field identifier type
  * @tparam B message type
  */
final case class Errors[A, B](head: (A, NonEmptyList[B]), tail: HashMap[A, NonEmptyList[B]]) {

  /** Returns the messages recorded for `value`, or an empty list when that field has none. */
  def get(value: A): List[B] =
    if (value == head._1) head._2.toList
    else tail.get(value).fold(List.empty[B])(_.toList)

  /** Renames every field with `f`.
    *
    * When `f` maps several fields to the same key, their messages are concatenated rather than
    * one replacing the other.
    */
  def modifyFields[C](f: A => C): Errors[C, B] =
    regroup(toMap.toList.map { case (key, messages) => (f(key), messages) })

  /** Effectful variant of [[modifyFields]].
    *
    * Renames every field, including `head`, with `f`. Effects are sequenced in the iteration order
    * of [[toMap]]. As in [[modifyFields]], fields that collapse onto the same key have their
    * messages concatenated.
    */
  def modifyFieldsF[F[_], C](f: A => F[C])(implicit F: Applicative[F]): F[Errors[C, B]] =
    toMap.toList
      .traverse { case (key, messages) => f(key).tupleRight(messages) }
      // Traversing all entries (not just `tail`) keeps `head`. The result is therefore never
      // empty, so `unsafeFromMap` inside `regroup` cannot fail.
      .map(values => regroup(values))

  /** Returns all entries as a single map. If `head`'s field also appears in `tail`, `head` wins. */
  def toMap: Map[A, NonEmptyList[B]] = tail + head

  /** Returns all entries as a non-empty list, with `head` first. */
  def toNel: NonEmptyList[(A, NonEmptyList[B])] = NonEmptyList(head, tail.toList)

  /** Converts all entries (see [[toMap]]) into the collection produced by `factory`. */
  def to[C1](factory: Factory[(A, NonEmptyList[B]), C1]): C1 = factory.fromSpecific(toMap)

  /** Rebuilds an `Errors` from renamed entries.
    *
    * Messages of entries that share a key are concatenated in encounter order. Callers must pass
    * a non-empty list.
    */
  private def regroup[C](entries: List[(C, NonEmptyList[B])]): Errors[C, B] = {
    val merged = entries.foldLeft(HashMap.empty[C, NonEmptyList[B]]) { case (result, (key, messages)) =>
      result.updatedWith(key) {
        case Some(current) => Some(current concatNel messages)
        case None          => Some(messages)
      }
    }
    Errors.unsafeFromMap(merged)
  }
}

object Errors {

  /** Builds errors from one or more field/messages pairs.
    *
    * Pairs that share a field are merged: their messages are concatenated in argument order,
    * rather than a later pair silently replacing an earlier one.
    */
  def apply[A, B](head: (A, NonEmptyList[B]), tail: (A, NonEmptyList[B])*): Errors[A, B] =
    unsafeFromMap(tail.foldLeft(HashMap(head)) { case (result, (field, messages)) =>
      append(result, field, messages)
    })

  /** Builds errors that hold a single `message` for `field`. */
  def one[A, B](field: A, message: B): Errors[A, B] = unsafeFromMap(HashMap(field -> NonEmptyList.one(message)))

  /** Builds errors that hold `messages` for `field`. */
  def of[A, B](field: A)(messages: NonEmptyList[B]): Errors[A, B] = unsafeFromMap(HashMap(field -> messages))

  /** Nests errors under a parent field type.
    *
    * The target field type is fixed first, so `B` and `C` can be inferred from the arguments:
    * `Errors.wrap[Parent](Parent.Child)(childErrors)`.
    */
  def wrap[A]: WrapApply[A] = new WrapApply[A]

  /** Second stage of [[wrap]]. The `dummy` field only exists so that this class can be a value class. */
  final class WrapApply[A](val dummy: Boolean = true) extends AnyVal {
    def apply[B, C](field: B => A)(errors: Errors[B, C]): Errors[A, C] = errors.modifyFields(field)
  }

  /** Builds errors from `values`, or returns `None` when `values` is empty. */
  def fromMap[A, B](values: Map[A, NonEmptyList[B]]): Option[Errors[A, B]] =
    Option.when(values.nonEmpty)(unsafeFromMap(values))

  /** Builds errors from `values`, using its first entry as `head`.
    *
    * Throws if `values` is empty (`values.head`); use [[fromMap]] when emptiness is possible.
    */
  def unsafeFromMap[A, B](values: Map[A, NonEmptyList[B]]): Errors[A, B] = Errors(values.head, values.tail.to(HashMap))

  /** Merges two error sets field by field, concatenating the messages of shared fields.
    *
    * NOTE(review): for a shared field, `x`'s messages come before `y`'s. `combine(x, y)` and
    * `combine(y, x)` can therefore differ in message order, so the instance is commutative only
    * up to that ordering.
    */
  implicit def semigroup[A, B]: CommutativeSemigroup[Errors[A, B]] = new CommutativeSemigroup[Errors[A, B]] {
    override def combine(x: Errors[A, B], y: Errors[A, B]): Errors[A, B] =
      unsafeFromMap(y.toMap.foldLeft(x.toMap.to(HashMap)) { case (result, (field, messages)) =>
        append(result, field, messages)
      })
  }

  /** Returns `result` with `messages` appended after any messages already stored for `field`. */
  private def append[A, B](
      result: HashMap[A, NonEmptyList[B]],
      field: A,
      messages: NonEmptyList[B]
  ): HashMap[A, NonEmptyList[B]] =
    result.updatedWith(field) {
      case Some(current) => Some(current concatNel messages)
      case None          => Some(messages)
    }
}
