/*
 * Decompiled with CFR 0.152.
 */
package io.codemodder.remediation.sqlinjection;

import com.github.javaparser.StaticJavaParser;
import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.body.BodyDeclaration;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.MethodDeclaration;
import com.github.javaparser.ast.body.Parameter;
import com.github.javaparser.ast.expr.BinaryExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.StringLiteralExpr;
import io.codemodder.ast.ASTTransforms;
import io.codemodder.ast.LinearizedStringExpression;
import io.codemodder.remediation.sqlinjection.SQLParameterizer;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.regex.Pattern;

/**
 * Static utility that looks for non-literal expressions concatenated right after a {@code FROM}
 * keyword in the SQL string passed to a JDBC-style call, and wraps each of them in a call to a
 * generated {@code validateTableName(String)} method.
 *
 * <p>The generated method is added to the class that encloses the first wrapped expression when
 * no single-{@code String}-parameter method of that name is already declared there.
 */
public final class SQLTableInjectionFilterTransform {
    /**
     * Matches a string literal value that ends with {@code from} (case-insensitive) followed by
     * one or more spaces, optionally followed by a double quote that may itself be preceded by a
     * backslash.
     */
    private static final Pattern FROM_CLAUSE_PATTERN =
            Pattern.compile(".*from +((\\\\)?\")?", Pattern.CASE_INSENSITIVE);

    /** Name of the generated validation method and of the call inserted around injections. */
    private static final String FILTER_METHOD_NAME = "validateTableName";

    /**
     * Source of the generated validation method. The dot separating the two identifier parts is
     * escaped so that only {@code name} or {@code name.name} is accepted; an unescaped dot would
     * match any character (space, quote, semicolon, ...) and defeat the allow-list.
     */
    private static final String FILTER_METHOD_SOURCE =
            " String validateTableName(final String tablename){\n"
                    + "  Pattern regex = Pattern.compile(\"[a-zA-Z0-9_]+(\\\\.[a-zA-Z0-9_]+)?\");\n"
                    + "  if (!regex.matcher(tablename).matches()){\n"
                    + "\t  throw new SecurityException(\"Supplied table name contains non-alphanumeric characters\");\n"
                    + "  }\n"
                    + "  return tablename;\n"
                    + " }\n";

    private SQLTableInjectionFilterTransform() {
    }

    /** Delegates to {@link SQLParameterizer#isParameterizationCandidate(MethodCallExpr)}. */
    private static boolean isExecuteCall(MethodCallExpr methodCallExpr) {
        return SQLParameterizer.isParameterizationCandidate(methodCallExpr);
    }

    /**
     * Tells whether the scope expression resolves to {@code java.sql.Connection}.
     *
     * @return {@code false} when the type differs or when resolution throws
     */
    private static boolean resolvesToSqlConnection(Expression scope) {
        try {
            return "java.sql.Connection".equals(scope.calculateResolvedType().describe());
        }
        catch (RuntimeException ignored) {
            // Best effort: a type that cannot be resolved is treated as "not a Connection".
            return false;
        }
    }

    /**
     * Tells whether the call is named {@code prepareStatement}, has a scope resolving to
     * {@code java.sql.Connection}, and has a first argument that is not a string literal.
     *
     * @return {@code false} if any check fails or any of them throws a {@link RuntimeException}
     */
    private static boolean isPrepareStatementCall(MethodCallExpr methodCallExpr) {
        try {
            if (!"prepareStatement".equals(methodCallExpr.getNameAsString())) {
                return false;
            }
            boolean hasSqlConnectionScope = methodCallExpr.getScope()
                    .filter(SQLTableInjectionFilterTransform::resolvesToSqlConnection)
                    .isPresent();
            if (!hasSqlConnectionScope) {
                return false;
            }
            return methodCallExpr.getArguments().getFirst()
                    .map(firstArgument -> !(firstArgument instanceof StringLiteralExpr))
                    .orElse(false);
        }
        catch (RuntimeException ignored) {
            // Best effort: a call that cannot be analysed is simply not a match.
            return false;
        }
    }

    /**
     * Tells whether the call is one this transform handles.
     *
     * @param call the call to inspect
     * @return {@code true} if it is a matching {@code prepareStatement} call or a
     *     parameterization candidate according to {@link SQLParameterizer}
     */
    public static boolean matchCall(MethodCallExpr call) {
        return isPrepareStatementCall(call) || isExecuteCall(call);
    }

    /**
     * Wraps every table-name injection found in the first argument of {@code call} with a call
     * to the validation method. Expressions that already are calls to that method are skipped.
     *
     * @param call the call whose first argument is analysed
     * @return {@code true} if at least one expression was wrapped; {@code false} otherwise,
     *     including when {@code call} has no arguments
     */
    public static boolean fix(MethodCallExpr call) {
        if (call.getArguments().isEmpty()) {
            // Nothing to analyse; avoids an index error on getArgument(0).
            return false;
        }
        LinearizedStringExpression linearized = new LinearizedStringExpression(call.getArgument(0));
        List<Expression> injections = findTableInjections(linearized).stream()
                .filter(candidate -> !isFilterCall(candidate))
                .toList();
        if (injections.isEmpty()) {
            return false;
        }
        fix(injections, linearized.getResolvedExpressionsMap());
        return true;
    }

    /**
     * Applies {@link #fix(MethodCallExpr)} when {@link #matchCall(MethodCallExpr)} accepts the
     * call.
     *
     * @param call the call to inspect and possibly change
     * @return {@code true} if the call matched and at least one expression was wrapped
     */
    public static boolean findAndFix(MethodCallExpr call) {
        return matchCall(call) && fix(call);
    }

    /** Tells whether the expression is a method call named like the validation method. */
    private static boolean isFilterCall(Expression expr) {
        return expr.isMethodCallExpr()
                && expr.asMethodCallExpr().getNameAsString().equals(FILTER_METHOD_NAME);
    }

    /**
     * Collects every non-literal expression that directly follows a string literal matching
     * {@link #FROM_CLAUSE_PATTERN} in the linearized expression.
     *
     * <p>Each element is examined exactly once, so a matching literal that itself follows
     * another matching literal is still taken into account.
     */
    private static List<Expression> findTableInjections(LinearizedStringExpression linearized) {
        List<Expression> tableInjections = new ArrayList<>();
        boolean previousEndsWithFrom = false;
        for (Expression expr : linearized.getLinearized()) {
            if (expr.isStringLiteralExpr()) {
                String value = expr.asStringLiteralExpr().getValue();
                previousEndsWithFrom = FROM_CLAUSE_PATTERN.matcher(value).matches();
            } else {
                if (previousEndsWithFrom) {
                    tableInjections.add(expr);
                }
                previousEndsWithFrom = false;
            }
        }
        return tableInjections;
    }

    /**
     * Adds the validation method to {@code classDecl} unless a method with the same name and a
     * single {@code String} parameter is already found in it, then makes sure
     * {@code java.util.regex.Pattern} is imported when a compilation unit is available.
     */
    private static void addFilterMethodIfMissing(ClassOrInterfaceDeclaration classDecl) {
        boolean filterMethodPresent = classDecl.findAll(MethodDeclaration.class).stream()
                .anyMatch(md -> md.getNameAsString().equals(FILTER_METHOD_NAME)
                        && md.getParameters().size() == 1
                        && md.getParameters().get(0).getTypeAsString().equals("String"));
        if (!filterMethodPresent) {
            classDecl.addMember(StaticJavaParser.parseMethodDeclaration(FILTER_METHOD_SOURCE));
        }
        classDecl.findCompilationUnit()
                .ifPresent(cu -> ASTTransforms.addImportIfMissing(cu, "java.util.regex.Pattern"));
    }

    /**
     * Wraps the unresolved form of each injection with a validation call, then adds the
     * validation method to the class enclosing the first injection, if there is one.
     *
     * @param injections non-empty list of expressions to protect
     * @param resolutionMap map used by {@link #unresolve(Expression, Map)}
     */
    private static void fix(List<Expression> injections, Map<Expression, Expression> resolutionMap) {
        for (Expression injection : injections) {
            wrapExpressionWithCall(unresolve(injection, resolutionMap));
        }
        Optional<ClassOrInterfaceDeclaration> classDecl =
                injections.get(0).findAncestor(ClassOrInterfaceDeclaration.class);
        classDecl.ifPresent(SQLTableInjectionFilterTransform::addFilterMethodIfMissing);
    }

    /**
     * Follows {@code resolutionMap} starting at {@code expr} until an expression without a
     * mapping is reached, and returns that expression.
     */
    private static Expression unresolve(Expression expr, Map<Expression, Expression> resolutionMap) {
        Expression current = expr;
        Expression next;
        while ((next = resolutionMap.get(current)) != null) {
            current = next;
        }
        return current;
    }

    /**
     * Replaces {@code expr} in its tree with {@code validateTableName(expr + "")}.
     */
    private static void wrapExpressionWithCall(Expression expr) {
        MethodCallExpr newCall = new MethodCallExpr(FILTER_METHOD_NAME);
        // Replace first, while expr is still attached to its parent; it is re-parented below.
        expr.replace(newCall);
        newCall.addArgument(new BinaryExpr(expr, new StringLiteralExpr(""), BinaryExpr.Operator.PLUS));
    }
}

