package ai.knowly.langtoch.llm.providers.openai;

import ai.knowly.langtoch.llm.base.chatmodel.BaseChatModel;
import ai.knowly.langtoch.llm.message.AssistantMessage;
import ai.knowly.langtoch.llm.message.BaseChatMessage;
import ai.knowly.langtoch.llm.message.Role;
import ai.knowly.langtoch.llm.message.SystemMessage;
import ai.knowly.langtoch.llm.message.UserMessage;
import com.google.common.collect.ImmutableList;
import com.google.common.flogger.FluentLogger;
import com.theokanning.openai.completion.chat.ChatCompletionChoice;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.service.OpenAiService;
import java.util.List;
import java.util.Locale;
import javax.inject.Inject;

/* loaded from: input_file:ai/knowly/langtoch/llm/providers/openai/OpenAIChat.class */
/**
 * Chat model that delegates to the OpenAI chat-completion API through {@link OpenAiService}.
 *
 * <p>Request settings (max tokens, model, temperature) are accumulated on a single request
 * builder held by the instance; the fluent setters mutate that builder and return {@code this}.
 * Each {@code run} call also stores its messages on the same builder before building the
 * request, so one instance should not be shared between concurrent callers without external
 * synchronization.
 */
public class OpenAIChat extends BaseChatModel {
    private static final FluentLogger logger = FluentLogger.forEnclosingClass();

    /** Max-token limit applied to the request builder until {@link #setMaxTokens(int)} is called. */
    private static final int DEFAULT_MAX_TOKENS = 2048;

    /** Model name applied to the request builder until {@link #setModel(String)} is called. */
    private static final String DEFAULT_MODEL = "gpt-3.5-turbo";

    private final OpenAiService openAiService;

    /** Shared request builder, pre-populated with the default max tokens and model. */
    private final ChatCompletionRequest.ChatCompletionRequestBuilder completionRequest =
            ChatCompletionRequest.builder().maxTokens(DEFAULT_MAX_TOKENS).model(DEFAULT_MODEL);

    /**
     * Creates a chat model that uses the supplied, already-configured service.
     *
     * @param openAiService the client used to issue chat-completion requests
     */
    @Inject
    OpenAIChat(OpenAiService openAiService) {
        this.openAiService = openAiService;
    }

    /**
     * Creates a chat model with a new {@link OpenAiService} built from the given API key.
     *
     * @param str the OpenAI API key; it is handed to {@code Utils.logPartialApiKey} for logging
     */
    public OpenAIChat(String str) {
        Utils.logPartialApiKey(logger, str);
        this.openAiService = new OpenAiService(str);
    }

    /**
     * Creates a chat model with a new {@link OpenAiService} whose API key is obtained from
     * {@code Utils.getApiKeyFromEnv}.
     */
    public OpenAIChat() {
        this.openAiService = new OpenAiService(Utils.getApiKeyFromEnv(logger));
    }

    /**
     * Returns the lower-case wire name of a role.
     *
     * <p>{@link Locale#ROOT} is used so the result does not depend on the JVM default locale
     * (for example, a Turkish default locale would otherwise turn {@code I} into a dotless
     * {@code ı} and produce a role string that never matches).
     */
    private static String roleName(Role role) {
        return role.name().toLowerCase(Locale.ROOT);
    }

    /** Converts a library message into the OpenAI client's {@link ChatMessage}. */
    private static ChatMessage toChatMessage(BaseChatMessage baseChatMessage) {
        ChatMessage chatMessage = new ChatMessage();
        chatMessage.setContent(baseChatMessage.getMessage());
        chatMessage.setRole(roleName(baseChatMessage.getRole()));
        return chatMessage;
    }

    /**
     * Sets the max-token limit used for subsequent requests.
     *
     * @param i the max-token value to put on the request builder
     * @return this instance, for chaining
     */
    public OpenAIChat setMaxTokens(int i) {
        this.completionRequest.maxTokens(Integer.valueOf(i));
        return this;
    }

    /**
     * Sets the model name used for subsequent requests.
     *
     * @param str the model name to put on the request builder
     * @return this instance, for chaining
     */
    public OpenAIChat setModel(String str) {
        this.completionRequest.model(str);
        return this;
    }

    /**
     * Sets the sampling temperature used for subsequent requests.
     *
     * @param d the temperature value to put on the request builder
     * @return this instance, for chaining
     */
    public OpenAIChat setTemperature(double d) {
        this.completionRequest.temperature(Double.valueOf(d));
        return this;
    }

    /**
     * Sends the given messages with the current request settings and returns the message of the
     * first choice in the response.
     *
     * @throws IllegalStateException if the response carries no choices
     */
    private ChatMessage requestFirstReply(List<ChatMessage> messages) {
        ChatCompletionRequest request = this.completionRequest.messages(messages).build();
        List<ChatCompletionChoice> choices =
                this.openAiService.createChatCompletion(request).getChoices();
        if (choices == null || choices.isEmpty()) {
            throw new IllegalStateException("OpenAI chat completion returned no choices");
        }
        return choices.get(0).getMessage();
    }

    /**
     * Sends a conversation and returns the first reply, wrapped in the message type that matches
     * the role reported in the response.
     *
     * @param list the conversation messages, in order
     * @return a {@link UserMessage}, {@link SystemMessage} or {@link AssistantMessage} carrying
     *     the reply content, chosen by the reply's role
     * @throws IllegalStateException if the response carries no choices
     * @throws RuntimeException if the reply's role is not user, system or assistant
     */
    @Override // ai.knowly.langtoch.llm.base.chatmodel.BaseChatModel
    public BaseChatMessage run(List<BaseChatMessage> list) {
        List<ChatMessage> messages =
                list.stream().map(OpenAIChat::toChatMessage).collect(ImmutableList.toImmutableList());
        ChatMessage message = requestFirstReply(messages);
        String role = message.getRole();
        if (roleName(Role.USER).equals(role)) {
            return UserMessage.builder().setMessage(message.getContent()).build();
        }
        if (roleName(Role.SYSTEM).equals(role)) {
            return SystemMessage.builder().setMessage(message.getContent()).build();
        }
        if (roleName(Role.ASSISTANT).equals(role)) {
            return AssistantMessage.builder().setMessage(message.getContent()).build();
        }
        throw new RuntimeException(String.format("Unknown role %s with message: %s ", role, message.getContent()));
    }

    /**
     * Sends a single user message and returns the text content of the first reply.
     *
     * @param str the text to send with the user role
     * @return the content of the first choice's message
     * @throws IllegalStateException if the response carries no choices
     */
    @Override // ai.knowly.langtoch.llm.base.chatmodel.BaseChatModel, ai.knowly.langtoch.llm.base.BaseModel
    public String run(String str) {
        ChatMessage chatMessage = new ChatMessage();
        chatMessage.setRole(roleName(Role.USER));
        chatMessage.setContent(str);
        return requestFirstReply(ImmutableList.of(chatMessage)).getContent();
    }
}
