package chat.octet.model.utils;

import chat.octet.model.beans.ChatMessage;
import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import com.hubspot.jinjava.Jinjava;
import java.util.Arrays;
import java.util.LinkedHashMap;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:chat/octet/model/utils/ChatFormatter.class */
public class ChatFormatter {
    // NOTE(review): not referenced anywhere in this class; kept for compatibility.
    private static final Logger log = LoggerFactory.getLogger(ChatFormatter.class);
    public static final String DEFAULT_COMMON_SYSTEM = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.";
    public static final String CHATML_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";
    public static final String CHATML_BOS_TOKEN = "<s>";
    public static final String CHATML_EOS_TOKEN = "<|im_end|>";
    private final Jinjava jinJava;
    private final String template;
    private final String bos;
    private final String eos;

    /**
     * Creates a formatter that renders chat messages with the given Jinja template.
     *
     * @param template Jinja template text; must not be {@code null}
     * @param bos      value bound to {@code bos_token} when the template references it
     * @param eos      value bound to {@code eos_token} when the template references it
     * @throws NullPointerException if {@code template} is {@code null}
     */
    public ChatFormatter(String template, String bos, String eos) {
        this.jinJava = new Jinjava();
        // Fail fast here instead of with an NPE inside Jinjava on the first format() call.
        this.template = Preconditions.checkNotNull(template, "Chat template cannot be null");
        this.bos = bos;
        this.eos = eos;
    }

    /**
     * Creates a formatter for {@code template} with empty BOS and EOS tokens.
     *
     * @param template Jinja template text; must not be {@code null}
     * @throws NullPointerException if {@code template} is {@code null}
     */
    public ChatFormatter(String template) {
        this(template, "", "");
    }

    /**
     * Creates a formatter that uses the built-in ChatML template and tokens
     * ({@link #CHATML_CHAT_TEMPLATE}, {@link #CHATML_BOS_TOKEN}, {@link #CHATML_EOS_TOKEN}).
     */
    public ChatFormatter() {
        this(CHATML_CHAT_TEMPLATE, CHATML_BOS_TOKEN, CHATML_EOS_TOKEN);
    }

    /**
     * Renders the template over the given messages.
     *
     * <p>The messages are bound as {@code messages}. {@code bos_token}, {@code eos_token} and
     * {@code add_generation_prompt} are bound only when their name occurs somewhere in the
     * template text.
     *
     * @param addGenerationPrompt value bound to {@code add_generation_prompt}
     * @param messages            chat messages to render, in order; the array must not be {@code null}
     * @return the rendered prompt text
     * @throws NullPointerException if {@code messages} is {@code null}
     */
    public String format(boolean addGenerationPrompt, ChatMessage... messages) {
        Preconditions.checkNotNull(messages, "Chat messages cannot be null");
        LinkedHashMap<String, Object> bindings = Maps.newLinkedHashMap();
        bindings.put("messages", Arrays.asList(messages));
        if (StringUtils.contains(template, "bos_token")) {
            bindings.put("bos_token", bos);
        }
        if (StringUtils.contains(template, "eos_token")) {
            bindings.put("eos_token", eos);
        }
        if (StringUtils.contains(template, "add_generation_prompt")) {
            bindings.put("add_generation_prompt", Boolean.valueOf(addGenerationPrompt));
        }
        return jinJava.render(template, bindings);
    }

    /**
     * Renders the template over the given messages with {@code add_generation_prompt} set to true.
     *
     * @param messages chat messages to render, in order; the array must not be {@code null}
     * @return the rendered prompt text
     * @throws NullPointerException if {@code messages} is {@code null}
     */
    public String format(ChatMessage... messages) {
        return format(true, messages);
    }

    /**
     * Renders a single-turn prompt consisting of an optional system message and one user message.
     *
     * @param system   system prompt; omitted entirely when {@code null}, empty or whitespace-only
     * @param question user message; must not be {@code null}
     * @return the rendered prompt text
     * @throws NullPointerException if {@code question} is {@code null}
     */
    public String format(String system, String question) {
        Preconditions.checkNotNull(question, "User question cannot be null");
        if (StringUtils.isNotBlank(system)) {
            return format(ChatMessage.toSystem(system), ChatMessage.toUser(question));
        }
        return format(ChatMessage.toUser(question));
    }

    /**
     * Renders a prompt consisting of a single user message and no system message.
     *
     * @param question user message; must not be {@code null}
     * @return the rendered prompt text
     * @throws NullPointerException if {@code question} is {@code null}
     */
    public String format(String question) {
        return format((String) null, question);
    }

    /** @return the Jinja template text this formatter renders */
    public String getTemplate() {
        return this.template;
    }

    /** @return the value bound to {@code bos_token} */
    public String getBos() {
        return this.bos;
    }

    /** @return the value bound to {@code eos_token} */
    public String getEos() {
        return this.eos;
    }
}
