package dev.langchain4j.model.chat;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.input.Prompt;
import java.util.ArrayList;
import java.util.List;
import org.assertj.core.api.WithAssertions;
import org.junit.jupiter.api.Test;

/**
 * Verifies that the default {@code estimateTokenCount} overloads of {@link TokenCountEstimator}
 * (for {@code String}, {@code UserMessage}, {@code Prompt}, {@code TextSegment}) and the
 * list-of-messages variant all yield consistent counts when backed by a trivial estimator.
 */
class TokenCountEstimatorTest implements WithAssertions {

    /**
     * Minimal estimator used only for testing: each message contributes the number of pieces its
     * text splits into on runs of whitespace, and the per-message counts are summed.
     */
    public static class WhitespaceSplitTokenCountEstimator implements TokenCountEstimator {

        /**
         * Estimates the token count of the given messages.
         *
         * @param list the messages to count; each message's {@code text()} is split on {@code \s+}
         * @return the sum of the whitespace-split piece counts over all messages
         */
        public int estimateTokenCount(List<ChatMessage> list) {
            return list.stream()
                    .mapToInt(chatMessage -> chatMessage.text().split("\\s+").length)
                    .sum();
        }
    }

    @Test
    public void test() {
        WhitespaceSplitTokenCountEstimator estimator = new WhitespaceSplitTokenCountEstimator();

        // "foo bar, baz" splits into three whitespace-separated pieces regardless of wrapper type.
        assertThat(estimator.estimateTokenCount("foo bar, baz")).isEqualTo(3);
        assertThat(estimator.estimateTokenCount(new UserMessage("foo bar, baz"))).isEqualTo(3);
        assertThat(estimator.estimateTokenCount(new Prompt("foo bar, baz"))).isEqualTo(3);
        assertThat(estimator.estimateTokenCount(TextSegment.from("foo bar, baz"))).isEqualTo(3);

        // Typed list (not a raw ArrayList) so it matches estimateTokenCount(List<ChatMessage>).
        List<ChatMessage> messages = new ArrayList<>();
        messages.add(new UserMessage("Hello, world!")); // 2 pieces
        messages.add(new AiMessage("How are you?"));    // 3 pieces
        assertThat(estimator.estimateTokenCount(messages)).isEqualTo(5);
    }
}
