package codacy.patterns

import codacy.base.Pattern

import scala.meta._
import scala.util.Try

/**
 * Flags cryptographic key material that is hard coded in the source.
 *
 * The pattern looks for call sites of known key classes (the sinks of each
 * [[FrameworkTrail]]) and reports the key argument when it is either a
 * literal (`"...".getBytes`, `Array(...)`) or a suspiciously named value that
 * is bound to such a literal somewhere in the file.
 */
case object Custom_Scala_HardCodedKey extends Pattern {

  /** Returns one result per hard coded key argument found in `tree`. */
  override def apply(tree: Tree): Seq[Result] = {
    val trails = Seq(
      new JavaTrail(tree)
    )
    // Only search trails whose framework the file actually imports.
    val finders = trails.map { trail => new Finder(tree, trail) }
      .filter(_.trailHasHints())
    val places = finders.flatMap(_.findHardCodedKeys())
    places.map { item => Result(message(item), item) }
  }

  /**
   * A call that receives key material.
   *
   * @param methodName simple name of the constructed class / called method
   * @param position   zero-based index of the key argument in the flattened argument lists
   * @param hints      import strings whose presence in the file enables this sink
   */
  private[this] case class Sink(methodName: String, position: Int, hints: Seq[String])

  /** A call whose callee `ctorname` matched `sink`; `args` are its flattened arguments. */
  private[this] case class CallSite(ctorname: Tree, args: Seq[Term], sink: Sink)

  /** A family of sinks belonging to one framework/library. */
  private[this] trait FrameworkTrail {
    val sinks: Seq[Sink]
  }

  /** Sinks from the JDK crypto / security APIs (plus `sun.security.provider`). */
  private[this] class JavaTrail(tree: Tree) extends FrameworkTrail {
    val sinks = Seq(
      Sink("DESKeySpec", 0,
        hints = Seq("javax.crypto.spec.DESKeySpec", "javax.crypto.spec._")),
      Sink("DESedeKeySpec", 0,
        hints = Seq("javax.crypto.spec.DESedeKeySpec", "javax.crypto.spec._")),
      Sink("KerberosKey", 0,
        hints = Seq("javax.security.auth.kerberos.KerberosKey",
                    "javax.security.auth.kerberos._")),
      Sink("SecretKeySpec", 0,
        hints = Seq("javax.crypto.spec.SecretKeySpec", "javax.crypto.spec._")),
      Sink("X509EncodedKeySpec", 0,
        hints = Seq("java.security.spec.X509EncodedKeySpec", "java.security.spec._")),
      Sink("PKCS8EncodedKeySpec", 0,
        hints = Seq("java.security.spec.PKCS8EncodedKeySpec", "java.security.spec._")),
      Sink("KerberosTicket", 3,
        hints = Seq("javax.security.auth.kerberos.KerberosTicket", "javax.security.auth.kerberos._")),
      Sink("DSAPublicKeySpec", 0,
        hints = Seq("java.security.spec.DSAPublicKeySpec", "java.security.spec._")),
      Sink("DSAPublicKeyImpl", 0,
        hints = Seq("sun.security.provider.DSAPublicKeyImpl", "sun.security.provider._"))
    )
  }

  /** Searches `tree` for call sites of `trail`'s sinks whose key argument is hard coded. */
  private[this] class Finder(tree: Tree, trail: FrameworkTrail) {

    /**
     * Walks up from `node` to the nearest enclosing `def` or `class` and returns it.
     * Returns `None` when the top-level `Source` is reached or there is no parent.
     */
    private[this] def determineScope(node: Tree): Option[Tree] = {
      node match {
        case t@q"..$mods def $name[..$tparams](...$paramss): $tpeopt = $expr" => Some(t)
        case t@q"..$mods class $name (...$params) extends $template" => Some(t)
        case source"..$stats" => None
        case item => item.parent.flatMap(determineScope)
      }
    }

    // Support shadowing of local variables. This should avoid some
    // edge cases where a shadowing name has the same name as the
    // hardcoded value.
    /**
     * True if some `Init` inside `stats` has a type tree equal to `child`.
     * NOTE(review): `child` is a callee tree from `callSites()`, and `==` on
     * scalameta trees is not structural — confirm this can actually match.
     */
    private[this] def hasConstructorChild(stats: Tree, child: Tree): Boolean = {
      stats.collect {
        case init"$tpe(...$exprss)" if tpe == child => true
      }.nonEmpty
    }

    /** True if `param` (or a parameter nested in it) is named like `name`. */
    private[this] def sameShadowedParameter(param: Term.Param, name: Term.Name): Boolean = {
      param.collect {
        case param"..$mods ${paramname: Name}: $atpeopt = $exprop" if paramname.toString() == name.toString() => true
      }.nonEmpty
    }

    /** True if the declared parameter type prints exactly as `String`; `param` is unused. */
    private[this] def isStringParam(param: Name, atpeopt: Option[Type]): Boolean = {
      atpeopt.exists(_.toString == "String")
    }

    /**
     * Heuristic shadowing check for `name` within the block `stats`. Returns true if either:
     *  - `stats` contains the call site's constructor, and a lambda parameter or a
     *    `case` pattern variable in `stats` has the same name; or
     *  - `scope` declares a `String`-typed parameter with the same name whose own
     *    enclosing scope is `scope`.
     */
    private[this] def findShadowingInBlock(stats: Tree, scope: Tree, name: Term.Name, callSite: CallSite): Boolean = {
      stats.collect{
        case q"(..${params: Seq[Term.Param]}) => $expr"
          if hasConstructorChild(stats, callSite.ctorname) &&
            params.exists(param => sameShadowedParameter(param, name)) => params
        case p"case ${pat: Pat} if $expropt => $expr"
          if hasConstructorChild(stats, callSite.ctorname) &&
            pat.collect{
              case p"${patname: Term.Name}" if patname.toString == name.toString => name
            }.nonEmpty => Seq(pat)
      }.flatten.nonEmpty || scope.collect{
        case param"..$mods ${param: Name}: $atpeopt = $exprop"
          if param.value == name.toString &&
            isStringParam(param, atpeopt) &&
            determineScope(param).exists{ paramScope => paramScope == scope} => true
      }.nonEmpty
    }

    /**
     * True if any block in the whole file shadows `name`, as judged by `findShadowingInBlock`.
     * Both `{ ... }` blocks and `{ case ... }` blocks are checked.
     */
    private[this] def isShadowed(name: Term.Name, callSite: CallSite, scope: Tree): Boolean = {
      tree.collect {
        case t@q"{ ..$stats }" => findShadowingInBlock(t, scope, name, callSite)
        case t@q"{ ..case $casesnel }" => findShadowingInBlock(t, scope, name, callSite)
      }.exists(identity)
    }

    /**
     * Finds `val` definitions bound to hard coded key material.
     * Matching right-hand sides are `"lit".getBytes`, `"lit".getBytes(...)`,
     * and any RHS accepted by `isHardCodedArray`.
     * Each hit is paired with the definition's grandparent tree and the names it binds.
     */
    private[this] def hardCodedScopeValues(tree: Tree): List[(scala.meta.Tree, Seq[scala.meta.Term.Name])] = {
      tree.collect {
        case t@q"..${mods: Seq[Mod]} val ..${patsnel: Seq[Pat]}: $tpeopt = ${expr: Lit}.getBytes" if isStringLiteral(expr.value) =>
          flattenedParent(t, patsnel)
        // `"lit".getBytes("UTF-8")` / `"lit".getBytes()` are an Apply, not a Select.
        case t@q"..${mods: Seq[Mod]} val ..${patsnel: Seq[Pat]}: $tpeopt = ${expr: Lit}.getBytes(..$_)" if isStringLiteral(expr.value) =>
          flattenedParent(t, patsnel)
        case t@q"..${mods: Seq[Mod]} val ..${patsnel: Seq[Pat]}: $tpeopt = ${expr: Term}" if isHardCodedArray(expr) =>
          flattenedParent(t, patsnel)
      }.flatten
    }

    // Lazy: only needed once a sink argument is a plain name. Also never computed
    // for finders that `apply` discards because the trail's imports are absent.
    private[this] lazy val scopedValues = hardCodedScopeValues(tree)

    // Unanchored: `findFirstIn` matches these fragments anywhere in the name.
    private[this] val pattern = "(?i)(pass|pwd|psw|secret|key|cipher|crypt|des|aes|mac|private|sign|cert).*".r

    /**
     * True if `name` looks like a key/secret and some `val` in the file with the same
     * name is bound to a hard coded value. The name must also not be shadowed in its
     * enclosing scope.
     */
    private[this] def isSuspiciousName(name: Term.Name, callSite: CallSite): Boolean = {
      val nameText = name.toString
      pattern.findFirstIn(nameText).isDefined && {
        val isHardCodedSomewhere = scopedValues.exists {
          case (_, names) => names.exists(_.toString == nameText)
        }
        // Short-circuit: the shadowing scan walks the whole tree, so only run it when needed.
        isHardCodedSomewhere && !determineScope(name).exists(scope => isShadowed(name, callSite, scope))
      }
    }

    /**
     * Pairs the grandparent of the definition `t` with the term names bound by `patsnel`.
     * The parent is typically the statement container (a Template or Block). The
     * grandparent is typically the owning class/object/def.
     */
    private[this] def flattenedParent(t: Tree, patsnel: Seq[Pat]): Option[(Tree, Seq[Term.Name])] = {
      t.parent.flatMap(_.parent).map { owner =>
        (owner, patsnel.flatMap(_.collect { case p"${name: Term.Name}" => name }))
      }
    }

    /** Every `new X(...)` / `f(...)` in the file as (callee tree, flattened arguments). */
    private[this] def callSites(): Seq[(Tree, Seq[Term])] = {
      tree.collect {
        case q"new $expr(...$exprss)" => (expr, exprss.flatten)
        case q"new $expr(...$exprss) with ..$inits { $self => ..$stats }" => (expr, exprss.flatten)
        case q"$expr(...$exprss)" => (expr, exprss.flatten)
      }
    }

    /**
     * True if the callee `ctorname` names `sink`. It matches either as a whole or by
     * the last segment of a qualified reference.
     */
    private[this] def sameSinkConstructor(sink: Sink, ctorname: Tree): Boolean = {
      val subref: String = ctorname.collect {
        case init"${tpe: Type}(..$_)" =>
          // Fall back to the full text (a String, not the Type) so `subref` stays a String.
          Try(tpe.toString.split('.').last).getOrElse(tpe.toString)
        case q"$ref.$selected" =>
          selected.toString
      }.headOption.getOrElse("")

      sink.methodName == ctorname.toString() || subref == sink.methodName
    }

    /** Call sites paired with the first sink (in declaration order) they match. */
    private[this] def matchedCallSites(): Seq[CallSite] = {
      callSites().flatMap { case (ctorname, args) =>
        trail.sinks
          .find(sink => sameSinkConstructor(sink, ctorname))
          .map(sink => CallSite(ctorname, args, sink))
      }
    }

    /**
     * Every import spelling in the file. Each importer contributes its text as written
     * (e.g. `a.b.{C, D}`) plus one `prefix.importee` string per selector (`a.b.C`,
     * `a.b.D`, `a.b._`). This lets brace imports match per-class hints.
     */
    private[this] lazy val importedElements: Set[String] =
      tree.collect {
        case q"import ..$importersnel" =>
          importersnel.flatMap { importer =>
            importer.toString +: importer.importees.map(importee => s"${importer.ref}.$importee")
          }
      }.flatten.toSet

    /** True if the file imports any of `importElements`. */
    private[this] def treeHasImport(importElements: Seq[String]): Boolean = {
      importElements.exists(importedElements.contains)
    }

    /**
     * True if `arg` contains `"lit".getBytes` anywhere in its subtree.
     * This also covers `"lit".getBytes(charset)`.
     */
    private[this] def isHardCodedBytesConversion(arg: Term): Boolean = {
      arg.collect{
        case q"${lit: Lit}.getBytes" if isStringLiteral(lit.value) => true
      }.nonEmpty
    }

    /**
     * True if `arg` contains an `Array(...)` / `Array[Byte](...)` application anywhere,
     * or an `Init` whose type prints exactly as `Array`.
     */
    private[this] def isHardCodedArray(arg: Tree): Boolean = {
      arg.collect{
        case init"${ctorname: Type}(..${ctornameexprssnel: Seq[Term]})"
          if ctorname.toString == "Array" => true
        case q"$expr(...$exprss)"=>
          expr match {
            case q"Array[Byte]" => true
            case q"Array" => true
            case _ => false
          }
      }.exists(identity)
    }

    /** True if the literal's value is a `String`. */
    private[this] def isStringLiteral(value: Any): Boolean = {
      value match {
        case _: String => true
        case _ => false
      }
    }

    /** True if `arg` contains a name judged suspicious by `isSuspiciousName`. */
    private[this] def isHardCodedVariable(arg: Term, callSite: CallSite): Boolean = {
      arg.collect{
        case q"${name: Term.Name}" if isSuspiciousName(name, callSite) => true
      }.nonEmpty
    }

    /** True if `arg` itself is a literal byte conversion or array. */
    private[this] def isValueHardcoded(arg: Term): Boolean = {
      isHardCodedBytesConversion(arg) || isHardCodedArray(arg)
    }

    /** The call site's key argument, if it is hard coded directly or via a suspicious name. */
    private[this] def hardCodedValue(callSite: CallSite): Option[Term] = {
      val position = callSite.sink.position
      val callSiteArg = callSite.args.lift(position)
      callSiteArg match {
        case Some(arg) if isValueHardcoded(arg) || isHardCodedVariable(arg, callSite) => Some(arg)
        // Can't find the particular argument.
        case _ => Option.empty
      }
    }

    /** All hard coded key arguments passed to this trail's sinks. */
    def findHardCodedKeys(): Seq[Term] = {
      val foundCallSites = matchedCallSites()
      foundCallSites.flatMap{ callSite => hardCodedValue(callSite)}
    }

    /** True if the file imports at least one hint of at least one sink. */
    def trailHasHints(): Boolean = {
      trail.sinks.exists{ sink => treeHasImport(sink.hints) }
    }
  }

  private[this] def message(tree: Tree) = Message("Hard coded cryptographic key")
}