/*
 * Decompiled with CFR 0.152.
 */
package org.bsc.langgraph4j.agentexecutor;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import org.bsc.langgraph4j.GraphStateException;
import org.bsc.langgraph4j.StateGraph;
import org.bsc.langgraph4j.action.AsyncEdgeAction;
import org.bsc.langgraph4j.action.AsyncNodeAction;
import org.bsc.langgraph4j.agentexecutor.Agent;
import org.bsc.langgraph4j.agentexecutor.AgentAction;
import org.bsc.langgraph4j.agentexecutor.AgentFinish;
import org.bsc.langgraph4j.agentexecutor.AgentOutcome;
import org.bsc.langgraph4j.agentexecutor.IntermediateStep;
import org.bsc.langgraph4j.agentexecutor.serializer.json.JSONStateSerializer;
import org.bsc.langgraph4j.agentexecutor.serializer.std.STDStateSerializer;
import org.bsc.langgraph4j.langchain4j.generators.LLMStreamingGenerator;
import org.bsc.langgraph4j.langchain4j.tool.ToolNode;
import org.bsc.langgraph4j.serializer.StateSerializer;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AppenderChannel;
import org.bsc.langgraph4j.state.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds a LangGraph4j {@link StateGraph} that runs a tool-calling agent loop.
 *
 * <p>The graph has two nodes:
 * <ul>
 *   <li>{@code "agent"}: asks the {@link Agent} what to do next.</li>
 *   <li>{@code "action"}: executes the requested tool through a {@link ToolNode}.</li>
 * </ul>
 * After {@code "agent"}, {@link #shouldContinue(State)} routes either to {@code "action"} or to
 * {@link StateGraph#END}. {@code "action"} always routes back to {@code "agent"}.
 */
public class AgentExecutor {
    private static final Logger log = LoggerFactory.getLogger(AgentExecutor.class);

    /**
     * Creates a new builder for the agent-executor graph.
     *
     * @return a fresh {@link GraphBuilder} bound to this executor
     */
    public final GraphBuilder graphBuilder() {
        return new GraphBuilder();
    }

    /**
     * Implements the {@code "agent"} graph node.
     *
     * <p>Sends the state's {@code "input"} and the intermediate steps collected so far to the
     * agent, and stores its decision under {@code "agent_outcome"}. In streaming mode the value
     * stored under {@code "agent_outcome"} is the {@link LLMStreamingGenerator}. That generator
     * is configured with the same response-to-outcome mapping used by the non-streaming path.
     *
     * @param agentRunnable the agent to invoke
     * @param state the current graph state
     * @return the state update containing {@code "agent_outcome"}
     * @throws IllegalArgumentException if the state has no {@code "input"} value
     */
    Map<String, Object> callAgent(Agent agentRunnable, State state) {
        log.trace("callAgent");
        String input = state.input().orElseThrow(() -> new IllegalArgumentException("no input provided!"));
        List<IntermediateStep> intermediateSteps = state.intermediateSteps();
        if (agentRunnable.isStreaming()) {
            // The function type handed to the generator builder is kept as it was; the mapping
            // itself is delegated to the typed helper below.
            Function<Response, Map> mapResult =
                    response -> toAgentOutcome((AiMessage) response.content(), response.finishReason());
            LLMStreamingGenerator generator = LLMStreamingGenerator.builder()
                    .mapResult(mapResult)
                    .startingNode("agent")
                    .startingState((AgentState) state)
                    .build();
            agentRunnable.execute(input, intermediateSteps, (StreamingResponseHandler<AiMessage>) generator.handler());
            return Map.of("agent_outcome", generator);
        }
        Response<AiMessage> response = agentRunnable.execute(input, intermediateSteps);
        return toAgentOutcome(response.content(), response.finishReason());
    }

    /**
     * Converts a completed LLM message into the {@code "agent_outcome"} state update.
     *
     * <p>The message is treated as a tool call when either condition holds:
     * <ul>
     *   <li>the finish reason is {@link FinishReason#TOOL_EXECUTION}, or</li>
     *   <li>the message carries at least one tool execution request.</li>
     * </ul>
     * The second check means routing does not rely only on the reported finish reason. Only the
     * first request becomes an {@link AgentAction}; additional requests are logged and ignored.
     * Otherwise the message text becomes an {@link AgentFinish}. A null text is replaced with an
     * empty string, because {@link Map#of} rejects null values.
     *
     * @param message the LLM message
     * @param finishReason the finish reason reported with the message (may be null)
     * @return a map holding an {@link AgentOutcome} under {@code "agent_outcome"}
     * @throws IllegalStateException if the finish reason is TOOL_EXECUTION but no tool
     *     execution request is present
     */
    private static Map<String, Object> toAgentOutcome(AiMessage message, FinishReason finishReason) {
        List<ToolExecutionRequest> requests = message.toolExecutionRequests();
        boolean hasRequests = requests != null && !requests.isEmpty();
        if (finishReason == FinishReason.TOOL_EXECUTION || hasRequests) {
            if (!hasRequests) {
                throw new IllegalStateException("finish reason is TOOL_EXECUTION but no tool execution request was provided!");
            }
            if (requests.size() > 1) {
                log.warn("{} tool execution requests received; only the first ('{}') will be executed",
                        requests.size(), requests.get(0).name());
            }
            AgentAction action = new AgentAction(requests.get(0), "");
            return Map.of("agent_outcome", new AgentOutcome(action, null));
        }
        String result = Objects.requireNonNullElse(message.text(), "");
        AgentFinish finish = new AgentFinish(Map.of("returnValues", result), result);
        return Map.of("agent_outcome", new AgentOutcome(null, finish));
    }

    /**
     * Implements the {@code "action"} graph node.
     *
     * <p>Executes the tool requested by the current {@link AgentOutcome} and appends the result,
     * as an {@link IntermediateStep}, to {@code "intermediate_steps"}.
     *
     * @param toolNode the node that owns the available tools
     * @param state the current graph state
     * @return the state update containing the new intermediate step
     * @throws IllegalArgumentException if the state has no {@code "agent_outcome"}
     * @throws IllegalStateException if the outcome has no action, or no tool matches the request
     */
    Map<String, Object> executeTools(ToolNode toolNode, State state) {
        log.trace("executeTools");
        AgentOutcome agentOutcome = state.agentOutcome().orElseThrow(() -> new IllegalArgumentException("no agentOutcome provided!"));
        ToolExecutionRequest toolExecutionRequest = Optional.ofNullable(agentOutcome.action())
                .map(AgentAction::toolExecutionRequest)
                .orElseThrow(() -> new IllegalStateException("no action provided!"));
        String result = toolNode.execute(toolExecutionRequest)
                .map(ToolExecutionResultMessage::text)
                .orElseThrow(() -> new IllegalStateException("no tool found for: " + toolExecutionRequest.name()));
        return Map.of("intermediate_steps", new IntermediateStep(agentOutcome.action(), result));
    }

    /**
     * Routes the graph after the {@code "agent"} node.
     *
     * @param state the current graph state
     * @return {@code "end"} when the agent outcome carries a non-null finish, otherwise
     *     {@code "continue"}
     */
    String shouldContinue(State state) {
        return state.agentOutcome().map(AgentOutcome::finish).map(finish -> "end").orElse("continue");
    }

    /**
     * Fluent builder for the agent-executor {@link StateGraph}.
     *
     * <p>Exactly one of a blocking or a streaming chat model must be supplied; both are set
     * through the overloaded {@code chatLanguageModel} methods. The class is non-static because
     * the graph nodes call instance methods of the enclosing {@link AgentExecutor}.
     */
    public class GraphBuilder {
        private StreamingChatLanguageModel streamingChatLanguageModel;
        private ChatLanguageModel chatLanguageModel;
        private List<Object> objectsWithTools;
        private StateSerializer<State> stateSerializer;

        /**
         * Sets the blocking chat model. It cannot be combined with a streaming model.
         *
         * @param chatLanguageModel the blocking chat model
         * @return this builder
         */
        public GraphBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) {
            this.chatLanguageModel = chatLanguageModel;
            return this;
        }

        /**
         * Sets the streaming chat model. It cannot be combined with a blocking model.
         *
         * @param streamingChatLanguageModel the streaming chat model
         * @return this builder
         */
        public GraphBuilder chatLanguageModel(StreamingChatLanguageModel streamingChatLanguageModel) {
            this.streamingChatLanguageModel = streamingChatLanguageModel;
            return this;
        }

        /**
         * Sets the objects whose tools are exposed to the agent through {@link ToolNode#of}.
         *
         * @param objectsWithTools the tool-bearing objects (required)
         * @return this builder
         */
        public GraphBuilder objectsWithTools(List<Object> objectsWithTools) {
            this.objectsWithTools = objectsWithTools;
            return this;
        }

        /**
         * Sets the serializer for graph state. When unset, {@link Serializers#STD} is used.
         *
         * @param stateSerializer the state serializer
         * @return this builder
         */
        public GraphBuilder stateSerializer(StateSerializer<State> stateSerializer) {
            this.stateSerializer = stateSerializer;
            return this;
        }

        /**
         * Validates the configuration and assembles the graph.
         *
         * <p>The assembled graph is:
         * {@code START -> agent -(continue)-> action -> agent}, with {@code agent -(end)-> END}.
         *
         * @return the (uncompiled) state graph
         * @throws NullPointerException if no tool objects were supplied
         * @throws IllegalArgumentException if both models, or neither model, were supplied
         * @throws GraphStateException if the graph definition is rejected by {@link StateGraph}
         */
        public StateGraph<State> build() throws GraphStateException {
            Objects.requireNonNull(this.objectsWithTools, "objectsWithTools is required!");
            if (this.streamingChatLanguageModel != null && this.chatLanguageModel != null) {
                throw new IllegalArgumentException("chatLanguageModel and streamingChatLanguageModel are mutually exclusive!");
            }
            if (this.streamingChatLanguageModel == null && this.chatLanguageModel == null) {
                throw new IllegalArgumentException("a chatLanguageModel or streamingChatLanguageModel is required!");
            }
            ToolNode toolNode = ToolNode.of(this.objectsWithTools);
            Agent agentRunnable = Agent.builder()
                    .chatLanguageModel(this.chatLanguageModel)
                    .streamingChatLanguageModel(this.streamingChatLanguageModel)
                    .tools(toolNode.toolSpecifications())
                    .build();
            // Resolve the default locally so that build() does not mutate the builder's configuration.
            StateSerializer<State> serializer = this.stateSerializer != null
                    ? this.stateSerializer
                    : Serializers.STD.object();
            return new StateGraph<>(State.SCHEMA, serializer)
                    .addEdge(StateGraph.START, "agent")
                    .addNode("agent", AsyncNodeAction.node_async(state -> AgentExecutor.this.callAgent(agentRunnable, (State) state)))
                    .addNode("action", AsyncNodeAction.node_async(state -> AgentExecutor.this.executeTools(toolNode, (State) state)))
                    .addConditionalEdges("agent",
                            AsyncEdgeAction.edge_async(AgentExecutor.this::shouldContinue),
                            Map.of("continue", "action", "end", StateGraph.END))
                    .addEdge("action", "agent");
        }
    }

    /**
     * Graph state for the agent executor.
     *
     * <p>Keys read by this class:
     * <ul>
     *   <li>{@code "input"}: the user input.</li>
     *   <li>{@code "agent_outcome"}: the latest agent decision.</li>
     *   <li>{@code "intermediate_steps"}: executed tool steps; an appender channel in {@link #SCHEMA}.</li>
     * </ul>
     */
    public static class State
    extends AgentState {
        /** Channel schema: {@code "intermediate_steps"} accumulates values via an {@link AppenderChannel}. */
        static final Map<String, Channel<?>> SCHEMA = Map.of("intermediate_steps", AppenderChannel.of(ArrayList::new));

        /**
         * Creates a state from raw key/value data.
         *
         * @param initData the initial state data
         */
        public State(Map<String, Object> initData) {
            super(initData);
        }

        /** Returns the {@code "input"} value, if present. */
        Optional<String> input() {
            return this.value("input");
        }

        /** Returns the {@code "agent_outcome"} value, if present. */
        Optional<AgentOutcome> agentOutcome() {
            return this.value("agent_outcome");
        }

        /** Returns the {@code "intermediate_steps"} value, or a new empty list when absent. */
        List<IntermediateStep> intermediateSteps() {
            return this.value("intermediate_steps").orElseGet(ArrayList::new);
        }
    }

    /** Available state serializers for {@link State}. */
    public static enum Serializers {
        STD((StateSerializer<State>)new STDStateSerializer()),
        JSON((StateSerializer<State>)new JSONStateSerializer());

        private final StateSerializer<State> serializer;

        private Serializers(StateSerializer<State> serializer) {
            this.serializer = serializer;
        }

        /**
         * Returns the serializer instance held by this constant.
         *
         * @return the shared serializer for this constant
         */
        public StateSerializer<State> object() {
            return this.serializer;
        }
    }
}

