/*
 * Decompiled with CFR 0.152.
 */
package io.codemodder.remediation.ssrf;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.FieldAccessExpr;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.expr.ObjectCreationExpr;
import io.codemodder.DependencyGAV;
import io.codemodder.ast.ASTTransforms;
import io.codemodder.remediation.RemediationStrategy;
import io.codemodder.remediation.SuccessOrReason;
import io.github.pixee.security.HostValidator;
import io.github.pixee.security.Urls;
import java.util.List;
import java.util.Optional;

/**
 * Remediation strategy for server-side request forgery (SSRF) findings that routes URL
 * construction through {@code Urls.create(..., Urls.HTTP_PROTOCOLS,
 * HostValidator.DENY_COMMON_INFRASTRUCTURE_TARGETS)} from the Java Security Toolkit.
 *
 * <p>Two node shapes are handled:
 *
 * <ul>
 *   <li>A constructor call ({@code new X(args...)}, presumably {@code new URL(...)}; the type is
 *       not checked here) is replaced wholesale by {@code Urls.create(args..., Urls.HTTP_PROTOCOLS,
 *       HostValidator.DENY_COMMON_INFRASTRUCTURE_TARGETS)}.
 *   <li>For a method call, only its first argument {@code a} is rewritten to {@code
 *       Urls.create(a, Urls.HTTP_PROTOCOLS,
 *       HostValidator.DENY_COMMON_INFRASTRUCTURE_TARGETS).toString()}.
 * </ul>
 *
 * <p>Any other node type is rejected with a reason rather than modified.
 */
public final class SSRFFixStrategy implements RemediationStrategy {

    /**
     * Applies the SSRF hardening appropriate to the given node.
     *
     * @param cu the compilation unit containing {@code node}; imports are added to it
     * @param node the flagged node, expected to be a {@link MethodCallExpr} or an {@link
     *     ObjectCreationExpr}
     * @return success (requiring the Java Security Toolkit dependency) when the node was rewritten,
     *     otherwise a reason describing why nothing was changed
     */
    @Override
    public SuccessOrReason fix(final CompilationUnit cu, final Node node) {
        if (node instanceof MethodCallExpr) {
            return hardenRT(cu, (MethodCallExpr) node);
        }
        if (node instanceof ObjectCreationExpr) {
            return harden(cu, (ObjectCreationExpr) node);
        }
        return SuccessOrReason.reason("Not a method call or constructor");
    }

    /**
     * Replaces a constructor call with an equivalent {@code Urls.create(...)} call that carries the
     * same arguments followed by the protocol and host-validation policies.
     *
     * @param cu the compilation unit to receive the {@code Urls}/{@code HostValidator} imports
     * @param newUrlCall the constructor call to replace
     * @return success requiring the Java Security Toolkit dependency
     */
    private SuccessOrReason harden(final CompilationUnit cu, final ObjectCreationExpr newUrlCall) {
        // The original argument nodes are moved (not cloned) into the new call; that is safe here
        // because the constructor call that owned them is itself being replaced.
        final MethodCallExpr safeCall = wrapInUrlsCreate(cu, newUrlCall.getArguments());
        newUrlCall.replace(safeCall);
        return SuccessOrReason.success(List.of(DependencyGAV.JAVA_SECURITY_TOOLKIT));
    }

    /**
     * Builds {@code Urls.create(arguments..., Urls.HTTP_PROTOCOLS,
     * HostValidator.DENY_COMMON_INFRASTRUCTURE_TARGETS)} and ensures both toolkit classes are
     * imported into {@code cu}.
     *
     * <p>The supplied argument nodes are placed directly into the new call without cloning; callers
     * that must keep the originals in place should pass clones.
     *
     * @param cu the compilation unit to which missing imports are added
     * @param arguments the leading arguments for {@code Urls.create}
     * @return the newly built (unattached) {@code Urls.create} call
     */
    private MethodCallExpr wrapInUrlsCreate(
            final CompilationUnit cu, final NodeList<Expression> arguments) {
        ASTTransforms.addImportIfMissing(cu, Urls.class.getName());
        ASTTransforms.addImportIfMissing(cu, HostValidator.class.getName());

        final FieldAccessExpr httpProtocolsExpr =
                new FieldAccessExpr(new NameExpr(Urls.class.getSimpleName()), "HTTP_PROTOCOLS");
        final FieldAccessExpr denyCommonTargetsExpr =
                new FieldAccessExpr(
                        new NameExpr(HostValidator.class.getSimpleName()),
                        "DENY_COMMON_INFRASTRUCTURE_TARGETS");

        final NodeList<Expression> newArguments = new NodeList<>();
        newArguments.addAll(arguments);
        newArguments.add(httpProtocolsExpr);
        newArguments.add(denyCommonTargetsExpr);
        return new MethodCallExpr(new NameExpr(Urls.class.getSimpleName()), "create", newArguments);
    }

    /**
     * Wraps the first argument of a method call in {@code Urls.create(...).toString()}, leaving the
     * call itself and its remaining arguments untouched.
     *
     * @param cu the compilation unit to receive the toolkit imports
     * @param call the method call whose first argument is hardened
     * @return success requiring the Java Security Toolkit dependency, or a reason when the call has
     *     no arguments
     */
    private SuccessOrReason hardenRT(final CompilationUnit cu, final MethodCallExpr call) {
        final Optional<Expression> maybeFirstArg = call.getArguments().stream().findFirst();
        if (!maybeFirstArg.isPresent()) {
            return SuccessOrReason.reason("Could not find first argument");
        }
        final Expression firstArg = maybeFirstArg.get();
        // The wrapper must hold a copy: the original node is the one being swapped out for the
        // wrapper, so it cannot also be a child of that wrapper.
        final MethodCallExpr wrappedArg =
                new MethodCallExpr(
                        wrapInUrlsCreate(cu, new NodeList<>(firstArg.clone())), "toString");
        firstArg.replace(wrappedArg);
        return SuccessOrReason.success(List.of(DependencyGAV.JAVA_SECURITY_TOOLKIT));
    }
}

