package org.mvnsearch.chatgpt.spring.service;

import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.mvnsearch.chatgpt.model.ChatCompletionRequest;
import org.mvnsearch.chatgpt.model.ChatCompletionResponse;
import org.mvnsearch.chatgpt.model.ChatFunction;
import org.mvnsearch.chatgpt.model.ChatMessage;
import org.mvnsearch.chatgpt.model.FunctionCall;
import org.mvnsearch.chatgpt.model.function.ChatGPTJavaFunction;
import org.mvnsearch.chatgpt.model.function.GPTFunctionUtils;
import org.mvnsearch.chatgpt.model.function.GPTFunctionsStub;
import org.mvnsearch.chatgpt.spring.http.OpenAIChatAPI;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
 * Default {@link ChatGPTService} implementation. It delegates chat calls to {@link OpenAIChatAPI}
 * and wires Java-backed GPT functions into requests and responses.
 *
 * <p>At construction time the class of every {@link GPTFunctionsStub} is scanned with
 * {@link GPTFunctionUtils#extractFunctions(Class)}. Each discovered function is bound to its stub
 * instance and registered by name. When a request lists function names via
 * {@code getFunctionNames()}, the matching {@link ChatFunction} definitions are added to it. When
 * functions were sent, any {@link FunctionCall} in a reply message gets a stub that invokes the
 * registered Java function.
 *
 * <p>The function registries are written only in the constructor and are only read afterwards.
 */
public class ChatGPTServiceImpl implements ChatGPTService {
    private static final Logger log = LoggerFactory.getLogger(ChatGPTServiceImpl.class);

    private final OpenAIChatAPI openAIChatAPI;
    /** Java functions keyed by function name; used to execute function calls found in replies. */
    private final Map<String, ChatGPTJavaFunction> allJsonSchemaFunctions = new HashMap<>();
    /** Function definitions keyed by function name; used to advertise functions in requests. */
    private final Map<String, ChatFunction> allChatFunctions = new HashMap<>();

    /**
     * Creates the service and registers every function exposed by the given stubs.
     *
     * @param openAIChatAPI client used to send chat requests
     * @param list function stubs to scan; {@code null} is treated as an empty list
     * @throws Exception propagated from function extraction or binding on a stub
     */
    public ChatGPTServiceImpl(OpenAIChatAPI openAIChatAPI, List<GPTFunctionsStub> list) throws Exception {
        this.openAIChatAPI = openAIChatAPI;
        if (list != null) {
            for (GPTFunctionsStub stub : list) {
                registerFunctions(stub);
            }
        }
        log.info("ChatGPTService initialized with {} functions", this.allJsonSchemaFunctions.size());
    }

    /**
     * Binds each function found on {@code stub}'s class to {@code stub} and records it by name.
     * If two stubs expose the same name, the later one replaces the earlier one, and a warning
     * is logged so the override is visible.
     */
    private void registerFunctions(GPTFunctionsStub stub) throws Exception {
        Map<String, ChatGPTJavaFunction> extracted = GPTFunctionUtils.extractFunctions(stub.getClass());
        for (Map.Entry<String, ChatGPTJavaFunction> entry : extracted.entrySet()) {
            String name = entry.getKey();
            ChatGPTJavaFunction function = entry.getValue();
            function.setTarget(stub);
            if (this.allJsonSchemaFunctions.put(name, function) != null) {
                log.warn("GPT function '{}' is registered more than once; the definition from {} wins",
                        name, stub.getClass().getName());
            }
            this.allChatFunctions.put(name, function.toChatFunction());
        }
    }

    /**
     * Sends a non-streaming chat request.
     *
     * <p>Registered functions named by the request are added first, and the request's stream flag
     * is cleared. If the request carries functions, each reply message's function call is given
     * an executable stub.
     */
    @Override // org.mvnsearch.chatgpt.spring.service.ChatGPTService
    public Mono<ChatCompletionResponse> chat(ChatCompletionRequest chatCompletionRequest) {
        injectFunctions(chatCompletionRequest);
        chatCompletionRequest.setStream(null);
        // Read before the API call, matching the original evaluation order.
        boolean hasFunctions = chatCompletionRequest.getFunctions() != null;
        Mono<ChatCompletionResponse> response = this.openAIChatAPI.chat(chatCompletionRequest);
        return hasFunctions ? response.doOnNext(this::injectFunctionCallLambdas) : response;
    }

    /**
     * Sends a streaming chat request.
     *
     * <p>Elements that fail upstream are skipped via {@code onErrorContinue} rather than
     * terminating the stream. The error is logged instead of being discarded silently. If the
     * request carries functions, each reply message's function call is given an executable stub.
     */
    @Override // org.mvnsearch.chatgpt.spring.service.ChatGPTService
    public Flux<ChatCompletionResponse> stream(ChatCompletionRequest chatCompletionRequest) {
        chatCompletionRequest.setStream(true);
        injectFunctions(chatCompletionRequest);
        boolean hasFunctions = chatCompletionRequest.getFunctions() != null;
        // onErrorContinue stays upstream of doOnNext, as in the original operator order.
        Flux<ChatCompletionResponse> responses = this.openAIChatAPI.stream(chatCompletionRequest)
                .onErrorContinue(ChatGPTServiceImpl::logDroppedStreamElement);
        return hasFunctions ? responses.doOnNext(this::injectFunctionCallLambdas) : responses;
    }

    /** Records a stream element dropped by {@code onErrorContinue} together with its error. */
    private static void logDroppedStreamElement(Throwable error, Object element) {
        log.warn("Skipped chat completion stream element after error: {}", element, error);
    }

    /**
     * Adds the definition of every registered function named in the request's
     * {@code getFunctionNames()}. Names with no registered function are skipped with a warning.
     */
    private void injectFunctions(ChatCompletionRequest chatCompletionRequest) {
        List<String> functionNames = chatCompletionRequest.getFunctionNames();
        if (functionNames == null) {
            return;
        }
        for (String functionName : functionNames) {
            ChatFunction chatFunction = this.allChatFunctions.get(functionName);
            if (chatFunction != null) {
                chatCompletionRequest.addFunction(chatFunction);
            } else {
                log.warn("GPT function '{}' requested but not registered; ignoring it", functionName);
            }
        }
    }

    /** Applies {@link #injectFunctionCallLambda(ChatMessage)} to every reply message of a response. */
    private void injectFunctionCallLambdas(ChatCompletionResponse response) {
        for (Iterator<ChatMessage> it = response.getReply().iterator(); it.hasNext(); ) {
            injectFunctionCallLambda(it.next());
        }
    }

    /**
     * If the message carries a function call whose name is registered, sets a stub on it that
     * invokes the Java function with the call's arguments.
     */
    private void injectFunctionCallLambda(ChatMessage chatMessage) {
        FunctionCall functionCall = chatMessage.getFunctionCall();
        if (functionCall == null) {
            return;
        }
        ChatGPTJavaFunction javaFunction = this.allJsonSchemaFunctions.get(functionCall.getName());
        if (javaFunction != null) {
            // Deferred: the Java method runs only when the caller invokes the stub.
            functionCall.setFunctionStub(() -> javaFunction.call(functionCall.getArguments()));
        }
    }
}
