package io.codemodder.plugins.llm;

import com.contrastsecurity.sarif.Location;
import com.contrastsecurity.sarif.Region;
import com.contrastsecurity.sarif.Result;
import com.github.difflib.DiffUtils;
import com.github.difflib.patch.AbstractDelta;
import com.github.difflib.patch.Patch;
import com.theokanning.openai.completion.chat.ChatCompletionChoice;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatFunction;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageRole;
import com.theokanning.openai.service.FunctionExecutor;
import io.codemodder.CodemodChange;
import io.codemodder.CodemodInvocationContext;
import io.codemodder.RuleSarif;
import io.codemodder.SarifPluginRawFileChanger;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base codemod that asks an OpenAI chat model to verify and fix SARIF findings in a file.
 *
 * <p>Processing runs in two stages:
 *
 * <ol>
 *   <li>A threat analysis with {@value #ANALYSIS_MODEL}. If the model rates the risk {@link
 *       BinaryThreatRisk#LOW}, or the prompt is estimated to be too large, the file is left alone.
 *   <li>Otherwise a fix request with {@value #FIX_MODEL}, which is expected to return a unified
 *       diff. The diff is applied to the file only if it is non-empty and {@link
 *       #isPatchExpected(Patch)} accepts it.
 * </ol>
 *
 * <p>In both stages the model is forced to answer through the {@code save_analysis} function so
 * the response can be deserialized into a typed result object.
 */
public abstract class SarifToLLMForBinaryVerificationAndFixingCodemod extends SarifPluginRawFileChanger {
    private static final Logger logger = LoggerFactory.getLogger(SarifToLLMForBinaryVerificationAndFixingCodemod.class);

    /** Model used for the first (analysis-only) stage. */
    private static final String ANALYSIS_MODEL = "gpt-3.5-turbo-0613";

    /** Sampling temperature for the analysis stage. */
    private static final double ANALYSIS_TEMPERATURE = 0.2;

    /**
     * Files whose estimated analysis prompt exceeds this many tokens are skipped.
     *
     * <p>NOTE(review): presumably chosen to leave ~300 tokens of gpt-3.5-turbo-0613's 4,096-token
     * context window for the response; confirm before changing the model.
     */
    private static final int MAX_ANALYSIS_PROMPT_TOKENS = 3796;

    /** Model used for the second (fixing) stage. */
    private static final String FIX_MODEL = "gpt-4-0613";

    /** Sampling temperature for the fixing stage; zero for the most deterministic patch. */
    private static final double FIX_TEMPERATURE = 0.0;

    /** Name of the function the model is forced to call to return its result. */
    private static final String SAVE_ANALYSIS_FUNCTION = "save_analysis";

    private final OpenAIService openAI;

    private static final String SYSTEM_MESSAGE_TEMPLATE =
            "You are a security analyst bot. You are helping analyze Java code to assess its risk to a specific security threat.\n"
                    + "\n"
                    + "%s\n";

    private static final String ANALYZE_USER_MESSAGE_TEMPLATE =
            "A file with line numbers is provided below. Analyze it and save your threat analysis.\n"
                    + "\n"
                    + "--- %s\n"
                    + "%s\n";

    private static final String FIX_USER_MESSAGE_TEMPLATE =
            "A file with line numbers is provided below. Analyze it. If the risk is HIGH, use these rules to make the MINIMUM number of changes necessary to reduce the file's risk to LOW:\n"
                    + "- Each change MUST be syntactically correct.\n"
                    + "- DO NOT change the file's formatting or comments.\n"
                    + "%s\n"
                    + "\n"
                    + "Create a diff patch for the changed file, using the unified format with a header. Include the diff patch and a summary of the changes with your threat analysis.\n"
                    + "\n"
                    + "Save your threat analysis.\n"
                    + "\n"
                    + "--- %s\n"
                    + "%s\n";

    /**
     * @param ruleSarif the SARIF results this codemod reacts to
     * @param openAIService client used for all chat-completion calls; must not be null
     * @throws NullPointerException if {@code openAIService} is null
     */
    protected SarifToLLMForBinaryVerificationAndFixingCodemod(RuleSarif ruleSarif, OpenAIService openAIService) {
        super(ruleSarif);
        this.openAI = Objects.requireNonNull(openAIService, "openAIService");
    }

    /**
     * Analyzes a file flagged by SARIF and, if the model judges the risk not LOW, asks for a fix
     * and rewrites the file with it.
     *
     * @param context the invocation context for the file being processed
     * @param results the SARIF results located in this file
     * @return a single change pointing at the first modified line, or an empty list if the file
     *     was not changed
     * @throws UncheckedIOException if the fixed file cannot be written
     */
    public List<CodemodChange> onFileFound(CodemodInvocationContext context, List<Result> results) {
        logger.debug("processing: {}", context.path());
        logResults(results);
        try {
            FileDescription file = FileDescription.from(context.path());

            BinaryThreatAnalysis analysis = analyzeThreat(file, context, results);
            logger.debug("risk: {}", analysis.getRisk());
            logger.debug("analysis: {}", analysis.getAnalysis());
            if (analysis.getRisk() == BinaryThreatRisk.LOW) {
                return List.of();
            }

            BinaryThreatAnalysisAndFix fix = fixThreat(file, context, results);
            logger.debug("risk: {}", fix.getRisk());
            logger.debug("analysis: {}", fix.getAnalysis());
            logger.debug("fix: {}", fix.getFix());
            logger.debug("fix description: {}", fix.getFixDescription());
            if (fix.getRisk() == BinaryThreatRisk.LOW) {
                return List.of();
            }

            String patchText = fix.getFix();
            if (patchText == null || patchText.isBlank()) {
                logger.info("unable to fix: {}", context.path());
                return List.of();
            }

            List<String> fixedLines = LLMDiffs.applyDiff(file.getLines(), patchText);
            // Recompute the diff ourselves so subclasses vet what will actually be written,
            // not the model's (possibly inaccurate) patch text.
            Patch<String> diff = DiffUtils.diff(file.getLines(), fixedLines);
            if (diff.getDeltas().isEmpty() || !isPatchExpected(diff)) {
                logger.error("unexpected patch: {}", diff);
                return List.of();
            }

            writeLines(context, file, fixedLines);
            // Delta positions are 0-based list indices; +1 reports the first changed line 1-based.
            int firstChangedLine = diff.getDeltas().get(0).getSource().getPosition() + 1;
            return List.of(CodemodChange.from(firstChangedLine, fix.getFixDescription()));
        } catch (Exception e) {
            logger.error("failed to process: {}", context.path(), e);
            throw e;
        }
    }

    /**
     * Returns the description of the security threat. It is inserted, stripped, into the system
     * message used by both stages.
     */
    protected abstract String getThreatPrompt(CodemodInvocationContext codemodInvocationContext, List<Result> list);

    /**
     * Returns extra fix rules. They are inserted, stripped, into the fix request after the
     * built-in rules.
     */
    protected abstract String getFixPrompt();

    /**
     * Returns whether a non-empty patch (computed from the model's fix) may be applied. Rejected
     * patches are logged and the file is left unchanged.
     */
    protected abstract boolean isPatchExpected(Patch<String> patch);

    /**
     * Logs, at debug level, the start line and snippet of each result's first location. Results
     * without a location or region are skipped, and a missing snippet is logged as {@code null}.
     * This keeps optional SARIF fields from aborting processing.
     */
    private static void logResults(List<Result> results) {
        if (!logger.isDebugEnabled()) {
            return;
        }
        for (Result result : results) {
            List<Location> locations = result.getLocations();
            if (locations == null || locations.isEmpty()) {
                continue;
            }
            var physicalLocation = locations.get(0).getPhysicalLocation();
            Region region = physicalLocation == null ? null : physicalLocation.getRegion();
            if (region == null) {
                continue;
            }
            var snippet = region.getSnippet();
            logger.debug("{}:{}", region.getStartLine(), snippet == null ? null : snippet.getText());
        }
    }

    /**
     * Overwrites the file with {@code lines}, keeping the original file's line separator and
     * charset.
     *
     * @throws UncheckedIOException if the write fails
     */
    private static void writeLines(CodemodInvocationContext context, FileDescription original, List<String> lines) {
        try {
            Files.writeString(context.path(), String.join(original.getLineSeparator(), lines), original.getCharset());
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Runs the cheap first-stage analysis.
     *
     * <p>If the estimated prompt size exceeds {@link #MAX_ANALYSIS_PROMPT_TOKENS}, no request is
     * made. A synthetic LOW-risk result is returned instead, so the file is skipped.
     */
    private BinaryThreatAnalysis analyzeThreat(FileDescription file, CodemodInvocationContext context, List<Result> results) {
        ChatMessage systemMessage = getSystemMessage(context, results);
        ChatMessage userMessage = getAnalyzeUserMessage(file);
        int tokenCount = Tokens.countTokens(List.of(systemMessage, userMessage));
        logger.debug("estimated prompt token count: {}", tokenCount);
        if (tokenCount > MAX_ANALYSIS_PROMPT_TOKENS) {
            return new BinaryThreatAnalysis("Ignoring file: estimated prompt token count (" + tokenCount + ") is too high.", BinaryThreatRisk.LOW);
        }
        return getLLMResponse(ANALYSIS_MODEL, ANALYSIS_TEMPERATURE, systemMessage, userMessage, BinaryThreatAnalysis.class);
    }

    /** Runs the second-stage request, which returns a fresh analysis plus a unified-diff fix. */
    private BinaryThreatAnalysisAndFix fixThreat(FileDescription file, CodemodInvocationContext context, List<Result> results) {
        return getLLMResponse(FIX_MODEL, FIX_TEMPERATURE, getSystemMessage(context, results), getFixUserMessage(file), BinaryThreatAnalysisAndFix.class);
    }

    /**
     * Sends a two-message chat completion and forces the model to answer via the
     * {@code save_analysis} function, whose arguments are deserialized into {@code responseType}.
     *
     * @throws IllegalStateException if the response has no choices or no function call
     */
    private <T> T getLLMResponse(String model, double temperature, ChatMessage systemMessage, ChatMessage userMessage, Class<T> responseType) {
        // The "executor" is the identity: we only want the deserialized arguments back.
        ChatFunction saveAnalysis = ChatFunction.builder()
                .name(SAVE_ANALYSIS_FUNCTION)
                .description("Saves a security threat analysis.")
                .executor(responseType, response -> response)
                .build();
        FunctionExecutor functionExecutor = new FunctionExecutor(List.of(saveAnalysis));

        ChatCompletionRequest request = ChatCompletionRequest.builder()
                .model(model)
                .messages(List.of(systemMessage, userMessage))
                .functions(functionExecutor.getFunctions())
                .functionCall(ChatCompletionRequest.ChatCompletionRequestFunctionCall.of(saveAnalysis.getName()))
                .temperature(temperature)
                .build();
        ChatCompletionResult completion = this.openAI.createChatCompletion(request);
        // Parameterized so a missing usage block logs "null" instead of throwing.
        logger.debug("{}", completion.getUsage());

        List<ChatCompletionChoice> choices = completion.getChoices();
        if (choices == null || choices.isEmpty()) {
            throw new IllegalStateException("model " + model + " returned no choices");
        }
        var functionCall = choices.get(0).getMessage().getFunctionCall();
        if (functionCall == null) {
            throw new IllegalStateException("model " + model + " did not call " + SAVE_ANALYSIS_FUNCTION);
        }
        return responseType.cast(functionExecutor.execute(functionCall));
    }

    /** Builds the system message that frames the threat for both stages. */
    private ChatMessage getSystemMessage(CodemodInvocationContext context, List<Result> results) {
        String threat = getThreatPrompt(context, results).strip();
        return new ChatMessage(ChatMessageRole.SYSTEM.value(), SYSTEM_MESSAGE_TEMPLATE.formatted(threat).strip());
    }

    /**
     * Builds the user message for the analysis stage: the file name and its line-numbered
     * contents. It is sent with the USER role, matching the fix-stage message.
     */
    private ChatMessage getAnalyzeUserMessage(FileDescription file) {
        String content = ANALYZE_USER_MESSAGE_TEMPLATE.formatted(file.getFileName(), file.formatLinesWithLineNumbers()).strip();
        return new ChatMessage(ChatMessageRole.USER.value(), content);
    }

    /**
     * Builds the user message for the fix stage: the fix rules, then the file name and its
     * line-numbered contents.
     */
    private ChatMessage getFixUserMessage(FileDescription file) {
        String content = FIX_USER_MESSAGE_TEMPLATE.formatted(getFixPrompt().strip(), file.getFileName(), file.formatLinesWithLineNumbers()).strip();
        return new ChatMessage(ChatMessageRole.USER.value(), content);
    }
}
