/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.chat;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.ModelProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ChatRequestParameters;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.ChatResponseMetadata;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.assertj.core.api.AbstractComparableAssert;
import org.assertj.core.api.Assertions;
import org.assertj.core.data.Percentage;
import org.junit.jupiter.api.Test;

public abstract class ChatModelListenerIT {
    protected abstract ChatLanguageModel createModel(ChatModelListener var1);

    protected abstract String modelName();

    protected Double temperature() {
        return 0.7;
    }

    protected Double topP() {
        return 1.0;
    }

    protected Integer maxTokens() {
        return 7;
    }

    protected abstract ChatLanguageModel createFailingModel(ChatModelListener var1);

    protected abstract Class<? extends Exception> expectedExceptionClass();

    @Test
    void should_listen_request_and_response() {
        final AtomicReference chatRequestReference = new AtomicReference();
        final AtomicReference chatResponseReference = new AtomicReference();
        final AtomicInteger onRequestInvocations = new AtomicInteger();
        final AtomicReference requestReference = new AtomicReference();
        final AtomicReference responseReference = new AtomicReference();
        final AtomicInteger onResponseInvocations = new AtomicInteger();
        ChatModelListener listener = new ChatModelListener(){

            public void onRequest(ChatModelRequestContext requestContext) {
                chatRequestReference.set(requestContext.chatRequest());
                requestReference.set(requestContext.request());
                onRequestInvocations.incrementAndGet();
                ((AbstractComparableAssert)Assertions.assertThat((Comparable)requestContext.modelProvider()).isNotNull()).isNotEqualTo((Object)ModelProvider.OTHER);
                requestContext.attributes().put("id", "12345");
            }

            public void onResponse(ChatModelResponseContext responseContext) {
                chatResponseReference.set(responseContext.chatResponse());
                responseReference.set(responseContext.response());
                onResponseInvocations.incrementAndGet();
                Assertions.assertThat((Object)responseContext.chatRequest()).isEqualTo(chatRequestReference.get());
                Assertions.assertThat((Object)responseContext.request()).isEqualTo(requestReference.get());
                ((AbstractComparableAssert)Assertions.assertThat((Comparable)responseContext.modelProvider()).isNotNull()).isNotEqualTo((Object)ModelProvider.OTHER);
                Assertions.assertThat((Map)responseContext.attributes()).containsEntry((Object)"id", (Object)"12345");
            }

            public void onError(ChatModelErrorContext errorContext) {
                Assertions.fail((String)("onError() must not be called. Exception: " + errorContext.error().getMessage()));
            }
        };
        ChatLanguageModel model = this.createModel(listener);
        UserMessage userMessage = UserMessage.from((String)"hello");
        ChatRequest.Builder chatRequestBuilder = ChatRequest.builder().messages(new ChatMessage[]{userMessage});
        ToolSpecification toolSpecification = null;
        if (this.supportToolCalls()) {
            toolSpecification = ToolSpecification.builder().name("add").parameters(JsonObjectSchema.builder().addIntegerProperty("a").addIntegerProperty("b").build()).build();
            chatRequestBuilder.toolSpecifications(new ToolSpecification[]{toolSpecification});
        }
        ChatRequest chatRequest = chatRequestBuilder.build();
        AiMessage aiMessage = model.chat(chatRequest).aiMessage();
        ChatRequest observedChatRequest = (ChatRequest)chatRequestReference.get();
        Assertions.assertThat((List)observedChatRequest.messages()).containsExactly((Object[])new ChatMessage[]{userMessage});
        ChatRequestParameters parameters = observedChatRequest.parameters();
        Assertions.assertThat((String)parameters.modelName()).isEqualTo(this.modelName());
        Assertions.assertThat((Double)parameters.temperature()).isCloseTo(this.temperature(), Percentage.withPercentage((double)1.0));
        Assertions.assertThat((Double)parameters.topP()).isEqualTo(this.topP());
        Assertions.assertThat((Integer)parameters.maxOutputTokens()).isEqualTo((Object)this.maxTokens());
        if (this.supportToolCalls()) {
            Assertions.assertThat((List)parameters.toolSpecifications()).containsExactly((Object[])new ToolSpecification[]{toolSpecification});
        }
        Assertions.assertThat((AtomicInteger)onRequestInvocations).hasValue(1);
        ChatModelRequest request = (ChatModelRequest)requestReference.get();
        Assertions.assertThat((String)request.model()).isEqualTo(this.modelName());
        Assertions.assertThat((Double)request.temperature()).isCloseTo(this.temperature(), Percentage.withPercentage((double)1.0));
        Assertions.assertThat((Double)request.topP()).isEqualTo(this.topP());
        Assertions.assertThat((Integer)request.maxTokens()).isEqualTo((Object)this.maxTokens());
        Assertions.assertThat((List)request.messages()).containsExactly((Object[])new ChatMessage[]{userMessage});
        if (this.supportToolCalls()) {
            Assertions.assertThat((List)request.toolSpecifications()).containsExactly((Object[])new ToolSpecification[]{toolSpecification});
        }
        ChatResponse chatResponse = (ChatResponse)chatResponseReference.get();
        Assertions.assertThat((Object)chatResponse.aiMessage()).isEqualTo((Object)aiMessage);
        ChatResponseMetadata metadata = chatResponse.metadata();
        if (this.assertResponseId()) {
            Assertions.assertThat((String)metadata.id()).isNotBlank();
        }
        Assertions.assertThat((String)metadata.modelName()).isNotBlank();
        Assertions.assertThat((Integer)metadata.tokenUsage().inputTokenCount()).isGreaterThan(0);
        Assertions.assertThat((Integer)metadata.tokenUsage().outputTokenCount()).isGreaterThan(0);
        Assertions.assertThat((Integer)metadata.tokenUsage().totalTokenCount()).isGreaterThan(0);
        if (this.assertFinishReason()) {
            Assertions.assertThat((Comparable)metadata.finishReason()).isNotNull();
        }
        Assertions.assertThat((AtomicInteger)onResponseInvocations).hasValue(1);
        ChatModelResponse response2 = (ChatModelResponse)responseReference.get();
        if (this.assertResponseId()) {
            Assertions.assertThat((String)response2.id()).isNotBlank();
        }
        Assertions.assertThat((String)response2.model()).isNotBlank();
        Assertions.assertThat((Integer)response2.tokenUsage().inputTokenCount()).isGreaterThan(0);
        Assertions.assertThat((Integer)response2.tokenUsage().outputTokenCount()).isGreaterThan(0);
        Assertions.assertThat((Integer)response2.tokenUsage().totalTokenCount()).isGreaterThan(0);
        if (this.assertFinishReason()) {
            Assertions.assertThat((Comparable)response2.finishReason()).isNotNull();
        }
        Assertions.assertThat((Object)response2.aiMessage()).isEqualTo((Object)aiMessage);
    }

    protected boolean supportToolCalls() {
        return true;
    }

    protected boolean assertResponseId() {
        return true;
    }

    protected boolean assertFinishReason() {
        return true;
    }

    @Test
    void should_listen_error() {
        final AtomicReference chatRequestReference = new AtomicReference();
        final AtomicReference requestReference = new AtomicReference();
        final AtomicInteger onRequestInvocations = new AtomicInteger();
        final AtomicReference errorReference = new AtomicReference();
        final AtomicInteger onErrorInvocations = new AtomicInteger();
        ChatModelListener listener = new ChatModelListener(){

            public void onRequest(ChatModelRequestContext requestContext) {
                chatRequestReference.set(requestContext.chatRequest());
                requestReference.set(requestContext.request());
                onRequestInvocations.incrementAndGet();
                ((AbstractComparableAssert)Assertions.assertThat((Comparable)requestContext.modelProvider()).isNotNull()).isNotEqualTo((Object)ModelProvider.OTHER);
                requestContext.attributes().put("id", "12345");
            }

            public void onResponse(ChatModelResponseContext responseContext) {
                Assertions.fail((String)"onResponse() must not be called");
            }

            public void onError(ChatModelErrorContext errorContext) {
                errorReference.set(errorContext.error());
                onErrorInvocations.incrementAndGet();
                Assertions.assertThat((Object)errorContext.chatRequest()).isEqualTo(chatRequestReference.get());
                Assertions.assertThat((Object)errorContext.request()).isEqualTo(requestReference.get());
                Assertions.assertThat((Object)errorContext.partialResponse()).isNull();
                ((AbstractComparableAssert)Assertions.assertThat((Comparable)errorContext.modelProvider()).isNotNull()).isNotEqualTo((Object)ModelProvider.OTHER);
                Assertions.assertThat((Map)errorContext.attributes()).containsEntry((Object)"id", (Object)"12345");
            }
        };
        ChatLanguageModel model = this.createFailingModel(listener);
        String userMessage = "this message will fail";
        Exception thrown = null;
        try {
            model.chat(userMessage);
        }
        catch (Exception e) {
            thrown = e;
        }
        Throwable error = (Throwable)errorReference.get();
        Assertions.assertThat((Throwable)error).isExactlyInstanceOf(this.expectedExceptionClass());
        Assertions.assertThat((thrown == error || thrown.getCause() == error ? 1 : 0) != 0).isTrue();
        Assertions.assertThat((AtomicInteger)onRequestInvocations).hasValue(1);
        Assertions.assertThat((AtomicInteger)onErrorInvocations).hasValue(1);
    }
}

