package io.joern.c2cpg.astcreation

import io.shiftleft.codepropertygraph.generated.ControlStructureTypes
import io.joern.x2cpg.Ast
import io.shiftleft.codepropertygraph.generated.nodes.ExpressionNew
import org.eclipse.cdt.core.dom.ast._
import org.eclipse.cdt.core.dom.ast.cpp._
import org.eclipse.cdt.core.dom.ast.gnu.IGNUASTGotoStatement
import org.eclipse.cdt.internal.core.dom.parser.c.CASTIfStatement
import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTIfStatement
import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTNamespaceAlias
import org.eclipse.cdt.internal.core.model.ASTStringUtil

/** Lowers CDT statement nodes (blocks, control structures, jumps, labels and declaration statements) into CPG ASTs.
  *
  * Mixed into [[AstCreator]]; it relies on that class's node builders, scope stack and null-safe helpers.
  */
trait AstForStatementsCreator { this: AstCreator =>

  import io.joern.c2cpg.astcreation.AstCreatorHelper.OptionSafeAst

  /** Builds a BLOCK for a compound statement and lowers every contained statement inside a fresh scope.
    *
    * A running counter starting at 1 is handed to [[astsForStatement]] as the argument index of each child; a child
    * statement that expands into several ASTs advances the counter by that many.
    *
    * @param blockStmt
    *   the compound statement to lower
    * @param order
    *   order and argument index assigned to the resulting BLOCK node (-1 when unspecified)
    */
  protected def astForBlockStatement(blockStmt: IASTCompoundStatement, order: Int = -1): Ast = {
    val node = blockNode(blockStmt, Defines.empty, registerType(Defines.voidTypeName)).order(order).argumentIndex(order)
    scope.pushNewScope(node)
    var currOrder = 1
    val childAsts = blockStmt.getStatements.flatMap { stmt =>
      val r = astsForStatement(stmt, currOrder)
      currOrder = currOrder + r.length
      r
    }
    scope.popScope()
    blockAst(node, childAsts.toList)
  }

  /** Lowers a declaration appearing in statement position.
    *
    * A simple declaration whose first declarator is a function declarator becomes a single function-declarator AST.
    * Any other simple declaration yields one AST per declarator, followed by one initializer AST for every declarator
    * that has an initializer. Using-directives produce nothing; unknown declarations fall back to `astForNode`.
    */
  private def astsForDeclarationStatement(decl: IASTDeclarationStatement): Seq[Ast] =
    decl.getDeclaration match {
      case simplDecl: IASTSimpleDeclaration =>
        val declarators = simplDecl.getDeclarators.toList
        declarators.headOption match {
          case Some(funcDecl: IASTFunctionDeclarator) =>
            // NOTE(review): only the first declarator is lowered here, so e.g. `int f(), x;` drops `x`.
            Seq(astForFunctionDeclarator(funcDecl))
          case _ =>
            val locals = declarators.zipWithIndex.map { case (d, i) => astForDeclarator(simplDecl, d, i) }
            val calls  = declarators.filter(_.getInitializer != null).map(d => astForInitializer(d, d.getInitializer))
            locals ++ calls
        }
      case s: ICPPASTStaticAssertDeclaration         => Seq(astForStaticAssert(s))
      case usingDeclaration: ICPPASTUsingDeclaration => handleUsingDeclaration(usingDeclaration)
      case alias: ICPPASTAliasDeclaration            => Seq(astForAliasDeclaration(alias))
      case func: IASTFunctionDefinition              => Seq(astForFunctionDefinition(func))
      case alias: CPPASTNamespaceAlias               => Seq(astForNamespaceAlias(alias))
      case asm: IASTASMDeclaration                   => Seq(astForASMDeclaration(asm))
      case _: ICPPASTUsingDirective                  => Seq.empty
      case other                                     => Seq(astForNode(other))
    }

  /** Builds a RETURN node whose child (and argument) is the returned expression; the expression AST is empty for a
    * bare `return;`.
    */
  private def astForReturnStatement(ret: IASTReturnStatement): Ast = {
    val cpgReturn = returnNode(ret, nodeSignature(ret))
    val expr      = nullSafeAst(ret.getReturnValue)
    Ast(cpgReturn).withChild(expr).withArgEdge(cpgReturn, expr.root)
  }

  /** Builds a BREAK control structure node. */
  private def astForBreakStatement(br: IASTBreakStatement): Ast =
    Ast(controlStructureNode(br, ControlStructureTypes.BREAK, nodeSignature(br)))

  /** Builds a CONTINUE control structure node. */
  private def astForContinueStatement(cont: IASTContinueStatement): Ast =
    Ast(controlStructureNode(cont, ControlStructureTypes.CONTINUE, nodeSignature(cont)))

  /** Builds a GOTO control structure node whose code is `goto <label>;`. */
  private def astForGotoStatement(goto: IASTGotoStatement): Ast = {
    val code = s"goto ${ASTStringUtil.getSimpleName(goto.getName)};"
    Ast(controlStructureNode(goto, ControlStructureTypes.GOTO, code))
  }

  /** Lowers a GNU computed goto (`goto *expr;`) into a GOTO node followed by the label-expression AST.
    *
    * See https://gcc.gnu.org/onlinedocs/gcc/Labels-as-Values.html. The target label cannot be determined statically,
    * so as a quick hack the target is written as `*`, meaning "any label". This may over-taint.
    */
  private def astsForGnuGotoStatement(goto: IGNUASTGotoStatement): Seq[Ast] = {
    val code     = "goto *;"
    val gotoNode = Ast(controlStructureNode(goto, ControlStructureTypes.GOTO, code))
    val exprNode = nullSafeAst(goto.getLabelNameExpression)
    Seq(gotoNode, exprNode)
  }

  /** Lowers `label: stmt` into a jump target followed by the ASTs of the nested statement. */
  private def astsForLabelStatement(label: IASTLabelStatement): Seq[Ast] = {
    val cpgLabel    = newJumpTargetNode(label)
    val nestedStmts = nullSafeAst(label.getNestedStatement)
    Ast(cpgLabel) +: nestedStmts
  }

  /** Builds a DO control structure; the condition is placed after the body. */
  private def astForDoStatement(doStmt: IASTDoStatement): Ast = {
    val code         = nodeSignature(doStmt)
    val doNode       = controlStructureNode(doStmt, ControlStructureTypes.DO, code)
    val conditionAst = astForConditionExpression(doStmt.getCondition)
    val bodyAst      = nullSafeAst(doStmt.getBody)
    controlStructureAst(doNode, Some(conditionAst), bodyAst, placeConditionLast = true)
  }

  /** Builds a SWITCH control structure whose condition is the controller expression. */
  private def astForSwitchStatement(switchStmt: IASTSwitchStatement): Ast = {
    val code         = s"switch(${nullSafeCode(switchStmt.getControllerExpression)})"
    val switchNode   = controlStructureNode(switchStmt, ControlStructureTypes.SWITCH, code)
    val conditionAst = astForConditionExpression(switchStmt.getControllerExpression)
    val stmtAsts     = nullSafeAst(switchStmt.getBody)
    controlStructureAst(switchNode, Some(conditionAst), stmtAsts)
  }

  /** Lowers `case expr:` into a jump target followed by the case expression AST. */
  private def astsForCaseStatement(caseStmt: IASTCaseStatement): Seq[Ast] = {
    val labelNode = newJumpTargetNode(caseStmt)
    val stmt      = astForConditionExpression(caseStmt.getExpression)
    Seq(Ast(labelNode), stmt)
  }

  /** Lowers `default:` into a jump target. */
  private def astForDefaultStatement(caseStmt: IASTDefaultStatement): Ast =
    Ast(newJumpTargetNode(caseStmt))

  /** Builds a TRY control structure with the try body followed by every catch body. */
  private def astForTryStatement(tryStmt: ICPPASTTryBlockStatement): Ast = {
    val cpgTry = controlStructureNode(tryStmt, ControlStructureTypes.TRY, "try")
    val body   = nullSafeAst(tryStmt.getTryBody)
    // All catches must have order 2 for correct control flow generation.
    // TODO fix this. Multiple siblings with the same order are invalid
    val catches = tryStmt.getCatchHandlers.flatMap { handler =>
      // Null-safe like the try body above: a missing catch body yields no AST instead of failing.
      nullSafeAst(handler.getCatchBody, 2)
    }.toIndexedSeq
    Ast(cpgTry).withChildren(body).withChildren(catches)
  }

  /** Dispatches a statement to its specific lowering and wraps each resulting AST via `asChildOfMacroCall`.
    *
    * @param statement
    *   the statement to lower
    * @param argIndex
    *   only used for compound statements, where it becomes the nested BLOCK's order and argument index
    * @return
    *   zero or more ASTs; null statements produce none
    */
  protected def astsForStatement(statement: IASTStatement, argIndex: Int = -1): Seq[Ast] = {
    val r = statement match {
      case expr: IASTExpressionStatement          => Seq(astForExpression(expr.getExpression))
      case block: IASTCompoundStatement           => Seq(astForBlockStatement(block, argIndex))
      case ifStmt: IASTIfStatement                => Seq(astForIf(ifStmt))
      case whileStmt: IASTWhileStatement          => Seq(astForWhile(whileStmt))
      case forStmt: IASTForStatement              => Seq(astForFor(forStmt))
      case forStmt: ICPPASTRangeBasedForStatement => Seq(astForRangedFor(forStmt))
      case doStmt: IASTDoStatement                => Seq(astForDoStatement(doStmt))
      case switchStmt: IASTSwitchStatement        => Seq(astForSwitchStatement(switchStmt))
      case ret: IASTReturnStatement               => Seq(astForReturnStatement(ret))
      case br: IASTBreakStatement                 => Seq(astForBreakStatement(br))
      case cont: IASTContinueStatement            => Seq(astForContinueStatement(cont))
      case goto: IASTGotoStatement                => Seq(astForGotoStatement(goto))
      case goto: IGNUASTGotoStatement             => astsForGnuGotoStatement(goto)
      case defStmt: IASTDefaultStatement          => Seq(astForDefaultStatement(defStmt))
      case tryStmt: ICPPASTTryBlockStatement      => Seq(astForTryStatement(tryStmt))
      case caseStmt: IASTCaseStatement            => astsForCaseStatement(caseStmt)
      case decl: IASTDeclarationStatement         => astsForDeclarationStatement(decl)
      case label: IASTLabelStatement              => astsForLabelStatement(label)
      case _: IASTNullStatement                   => Seq.empty
      case _                                      => Seq(astForNode(statement))
    }
    r.map(x => asChildOfMacroCall(statement, x))
  }

  /** Lowers a condition expression.
    *
    * A comma-separated expression list becomes a BLOCK (with its own scope) whose children receive consecutive argument
    * indices; any other expression is lowered directly.
    *
    * @param explicitArgumentIndex
    *   if set, assigned as the argument index of the resulting root when that root is an expression node
    */
  private def astForConditionExpression(expr: IASTExpression, explicitArgumentIndex: Option[Int] = None): Ast = {
    val ast = expr match {
      case exprList: IASTExpressionList =>
        val compareAstBlock = blockNode(expr, Defines.empty, registerType(Defines.voidTypeName))
        scope.pushNewScope(compareAstBlock)
        val compareBlockAstChildren = exprList.getExpressions.toList.map(nullSafeAst)
        setArgumentIndices(compareBlockAstChildren)
        val compareBlockAst = blockAst(compareAstBlock, compareBlockAstChildren)
        scope.popScope()
        compareBlockAst
      case other =>
        nullSafeAst(other)
    }
    explicitArgumentIndex.foreach { i =>
      ast.root.foreach {
        case exprNode: ExpressionNew => exprNode.argumentIndex = i
        case _                       => // non-expression roots carry no argument index; previously a MatchError
      }
    }
    ast
  }

  /** Builds a classic FOR control structure.
    *
    * The initializer is wrapped in its own BLOCK and scope. Init, condition, update and body get argument indices 1
    * to 4.
    */
  private def astForFor(forStmt: IASTForStatement): Ast = {
    val codeInit = nullSafeCode(forStmt.getInitializerStatement)
    val codeCond = nullSafeCode(forStmt.getConditionExpression)
    val codeIter = nullSafeCode(forStmt.getIterationExpression)

    // No separator after codeInit: presumably the initializer statement's code already ends with `;`.
    val code    = s"for ($codeInit$codeCond;$codeIter)"
    val forNode = controlStructureNode(forStmt, ControlStructureTypes.FOR, code)

    val initAstBlock = blockNode(forStmt, Defines.empty, registerType(Defines.voidTypeName))
    scope.pushNewScope(initAstBlock)
    val initAst = blockAst(initAstBlock, nullSafeAst(forStmt.getInitializerStatement, 1).toList)
    scope.popScope()

    val compareAst = astForConditionExpression(forStmt.getConditionExpression, Some(2))
    val updateAst  = nullSafeAst(forStmt.getIterationExpression, 3)
    val bodyAsts   = nullSafeAst(forStmt.getBody, 4)
    forAst(forNode, Seq(), Seq(initAst), Seq(compareAst), Seq(updateAst), bodyAsts)
  }

  /** Builds a FOR control structure for a C++ range-based for.
    *
    * The children are, in order, the range initializer, the loop declaration and the body. There is no condition.
    */
  private def astForRangedFor(forStmt: ICPPASTRangeBasedForStatement): Ast = {
    val codeDecl = nullSafeCode(forStmt.getDeclaration)
    val codeInit = nullSafeCode(forStmt.getInitializerClause)

    val code    = s"for ($codeDecl:$codeInit)"
    val forNode = controlStructureNode(forStmt, ControlStructureTypes.FOR, code)

    val initAst = astForNode(forStmt.getInitializerClause)
    val declAst = astsForDeclaration(forStmt.getDeclaration)
    val stmtAst = nullSafeAst(forStmt.getBody)
    controlStructureAst(forNode, None, Seq(initAst) ++ declAst ++ stmtAst)
  }

  /** Builds a WHILE control structure, carrying the statement's source line and column like the other loops. */
  private def astForWhile(whileStmt: IASTWhileStatement): Ast = {
    val code       = s"while (${nullSafeCode(whileStmt.getCondition)})"
    val compareAst = astForConditionExpression(whileStmt.getCondition)
    val bodyAst    = nullSafeAst(whileStmt.getBody)
    whileAst(
      Some(compareAst),
      bodyAst,
      Some(code),
      lineNumber = line(whileStmt),
      columnNumber = column(whileStmt)
    )
  }

  /** Lowers a statement that is not itself a compound statement into a synthetic BLOCK with its own scope.
    *
    * This keeps then/else branches uniformly block-shaped.
    */
  private def astForStatementAsBlock(stmt: IASTStatement): Ast = {
    val block = blockNode(stmt, Defines.empty, registerType(Defines.voidTypeName))
    scope.pushNewScope(block)
    val stmtAsts = astsForStatement(stmt)
    setArgumentIndices(stmtAsts)
    scope.popScope()
    blockAst(block, stmtAsts.toList)
  }

  /** Builds an IF control structure with its condition, then-branch and (optionally ELSE-wrapped) else-branch.
    *
    * A C++ condition declaration (e.g. `if (int x = f())`) is lowered into its own BLOCK and scope. Any other if
    * statement uses its condition expression; a missing one goes through the null-safe helpers instead of raising a
    * MatchError.
    */
  private def astForIf(ifStmt: IASTIfStatement): Ast = {
    val (code, conditionAst) = ifStmt match {
      case s: CPPASTIfStatement if s.getConditionExpression == null && s.getConditionDeclaration != null =>
        val c         = s"if (${nullSafeCode(s.getConditionDeclaration)})"
        val exprBlock = blockNode(s.getConditionDeclaration, Defines.empty, registerType(Defines.voidTypeName))
        scope.pushNewScope(exprBlock)
        val declAsts = astsForDeclaration(s.getConditionDeclaration)
        setArgumentIndices(declAsts)
        scope.popScope()
        (c, blockAst(exprBlock, declAsts.toList))
      case s =>
        val c = s"if (${nullSafeCode(s.getConditionExpression)})"
        (c, astForConditionExpression(s.getConditionExpression))
    }

    val ifNode = controlStructureNode(ifStmt, ControlStructureTypes.IF, code)

    val thenAst = ifStmt.getThenClause match {
      case null                         => Ast()
      case block: IASTCompoundStatement => astForBlockStatement(block)
      case other                        => astForStatementAsBlock(other)
    }

    val elseAst = ifStmt.getElseClause match {
      case null => Ast()
      case elseStmt =>
        val elseNode = controlStructureNode(elseStmt, ControlStructureTypes.ELSE, "else")
        val elseBody = elseStmt match {
          case block: IASTCompoundStatement => astForBlockStatement(block)
          case other                        => astForStatementAsBlock(other)
        }
        Ast(elseNode).withChild(elseBody)
    }
    controlStructureAst(ifNode, Some(conditionAst), Seq(thenAst, elseAst))
  }
}
