package chat.octet.model.components.prompt;

import chat.octet.model.beans.ChatMessage;
import chat.octet.model.exceptions.ModelException;
import chat.octet.model.functions.Function;
import chat.octet.model.functions.FunctionConstants;
import com.google.common.base.Charsets;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.Resources;
import com.hubspot.jinjava.Jinjava;
import com.hubspot.jinjava.JinjavaConfig;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:chat/octet/model/components/prompt/DefaultChatTemplateFormatter.class */
/**
 * Chat template formatter that renders a conversation through a Jinja template using Jinjava.
 *
 * <p>The template is loaded from the classpath resource {@code chat_templates/<model type>.tmpl}
 * (model type lower-cased with {@link Locale#ROOT}). If that resource cannot be loaded, a
 * caller-supplied default template is used instead, when one is given.
 */
public class DefaultChatTemplateFormatter implements ChatTemplateFormatter {
    private static final Logger log = LoggerFactory.getLogger(DefaultChatTemplateFormatter.class);
    private final Jinjava jinJava;
    private final String chatTemplate;

    /**
     * Creates a formatter for the given model type.
     *
     * @param modelType           model type name used to locate {@code chat_templates/<type>.tmpl};
     *                            must not be {@code null}
     * @param defaultChatTemplate template used when the classpath resource cannot be loaded;
     *                            may be {@code null} or blank to disable the fallback
     * @throws NullPointerException if {@code modelType} is {@code null}
     * @throws ModelException       if the resource cannot be loaded and no non-blank fallback is given
     */
    public DefaultChatTemplateFormatter(String modelType, String defaultChatTemplate) {
        Preconditions.checkNotNull(modelType, "Model type name cannot be null");
        // Locale.ROOT keeps the resource name stable regardless of the JVM default locale
        // (e.g. under a Turkish locale "I" would otherwise lower-case to a dotless "ı").
        String resourcePath = "chat_templates/" + modelType.toLowerCase(Locale.ROOT) + ".tmpl";
        this.jinJava = new Jinjava(JinjavaConfig.newBuilder().withTrimBlocks(true).build());
        this.chatTemplate = loadChatTemplate(resourcePath, defaultChatTemplate);
        log.debug("Created a new chat formatter, chat template: {}", this.chatTemplate);
    }

    /**
     * Creates a formatter for the given model type without a fallback template.
     *
     * @param modelType model type name used to locate {@code chat_templates/<type>.tmpl}
     * @throws NullPointerException if {@code modelType} is {@code null}
     * @throws ModelException       if the classpath template cannot be loaded
     */
    public DefaultChatTemplateFormatter(String modelType) {
        this(modelType, null);
    }

    /**
     * Reads the chat template from the classpath, falling back to {@code defaultChatTemplate}.
     *
     * @param resourcePath        classpath location of the template
     * @param defaultChatTemplate fallback template; ignored when blank
     * @return the template text
     * @throws ModelException if the resource cannot be read and the fallback is blank
     */
    private static String loadChatTemplate(String resourcePath, String defaultChatTemplate) {
        try {
            // getResource throws IllegalArgumentException when the resource is missing;
            // toString throws IOException when it cannot be read.
            String template = Resources.toString(Resources.getResource(resourcePath), Charsets.UTF_8);
            log.info("Loaded chat template from local resource: {}", resourcePath);
            return template;
        } catch (IllegalArgumentException | IOException e) {
            if (StringUtils.isBlank(defaultChatTemplate)) {
                throw new ModelException("Failed to load local chat template: " + resourcePath, e);
            }
            log.warn("Failed to load local chat template, use default template.");
            // Keep the underlying cause available for diagnosis without a noisy warn-level trace.
            log.debug("Chat template resource {} could not be loaded", resourcePath, e);
            return defaultChatTemplate;
        }
    }

    /**
     * Renders the conversation with the loaded chat template.
     *
     * <p>Template variables are bound as follows:
     * <ul>
     *   <li>all entries of {@code params} (if any) are bound first;</li>
     *   <li>{@code add_generation_prompt} is bound only if the template text mentions it;</li>
     *   <li>if the template mentions {@code add_function_calls} and {@code functions} is non-empty,
     *       function-call variables are bound and {@code messages} is filtered (see
     *       {@link #collectFunctionCallMessages}); otherwise {@code messages} is bound unchanged.</li>
     * </ul>
     * Bound {@code messages} always overrides any {@code messages} entry supplied in {@code params}.
     *
     * @param messages            conversation messages; must not be {@code null}
     * @param functions           available functions; may be {@code null} or empty
     * @param addGenerationPrompt value bound to {@code add_generation_prompt}
     * @param params              extra template variables; may be {@code null}
     * @return the rendered prompt
     * @throws NullPointerException     if {@code messages} is {@code null}
     * @throws IllegalArgumentException if the template does not reference {@code messages}
     */
    @Override // chat.octet.model.components.prompt.ChatTemplateFormatter
    public String format(List<ChatMessage> messages, List<Function> functions, boolean addGenerationPrompt, Map<String, Object> params) {
        Preconditions.checkNotNull(messages, "Chat messages cannot be null");
        if (!StringUtils.contains(this.chatTemplate, "messages")) {
            throw new IllegalArgumentException("The chat template must contain the 'messages' variable");
        }
        Map<String, Object> bindings = new LinkedHashMap<>();
        if (params != null && !params.isEmpty()) {
            bindings.putAll(params);
        }
        if (StringUtils.contains(this.chatTemplate, "add_generation_prompt")) {
            bindings.put("add_generation_prompt", addGenerationPrompt);
        }
        boolean functionCallsEnabled = StringUtils.contains(this.chatTemplate, "add_function_calls")
                && functions != null && !functions.isEmpty();
        if (functionCallsEnabled) {
            bindings.put("add_function_calls", true);
            bindings.put("functions", functions);
            // The helper binds query/result flags before "messages" is inserted, matching the
            // original insertion order of the bindings map.
            bindings.put("messages", collectFunctionCallMessages(messages, bindings));
        } else {
            bindings.put("messages", messages);
        }
        return this.jinJava.render(this.chatTemplate, bindings);
    }

    /**
     * Binds function-calling variables derived from the conversation and returns the messages
     * that should still be rendered through the template's {@code messages} variable.
     *
     * <p>SYSTEM and USER messages are retained. Each USER message's content is bound to
     * {@link FunctionConstants#FUNCTION_TEMPLATE_ARGS_QUERY}, so the last one wins. An ASSISTANT
     * message with tool calls sets {@code function_call}. A FUNCTION message sets
     * {@code function_feedback} and binds its content (or
     * {@link FunctionConstants#FUNCTION_TEMPLATE_ARGS_NONE} when {@code null}) to
     * {@link FunctionConstants#FUNCTION_TEMPLATE_ARGS_RESULT}, again last one wins. ASSISTANT and
     * FUNCTION messages are not included in the returned list.
     *
     * @param messages conversation messages
     * @param bindings template variables to populate
     * @return the SYSTEM and USER messages, in their original order
     */
    private static List<ChatMessage> collectFunctionCallMessages(List<ChatMessage> messages, Map<String, Object> bindings) {
        List<ChatMessage> retained = new ArrayList<>();
        for (ChatMessage message : messages) {
            ChatMessage.ChatRole role = message.getRole();
            if (ChatMessage.ChatRole.SYSTEM == role) {
                retained.add(message);
            } else if (ChatMessage.ChatRole.USER == role) {
                bindings.put(FunctionConstants.FUNCTION_TEMPLATE_ARGS_QUERY, message.getContent());
                retained.add(message);
            } else if (ChatMessage.ChatRole.ASSISTANT == role && message.getToolCalls() != null) {
                bindings.put("function_call", true);
            } else if (ChatMessage.ChatRole.FUNCTION == role) {
                bindings.put("function_feedback", true);
                bindings.put(FunctionConstants.FUNCTION_TEMPLATE_ARGS_RESULT,
                        Optional.ofNullable(message.getContent()).orElse(FunctionConstants.FUNCTION_TEMPLATE_ARGS_NONE));
            }
        }
        return retained;
    }

    /**
     * Returns the chat template text in use (classpath resource or fallback).
     *
     * @return the chat template
     */
    public String getChatTemplate() {
        return this.chatTemplate;
    }
}
