package org.bsc.langgraph4j.agentexecutor;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.service.tool.ToolExecutor;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import org.bsc.langgraph4j.GraphStateException;
import org.bsc.langgraph4j.StateGraph;
import org.bsc.langgraph4j.action.AsyncEdgeAction;
import org.bsc.langgraph4j.action.AsyncNodeAction;
import org.bsc.langgraph4j.agentexecutor.serializer.json.JSONStateSerializer;
import org.bsc.langgraph4j.agentexecutor.serializer.std.STDStateSerializer;
import org.bsc.langgraph4j.langchain4j.generators.LLMStreamingGenerator;
import org.bsc.langgraph4j.langchain4j.tool.ToolNode;
import org.bsc.langgraph4j.serializer.StateSerializer;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AppenderChannel;
import org.bsc.langgraph4j.state.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds a LangGraph4j {@link StateGraph} that runs a tool-calling agent loop.
 *
 * <p>The graph produced by {@link GraphBuilder#build()} contains two nodes:
 * <ul>
 *   <li>{@code "agent"}: asks the chat model for the next step and stores the result under
 *       the {@code "agent_outcome"} key;</li>
 *   <li>{@code "action"}: executes the tool requested by the agent and records the result
 *       under {@code "intermediate_steps"}.</li>
 * </ul>
 * Execution starts at {@code "agent"}. After each agent step the graph either goes to
 * {@code "action"} (which loops back to {@code "agent"}) or ends, depending on whether the
 * agent produced a final answer (see {@link #shouldContinue(State)}).
 */
public class AgentExecutor {
    private static final Logger log = LoggerFactory.getLogger(AgentExecutor.class);

    /**
     * Fluent builder for the agent-executor graph.
     *
     * <p>Exactly one of a blocking {@link ChatLanguageModel} or a
     * {@link StreamingChatLanguageModel} must be configured before calling {@link #build()}.
     * This is an inner (non-static) class because the graph nodes it creates delegate to the
     * enclosing {@link AgentExecutor}; obtain instances through {@link AgentExecutor#graphBuilder()}.
     */
    public class GraphBuilder {
        private StreamingChatLanguageModel streamingChatLanguageModel;
        private ChatLanguageModel chatLanguageModel;
        private final ToolNode.Builder toolNodeBuilder = ToolNode.builder();
        private StateSerializer<State> stateSerializer;

        public GraphBuilder() {
        }

        /**
         * Sets the blocking chat model used by the {@code "agent"} node.
         *
         * @param chatLanguageModel the model to use
         * @return this builder
         */
        public GraphBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) {
            this.chatLanguageModel = chatLanguageModel;
            return this;
        }

        /**
         * Sets the streaming chat model used by the {@code "agent"} node.
         *
         * @param streamingChatLanguageModel the streaming model to use
         * @return this builder
         */
        public GraphBuilder chatLanguageModel(StreamingChatLanguageModel streamingChatLanguageModel) {
            this.streamingChatLanguageModel = streamingChatLanguageModel;
            return this;
        }

        /**
         * Registers every element of {@code list} with the tool node builder.
         *
         * @param list objects passed one by one to {@code ToolNode.Builder#specification(Object)}
         * @return this builder
         * @deprecated call {@link #toolSpecification(Object)} for each object instead
         */
        @Deprecated
        public GraphBuilder objectsWithTools(List<Object> list) {
            list.forEach(obj -> toolNodeBuilder.specification(obj));
            return this;
        }

        /**
         * Registers a tool source with the tool node builder.
         *
         * @param obj object handed to {@code ToolNode.Builder#specification(Object)}
         *            (presumably an instance exposing tool methods; see ToolNode)
         * @return this builder
         */
        public GraphBuilder toolSpecification(Object obj) {
            toolNodeBuilder.specification(obj);
            return this;
        }

        /**
         * Registers a tool given its specification and the executor that runs it.
         *
         * @param toolSpecification the tool description sent to the model
         * @param toolExecutor      the executor invoked when the model requests this tool
         * @return this builder
         */
        public GraphBuilder toolSpecification(ToolSpecification toolSpecification, ToolExecutor toolExecutor) {
            toolNodeBuilder.specification(toolSpecification, toolExecutor);
            return this;
        }

        /**
         * Registers a pre-built {@link ToolNode.Specification}.
         *
         * @param specification the tool specification
         * @return this builder
         */
        public GraphBuilder toolSpecification(ToolNode.Specification specification) {
            toolNodeBuilder.specification(specification);
            return this;
        }

        /**
         * Sets the serializer used for graph state. Defaults to {@link Serializers#STD} when unset.
         *
         * @param stateSerializer the serializer to use
         * @return this builder
         */
        public GraphBuilder stateSerializer(StateSerializer<State> stateSerializer) {
            this.stateSerializer = stateSerializer;
            return this;
        }

        /**
         * Assembles the agent graph: START → agent → (action → agent)* → END.
         *
         * @return the configured (uncompiled) state graph
         * @throws IllegalArgumentException if both or neither chat models were configured
         * @throws GraphStateException      if the graph definition is rejected by {@link StateGraph}
         */
        public StateGraph<State> build() throws GraphStateException {
            if (streamingChatLanguageModel != null && chatLanguageModel != null) {
                throw new IllegalArgumentException("chatLanguageModel and streamingChatLanguageModel are mutually exclusive!");
            }
            if (streamingChatLanguageModel == null && chatLanguageModel == null) {
                throw new IllegalArgumentException("a chatLanguageModel or streamingChatLanguageModel is required!");
            }

            ToolNode toolNode = toolNodeBuilder.build();

            Agent agent = Agent.builder()
                    .chatLanguageModel(chatLanguageModel)
                    .streamingChatLanguageModel(streamingChatLanguageModel)
                    .tools(toolNode.toolSpecifications())
                    .build();

            // Resolve the default locally so build() does not mutate the builder's configuration.
            StateSerializer<State> serializer =
                    (stateSerializer != null) ? stateSerializer : Serializers.STD.object();

            return new StateGraph<State>(State.SCHEMA, serializer)
                    .addEdge(StateGraph.START, "agent")
                    .addNode("agent", AsyncNodeAction.node_async(state -> callAgent(agent, state)))
                    .addNode("action", AsyncNodeAction.node_async(state -> executeTools(toolNode, state)))
                    .addConditionalEdges(
                            "agent",
                            AsyncEdgeAction.edge_async(AgentExecutor.this::shouldContinue),
                            Map.of("continue", "action", "end", StateGraph.END))
                    .addEdge("action", "agent");
        }
    }

    /** Built-in state serializers available to {@link GraphBuilder#stateSerializer(StateSerializer)}. */
    public enum Serializers {
        STD(new STDStateSerializer()),
        JSON(new JSONStateSerializer());

        private final StateSerializer<State> serializer;

        Serializers(StateSerializer<State> serializer) {
            this.serializer = serializer;
        }

        /**
         * Returns the serializer instance held by this constant.
         *
         * @return the state serializer
         */
        public StateSerializer<State> object() {
            return serializer;
        }
    }

    /**
     * Graph state for the agent executor.
     *
     * <p>Keys read by this class: {@code "input"} (user input), {@code "agent_outcome"} (latest
     * agent decision) and {@code "intermediate_steps"} (tool executions so far, accumulated via an
     * {@link AppenderChannel}).
     */
    public static class State extends AgentState {
        /** Channel schema; {@code "intermediate_steps"} is appended to rather than overwritten. */
        static final Map<String, Channel<?>> SCHEMA = Map.of("intermediate_steps", AppenderChannel.of(ArrayList::new));

        public State(Map<String, Object> map) {
            super(map);
        }

        Optional<String> input() {
            return value("input");
        }

        Optional<AgentOutcome> agentOutcome() {
            return value("agent_outcome");
        }

        /** Returns the recorded intermediate steps, or a new empty list when none exist yet. */
        List<IntermediateStep> intermediateSteps() {
            return this.<List<IntermediateStep>>value("intermediate_steps").orElseGet(ArrayList::new);
        }
    }

    /**
     * Creates a new graph builder bound to this executor.
     *
     * @return a fresh {@link GraphBuilder}
     */
    public final GraphBuilder graphBuilder() {
        return new GraphBuilder();
    }

    /**
     * Node action for {@code "agent"}: asks the model for the next step.
     *
     * <p>For a blocking agent the mapped {@code "agent_outcome"} is returned directly. For a
     * streaming agent, an {@link LLMStreamingGenerator} is returned under {@code "agent_outcome"};
     * it applies the same mapping to the final response.
     *
     * @param agent the agent wrapping the configured chat model
     * @param state current graph state
     * @return the state update
     * @throws IllegalArgumentException if the state has no {@code "input"}
     */
    Map<String, Object> callAgent(Agent agent, State state) {
        log.trace("callAgent");
        String input = state.input()
                .orElseThrow(() -> new IllegalArgumentException("no input provided!"));
        List<IntermediateStep> intermediateSteps = state.intermediateSteps();

        Function<Response<AiMessage>, Map<String, Object>> mapResult = this::toAgentOutcome;

        if (!agent.isStreaming()) {
            return mapResult.apply(agent.execute(input, intermediateSteps));
        }

        var generator = LLMStreamingGenerator.<AiMessage, State>builder()
                .mapResult(mapResult)
                .startingNode("agent")
                .startingState(state)
                .build();
        agent.execute(input, intermediateSteps, generator.handler());
        return Map.of("agent_outcome", generator);
    }

    /**
     * Converts a model response into an {@code "agent_outcome"} state update.
     *
     * <p>The decision is based on whether the message actually carries tool execution requests,
     * not only on the finish reason. Some responses report tool calls with another finish reason,
     * or report {@code TOOL_EXECUTION} with no requests. Both cases previously crashed or
     * misrouted.
     */
    private Map<String, Object> toAgentOutcome(Response<AiMessage> response) {
        AiMessage message = response.content();
        List<ToolExecutionRequest> toolRequests = message.toolExecutionRequests();
        if (toolRequests != null && !toolRequests.isEmpty()) {
            // Only the first requested tool is executed per agent step.
            AgentAction action = new AgentAction(toolRequests.get(0), "");
            return Map.of("agent_outcome", new AgentOutcome(action, null));
        }
        if (response.finishReason() == FinishReason.TOOL_EXECUTION) {
            log.warn("finish reason is {} but no tool execution request was returned; treating response as final",
                    FinishReason.TOOL_EXECUTION);
        }
        // Map.of rejects null values, so a text-less response is reported as an empty answer.
        String text = (message.text() != null) ? message.text() : "";
        AgentFinish finish = new AgentFinish(Map.of("returnValues", text), text);
        return Map.of("agent_outcome", new AgentOutcome(null, finish));
    }

    /**
     * Node action for {@code "action"}: executes the tool requested in {@code "agent_outcome"}.
     *
     * @param toolNode the tool node holding the registered tools
     * @param state    current graph state
     * @return a state update appending one {@link IntermediateStep} to {@code "intermediate_steps"}
     * @throws IllegalArgumentException if the state has no {@code "agent_outcome"}
     * @throws IllegalStateException    if the outcome has no action/request, or if
     *                                  {@code toolNode.execute} yields no result
     *                                  (presumably no matching tool is registered)
     */
    Map<String, Object> executeTools(ToolNode toolNode, State state) {
        log.trace("executeTools");
        AgentOutcome outcome = state.agentOutcome()
                .orElseThrow(() -> new IllegalArgumentException("no agentOutcome provided!"));

        ToolExecutionRequest request = Optional.ofNullable(outcome.action())
                .map(AgentAction::toolExecutionRequest)
                .orElseThrow(() -> new IllegalStateException("no action provided!"));

        String result = toolNode.execute(request)
                .map(resultMessage -> resultMessage.text())
                .orElseThrow(() -> new IllegalStateException("no tool found for: " + request.name()));

        return Map.of("intermediate_steps", new IntermediateStep(outcome.action(), result));
    }

    /**
     * Edge selector after the {@code "agent"} node.
     *
     * @param state current graph state
     * @return {@code "end"} when the agent outcome carries a finish, {@code "continue"} otherwise
     */
    String shouldContinue(State state) {
        return state.agentOutcome()
                .map(AgentOutcome::finish)
                .map(finish -> "end")
                .orElse("continue");
    }
}
