package io.codemodder.plugins.llm;

import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestSystemMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.contrastsecurity.sarif.Location;
import com.contrastsecurity.sarif.Region;
import com.contrastsecurity.sarif.Result;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import com.github.difflib.DiffUtils;
import com.github.difflib.patch.Patch;
import io.codemodder.CodemodChange;
import io.codemodder.CodemodFileScanningResult;
import io.codemodder.CodemodInvocationContext;
import io.codemodder.RuleSarif;
import io.codemodder.SarifPluginRawFileChanger;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/codemodder/plugins/llm/SarifToLLMForMultiOutcomeCodemod.class */
/**
 * Base class for SARIF-driven codemods that ask an LLM to sort each SARIF result into one of
 * several {@link LLMRemediationOutcome}s and, when the chosen outcome calls for it, to produce a
 * fix.
 *
 * <p>Each SARIF result is handled in up to two LLM round trips:
 *
 * <ol>
 *   <li>a categorization request (sent to the categorization model) that returns an outcome key;
 *   <li>if the matched outcome requires code changes, a code-change request (sent to the
 *       code-changing model). It returns an outcome key again plus a unified diff, and the diff is
 *       applied to the file on disk.
 * </ol>
 */
public abstract class SarifToLLMForMultiOutcomeCodemod extends SarifPluginRawFileChanger {
    private static final Logger logger = LoggerFactory.getLogger(SarifToLLMForMultiOutcomeCodemod.class);

    /**
     * Token margin added to the estimated prompt size before comparing it against a model's context
     * window. NOTE(review): presumably reserved for prompt scaffolding and the response; the value
     * is a heuristic.
     */
    private static final int CONTEXT_WINDOW_HEADROOM_TOKENS = 300;

    private final OpenAIService openAI;
    private final List<LLMRemediationOutcome> remediationOutcomes;
    private final Model categorizationModel;
    private final Model codeChangingModel;

    /** System prompt; the single {@code %s} receives {@link #getThreatPrompt()}. */
    private static final String SYSTEM_MESSAGE_TEMPLATE =
            "You are a security analyst bot. You are helping analyze code to assess its risk to a"
                    + " specific security threat. Your code change recommendations are safe and accurate.\n"
                    + "%s\n";

    /**
     * User prompt for the categorization pass. Placeholders, in order: line, column, outcome
     * descriptions, file name, file contents with line numbers.
     */
    private static final String CATEGORIZE_CODE_USER_MESSAGE_TEMPLATE =
            "Analyze ONLY line %s, column %s, and discern which \"outcome\" best describes the code."
                    + " You should save your categorization analysis. You MUST ignore any other file contents,"
                    + " even if they look like they have issues.\n"
                    + "Here are the possible outcomes:\n"
                    + "%s\n"
                    + "\n"
                    + "Return a JSON object as a response with the following keys in this order:\n"
                    + "  - analysis: A detailed analysis of how the analysis arrived at the outcome\n"
                    + "  - outcomeKey: The category of the analysis, or empty if the analysis could not be"
                    + " categorized\n"
                    + "--- %s\n"
                    + "%s\n";

    /**
     * User prompt for the code-change pass. Placeholders, in order: location description, outcome
     * descriptions (with fix directions), file name, file contents with line numbers.
     */
    private static final String CHANGE_CODE_USER_MESSAGE_TEMPLATE =
            "The tool has cited the following location for you to analyze:\n"
                    + "%s\n"
                    + "Decide which \"outcome\" you want to place it in. Then, if that outcome requires code"
                    + " changes, make the changes as described in the Code Change Directions and save them."
                    + " Here are the possible outcomes:\n"
                    + "%s\n"
                    + "Pick which outcome best describes the code. If you are making code changes, you MUST"
                    + " make the MINIMUM number of changes necessary to fix the issue.\n"
                    + "- Each change MUST be syntactically correct.\n"
                    + "- DO NOT change the file's formatting or comments.\n"
                    + "- Create a diff patch for the changed file if and only if any of the outcomes require"
                    + " code changes.\n"
                    + "- The patch must use the unified format with a header. Include the diff patch and a"
                    + " summary of the changes with your analysis.\n"
                    + "If the outcome says you should suppress a Semgrep finding in the code, insert a comment"
                    + " above it and put `// nosemgrep: <ruleid>`\n"
                    + "Save your categorization and code change analysis when you're done.\n"
                    + "\n"
                    + "Return a JSON object as a response with the following keys in this order:\n"
                    + "  - outcomeKey: The outcome key associated with this particular result location\n"
                    + "  - fixDescription: A short description of the code change. Required only if the file"
                    + " needs a change.\n"
                    + "  - codeChange: A diff patch in unified format. Required if any of the outcome keys"
                    + " indicate a change.\n"
                    + "  - line: The line in the file to which this analysis is related\n"
                    + "  - column: The column to which this analysis is related\n"
                    + "--- %s\n"
                    + "%s\n";

    /** Outcome entry for outcomes whose {@code shouldApplyCodeChanges()} is false. */
    private static final String OUTCOME_TEMPLATE_NO_CHANGES =
            "============\nOutcome: %s\nDescription: %s\nCode Changes Required: NO\n";

    /** Outcome entry for outcomes that require changes, without fix directions. */
    private static final String OUTCOME_TEMPLATE_CHANGES =
            "============\nOutcome: %s\nDescription: %s\nCode Changes Required: YES\n";

    /** Outcome entry for outcomes that require changes, including the outcome's fix directions. */
    private static final String OUTCOME_TEMPLATE_CHANGES_WITH_DIRECTIONS =
            OUTCOME_TEMPLATE_CHANGES + "Code Change Directions For Outcome: %s\n";

    /** JSON shape the categorization model is asked to return. */
    public static class CategorizeResponse {

        @JsonPropertyDescription("A detailed analysis of how the analysis arrived at the outcome")
        @JsonProperty(required = true)
        private String analysis;

        @JsonPropertyDescription(
                "The category of the analysis, or empty if the analysis could not be categorized")
        @JsonProperty(required = true)
        private String outcomeKey;

        /** No-arg constructor; presumably used by the JSON deserializer. */
        public CategorizeResponse() {
        }

        private CategorizeResponse(final String analysis, final String outcomeKey) {
            this.analysis = analysis;
            this.outcomeKey = outcomeKey;
        }

        /** Returns the model's free-text reasoning for its categorization. */
        public String getAnalysis() {
            return this.analysis;
        }

        /** Returns the chosen outcome key; may be null or blank if the model could not categorize. */
        public String getOutcomeKey() {
            return this.outcomeKey;
        }
    }

    /** JSON shape the code-changing model is asked to return. */
    public static final class CodeChangeResponse {

        @JsonPropertyDescription(
                "The code change as a diff patch in unified format. Required if any of the outcome keys"
                        + " indicate a change.")
        private String codeChange;

        @JsonPropertyDescription("The line in the file to which this analysis is related")
        private int line;

        @JsonPropertyDescription("The column to which this analysis is related")
        private int column;

        @JsonPropertyDescription("The outcome key associated with this particular result location")
        private String outcomeKey;

        @JsonPropertyDescription("A short description of the code change. Required only if the file needs a change.")
        private String fixDescription;

        CodeChangeResponse() {
        }

        /** Returns the model's short description of the change, if any. */
        public String getFixDescription() {
            return this.fixDescription;
        }

        /** Returns the outcome key chosen in the code-change pass; may be null. */
        public String getOutcomeKey() {
            return this.outcomeKey;
        }

        /** Returns the line the model reports; 0 if the field was absent from the response. */
        public int getLine() {
            return this.line;
        }

        /** Returns the column the model reports; 0 if the field was absent from the response. */
        public int getColumn() {
            return this.column;
        }

        /** Returns the unified diff produced by the model; may be null or empty. */
        public String getCodeChange() {
            return this.codeChange;
        }
    }

    /**
     * Creates the codemod with the default models: GPT-4o (2024-05-13) for categorization and GPT-4
     * Turbo (2024-04-09) for code changes.
     *
     * @param ruleSarif the SARIF results this codemod handles
     * @param openAIService client used for both LLM passes
     * @param remediationOutcomes the possible outcomes; at least two are required
     * @throws NullPointerException if any argument is null
     * @throws IllegalArgumentException if fewer than two outcomes are given
     */
    protected SarifToLLMForMultiOutcomeCodemod(
            final RuleSarif ruleSarif,
            final OpenAIService openAIService,
            final List<LLMRemediationOutcome> remediationOutcomes) {
        this(
                ruleSarif,
                openAIService,
                remediationOutcomes,
                StandardModel.GPT_4O_2024_05_13,
                StandardModel.GPT_4_TURBO_2024_04_09);
    }

    /**
     * Creates the codemod with explicit models for each pass.
     *
     * @param ruleSarif the SARIF results this codemod handles
     * @param openAIService client used for both LLM passes
     * @param remediationOutcomes the possible outcomes; at least two are required. The list is
     *     copied, so later changes by the caller have no effect.
     * @param categorizationModel model used for the categorization pass
     * @param codeChangingModel model used for the code-change pass
     * @throws NullPointerException if any argument (or any outcome) is null
     * @throws IllegalArgumentException if fewer than two outcomes are given
     */
    protected SarifToLLMForMultiOutcomeCodemod(
            final RuleSarif ruleSarif,
            final OpenAIService openAIService,
            final List<LLMRemediationOutcome> remediationOutcomes,
            final Model categorizationModel,
            final Model codeChangingModel) {
        super(ruleSarif);
        this.openAI = Objects.requireNonNull(openAIService);
        this.remediationOutcomes = List.copyOf(Objects.requireNonNull(remediationOutcomes));
        if (this.remediationOutcomes.size() < 2) {
            throw new IllegalArgumentException("must have 2+ remediation outcome");
        }
        this.categorizationModel = Objects.requireNonNull(categorizationModel);
        this.codeChangingModel = Objects.requireNonNull(codeChangingModel);
    }

    /**
     * Runs the two-pass LLM flow for every SARIF result found in one file.
     *
     * <p>If the file is estimated to be too large for either model's context window, no result is
     * processed. That check does not depend on the individual result, so it is done once per file.
     *
     * @return the changes made to the file (possibly none)
     * @throws java.io.UncheckedIOException if reading or writing the file, or an LLM call, fails
     *     with an I/O error
     */
    @Override
    public CodemodFileScanningResult onFileFound(
            final CodemodInvocationContext context, final List<Result> results) {
        logger.info("processing: {}", context.path());
        if (results.isEmpty()) {
            return CodemodFileScanningResult.withOnlyChanges(List.of());
        }
        if (estimatedToExceedContextWindow(context)) {
            logger.debug("code too long: {}", context.path());
            return CodemodFileScanningResult.withOnlyChanges(List.of());
        }
        final List<CodemodChange> changes = new ArrayList<>();
        for (final Result result : results) {
            processResult(context, result).ifPresent(changes::add);
        }
        return CodemodFileScanningResult.withOnlyChanges(List.copyOf(changes));
    }

    /**
     * Categorizes one SARIF result and, if its outcome requires changes, asks for and applies a
     * diff.
     *
     * <p>The file is re-read from disk for each result. NOTE(review): if an earlier result in the
     * same file was patched, this result's SARIF line/column may no longer line up with the
     * current contents.
     *
     * @return the change made, or empty if the result was skipped at any step
     * @throws UncheckedIOException wrapping any {@link IOException} (logged first)
     */
    private Optional<CodemodChange> processResult(
            final CodemodInvocationContext context, final Result result) {
        try {
            final FileDescription file = FileDescription.from(context.path());

            // Pass 1: categorize the result into one of the configured outcomes.
            final CategorizeResponse categorization = categorize(file, result);
            final String categorizedKey = categorization.getOutcomeKey();
            logger.debug("outcomeKey: {}", categorizedKey);
            logger.debug("analysis: {}", categorization.getAnalysis());
            if (categorizedKey == null || categorizedKey.isBlank()) {
                logger.debug("unable to determine outcome");
                return Optional.empty();
            }
            final Optional<LLMRemediationOutcome> matched = remediationOutcomes.stream()
                    .filter(outcome -> categorizedKey.equals(outcome.key()))
                    .findFirst();
            if (matched.isEmpty()) {
                logger.debug("unable to find outcome for key: {}", categorizedKey);
                return Optional.empty();
            }
            final LLMRemediationOutcome outcome = matched.get();
            logger.debug("outcomeKey: {}", outcome.key());
            logger.debug("description: {}", outcome.description());
            if (!outcome.shouldApplyCodeChanges()) {
                logger.debug("Matched outcome suggests there should be no code changes");
                return Optional.empty();
            }

            // Pass 2: ask for a diff. The model re-picks an outcome, and that choice must also be
            // one that requires code changes.
            final CodeChangeResponse changeCode = changeCode(file, result);
            final String changeKey = changeCode.getOutcomeKey();
            logger.debug("outcome: {}", changeKey);
            logger.debug("codeChange: {}", changeCode.getCodeChange());
            if (changeKey == null || changeKey.isBlank()) {
                logger.debug("No outcomes detected");
                return Optional.empty();
            }
            final boolean stillRequiresChanges = remediationOutcomes.stream()
                    .anyMatch(o -> o.shouldApplyCodeChanges() && changeKey.equals(o.key()));
            if (!stillRequiresChanges) {
                logger.debug("On second analysis, outcomes require no code changes");
                return Optional.empty();
            }

            final String diff = changeCode.getCodeChange();
            if (diff == null || diff.isEmpty()) {
                logger.info("unable to fix because diff not present: {}", context.path());
                return Optional.empty();
            }
            final List<String> patchedLines = LLMDiffs.applyDiff(file.getLines(), diff);
            final Patch<String> patch = DiffUtils.diff(file.getLines(), patchedLines);
            if (patch.getDeltas().isEmpty()) {
                logger.error("empty patch: {}", patch);
                return Optional.empty();
            }
            Files.writeString(
                    context.path(),
                    String.join(file.getLineSeparator(), patchedLines),
                    file.getCharset());
            return Optional.of(
                    createCodemodChange(result, changeCode.getLine(), changeCode.getFixDescription()));
        } catch (IOException e) {
            logger.error("failed to process: {}", context.path(), e);
            throw new UncheckedIOException(e);
        } catch (RuntimeException e) {
            logger.error("failed to process: {}", context.path(), e);
            throw e;
        }
    }

    /**
     * Estimates whether the system prompt plus the raw file contents, plus a fixed headroom, exceed
     * the context window of either model.
     *
     * <p>NOTE(review): the user-prompt scaffolding, outcome descriptions and line-number prefixes
     * are not counted, so this can underestimate the real prompt size.
     */
    private boolean estimatedToExceedContextWindow(final CodemodInvocationContext context) {
        final ChatRequestUserMessage fileMessage = new ChatRequestUserMessage(context.contents());
        final var messages =
                List.of(getSystemMessage().getContent(), fileMessage.getContent().toString());
        for (final Model model : List.of(categorizationModel, codeChangingModel)) {
            if (model.tokens(messages) + CONTEXT_WINDOW_HEADROOM_TOKENS > model.contextWindow()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds the {@link CodemodChange} reported for a successfully patched result. Subclasses may
     * override to enrich the change.
     *
     * @param result the SARIF result that was fixed
     * @param line the line reported by the code-change model (0 if the model omitted it)
     * @param description the fix description reported by the model (may be null)
     */
    protected CodemodChange createCodemodChange(
            final Result result, final int line, final String description) {
        return CodemodChange.from(line, description);
    }

    /** Returns the threat-specific text inserted into the system prompt; it is stripped before use. */
    protected abstract String getThreatPrompt();

    private CategorizeResponse categorize(final FileDescription file, final Result result)
            throws IOException {
        return getCategorizationResponse(getSystemMessage(), getCategorizationUserMessage(file, result));
    }

    private CodeChangeResponse changeCode(final FileDescription file, final Result result)
            throws IOException {
        return getCodeChangeResponse(getSystemMessage(), getChangeCodeMessage(file, result));
    }

    private CategorizeResponse getCategorizationResponse(
            final ChatRequestMessage systemMessage, final ChatRequestMessage userMessage)
            throws IOException {
        return (CategorizeResponse) openAI.getResponseForPrompt(
                List.of(systemMessage, userMessage), categorizationModel, CategorizeResponse.class);
    }

    private CodeChangeResponse getCodeChangeResponse(
            final ChatRequestMessage systemMessage, final ChatRequestMessage userMessage)
            throws IOException {
        return (CodeChangeResponse) openAI.getResponseForPrompt(
                List.of(systemMessage, userMessage), codeChangingModel, CodeChangeResponse.class);
    }

    private ChatRequestSystemMessage getSystemMessage() {
        return new ChatRequestSystemMessage(
                SYSTEM_MESSAGE_TEMPLATE.formatted(getThreatPrompt().strip()).strip());
    }

    /** Returns the region of the result's first location; fails if the result has no locations. */
    private static Region firstRegion(final Result result) {
        return ((Location) result.getLocations().get(0)).getPhysicalLocation().getRegion();
    }

    /** Renders the region's start column, or {@code "(unknown)"} when SARIF omits it. */
    private static String columnLabel(final Region region) {
        final Integer column = region.getStartColumn();
        return column != null ? String.valueOf(column) : "(unknown)";
    }

    /** Builds the categorization prompt; sent with the user role, like the code-change prompt. */
    private ChatRequestMessage getCategorizationUserMessage(
            final FileDescription file, final Result result) {
        final Region region = firstRegion(result);
        final int line = region.getStartLine();
        return new ChatRequestUserMessage(
                CATEGORIZE_CODE_USER_MESSAGE_TEMPLATE
                        .formatted(
                                String.valueOf(line),
                                columnLabel(region),
                                formatOutcomeDescriptions(false),
                                file.getFileName(),
                                file.formatLinesWithLineNumbers())
                        .strip());
    }

    /**
     * Renders all outcomes for a prompt. Each outcome's "Code Changes Required" label follows its
     * {@code shouldApplyCodeChanges()}.
     *
     * @param includeFixDirections whether to append fix directions for outcomes that require changes
     */
    private String formatOutcomeDescriptions(final boolean includeFixDirections) {
        return remediationOutcomes.stream()
                        .map(outcome -> formatOutcomeDescription(outcome, includeFixDirections))
                        .collect(Collectors.joining("\n"))
                + "\n============";
    }

    private static String formatOutcomeDescription(
            final LLMRemediationOutcome outcome, final boolean includeFixDirections) {
        if (!outcome.shouldApplyCodeChanges()) {
            return OUTCOME_TEMPLATE_NO_CHANGES.formatted(outcome.key(), outcome.description());
        }
        if (includeFixDirections) {
            return OUTCOME_TEMPLATE_CHANGES_WITH_DIRECTIONS.formatted(
                    outcome.key(), outcome.description(), outcome.fix());
        }
        return OUTCOME_TEMPLATE_CHANGES.formatted(outcome.key(), outcome.description());
    }

    /** Builds the code-change prompt for the result's first location. */
    private ChatRequestMessage getChangeCodeMessage(final FileDescription file, final Result result) {
        final Region region = firstRegion(result);
        final String location = "  Line " + region.getStartLine() + ", column " + columnLabel(region);
        return new ChatRequestUserMessage(
                CHANGE_CODE_USER_MESSAGE_TEMPLATE
                        .formatted(
                                location,
                                formatOutcomeDescriptions(true),
                                file.getFileName(),
                                file.formatLinesWithLineNumbers())
                        .strip());
    }
}
