package dev.langchain4j.rag.content.aggregator;

import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.ScoringModel;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.query.Query;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;

/* loaded from: input_file:dev/langchain4j/rag/content/aggregator/ReRankingContentAggregatorTest.class */
/**
 * Unit tests for {@link ReRankingContentAggregator}.
 *
 * <p>The {@link ScoringModel} is always a Mockito mock. Where a test stubs {@code scoreAll} with
 * concrete arguments, the stub only matches if the aggregator passes exactly those text segments,
 * in exactly that order, together with the text of the expected query; the stub therefore also
 * pins the input that reaches the scoring model. An unmatched call returns Mockito's default
 * ({@code null}) and makes the test fail.
 */
class ReRankingContentAggregatorTest {

    /**
     * A single query with a single content list must come back ordered by descending score.
     *
     * @param function factory that turns the mocked scoring model into the aggregator under test
     *                 (one case uses the constructor, the other the builder)
     */
    @MethodSource
    @ParameterizedTest
    void should_rerank_when_single_query_and_single_contents(Function<ScoringModel, ContentAggregator> function) {
        // given
        Query query = Query.from("query");
        Content content1 = Content.from("content 1");
        Content content2 = Content.from("content 2");
        Map<Query, Collection<List<Content>>> queryToContents =
                Collections.singletonMap(query, Collections.singletonList(Arrays.asList(content1, content2)));

        ScoringModel scoringModel = Mockito.mock(ScoringModel.class);
        // Scores are positional: content1 -> 0.5, content2 -> 0.7.
        Mockito.when(scoringModel.scoreAll(ArgumentMatchers.any(), ArgumentMatchers.any()))
                .thenReturn(Response.from(Arrays.asList(0.5, 0.7)));

        ContentAggregator aggregator = function.apply(scoringModel);

        // when
        List<Content> aggregated = aggregator.aggregate(queryToContents);

        // then: the higher-scored content comes first
        Assertions.assertThat(aggregated).containsExactly(content2, content1);
    }

    /**
     * Supplies the two ways of creating the aggregator: public constructor and builder.
     *
     * @return one {@link Arguments} per construction style
     */
    static Stream<Arguments> should_rerank_when_single_query_and_single_contents() {
        // Method references and lambdas need an explicit functional-interface target type;
        // they cannot be placed directly into the Object... varargs of Arguments.of.
        Function<ScoringModel, ContentAggregator> viaConstructor = ReRankingContentAggregator::new;
        Function<ScoringModel, ContentAggregator> viaBuilder =
                scoringModel -> ReRankingContentAggregator.builder().scoringModel(scoringModel).build();

        return Stream.of(Arguments.of(viaConstructor), Arguments.of(viaBuilder));
    }

    /**
     * One query with two content lists: both lists share an entry with the text "content".
     * The stub expects three segments (the shared text only once), and the result must be
     * ordered by the returned scores.
     */
    @Test
    void should_fuse_then_rerank_when_single_query_and_multiple_contents() {
        // given
        Query query = Query.from("query");

        Content content1 = Content.from("content 1");
        Content content2 = Content.from("content"); // same text as content3
        Content content3 = Content.from("content"); // same text as content2
        Content content4 = Content.from("content 4");

        Map<Query, Collection<List<Content>>> queryToContents = Collections.singletonMap(
                query,
                Arrays.asList(Arrays.asList(content1, content2), Arrays.asList(content3, content4)));

        ScoringModel scoringModel = Mockito.mock(ScoringModel.class);
        // Expected scoring input after fusion, with positional scores:
        // "content" -> 0.5, "content 1" -> 0.7, "content 4" -> 0.9.
        Mockito.when(scoringModel.scoreAll(
                        Arrays.asList(content2.textSegment(), content1.textSegment(), content4.textSegment()),
                        query.text()))
                .thenReturn(Response.from(Arrays.asList(0.5, 0.7, 0.9)));

        ContentAggregator aggregator = new ReRankingContentAggregator(scoringModel);

        // when
        List<Content> aggregated = aggregator.aggregate(queryToContents);

        // then: descending by score
        Assertions.assertThat(aggregated).containsExactly(content4, content1, content2);
    }

    /**
     * With more than one query and no custom query selector the aggregator must reject the input
     * with an {@link IllegalArgumentException} carrying the exact message asserted below.
     */
    @Test
    void should_fail_when_multiple_queries_with_default_query_selector() {
        // given: the values are irrelevant here, only the number of queries matters
        Map<Query, Collection<List<Content>>> queryToContents = new HashMap<>();
        queryToContents.put(Query.from("query 1"), null);
        queryToContents.put(Query.from("query 2"), null);

        ContentAggregator aggregator = new ReRankingContentAggregator(Mockito.mock(ScoringModel.class));

        // when-then
        Assertions.assertThatThrownBy(() -> aggregator.aggregate(queryToContents))
                .isExactlyInstanceOf(IllegalArgumentException.class)
                .hasMessage("The 'queryToContents' contains 2 queries, making the re-ranking ambiguous. Because there are multiple queries, it is unclear which one should be used for re-ranking. Please provide a 'querySelector' in the constructor/builder.");
    }

    /**
     * Two queries, a selector that picks the first one, and a minimum score of 0.4: contents
     * scored below the threshold are dropped, a content scored exactly 0.4 is kept.
     */
    @Test
    void should_fuse_then_rerank_against_first_query_then_filter_by_minScore() {
        // given
        // LinkedHashMap below keeps insertion order, so "first" is deterministically query1.
        Function<Map<Query, Collection<List<Content>>>, Query> querySelector =
                queryToContents -> queryToContents.keySet().iterator().next();
        double minScore = 0.4;

        Query query1 = Query.from("query 1");
        Content content1 = Content.from("content");
        Content content2 = Content.from("content 2");
        Content content3 = Content.from("content 3");
        Content content4 = Content.from("content");

        Query query2 = Query.from("query 2");
        Content content5 = Content.from("content 5");
        Content content6 = Content.from("content");
        Content content7 = Content.from("content");
        Content content8 = Content.from("content 8");

        Map<Query, Collection<List<Content>>> queryToContents = new LinkedHashMap<>();
        queryToContents.put(query1,
                Arrays.asList(Arrays.asList(content1, content2), Arrays.asList(content3, content4)));
        queryToContents.put(query2,
                Arrays.asList(Arrays.asList(content5, content6), Arrays.asList(content7, content8)));

        ScoringModel scoringModel = Mockito.mock(ScoringModel.class);
        // Expected scoring input with positional scores:
        // "content" -> 0.6, "content 3" -> 0.2, "content 5" -> 0.3, "content 2" -> 0.4, "content 8" -> 0.5.
        Mockito.when(scoringModel.scoreAll(
                        Arrays.asList(
                                content1.textSegment(),
                                content3.textSegment(),
                                content5.textSegment(),
                                content2.textSegment(),
                                content8.textSegment()),
                        query1.text()))
                .thenReturn(Response.from(Arrays.asList(0.6, 0.2, 0.3, 0.4, 0.5)));

        ContentAggregator aggregator = new ReRankingContentAggregator(scoringModel, querySelector, minScore);

        // when
        List<Content> aggregated = aggregator.aggregate(queryToContents);

        // then: 0.2 and 0.3 are filtered out, the rest is ordered by descending score
        Assertions.assertThat(aggregated).containsExactly(content1, content8, content2);
    }

    /**
     * Same fixture as the minimum-score test, additionally limited to two results via the builder:
     * only the two best-scored contents may be returned.
     */
    @Test
    void test_should_got_max_results() {
        // given
        // LinkedHashMap below keeps insertion order, so "first" is deterministically query1.
        Function<Map<Query, Collection<List<Content>>>, Query> querySelector =
                queryToContents -> queryToContents.keySet().iterator().next();
        double minScore = 0.4;
        int maxResults = 2;

        Query query1 = Query.from("query 1");
        Content content1 = Content.from("content");
        Content content2 = Content.from("content 2");
        Content content3 = Content.from("content 3");
        Content content4 = Content.from("content");

        Query query2 = Query.from("query 2");
        Content content5 = Content.from("content 5");
        Content content6 = Content.from("content");
        Content content7 = Content.from("content");
        Content content8 = Content.from("content 8");

        Map<Query, Collection<List<Content>>> queryToContents = new LinkedHashMap<>();
        queryToContents.put(query1,
                Arrays.asList(Arrays.asList(content1, content2), Arrays.asList(content3, content4)));
        queryToContents.put(query2,
                Arrays.asList(Arrays.asList(content5, content6), Arrays.asList(content7, content8)));

        ScoringModel scoringModel = Mockito.mock(ScoringModel.class);
        // Expected scoring input with positional scores:
        // "content" -> 0.6, "content 3" -> 0.2, "content 5" -> 0.3, "content 2" -> 0.4, "content 8" -> 0.5.
        Mockito.when(scoringModel.scoreAll(
                        Arrays.asList(
                                content1.textSegment(),
                                content3.textSegment(),
                                content5.textSegment(),
                                content2.textSegment(),
                                content8.textSegment()),
                        query1.text()))
                .thenReturn(Response.from(Arrays.asList(0.6, 0.2, 0.3, 0.4, 0.5)));

        ContentAggregator aggregator = ReRankingContentAggregator.builder()
                .scoringModel(scoringModel)
                .querySelector(querySelector)
                .minScore(minScore)
                .maxResults(maxResults)
                .build();

        // when
        List<Content> aggregated = aggregator.aggregate(queryToContents);

        // then: content2 (score 0.4) passes the threshold but is cut off by maxResults
        Assertions.assertThat(aggregated).containsExactly(content1, content8);
    }

    /**
     * When there is nothing to re-rank the result must be empty and the scoring model must not
     * be called at all.
     *
     * @param map query-to-contents input that contains no content
     */
    @MethodSource
    @ParameterizedTest
    void should_return_empty_list_when_there_is_no_content_to_rerank(Map<Query, Collection<List<Content>>> map) {
        // given
        ScoringModel scoringModel = Mockito.mock(ScoringModel.class);
        ContentAggregator aggregator = new ReRankingContentAggregator(scoringModel);

        // when
        List<Content> aggregated = aggregator.aggregate(map);

        // then
        Assertions.assertThat(aggregated).isEmpty();
        Mockito.verifyNoInteractions(scoringModel);
    }

    /**
     * Supplies the "no content" inputs: no queries, a query without lists, a query with one
     * empty list, and a query with two empty lists.
     *
     * @return one {@link Arguments} per input shape
     */
    private static Stream<Arguments> should_return_empty_list_when_there_is_no_content_to_rerank() {
        Map<Query, Collection<List<Content>>> noQueries = Collections.emptyMap();
        Map<Query, Collection<List<Content>>> queryWithoutLists =
                Collections.singletonMap(Query.from("query"), Collections.emptyList());
        Map<Query, Collection<List<Content>>> queryWithOneEmptyList =
                Collections.singletonMap(Query.from("query"), Collections.singletonList(Collections.emptyList()));
        Map<Query, Collection<List<Content>>> queryWithTwoEmptyLists =
                Collections.singletonMap(
                        Query.from("query"),
                        Arrays.asList(Collections.emptyList(), Collections.emptyList()));

        return Stream.of(
                Arguments.of(noQueries),
                Arguments.of(queryWithoutLists),
                Arguments.of(queryWithOneEmptyList),
                Arguments.of(queryWithTwoEmptyLists));
    }
}
