/*
 * Decompiled with CFR 0.152.
 */
package io.codemodder.plugins.llm;

import com.contrastsecurity.sarif.Location;
import com.contrastsecurity.sarif.Region;
import com.contrastsecurity.sarif.Result;
import com.github.difflib.DiffUtils;
import com.github.difflib.patch.AbstractDelta;
import com.github.difflib.patch.Patch;
import com.theokanning.openai.completion.chat.ChatCompletionChoice;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatFunction;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageRole;
import com.theokanning.openai.service.FunctionExecutor;
import io.codemodder.CodemodChange;
import io.codemodder.CodemodInvocationContext;
import io.codemodder.RuleSarif;
import io.codemodder.SarifPluginRawFileChanger;
import io.codemodder.plugins.llm.BinaryThreatAnalysis;
import io.codemodder.plugins.llm.BinaryThreatAnalysisAndFix;
import io.codemodder.plugins.llm.BinaryThreatRisk;
import io.codemodder.plugins.llm.FileDescription;
import io.codemodder.plugins.llm.LLMDiffs;
import io.codemodder.plugins.llm.OpenAIService;
import io.codemodder.plugins.llm.Tokens;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for codemods that take the SARIF results reported for a file, ask an LLM whether the
 * file is actually at risk of a specific threat and, if the risk is not LOW, ask the LLM for a
 * unified diff that fixes it.
 *
 * <p>The work is done in two passes: a first analysis pass ({@link #ANALYSIS_MODEL}) that may
 * short-circuit with a LOW risk verdict, followed by an analyze-and-fix pass ({@link #FIX_MODEL})
 * whose diff is applied to the file only when {@link #isPatchExpected(Patch)} accepts it.
 */
public abstract class SarifToLLMForBinaryVerificationAndFixingCodemod
extends SarifPluginRawFileChanger {
    private static final Logger logger = LoggerFactory.getLogger(SarifToLLMForBinaryVerificationAndFixingCodemod.class);

    /** Model and sampling temperature used for the first (analysis-only) pass. */
    private static final String ANALYSIS_MODEL = "gpt-3.5-turbo-0613";
    private static final double ANALYSIS_TEMPERATURE = 0.2;

    /** Model and sampling temperature used for the second (analyze-and-fix) pass. */
    private static final String FIX_MODEL = "gpt-4-0613";
    private static final double FIX_TEMPERATURE = 0.0;

    /**
     * Largest estimated prompt token count accepted for the analysis pass; larger files are skipped.
     * NOTE(review): presumably chosen to leave headroom inside the analysis model's context window
     * for the function definition and the reply -- confirm before changing.
     */
    private static final int MAX_ANALYSIS_PROMPT_TOKENS = 3796;

    /** Name of the function the LLM is forced to call to hand back its structured answer. */
    private static final String SAVE_ANALYSIS_FUNCTION = "save_analysis";

    private final OpenAIService openAI;
    private static final String SYSTEM_MESSAGE_TEMPLATE = "You are a security analyst bot. You are helping analyze Java code to assess its risk to a specific security threat.\n\n%s\n";
    private static final String ANALYZE_USER_MESSAGE_TEMPLATE = "A file with line numbers is provided below. Analyze it and save your threat analysis.\n\n--- %s\n%s\n";
    private static final String FIX_USER_MESSAGE_TEMPLATE = "A file with line numbers is provided below. Analyze it. If the risk is HIGH, use these rules to make the MINIMUM number of changes necessary to reduce the file's risk to LOW:\n- Each change MUST be syntactically correct.\n- DO NOT change the file's formatting or comments.\n%s\n\nCreate a diff patch for the changed file, using the unified format with a header. Include the diff patch and a summary of the changes with your threat analysis.\n\nSave your threat analysis.\n\n--- %s\n%s\n";

    /**
     * @param sarif the SARIF results handed to the superclass
     * @param openAI the service used to run chat completions; must not be {@code null}
     * @throws NullPointerException if {@code openAI} is {@code null}
     */
    protected SarifToLLMForBinaryVerificationAndFixingCodemod(RuleSarif sarif, OpenAIService openAI) {
        super(sarif);
        this.openAI = Objects.requireNonNull(openAI);
    }

    /**
     * Analyzes the file at {@code context.path()} and, when the LLM judges the risk to be higher
     * than LOW and produces an acceptable patch, rewrites the file with that patch applied.
     *
     * @param context the invocation context; its path identifies the file to process
     * @param results the SARIF results reported for the file
     * @return a single change located at the first modified line, or an empty list when the file
     *     was left untouched (LOW risk, no fix offered, or the patch was rejected)
     * @throws UncheckedIOException if the fixed file cannot be written
     */
    @Override
    public List<CodemodChange> onFileFound(CodemodInvocationContext context, List<Result> results) {
        logger.debug("processing: {}", context.path());
        results.forEach(result -> {
            Region region = result.getLocations().get(0).getPhysicalLocation().getRegion();
            logger.debug("{}:{}", region.getStartLine(), region.getSnippet().getText());
        });
        try {
            FileDescription file = FileDescription.from(context.path());

            // First pass: analysis only. A LOW verdict means there is nothing to fix.
            BinaryThreatAnalysis analysis = this.analyzeThreat(file, context, results);
            logger.debug("risk: {}", analysis.getRisk());
            logger.debug("analysis: {}", analysis.getAnalysis());
            if (analysis.getRisk() == BinaryThreatRisk.LOW) {
                return List.of();
            }

            // Second pass: re-analyze and request a fix. A LOW verdict here also leaves the file alone.
            BinaryThreatAnalysisAndFix fix = this.fixThreat(file, context, results);
            logger.debug("risk: {}", fix.getRisk());
            logger.debug("analysis: {}", fix.getAnalysis());
            logger.debug("fix: {}", fix.getFix());
            logger.debug("fix description: {}", fix.getFixDescription());
            if (fix.getRisk() == BinaryThreatRisk.LOW) {
                return List.of();
            }
            String diff = fix.getFix();
            if (diff == null || diff.isEmpty()) {
                logger.info("unable to fix: {}", context.path());
                return List.of();
            }

            // Apply the LLM's diff, then re-diff against the original so the subclass can vet
            // the change that would actually be written.
            List<String> fixedLines = LLMDiffs.applyDiff(file.getLines(), diff);
            Patch<String> patch = DiffUtils.diff(file.getLines(), fixedLines);
            if (patch.getDeltas().isEmpty() || !this.isPatchExpected(patch)) {
                logger.error("unexpected patch: {}", patch);
                return List.of();
            }
            try {
                String fixedFile = String.join(file.getLineSeparator(), fixedLines);
                Files.writeString(context.path(), fixedFile, file.getCharset());
            }
            catch (IOException e) {
                throw new UncheckedIOException(e);
            }

            // The change is reported at the first delta; the delta position is incremented by one
            // to produce the reported line number.
            int line = patch.getDeltas().get(0).getSource().getPosition() + 1;
            return List.of(CodemodChange.from(line, fix.getFixDescription()));
        }
        catch (Exception e) {
            logger.error("failed to process: {}", context.path(), e);
            throw e;
        }
    }

    /**
     * Supplies the threat description that is embedded in the system message of both LLM passes.
     *
     * @param var1 the invocation context of the file being processed
     * @param var2 the SARIF results reported for the file
     * @return the threat prompt; surrounding whitespace is stripped before use
     */
    protected abstract String getThreatPrompt(CodemodInvocationContext var1, List<Result> var2);

    /**
     * Supplies the codemod-specific fix rules that are appended to the built-in rules in the fix
     * request.
     *
     * @return the fix prompt; surrounding whitespace is stripped before use
     */
    protected abstract String getFixPrompt();

    /**
     * Decides whether the patch derived from the LLM's fix is acceptable. When this returns
     * {@code false} the file is not modified and no change is reported.
     *
     * @param var1 the diff between the original lines and the fixed lines
     * @return {@code true} to accept and write the change
     */
    protected abstract boolean isPatchExpected(Patch<String> var1);

    /**
     * Runs the analysis-only pass. Files whose estimated prompt size exceeds
     * {@link #MAX_ANALYSIS_PROMPT_TOKENS} are not sent to the LLM and are reported as LOW risk.
     */
    private BinaryThreatAnalysis analyzeThreat(FileDescription file, CodemodInvocationContext context, List<Result> results) {
        ChatMessage systemMessage = this.getSystemMessage(context, results);
        ChatMessage userMessage = this.getAnalyzeUserMessage(file);
        int tokenCount = Tokens.countTokens(List.of(systemMessage, userMessage));
        if (tokenCount > MAX_ANALYSIS_PROMPT_TOKENS) {
            return new BinaryThreatAnalysis("Ignoring file: estimated prompt token count (" + tokenCount + ") is too high.", BinaryThreatRisk.LOW);
        }
        logger.debug("estimated prompt token count: {}", tokenCount);
        return this.getLLMResponse(ANALYSIS_MODEL, ANALYSIS_TEMPERATURE, systemMessage, userMessage, BinaryThreatAnalysis.class);
    }

    /** Runs the analyze-and-fix pass and returns the LLM's verdict together with its proposed diff. */
    private BinaryThreatAnalysisAndFix fixThreat(FileDescription file, CodemodInvocationContext context, List<Result> results) {
        return this.getLLMResponse(FIX_MODEL, FIX_TEMPERATURE, this.getSystemMessage(context, results), this.getFixUserMessage(file), BinaryThreatAnalysisAndFix.class);
    }

    /**
     * Sends the two messages to the given model, forcing it to answer through the
     * {@code save_analysis} function, and returns the function call's arguments as an instance of
     * {@code responseClass}.
     */
    private <T> T getLLMResponse(String model, Double temperature, ChatMessage systemMessage, ChatMessage userMessage, Class<T> responseClass) {
        // The executor is the identity function: the function-call arguments, once bound to
        // responseClass, ARE the answer we want back.
        ChatFunction function = ChatFunction.builder()
                .name(SAVE_ANALYSIS_FUNCTION)
                .description("Saves a security threat analysis.")
                .executor(responseClass, c -> c)
                .build();
        FunctionExecutor functionExecutor = new FunctionExecutor(Collections.singletonList(function));
        ChatCompletionRequest request = ChatCompletionRequest.builder()
                .model(model)
                .messages(List.of(systemMessage, userMessage))
                .functions(functionExecutor.getFunctions())
                .functionCall(ChatCompletionRequest.ChatCompletionRequestFunctionCall.of(function.getName()))
                .temperature(temperature)
                .build();
        ChatCompletionResult result = this.openAI.createChatCompletion(request);
        logger.debug("{}", result.getUsage());
        ChatMessage response = result.getChoices().get(0).getMessage();
        // Class.cast gives a checked conversion instead of an unchecked (T) cast.
        return responseClass.cast(functionExecutor.execute(response.getFunctionCall()));
    }

    /** Builds the system message shared by both passes from the subclass's threat prompt. */
    private ChatMessage getSystemMessage(CodemodInvocationContext context, List<Result> results) {
        String threatPrompt = this.getThreatPrompt(context, results);
        return new ChatMessage(ChatMessageRole.SYSTEM.value(), SYSTEM_MESSAGE_TEMPLATE.formatted(threatPrompt.strip()).strip());
    }

    /** Builds the user message for the analysis pass: the file name followed by its numbered lines. */
    private ChatMessage getAnalyzeUserMessage(FileDescription file) {
        // This is the user turn of the conversation, so it must carry the USER role (as the fix
        // request does); the SYSTEM role is reserved for getSystemMessage.
        return new ChatMessage(ChatMessageRole.USER.value(), ANALYZE_USER_MESSAGE_TEMPLATE.formatted(file.getFileName(), file.formatLinesWithLineNumbers()).strip());
    }

    /** Builds the user message for the fix pass: fix rules, the file name and its numbered lines. */
    private ChatMessage getFixUserMessage(FileDescription file) {
        return new ChatMessage(ChatMessageRole.USER.value(), FIX_USER_MESSAGE_TEMPLATE.formatted(this.getFixPrompt().strip(), file.getFileName(), file.formatLinesWithLineNumbers()).strip());
    }
}

