package dev.langchain4j.rag.content.retriever;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import java.util.Arrays;
import java.util.Collections;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

/**
 * Tests for {@link EmbeddingStoreContentRetriever}.
 *
 * <p>Each test stubs an {@link EmbeddingModel} and an {@link EmbeddingStore} with Mockito and builds a
 * retriever, either via a constructor or via the builder. It then retrieves content for one
 * {@link Query} and verifies two things:
 * <ul>
 *   <li>the query text was embedded exactly once;</li>
 *   <li>the store was searched exactly once with the expected {@code maxResults} / {@code minScore}.</li>
 * </ul>
 */
class EmbeddingStoreContentRetrieverTest {

    /** {@code maxResults} the retriever is expected to pass to the store when none is configured. */
    private static final int EXPECTED_DEFAULT_MAX_RESULTS = 3;

    /** {@code minScore} the retriever is expected to pass to the store when none is configured. */
    private static final double EXPECTED_DEFAULT_MIN_SCORE = 0.0;

    /** Embedding returned by the mocked model for any input text. */
    private static final Embedding QUERY_EMBEDDING = Embedding.from(Arrays.asList(1.0f, 2.0f, 3.0f));

    @Test
    void should_retrieve() {
        EmbeddingStore<TextSegment> embeddingStore = mockEmbeddingStoreReturning(
                match(0.9, "id 1", "content 1"),
                match(0.7, "id 2", "content 2"));
        EmbeddingModel embeddingModel = mockEmbeddingModelReturning(QUERY_EMBEDDING);
        EmbeddingStoreContentRetriever retriever = new EmbeddingStoreContentRetriever(embeddingStore, embeddingModel);
        Query query = Query.from("query");

        Assertions.assertThat(retriever.retrieve(query))
                .containsExactly(Content.from("content 1"), Content.from("content 2"));

        verifySingleLookup(embeddingModel, embeddingStore, query, EXPECTED_DEFAULT_MAX_RESULTS, EXPECTED_DEFAULT_MIN_SCORE);
    }

    @Test
    void should_retrieve_builder() {
        EmbeddingStore<TextSegment> embeddingStore = mockEmbeddingStoreReturning(
                match(0.9, "id 1", "content 1"),
                match(0.7, "id 2", "content 2"));
        EmbeddingModel embeddingModel = mockEmbeddingModelReturning(QUERY_EMBEDDING);
        EmbeddingStoreContentRetriever retriever = EmbeddingStoreContentRetriever.builder()
                .embeddingStore(embeddingStore)
                .embeddingModel(embeddingModel)
                .build();
        Query query = Query.from("query");

        Assertions.assertThat(retriever.retrieve(query))
                .containsExactly(Content.from("content 1"), Content.from("content 2"));

        verifySingleLookup(embeddingModel, embeddingStore, query, EXPECTED_DEFAULT_MAX_RESULTS, EXPECTED_DEFAULT_MIN_SCORE);
    }

    @Test
    void should_retrieve_with_custom_maxResults() {
        EmbeddingStore<TextSegment> embeddingStore = mockEmbeddingStoreReturning(
                match(0.9, "id 1", "content"));
        EmbeddingModel embeddingModel = mockEmbeddingModelReturning(QUERY_EMBEDDING);
        EmbeddingStoreContentRetriever retriever = new EmbeddingStoreContentRetriever(embeddingStore, embeddingModel, 1);
        Query query = Query.from("query");

        Assertions.assertThat(retriever.retrieve(query))
                .containsExactly(Content.from("content"));

        verifySingleLookup(embeddingModel, embeddingStore, query, 1, EXPECTED_DEFAULT_MIN_SCORE);
    }

    @Test
    void should_retrieve_with_custom_maxResults_builder() {
        EmbeddingStore<TextSegment> embeddingStore = mockEmbeddingStoreReturning(
                match(0.9, "id 1", "content"));
        EmbeddingModel embeddingModel = mockEmbeddingModelReturning(QUERY_EMBEDDING);
        EmbeddingStoreContentRetriever retriever = EmbeddingStoreContentRetriever.builder()
                .embeddingStore(embeddingStore)
                .embeddingModel(embeddingModel)
                .maxResults(1)
                .build();
        Query query = Query.from("query");

        Assertions.assertThat(retriever.retrieve(query))
                .containsExactly(Content.from("content"));

        verifySingleLookup(embeddingModel, embeddingStore, query, 1, EXPECTED_DEFAULT_MIN_SCORE);
    }

    @Test
    void should_retrieve_with_custom_minScore() {
        EmbeddingStore<TextSegment> embeddingStore = mockEmbeddingStoreReturning(
                match(0.9, "id 1", "content 1"),
                match(0.7, "id 2", "content 2"));
        EmbeddingModel embeddingModel = mockEmbeddingModelReturning(QUERY_EMBEDDING);
        // The (Integer) cast keeps overload resolution on the (store, model, Integer, Double) constructor.
        EmbeddingStoreContentRetriever retriever =
                new EmbeddingStoreContentRetriever(embeddingStore, embeddingModel, (Integer) null, 0.7);
        Query query = Query.from("query");

        Assertions.assertThat(retriever.retrieve(query))
                .containsExactly(Content.from("content 1"), Content.from("content 2"));

        verifySingleLookup(embeddingModel, embeddingStore, query, EXPECTED_DEFAULT_MAX_RESULTS, 0.7);
    }

    @Test
    void should_retrieve_with_custom_minScore_builder() {
        EmbeddingStore<TextSegment> embeddingStore = mockEmbeddingStoreReturning(
                match(0.9, "id 1", "content 1"),
                match(0.7, "id 2", "content 2"));
        EmbeddingModel embeddingModel = mockEmbeddingModelReturning(QUERY_EMBEDDING);
        EmbeddingStoreContentRetriever retriever = EmbeddingStoreContentRetriever.builder()
                .embeddingStore(embeddingStore)
                .embeddingModel(embeddingModel)
                .minScore(0.7)
                .build();
        Query query = Query.from("query");

        Assertions.assertThat(retriever.retrieve(query))
                .containsExactly(Content.from("content 1"), Content.from("content 2"));

        verifySingleLookup(embeddingModel, embeddingStore, query, EXPECTED_DEFAULT_MAX_RESULTS, 0.7);
    }

    /** Builds a match with the given score and id, wrapping {@code text} in a segment; the embedding itself is left null. */
    private static EmbeddingMatch<TextSegment> match(double score, String id, String text) {
        return new EmbeddingMatch<>(score, id, null, TextSegment.from(text));
    }

    /** Creates a mocked store whose three-argument {@code findRelevant} returns {@code matches} in order for any arguments. */
    @SafeVarargs
    private static EmbeddingStore<TextSegment> mockEmbeddingStoreReturning(EmbeddingMatch<TextSegment>... matches) {
        // Mockito.mock(Class) cannot carry the type argument, so the generic assignment is unchecked.
        @SuppressWarnings("unchecked")
        EmbeddingStore<TextSegment> embeddingStore = Mockito.mock(EmbeddingStore.class);
        Mockito.when(embeddingStore.findRelevant(Mockito.<Embedding>any(), Mockito.anyInt(), Mockito.anyDouble()))
                .thenReturn(Arrays.asList(matches));
        return embeddingStore;
    }

    /** Creates a mocked model whose {@code embed(String)} returns {@code embedding} for any text. */
    private static EmbeddingModel mockEmbeddingModelReturning(Embedding embedding) {
        EmbeddingModel embeddingModel = Mockito.mock(EmbeddingModel.class);
        Mockito.when(embeddingModel.embed(Mockito.anyString())).thenReturn(Response.from(embedding));
        return embeddingModel;
    }

    /**
     * Verifies the retriever's interactions with its two mocks.
     *
     * <p>The model must have embedded the query text once and nothing else. The store must have been
     * searched once, with {@link #QUERY_EMBEDDING} and the expected limits, and nothing else.
     */
    private static void verifySingleLookup(EmbeddingModel embeddingModel,
                                           EmbeddingStore<TextSegment> embeddingStore,
                                           Query query,
                                           int expectedMaxResults,
                                           double expectedMinScore) {
        Mockito.verify(embeddingModel).embed(query.text());
        Mockito.verifyNoMoreInteractions(embeddingModel);
        Mockito.verify(embeddingStore).findRelevant(QUERY_EMBEDDING, expectedMaxResults, expectedMinScore);
        Mockito.verifyNoMoreInteractions(embeddingStore);
    }
}
