package com.llmagent.azure;

import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.IntArrayList;
import com.llmagent.azure.chat.AzureAiChatModelName;
import com.llmagent.azure.embedding.AzureAiEmbeddingModelName;
import com.llmagent.data.message.AiMessage;
import com.llmagent.data.message.ChatMessage;
import com.llmagent.data.message.ImageContent;
import com.llmagent.data.message.SystemMessage;
import com.llmagent.data.message.TextContent;
import com.llmagent.data.message.ToolMessage;
import com.llmagent.data.message.UserMessage;
import com.llmagent.exception.Exceptions;
import com.llmagent.llm.Tokenizer;
import com.llmagent.llm.tool.ToolParameters;
import com.llmagent.llm.tool.ToolRequest;
import com.llmagent.llm.tool.ToolSpecification;
import com.llmagent.util.JsonUtil;
import com.llmagent.util.StringUtil;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

/* loaded from: input_file:com/llmagent/azure/AzureAiTokenizer.class */
/**
 * {@link Tokenizer} for Azure-hosted OpenAI models that counts tokens with a jtokkit
 * {@link Encoding}.
 *
 * <p>Message, tool-specification and tool-request counts are estimates: the tokens of the
 * contained text plus fixed per-item overheads, several of which depend on the model
 * (see {@link #isOneOfLatestModels()} and the other model checks below).
 *
 * <p>The encoding is looked up once in the constructor. When jtokkit returns no encoding for
 * the model name, construction still succeeds and every method that needs the encoding
 * throws {@link IllegalArgumentException}.
 *
 * <p>NOTE(review): the enum-based constructors store {@code modelType()}, while the
 * model-specific checks compare {@link #modelName} with {@code toString()} /
 * {@code modelName()} of {@link AzureAiChatModelName}. If those strings differ from
 * {@code modelType()}, the model-specific adjustments never apply to tokenizers built from
 * the enums — confirm against {@link AzureAiChatModelName}.
 */
public class AzureAiTokenizer implements Tokenizer {

    /** Model identifier used for the jtokkit lookup and for the model-specific overheads. */
    private final String modelName;

    /** Encoding resolved for {@link #modelName}; empty when jtokkit returned none. */
    private final Optional<Encoding> encoding;

    /** Creates a tokenizer for the model type of {@link AzureAiChatModelName#GPT_3_5_TURBO}. */
    public AzureAiTokenizer() {
        this(AzureAiChatModelName.GPT_3_5_TURBO.modelType());
    }

    /**
     * Creates a tokenizer for the model type of the given chat model.
     *
     * @param azureAiChatModelName chat model whose {@code modelType()} selects the encoding
     */
    public AzureAiTokenizer(AzureAiChatModelName azureAiChatModelName) {
        this(azureAiChatModelName.modelType());
    }

    /**
     * Creates a tokenizer for the model type of the given embedding model.
     *
     * @param azureAiEmbeddingModelName embedding model whose {@code modelType()} selects the encoding
     */
    public AzureAiTokenizer(AzureAiEmbeddingModelName azureAiEmbeddingModelName) {
        this(azureAiEmbeddingModelName.modelType());
    }

    /**
     * Creates a tokenizer for the given model name.
     *
     * <p>The registry reports the result as an {@link Optional}; an empty result does not fail
     * here, the failure is deferred to the first method that needs the encoding.
     *
     * @param modelName name passed to jtokkit's encoding registry
     */
    public AzureAiTokenizer(String modelName) {
        this.modelName = modelName;
        this.encoding = Encodings.newLazyEncodingRegistry().getEncodingForModel(modelName);
    }

    /**
     * Counts the tokens of {@code text} with jtokkit's {@code countTokensOrdinary}.
     *
     * @param text text to count
     * @return number of tokens
     * @throws IllegalArgumentException if no encoding is available for the model
     */
    public int estimateTokenCountInText(String text) {
        return requireEncoding().countTokensOrdinary(text);
    }

    /**
     * Estimates the tokens of one chat message: its content plus a per-message overhead.
     *
     * @param chatMessage a {@link SystemMessage}, {@link UserMessage}, {@link AiMessage} or {@link ToolMessage}
     * @return estimated token count
     * @throws IllegalArgumentException for any other message type, or if the model is unknown
     */
    public int estimateTokenCountInMessage(ChatMessage chatMessage) {
        // 1 extra token per message (presumably the role) plus the model-specific overhead.
        int overhead = 1 + extraTokensPerMessage();
        if (chatMessage instanceof SystemMessage) {
            return overhead + estimateTokenCountIn((SystemMessage) chatMessage);
        }
        if (chatMessage instanceof UserMessage) {
            return overhead + estimateTokenCountIn((UserMessage) chatMessage);
        }
        if (chatMessage instanceof AiMessage) {
            return overhead + estimateTokenCountIn((AiMessage) chatMessage);
        }
        if (chatMessage instanceof ToolMessage) {
            return overhead + estimateTokenCountIn((ToolMessage) chatMessage);
        }
        throw new IllegalArgumentException("Unknown message type: " + chatMessage);
    }

    /** Tokens of a system message's text. */
    private int estimateTokenCountIn(SystemMessage systemMessage) {
        return estimateTokenCountInText(systemMessage.content());
    }

    /**
     * Tokens of a user message: text contents, a flat amount per image, and the sender name
     * when present (except on the vision-preview model).
     */
    private int estimateTokenCountIn(UserMessage userMessage) {
        int tokenCount = 0;
        // Iterate as Object so that every content kind reaches the type checks below.
        for (Object content : userMessage.contents()) {
            if (content instanceof TextContent) {
                tokenCount += estimateTokenCountInText(((TextContent) content).text());
            } else if (content instanceof ImageContent) {
                // Flat per-image estimate; presumably OpenAI's fixed low-detail image cost.
                tokenCount += 85;
            } else {
                // Pass the content as an argument so its toString() is never parsed as a format string.
                throw Exceptions.illegalArgument("Unknown content type: %s", new Object[]{content});
            }
        }
        if (userMessage.name() != null
                && !this.modelName.equals(AzureAiChatModelName.GPT_4_VISION_PREVIEW.toString())) {
            tokenCount += extraTokensPerName() + estimateTokenCountInText(userMessage.name());
        }
        return tokenCount;
    }

    /**
     * Tokens of an AI message: its text plus, when it carries at least one tool request, the
     * tool-call overhead. A single request counts its name twice and its raw argument JSON;
     * several requests count each name and each argument key/value pair.
     */
    private int estimateTokenCountIn(AiMessage aiMessage) {
        int tokenCount = 0;
        if (aiMessage.content() != null) {
            tokenCount += estimateTokenCountInText(aiMessage.content());
        }

        List<? extends ToolRequest> toolRequests = aiMessage.toolRequests();
        // An empty list carries no tool calls, so it must not add the tool-call overhead.
        if (toolRequests == null || toolRequests.isEmpty()) {
            return tokenCount;
        }

        tokenCount += isOneOfLatestModels() ? 6 : 3;
        if (toolRequests.size() == 1) {
            ToolRequest toolRequest = toolRequests.get(0);
            tokenCount += -1
                    + 2 * estimateTokenCountInText(toolRequest.name())
                    + estimateTokenCountInText(toolRequest.arguments());
            return tokenCount;
        }

        tokenCount += 15;
        for (ToolRequest toolRequest : toolRequests) {
            tokenCount += 7 + estimateTokenCountInText(toolRequest.name());
            for (Map.Entry<?, ?> argument : parseArguments(toolRequest.arguments()).entrySet()) {
                // String.valueOf: a JSON null value is counted as "null" instead of throwing.
                tokenCount += 2
                        + estimateTokenCountInText(String.valueOf(argument.getKey()))
                        + estimateTokenCountInText(String.valueOf(argument.getValue()));
            }
        }
        return tokenCount;
    }

    /** Tokens of a tool message's text. */
    private int estimateTokenCountIn(ToolMessage toolMessage) {
        return estimateTokenCountInText(toolMessage.content());
    }

    /** Per-message overhead: 4 for GPT-3.5-turbo-0301, 3 otherwise. */
    private int extraTokensPerMessage() {
        // NOTE(review): compares with modelName() while the other checks use toString().
        return this.modelName.equals(AzureAiChatModelName.GPT_3_5_TURBO_0301.modelName()) ? 4 : 3;
    }

    /** Overhead of a sender name: -1 for GPT-3.5-turbo-0301, +1 otherwise. */
    private int extraTokensPerName() {
        // Negative for 0301, presumably because a named message omits the role there.
        return this.modelName.equals(AzureAiChatModelName.GPT_3_5_TURBO_0301.toString()) ? -1 : 1;
    }

    /**
     * Estimates the tokens of a whole conversation: every message plus a fixed overhead of 3.
     *
     * @param iterable messages to count
     * @return estimated token count
     */
    public int estimateTokenCountInMessages(Iterable<ChatMessage> iterable) {
        int tokenCount = 3;
        for (ChatMessage message : iterable) {
            tokenCount += estimateTokenCountInMessage(message);
        }
        return tokenCount;
    }

    /**
     * Estimates the tokens that the given tool specifications add to a request.
     *
     * @param iterable tool specifications
     * @return estimated token count (16 for an empty iterable)
     */
    public int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> iterable) {
        int tokenCount = 16;
        for (ToolSpecification specification : iterable) {
            tokenCount += 6 + estimateTokenCountInText(specification.name());
            if (specification.description() != null) {
                tokenCount += 2 + estimateTokenCountInText(specification.description());
            }
            tokenCount += estimateTokenCountInToolParameters(specification.parameters());
        }
        return tokenCount;
    }

    /**
     * Estimates the tokens of a tool's parameter schema: each property name plus its
     * {@code type}, {@code description} and {@code enum} attributes. Other attributes are
     * ignored.
     */
    private int estimateTokenCountInToolParameters(ToolParameters toolParameters) {
        if (toolParameters == null) {
            return 0;
        }
        boolean latest = isOneOfLatestModels();
        Map<?, ?> properties = toolParameters.properties();
        if (properties == null) {
            properties = Collections.emptyMap();
        }

        int tokenCount = latest ? properties.size() + 2 : 3;
        for (Map.Entry<?, ?> property : properties.entrySet()) {
            String name = String.valueOf(property.getKey());
            tokenCount += (latest ? 2 : 3) + estimateTokenCountInText(name);

            Object schema = property.getValue();
            if (!(schema instanceof Map)) {
                continue; // no attribute map to inspect for this property
            }
            for (Map.Entry<?, ?> attribute : ((Map<?, ?>) schema).entrySet()) {
                Object key = attribute.getKey();
                Object value = attribute.getValue();
                if ("type".equals(key)) {
                    if (latest && "array".equals(value)) {
                        tokenCount++;
                    }
                } else if ("description".equals(key)) {
                    tokenCount += 2 + estimateTokenCountInText(String.valueOf(value));
                    if (latest && toolParameters.required() != null
                            && toolParameters.required().contains(name)) {
                        tokenCount++;
                    }
                } else if ("enum".equals(key)) {
                    tokenCount += latest ? -2 : -3;
                    tokenCount += estimateTokenCountInEnumValues(value);
                }
            }
        }
        return tokenCount;
    }

    /**
     * Tokens of an {@code enum} attribute's values (3 per value plus its text). Accepts an
     * array or any {@link Iterable}; other shapes contribute nothing.
     */
    private int estimateTokenCountInEnumValues(Object enumValues) {
        int tokenCount = 0;
        if (enumValues instanceof Object[]) {
            for (Object enumValue : (Object[]) enumValues) {
                tokenCount += 3 + estimateTokenCountInText(String.valueOf(enumValue));
            }
        } else if (enumValues instanceof Iterable) {
            for (Object enumValue : (Iterable<?>) enumValues) {
                tokenCount += 3 + estimateTokenCountInText(String.valueOf(enumValue));
            }
        }
        return tokenCount;
    }

    /**
     * Estimates the tokens added when the model is forced to call the given tool.
     *
     * @param toolSpecification the forced tool
     * @return estimated token count
     */
    public int estimateTokenCountInForcefulToolSpecification(ToolSpecification toolSpecification) {
        int tokenCount = estimateTokenCountInToolSpecifications(Collections.singletonList(toolSpecification))
                + 4
                + estimateTokenCountInText(toolSpecification.name());
        if (isOneOfLatestModels()) {
            tokenCount += 3;
        }
        return tokenCount;
    }

    /**
     * Encodes {@code text} into token ids with jtokkit's {@code encodeOrdinary}.
     *
     * @throws IllegalArgumentException if no encoding is available for the model
     */
    public List<Integer> encode(String text) {
        return requireEncoding().encodeOrdinary(text).boxed();
    }

    /**
     * Encodes {@code text} into at most {@code maxTokens} token ids.
     *
     * @throws IllegalArgumentException if no encoding is available for the model
     */
    public List<Integer> encode(String text, int maxTokens) {
        return requireEncoding().encodeOrdinary(text, maxTokens).getTokens().boxed();
    }

    /**
     * Decodes token ids back into text.
     *
     * @throws IllegalArgumentException if no encoding is available for the model
     */
    public String decode(List<Integer> tokens) {
        IntArrayList tokenIds = new IntArrayList();
        for (Integer token : tokens) {
            tokenIds.add(token.intValue());
        }
        return requireEncoding().decode(tokenIds);
    }

    /** Returns the encoding, or throws if jtokkit did not resolve one for {@link #modelName}. */
    private Encoding requireEncoding() {
        return this.encoding.orElseThrow(unknownModelException());
    }

    /** Supplies the exception thrown when the model has no jtokkit encoding. */
    private Supplier<IllegalArgumentException> unknownModelException() {
        return () -> Exceptions.illegalArgument("Model '%s' is unknown to jtokkit", new Object[]{this.modelName});
    }

    /**
     * Estimates the tokens of tool requests returned by the model. Each request counts 4
     * plus its name and raw argument JSON; model-specific corrections depend on how many
     * requests have arguments and on the total number of arguments.
     *
     * @param iterable tool requests
     * @return estimated token count
     */
    public int estimateTokenCountInToolRequests(Iterable<ToolRequest> iterable) {
        int tokenCount = 0;
        int requestCount = 0;
        int requestsWithArguments = 0;
        int requestsWithoutArguments = 0;
        int totalArguments = 0;

        for (ToolRequest toolRequest : iterable) {
            tokenCount += 4
                    + estimateTokenCountInText(toolRequest.name())
                    + estimateTokenCountInText(toolRequest.arguments());
            int argumentCount = countArguments(toolRequest.arguments());
            if (argumentCount == 0) {
                requestsWithoutArguments++;
            } else {
                requestsWithArguments++;
            }
            totalArguments += argumentCount;
            requestCount++;
        }

        if (this.modelName.equals(AzureAiChatModelName.GPT_3_5_TURBO_1106.toString()) || isOneOfLatestGpt4Models()) {
            tokenCount += 16 + 3 * requestsWithoutArguments + requestCount;
            if (totalArguments > 0) {
                tokenCount += -1 - 2 * totalArguments + 2 * requestsWithArguments + requestCount;
            }
        }
        if (this.modelName.equals(AzureAiChatModelName.GPT_4_1106_PREVIEW.toString())) {
            tokenCount += 3;
            if (requestCount > 1) {
                tokenCount += 18 + 15 * requestCount + totalArguments - 3 * requestsWithoutArguments;
            }
        }
        return tokenCount;
    }

    /**
     * Estimates the tokens of a tool request produced when the tool call was forced.
     *
     * @param toolRequest the forced tool request
     * @return estimated token count (1 when the request has no arguments on the models special-cased here)
     */
    public int estimateTokenCountInForcefulToolRequest(ToolRequest toolRequest) {
        if (isOneOfLatestGpt4Models()) {
            if (countArguments(toolRequest.arguments()) == 0) {
                return 1;
            }
            return estimateTokenCountInText(toolRequest.arguments());
        }

        int tokenCount = estimateTokenCountInToolRequests(Collections.singletonList(toolRequest))
                - 4
                - estimateTokenCountInText(toolRequest.name());
        if (this.modelName.equals(AzureAiChatModelName.GPT_3_5_TURBO_1106.toString())) {
            int argumentCount = countArguments(toolRequest.arguments());
            if (argumentCount == 0) {
                return 1;
            }
            tokenCount += -19 + 2 * argumentCount;
        }
        return tokenCount;
    }

    /**
     * Counts the top-level entries of a JSON-object argument string.
     *
     * @param str tool arguments as JSON; {@code null}/blank counts as 0
     * @return number of arguments
     */
    static int countArguments(String str) {
        return parseArguments(str).size();
    }

    /**
     * Parses tool arguments as a JSON object. {@code null}/blank input, or JSON that parses to
     * {@code null}, yields an empty map.
     */
    private static Map<?, ?> parseArguments(String arguments) {
        if (StringUtil.isNullOrBlank(arguments)) {
            return Collections.emptyMap();
        }
        Map<?, ?> parsed = (Map<?, ?>) JsonUtil.fromJson(arguments, Map.class);
        if (parsed == null) {
            return Collections.emptyMap();
        }
        return parsed;
    }

    /** Whether the model is one of the "latest" GPT-3.5 or GPT-4 models listed below. */
    private boolean isOneOfLatestModels() {
        return isOneOfLatestGpt3Models() || isOneOfLatestGpt4Models();
    }

    /** Whether the model is GPT-3.5-turbo-1106 or GPT-3.5-turbo. */
    private boolean isOneOfLatestGpt3Models() {
        return this.modelName.equals(AzureAiChatModelName.GPT_3_5_TURBO_1106.toString())
                || this.modelName.equals(AzureAiChatModelName.GPT_3_5_TURBO.toString());
    }

    /** Whether the model is GPT-4-turbo, GPT-4-1106-preview or GPT-4-0125-preview. */
    private boolean isOneOfLatestGpt4Models() {
        return this.modelName.equals(AzureAiChatModelName.GPT_4_TURBO.toString())
                || this.modelName.equals(AzureAiChatModelName.GPT_4_1106_PREVIEW.toString())
                || this.modelName.equals(AzureAiChatModelName.GPT_4_0125_PREVIEW.toString());
    }
}
