/*
 * Decompiled with CFR 0.152.
 */
package io.codemodder.remediation.sqlinjection;

import com.github.javaparser.StaticJavaParser;
import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.body.VariableDeclarator;
import com.github.javaparser.ast.expr.AssignExpr;
import com.github.javaparser.ast.expr.BinaryExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.IntegerLiteralExpr;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.expr.NullLiteralExpr;
import com.github.javaparser.ast.expr.StringLiteralExpr;
import com.github.javaparser.ast.expr.VariableDeclarationExpr;
import com.github.javaparser.ast.stmt.ExpressionStmt;
import com.github.javaparser.ast.stmt.Statement;
import com.github.javaparser.ast.stmt.TryStmt;
import io.codemodder.Either;
import io.codemodder.ast.ASTTransforms;
import io.codemodder.ast.ASTs;
import io.codemodder.ast.ExpressionStmtVariableDeclaration;
import io.codemodder.ast.LocalScope;
import io.codemodder.ast.LocalVariableDeclaration;
import io.codemodder.ast.TryResourceDeclaration;
import io.codemodder.remediation.sqlinjection.QueryParameterizer;
import java.util.ArrayList;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Stream;

public final class SQLParameterizer {
    /** Preferred base name for the generated {@code PreparedStatement} variable. */
    private static final String preparedStatementNamePrefix = "stmt";
    /** Base name tried by generateNameWithSuffix when the preferred one is already taken. */
    private static final String preparedStatementNamePrefixAlternative = "statement";
    /** The JDBC execute call this instance analyzes and rewrites; never {@code null}. */
    private final MethodCallExpr executeCall;
    /** Compilation unit that receives the PreparedStatement import; reassigned by checkAndFix(). */
    private CompilationUnit compilationUnit;
    /** Method names accepted by isSupportedJdbcMethodCall. */
    private static final Set<String> fixableJdbcMethodNames = Set.of("executeQuery", "execute", "executeLargeUpdate", "executeUpdate");

    /**
     * Creates a parameterizer for the given execute call, leaving the compilation unit unset.
     *
     * @param methodCallExpr the execute call to analyze; must not be {@code null}
     * @throws NullPointerException if {@code methodCallExpr} is {@code null}
     */
    public SQLParameterizer(MethodCallExpr methodCallExpr) {
        this(methodCallExpr, null);
    }

    /**
     * Creates a parameterizer for the given execute call and compilation unit.
     *
     * @param methodCallExpr the execute call to analyze; must not be {@code null}
     * @param cu the compilation unit to associate with this instance; stored as given
     * @throws NullPointerException if {@code methodCallExpr} is {@code null}
     */
    public SQLParameterizer(MethodCallExpr methodCallExpr, CompilationUnit cu) {
        Objects.requireNonNull(methodCallExpr);
        this.executeCall = methodCallExpr;
        this.compilationUnit = cu;
    }

    /**
     * Tells whether a call looks like a JDBC execute call that could be parameterized.
     *
     * <p>All three conditions must hold, checked in this order: the method name is one of the
     * supported names, the scope resolves to {@code java.sql.Statement}, and a first argument
     * exists that is not a string literal.
     *
     * @param methodCallExpr the call to inspect
     * @return {@code true} when every condition holds; {@code false} otherwise, including when
     *     any check throws a {@link RuntimeException}
     */
    public static boolean isParameterizationCandidate(MethodCallExpr methodCallExpr) {
        try {
            if (!isSupportedJdbcMethodCall(methodCallExpr)) {
                return false;
            }
            final boolean scopeIsStatement = methodCallExpr.getScope()
                    .filter(scope -> {
                        try {
                            return "java.sql.Statement".equals(scope.calculateResolvedType().describe());
                        } catch (RuntimeException ignored) {
                            // Type resolution failed: treat the scope as not being a Statement.
                            return false;
                        }
                    })
                    .isPresent();
            if (!scopeIsStatement) {
                return false;
            }
            return methodCallExpr.getArguments().getFirst()
                    .filter(firstArgument -> !(firstArgument instanceof StringLiteralExpr))
                    .isPresent();
        } catch (RuntimeException ignored) {
            // Any unexpected failure means the call is not a candidate.
            return false;
        }
    }

    /**
     * Checks the call's method name against the set of fixable JDBC method names.
     *
     * @param methodCall the call whose name is examined
     * @return {@code true} if the name is contained in {@code fixableJdbcMethodNames}
     */
    public static boolean isSupportedJdbcMethodCall(MethodCallExpr methodCall) {
        final String calledMethod = methodCall.getNameAsString();
        return fixableJdbcMethodNames.contains(calledMethod);
    }

    /**
     * Returns the set of method names that isSupportedJdbcMethodCall accepts.
     *
     * @return the shared set instance held by this class (created with {@code Set.of})
     */
    public static Set<String> fixableJdbcMethodNames() {
        return fixableJdbcMethodNames;
    }

    /**
     * Follows a name back to the expression that initializes it, when that is unambiguous.
     *
     * <p>If {@code expr} is a name whose earliest local declaration is a local variable that
     * is final or never assigned and has an initializer, the initializer is resolved
     * recursively and returned. In every other case {@code expr} itself is returned.
     *
     * @param expr the expression to resolve; must not be {@code null}
     * @return the resolved initializer expression, or {@code expr} when it cannot be resolved
     * @throws NullPointerException if {@code expr} is {@code null}
     */
    public static Expression resolveExpression(Expression expr) {
        Objects.requireNonNull(expr);
        if (!(expr instanceof NameExpr)) {
            return expr;
        }
        return ASTs.findEarliestLocalDeclarationOf(expr.asNameExpr().getName())
                .filter(LocalVariableDeclaration.class::isInstance)
                .map(LocalVariableDeclaration.class::cast)
                .filter(ASTs::isFinalOrNeverAssigned)
                .flatMap(declaration -> declaration.getVariableDeclarator().getInitializer())
                .map(SQLParameterizer::resolveExpression)
                .orElse(expr);
    }

    /**
     * Recognizes a {@code createStatement} call whose scope resolves to
     * {@code java.sql.Connection}.
     *
     * @param expr the expression to inspect; must not be {@code null}
     * @return the expression as a method call when it matches, otherwise empty
     * @throws NullPointerException if {@code expr} is {@code null}
     */
    private Optional<MethodCallExpr> isConnectionCreateStatement(Expression expr) {
        Objects.requireNonNull(expr);
        if (!(expr instanceof MethodCallExpr)) {
            return Optional.empty();
        }
        final MethodCallExpr call = expr.asMethodCallExpr();
        final boolean scopeIsConnection = call.getScope()
                .filter(scope -> {
                    try {
                        return "java.sql.Connection".equals(scope.calculateResolvedType().describe());
                    } catch (RuntimeException ignored) {
                        // Unresolvable scope type: not treated as a Connection.
                        return false;
                    }
                })
                .isPresent();
        if (scopeIsConnection && call.getNameAsString().equals("createStatement")) {
            return Optional.of(call);
        }
        return Optional.empty();
    }

    /**
     * Checks that the execute call appears in one of the supported syntactic positions.
     *
     * <p>The call is accepted when it is the initializer of a local variable, is assigned, is
     * returned, is the expression of an expression statement, or is the initializer of the
     * first resource of a try statement.
     *
     * @param executeCall the execute call to validate
     * @return {@code executeCall} when it is in a supported position, otherwise empty
     */
    private Optional<MethodCallExpr> validateExecuteCall(MethodCallExpr executeCall) {
        // Walk outwards through calls that use the previous call as their scope.
        // NOTE(review): the resulting outermost call is not used by the checks below, which
        // all test executeCall itself - confirm whether they were meant to test it instead.
        MethodCallExpr methodCall = executeCall;
        Optional<MethodCallExpr> maybeCall = Optional.of(methodCall);
        while (maybeCall.isPresent()) {
            maybeCall = maybeCall.flatMap(ASTs::isScopeInMethodCall);
            methodCall = maybeCall.orElse(methodCall);
        }

        final Predicate<MethodCallExpr> isLocalInitExpr = call ->
                ASTs.isInitExpr(call).flatMap(LocalVariableDeclaration::fromVariableDeclarator).isPresent();
        final Predicate<MethodCallExpr> isAssigned = call -> ASTs.isAssigned(call).isPresent();
        final Predicate<MethodCallExpr> isReturned = call -> ASTs.isReturnExpr(call).isPresent();
        final Predicate<MethodCallExpr> isCall = call ->
                call.getParentNode().filter(parent -> parent instanceof ExpressionStmt).isPresent();
        // The call initializes a resource, and that resource is the first one of its try statement.
        final Predicate<MethodCallExpr> isFirstTryResource = call ->
                ASTs.isInitExpr(call)
                        .flatMap(ASTs::isResource)
                        .flatMap(pair -> ((TryStmt) pair.getValue0()).getResources().getFirst()
                                .filter(resource -> resource == pair.getValue1()))
                        .isPresent();

        final boolean supportedPosition = isLocalInitExpr
                .or(isAssigned)
                .or(isReturned)
                .or(isCall)
                .or(isFirstTryResource)
                .test(executeCall);
        return supportedPosition ? Optional.of(executeCall) : Optional.empty();
    }

    /**
     * Locates where the {@code Statement} used by the execute call is created.
     *
     * <p>Three shapes are recognized, tried in this order:
     * <ol>
     *   <li>the scope of the execute call is itself a {@code createStatement} call (left);</li>
     *   <li>the scope is a local variable with exactly one assignment, whose target is a name
     *       and whose value is a {@code createStatement} call (right-left);</li>
     *   <li>the scope is a local variable whose initializer is a {@code createStatement} call
     *       (right-right).</li>
     * </ol>
     *
     * @param executeCall the execute call whose scope is examined
     * @return the creation site wrapped in the matching side of the Either, or empty
     */
    private Optional<Either<MethodCallExpr, Either<AssignExpr, LocalVariableDeclaration>>> findStatementCreationExpr(MethodCallExpr executeCall) {
        // Shape 1: conn.createStatement().executeXxx(...)
        final Optional<Either<MethodCallExpr, Either<AssignExpr, LocalVariableDeclaration>>> maybeImmediate =
                executeCall.getScope().flatMap(this::isConnectionCreateStatement).map(Either::left);
        if (maybeImmediate.isPresent()) {
            return maybeImmediate;
        }

        // Shapes 2 and 3 both need the local declaration behind a named scope.
        final Optional<LocalVariableDeclaration> maybeLVD = executeCall.getScope()
                .map(expr -> expr instanceof NameExpr ? expr.asNameExpr() : null)
                .flatMap(ne -> ASTs.findEarliestLocalVariableDeclarationOf(ne, ne.getNameAsString()));

        // Shape 2: limit(2) is enough to tell "exactly one assignment" from "more than one".
        final Optional<AssignExpr> maybeSingleAssigned = maybeLVD
                .map(lvd -> ASTs.findAllAssignments(lvd).limit(2L).toList())
                .filter(allAssignments -> allAssignments.size() == 1)
                .map(allAssignments -> allAssignments.get(0))
                .filter(assign -> assign.getTarget().isNameExpr())
                .filter(assign -> this.isConnectionCreateStatement(assign.getValue()).isPresent());
        if (maybeSingleAssigned.isPresent()) {
            return maybeSingleAssigned.map(a -> Either.right(Either.left(a)));
        }

        // Shape 3: flatMap (not map) so a declaration only matches when its initializer really
        // is a createStatement call; map would yield a present Optional for any initializer.
        final Optional<LocalVariableDeclaration> maybeInitExpr = maybeLVD.filter(lvd ->
                lvd.getVariableDeclarator().getInitializer().flatMap(this::isConnectionCreateStatement).isPresent());
        return maybeInitExpr.map(init -> Either.right(Either.right(init)));
    }

    /**
     * Filters out statement creation sites that cannot be rewritten.
     *
     * <p>An immediate creation (left) is always accepted. A declaration (right-right) is
     * rejected when its type cannot be changed, or when it is a try resource that fails
     * {@code validateTryResource}. An assignment (right-left) is rejected unless the earliest
     * declaration of its target is an {@code ExpressionStmtVariableDeclaration}.
     *
     * @param stmtObject the creation site found by {@code findStatementCreationExpr}
     * @return {@code stmtObject} when it can be rewritten, otherwise empty
     */
    private Optional<Either<MethodCallExpr, Either<AssignExpr, LocalVariableDeclaration>>> validateStatementCreationExpr(Either<MethodCallExpr, Either<AssignExpr, LocalVariableDeclaration>> stmtObject) {
        if (!stmtObject.isRight()) {
            return Optional.of(stmtObject);
        }
        final Either<AssignExpr, LocalVariableDeclaration> assignOrLVD = stmtObject.getRight();

        // Declarations must tolerate their variable's type being changed.
        if (assignOrLVD.isRight() && !this.canChangeTypes(assignOrLVD.getRight())) {
            return Optional.empty();
        }

        final boolean rejected;
        if (assignOrLVD.isLeft()) {
            final AssignExpr assignment = assignOrLVD.getLeft();
            final String targetName = assignment.getTarget().asNameExpr().getNameAsString();
            rejected = ASTs.findEarliestLocalVariableDeclarationOf(assignment, targetName)
                    .filter(lvd -> lvd instanceof ExpressionStmtVariableDeclaration)
                    .isEmpty();
        } else {
            final LocalVariableDeclaration declaration = assignOrLVD.getRight();
            rejected = declaration instanceof TryResourceDeclaration
                    && !this.validateTryResource((TryResourceDeclaration) declaration, this.executeCall);
        }
        if (rejected) {
            return Optional.empty();
        }
        return Optional.of(stmtObject);
    }

    /**
     * Decides whether the declared type of a local variable can be changed.
     *
     * <p>This holds when every name expression in the declaration's scope that carries the
     * variable's name is used as the scope of a method call.
     *
     * @param localDeclaration the declaration whose usages are inspected
     * @return {@code true} if no matching name is used other than as a method call scope
     */
    private boolean canChangeTypes(LocalVariableDeclaration localDeclaration) {
        return localDeclaration.getScope().stream()
                .flatMap(node -> node
                        .findAll(NameExpr.class, ref -> ref.getNameAsString().equals(localDeclaration.getName()))
                        .stream())
                .noneMatch(ref -> ASTs.isScopeInMethodCall(ref).isEmpty());
    }

    /**
     * Checks that a try-resource statement declaration and the execute call are adjacent.
     *
     * <p>Accepted layouts: the declaration is the last resource and the execute call sits in
     * the first statement of the try block; or the execute call initializes another resource
     * of the same try statement whose index differs by exactly one.
     *
     * @param stmtObject the try resource that declares the statement object
     * @param executeCall the execute call being fixed
     * @return {@code true} if one of the accepted layouts applies
     */
    private boolean validateTryResource(TryResourceDeclaration stmtObject, MethodCallExpr executeCall) {
        final boolean declaredAsLastResource = stmtObject.getStatement().getResources().getLast()
                .filter(last -> last == stmtObject.getVariableDeclarationExpr())
                .isPresent();
        if (declaredAsLastResource) {
            final boolean executedInFirstStatement = stmtObject.getStatement().getTryBlock().getStatements().getFirst()
                    .filter(first -> ASTs.findParentStatementFrom(executeCall)
                            .filter(parent -> parent == first)
                            .isPresent())
                    .isPresent();
            if (executedInFirstStatement) {
                return true;
            }
        }

        // Otherwise the execute call must initialize a neighbouring resource of the same try.
        return ASTs.isInitExpr(executeCall)
                .flatMap(LocalVariableDeclaration::fromVariableDeclarator)
                .filter(TryResourceDeclaration.class::isInstance)
                .map(TryResourceDeclaration.class::cast)
                .filter(executeResource -> executeResource.getStatement() == stmtObject.getStatement())
                .map(executeResource -> {
                    final int declarationIndex = stmtObject.getStatement().getResources()
                            .indexOf(stmtObject.getVariableDeclarationExpr());
                    final int executeIndex = stmtObject.getStatement().getResources()
                            .indexOf(executeResource.getVariableDeclarationExpr());
                    return Math.abs(executeIndex - declarationIndex) == 1;
                })
                .orElse(false);
    }

    /**
     * Picks a variable name for the new prepared statement that does not clash with a name
     * visible from {@code start}.
     *
     * <p>Tries {@code "stmt"}, then {@code "statement"}; if both are taken it falls back to
     * {@code "stmt1"}, {@code "stmt2"}, ... until a free name is found.
     *
     * @param name currently unused; the candidates come from the class constants
     * @param start the node from which name visibility is checked
     * @return a name for which {@code ASTs.findNonCallableSimpleNameSource} finds no source
     */
    private String generateNameWithSuffix(String name, Node start) {
        String baseName = preparedStatementNamePrefix;
        Optional<Node> clash = ASTs.findNonCallableSimpleNameSource(start, baseName);
        if (clash.isPresent()) {
            // Preferred name is taken: try the alternative.
            baseName = preparedStatementNamePrefixAlternative;
            clash = ASTs.findNonCallableSimpleNameSource(start, baseName);
            if (clash.isPresent()) {
                // Both are taken: go back to the preferred base and number it below.
                baseName = preparedStatementNamePrefix;
            }
        }
        String candidate = baseName;
        for (int suffix = 1; clash.isPresent(); suffix++) {
            candidate = baseName + suffix;
            clash = ASTs.findNonCallableSimpleNameSource(start, candidate);
        }
        return candidate;
    }

    /**
     * Rewrites each injection into a {@code ?} placeholder and builds its parameter expression.
     *
     * <p>Each injection is a deque whose first and last elements are string literals and whose
     * remaining elements are the injected expressions. In the first literal, everything from
     * the last single quote onwards is replaced by {@code ?}; in the last literal, everything
     * up to and including the first single quote is removed. The text that sat between the
     * quotes and the injected expressions is concatenated onto the parameter expression.
     *
     * <p>Both literals are assumed to contain a single quote; if one does not, the string
     * index arithmetic throws {@link StringIndexOutOfBoundsException}.
     *
     * @param injections the injections to rewrite; the first and last element of each deque
     *     are removed from it
     * @param resolvedMap map handed to {@code buildParameter} to locate the original nodes
     * @return one parameter expression per injection, in the order of {@code injections}
     */
    private List<Expression> fixInjections(List<Deque<Expression>> injections, Map<Expression, Expression> resolvedMap) {
        final List<Expression> combinedExpressions = new ArrayList<>();
        for (final Deque<Expression> injection : injections) {
            // Leading literal: cut at the last quote and put the placeholder in its place.
            final Expression start = injection.removeFirst();
            final String startString = start.asStringLiteralExpr().getValue();
            final int lastQuoteIndex = startString.lastIndexOf('\'') + 1;
            final String prepend = startString.substring(lastQuoteIndex);
            start.asStringLiteralExpr().setValue(startString.substring(0, lastQuoteIndex - 1) + "?");

            // Trailing literal: drop everything up to and including the first quote.
            final Expression end = injection.removeLast();
            final String endString = end.asStringLiteralExpr().getValue();
            final int firstQuoteIndex = endString.indexOf('\'');
            final String newEnd = endString.substring(firstQuoteIndex + 1);
            final String append = endString.substring(0, firstQuoteIndex);
            end.asStringLiteralExpr().setValue(newEnd);

            Expression combined = this.buildParameter(injection, resolvedMap);
            // Compare by content: substring() is not guaranteed to return the "" literal
            // instance, so a reference comparison against "" is unreliable.
            if (!prepend.isEmpty()) {
                combined = new BinaryExpr(new StringLiteralExpr(prepend), combined, BinaryExpr.Operator.PLUS);
            }
            if (!append.isEmpty()) {
                combined = new BinaryExpr(combined, new StringLiteralExpr(append), BinaryExpr.Operator.PLUS);
            }
            combinedExpressions.add(combined);
        }
        return combinedExpressions;
    }

    /**
     * Follows the resolution map from {@code expr} until an expression with no mapping is
     * reached.
     *
     * @param expr the expression to start from
     * @param resolutionMap map consulted repeatedly with the current expression as key
     * @return the last expression in the chain; {@code expr} itself if it has no mapping
     */
    private Expression unresolve(Expression expr, Map<Expression, Expression> resolutionMap) {
        Expression current = expr;
        for (Expression next = resolutionMap.get(current); next != null; next = resolutionMap.get(current)) {
            current = next;
        }
        return current;
    }

    /**
     * Concatenates the injected expressions into a single parameter expression.
     *
     * <p>For each expression, the node obtained through {@code unresolve} is replaced in the
     * tree by an empty string literal before the expression is added to the concatenation.
     * If none of the expressions resolves to {@code java.lang.String}, an empty string literal
     * is appended to the concatenation.
     *
     * @param injectionExpressions the injected expressions, in order; must not be empty
     * @param resolutionMap map passed to {@code unresolve}
     * @return the concatenation of all expressions
     * @throws java.util.NoSuchElementException if {@code injectionExpressions} is empty
     */
    private Expression buildParameter(Deque<Expression> injectionExpressions, Map<Expression, Expression> resolutionMap) {
        final Iterator<Expression> parts = injectionExpressions.iterator();
        final Expression first = parts.next();
        boolean sawString = this.resolvesToString(first);
        // Replace before reusing the expression below, while it still has its original parent.
        this.unresolve(first, resolutionMap).replace(new StringLiteralExpr(""));

        Expression concatenation = first;
        while (parts.hasNext()) {
            final Expression part = parts.next();
            // Short-circuit: once a String has been seen, no further types are resolved.
            sawString = sawString || this.resolvesToString(part);
            this.unresolve(part, resolutionMap).replace(new StringLiteralExpr(""));
            concatenation = new BinaryExpr(concatenation, part, BinaryExpr.Operator.PLUS);
        }

        if (sawString) {
            return concatenation;
        }
        return new BinaryExpr(concatenation, new StringLiteralExpr(""), BinaryExpr.Operator.PLUS);
    }

    /** Returns whether the expression resolves to java.lang.String; false if resolution fails. */
    private boolean resolvesToString(Expression expr) {
        try {
            return "java.lang.String".equals(expr.calculateResolvedType().describe());
        } catch (Exception ignored) {
            // Best effort: an unresolvable type is treated as "not a String".
            return false;
        }
    }

    /**
     * Applies the rewrite: turns the statement creation into a {@code prepareStatement} call,
     * inserts one {@code setString} call per injection, and rewrites the execute call.
     *
     * @param stmtCreation where the statement object is created: an immediate
     *     {@code createStatement} call (left), an assignment (right-left) or a local
     *     declaration (right-right)
     * @param queryParameterizer supplies the injections, the resolved-expressions map and the
     *     root query expression
     * @param executeCall the execute call to rewrite; it must have a parent statement
     * @return the method call expression that now creates the prepared statement
     * @throws java.util.NoSuchElementException if no parent statement is found for
     *     {@code executeCall}
     */
    private MethodCallExpr fix(Either<MethodCallExpr, Either<AssignExpr, LocalVariableDeclaration>> stmtCreation, QueryParameterizer queryParameterizer, MethodCallExpr executeCall) {
        MethodCallExpr pstmtCreation;
        Statement executeStmt = ASTs.findParentStatementFrom((Node)executeCall).get();
        // The declaration and the execute call share one statement (both are resources of the
        // same try): split the resources at the declaration's index and continue with the
        // first statement of the try block returned by splitResources.
        if (stmtCreation.isRight() && stmtCreation.getRight().isRight() && executeStmt == stmtCreation.getRight().getRight().getStatement()) {
            int stmtObjectIndex = stmtCreation.getRight().getRight().getStatement().asTryStmt().getResources().indexOf((Object)stmtCreation.getRight().getRight().getVariableDeclarationExpr());
            executeStmt = ASTTransforms.splitResources(stmtCreation.getRight().getRight().getStatement().asTryStmt(), stmtObjectIndex).getTryBlock().getStatement(0);
        }
        // Immediate creation gets a freshly generated name; otherwise the existing variable
        // name (assignment target or declaration name) is reused.
        String stmtName = stmtCreation.ifLeftOrElseGet(mce -> this.generateNameWithSuffix(preparedStatementNamePrefix, (Node)mce), assignOrLVD -> assignOrLVD.ifLeftOrElseGet(a -> a.getTarget().asNameExpr().getNameAsString(), lvd -> lvd.getName()));
        List<Expression> combinedExpressions = this.fixInjections(queryParameterizer.getInjections(), queryParameterizer.getLinearizedQuery().getResolvedExpressionsMap());
        Statement topStatement = executeStmt;
        // Iterate from the last parameter to the first, inserting each setString before the
        // previously inserted one, so the calls end up in ascending index order.
        for (int i = combinedExpressions.size() - 1; i >= 0; --i) {
            Expression expr2 = combinedExpressions.get(i);
            ExpressionStmt setStmt = null;
            // Parameter indices passed to setString are 1-based (i + 1).
            setStmt = new ExpressionStmt((Expression)new MethodCallExpr((Expression)new NameExpr(stmtName), "setString", new NodeList((Node[])new Expression[]{new IntegerLiteralExpr(String.valueOf(i + 1)), expr2})));
            ASTTransforms.addStatementBeforeStatement(topStatement, (Statement)setStmt);
            topStatement = setStmt;
        }
        ASTTransforms.addImportIfMissing(this.compilationUnit, "java.sql.PreparedStatement");
        // prepareStatement arguments: the query root first, then whatever arguments the
        // original createStatement call had.
        NodeList args = new NodeList();
        args.addFirst((Node)queryParameterizer.getRoot());
        args.addAll(stmtCreation.ifLeftOrElseGet(mce -> mce.getArguments(), assignOrLVD -> assignOrLVD.ifLeftOrElseGet(a -> a.getValue().asMethodCallExpr().getArguments(), lvd -> ((Expression)lvd.getVariableDeclarator().getInitializer().get()).asMethodCallExpr().getArguments())));
        // NOTE(review): the call is renamed to "execute" whatever its original name was
        // (executeQuery, executeUpdate, ...) - confirm this suits call sites that use the
        // original call's return value.
        executeCall.setName("execute");
        executeCall.setScope((Expression)new NameExpr(stmtName));
        executeCall.setArguments(new NodeList());
        if (stmtCreation.isLeft()) {
            // Immediate creation: declare a new PreparedStatement variable before the first
            // inserted statement, using the original connection expression as scope.
            pstmtCreation = new MethodCallExpr((Expression)stmtCreation.getLeft().getScope().get(), "prepareStatement", args);
            ExpressionStmt pstmtCreationStmt = new ExpressionStmt((Expression)new VariableDeclarationExpr(new VariableDeclarator(StaticJavaParser.parseType((String)"PreparedStatement"), stmtName, (Expression)pstmtCreation)));
            ASTTransforms.addStatementBeforeStatement(topStatement, (Statement)pstmtCreationStmt);
        } else {
            Either<AssignExpr, LocalVariableDeclaration> assignOrLVD2 = stmtCreation.getRight();
            if (assignOrLVD2.isLeft()) {
                // Assignment: rewrite the assigned createStatement call in place.
                pstmtCreation = assignOrLVD2.getLeft().getValue().asMethodCallExpr();
                pstmtCreation.setArguments(args);
                pstmtCreation.setName("prepareStatement");
                // NOTE(review): a placeholder value is set and then immediately replaced by
                // the rewritten call; presumably this forces the assignment to register the
                // change - confirm before removing.
                assignOrLVD2.getLeft().setValue(StaticJavaParser.parseExpression((String)"a"));
                assignOrLVD2.getLeft().setValue((Expression)pstmtCreation);
                // The variable's declaration, when found, is retyped to PreparedStatement and
                // its initializer set to null.
                Optional<LocalVariableDeclaration> maybeLVD = ASTs.findEarliestLocalVariableDeclarationOf((Node)assignOrLVD2.getLeft().getTarget(), assignOrLVD2.getLeft().getTarget().asNameExpr().getNameAsString());
                if (maybeLVD.isPresent()) {
                    VariableDeclarator vd = maybeLVD.get().getVariableDeclarator();
                    vd.setInitializer((Expression)new NullLiteralExpr());
                    vd.setType(StaticJavaParser.parseType((String)"PreparedStatement"));
                }
            } else {
                // Declaration: retype the variable and rewrite its initializer call in place.
                assignOrLVD2.getRight().getVariableDeclarator().setType(StaticJavaParser.parseType((String)"PreparedStatement"));
                assignOrLVD2.getRight().getVariableDeclarator().getInitializer().ifPresent(expr -> expr.asMethodCallExpr().setName("prepareStatement"));
                assignOrLVD2.getRight().getVariableDeclarator().getInitializer().ifPresent(expr -> expr.asMethodCallExpr().setArguments(args));
                pstmtCreation = ((Expression)assignOrLVD2.getRight().getVariableDeclarator().getInitializer().get()).asMethodCallExpr();
            }
        }
        return pstmtCreation;
    }

    /**
     * Tells whether an expression lies inside the scope of the statement object.
     *
     * <p>For an assignment, the scope is derived from the assignment expression, and a scope
     * without any node counts as containing the expression. For a declaration, the
     * declaration's own scope is used.
     *
     * @param assignOrLVD the assignment or declaration that defines the scope
     * @param expr the expression to look for
     * @return {@code true} if {@code expr} is in scope, or the assignment scope has no nodes
     */
    private boolean resolvedInScope(Either<AssignExpr, LocalVariableDeclaration> assignOrLVD, Expression expr) {
        if (!assignOrLVD.isLeft()) {
            return assignOrLVD.getRight().getScope().inScope(expr);
        }
        final LocalScope assignmentScope = LocalScope.fromAssignExpression(assignOrLVD.getLeft());
        final boolean scopeHasNoNodes = assignmentScope.stream().findAny().isEmpty();
        return scopeHasNoNodes || assignmentScope.inScope(expr);
    }

    /**
     * Tells whether a name is assigned or defined inside the scope of the statement object.
     *
     * @param name the name used in the query
     * @param assignOrLVD the assignment or declaration that defines the scope
     * @return {@code true} if the scope has no nodes, if an assignment in the scope targets a
     *     variable with the same name, or if the source of the name lies within the scope
     */
    private boolean assignedOrDefinedInScope(NameExpr name, Either<AssignExpr, LocalVariableDeclaration> assignOrLVD) {
        final LocalScope scope = assignOrLVD.ifLeftOrElseGet(a -> LocalScope.fromAssignExpression(a), lvd -> lvd.getScope());
        if (scope.stream().findAny().isEmpty()) {
            return true;
        }
        final String queriedName = name.getNameAsString();
        // Names are compared by content; a reference comparison of two getNameAsString()
        // results is not guaranteed to match equal names.
        final boolean assignedInScope = scope.stream()
                .filter(AssignExpr.class::isInstance)
                .map(AssignExpr.class::cast)
                .flatMap(assignment -> ASTs.hasNamedTarget(assignment).stream())
                .anyMatch(target -> queriedName.equals(target.getNameAsString()));
        final boolean definedInScope = ASTs.findNonCallableSimpleNameSource(name.getName())
                .filter(source -> scope.inScope(source))
                .isPresent();
        return assignedInScope || definedInScope;
    }

    /**
     * Checks whether the execute call can be parameterized and, if so, applies the fix.
     *
     * <p>The fix is skipped when the call is not inside a compilation unit, is not a
     * parameterization candidate, is not in a supported position, has no rewritable statement
     * creation site, has no arguments, yields no injections, or when a resolved expression or
     * a name used in the query is tied to the scope of the statement object.
     *
     * <p>As a side effect, the compilation unit of the execute call replaces any compilation
     * unit supplied at construction.
     *
     * @return the call that now creates the prepared statement, or empty if nothing changed
     */
    public Optional<MethodCallExpr> checkAndFix() {
        final Optional<CompilationUnit> maybeCompilationUnit = this.executeCall.findCompilationUnit();
        if (maybeCompilationUnit.isEmpty()) {
            return Optional.empty();
        }
        this.compilationUnit = maybeCompilationUnit.get();

        if (!SQLParameterizer.isParameterizationCandidate(this.executeCall)
                || this.validateExecuteCall(this.executeCall).isEmpty()) {
            return Optional.empty();
        }
        final Optional<Either<MethodCallExpr, Either<AssignExpr, LocalVariableDeclaration>>> stmtObject =
                this.findStatementCreationExpr(this.executeCall).flatMap(this::validateStatementCreationExpr);
        if (stmtObject.isEmpty() || this.executeCall.getArguments().isEmpty()) {
            return Optional.empty();
        }
        final Either<MethodCallExpr, Either<AssignExpr, LocalVariableDeclaration>> creation = stmtObject.get();

        final QueryParameterizer queryp = new QueryParameterizer(this.executeCall.getArgument(0));
        // An immediate createStatement() call has no scope to check, hence false on the left.
        final boolean resolvedInScope = creation.ifLeftOrElseGet(
                mce -> false,
                assignOrLVD -> queryp.getLinearizedQuery().getResolvedExpressionsMap().keySet().stream()
                        .anyMatch(expr -> this.resolvedInScope(assignOrLVD, expr)));
        final boolean nameInScope = creation.ifLeftOrElseGet(
                mce -> false,
                assignOrLVD -> queryp.getLinearizedQuery().getLinearized().stream()
                        .filter(expr -> expr.isNameExpr())
                        .map(expr -> expr.asNameExpr())
                        .anyMatch(name -> this.assignedOrDefinedInScope(name, assignOrLVD)));
        if (queryp.getInjections().isEmpty() || resolvedInScope || nameInScope) {
            return Optional.empty();
        }
        return Optional.of(this.fix(creation, queryp, this.executeCall));
    }
}

