package dev.langchain4j.store.embedding;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.embedding.EmbeddingModel;
import java.util.Arrays;
import java.util.List;
import org.assertj.core.api.Assertions;
import org.assertj.core.data.Percentage;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:dev/langchain4j/store/embedding/EmbeddingStoreWithoutMetadataIT.class */
/**
 * Reusable integration-test contract for {@link EmbeddingStore} implementations that are
 * exercised without segment metadata.
 *
 * <p>Concrete subclasses supply the store under test and an embedding model. Every test adds
 * one or two embeddings (for the texts {@code "hello"} and {@code "hi"}), optionally waits via
 * {@link #awaitUntilPersisted()}, and then checks what {@code findRelevant} returns.
 *
 * <p>NOTE(review): decompiler notes on the original indicate that {@link #embeddingStore()},
 * {@link #embeddingModel()} and {@link #awaitUntilPersisted()} were declared {@code protected}
 * in the upstream source. They are kept {@code public} here so that nothing compiled against
 * this class breaks.
 */
public abstract class EmbeddingStoreWithoutMetadataIT {

    /** Relative tolerance (1%) used when comparing relevance scores. */
    private static final Percentage SCORE_TOLERANCE = Percentage.withPercentage(1.0d);

    /** Relative tolerance (0.01%) used when comparing a returned embedding with the stored one. */
    private static final Percentage EMBEDDING_TOLERANCE = Percentage.withPercentage(0.01d);

    /**
     * Supplies the embedding store under test.
     *
     * @return the store every test in this class reads from and writes to
     */
    public abstract EmbeddingStore<TextSegment> embeddingStore();

    /**
     * Supplies the model used to turn test texts into embeddings.
     *
     * @return the embedding model
     */
    public abstract EmbeddingModel embeddingModel();

    /** Runs before each test: invokes {@link #clearStore()} and then {@link #ensureStoreIsEmpty()}. */
    @BeforeEach
    void beforeEach() {
        clearStore();
        ensureStoreIsEmpty();
    }

    /**
     * Hook invoked before each test. The default implementation does nothing; subclasses
     * presumably override it to remove leftovers from a store that outlives a single test.
     */
    protected void clearStore() {
    }

    /**
     * Asserts that a search for the embedding of {@code "hello"} with up to 1000 results
     * returns nothing.
     */
    protected void ensureStoreIsEmpty() {
        Assertions.assertThat(embeddingStore().findRelevant(embed("hello"), 1000)).isEmpty();
    }

    /** A single embedding added without an id is returned with the generated id and no segment. */
    @Test
    void should_add_embedding() {
        Embedding embedding = embed("hello");

        String id = embeddingStore().add(embedding);
        Assertions.assertThat(id).isNotBlank();

        awaitUntilPersisted();

        List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
        Assertions.assertThat(relevant).hasSize(1);

        EmbeddingMatch<TextSegment> match = relevant.get(0);
        Assertions.assertThat(match.score()).isCloseTo(1.0d, SCORE_TOLERANCE);
        Assertions.assertThat(match.embeddingId()).isEqualTo(id);
        Assertions.assertThat(match.embedding()).isEqualTo(embedding);
        Assertions.assertThat(match.embedded()).isNull();
    }

    /** A single embedding added under a caller-chosen id is returned with that id. */
    @Test
    void should_add_embedding_with_id() {
        String id = Utils.randomUUID();
        Embedding embedding = embed("hello");

        embeddingStore().add(id, embedding);

        awaitUntilPersisted();

        List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
        Assertions.assertThat(relevant).hasSize(1);

        EmbeddingMatch<TextSegment> match = relevant.get(0);
        Assertions.assertThat(match.score()).isCloseTo(1.0d, SCORE_TOLERANCE);
        Assertions.assertThat(match.embeddingId()).isEqualTo(id);
        Assertions.assertThat(match.embedding()).isEqualTo(embedding);
        Assertions.assertThat(match.embedded()).isNull();
    }

    /** An embedding added together with its text segment is returned with that segment. */
    @Test
    void should_add_embedding_with_segment() {
        TextSegment segment = TextSegment.from("hello");
        Embedding embedding = embed(segment.text());

        String id = embeddingStore().add(embedding, segment);
        Assertions.assertThat(id).isNotBlank();

        awaitUntilPersisted();

        List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
        Assertions.assertThat(relevant).hasSize(1);

        EmbeddingMatch<TextSegment> match = relevant.get(0);
        Assertions.assertThat(match.score()).isCloseTo(1.0d, SCORE_TOLERANCE);
        Assertions.assertThat(match.embeddingId()).isEqualTo(id);
        Assertions.assertThat(match.embedding()).isEqualTo(embedding);
        Assertions.assertThat(match.embedded()).isEqualTo(segment);
    }

    /**
     * Two embeddings added in one call get distinct, non-blank ids and are both returned,
     * the queried one first.
     */
    @Test
    void should_add_multiple_embeddings() {
        Embedding firstEmbedding = embed("hello");
        Embedding secondEmbedding = embed("hi");

        List<String> ids = embeddingStore().addAll(Arrays.asList(firstEmbedding, secondEmbedding));
        Assertions.assertThat(ids).hasSize(2);
        Assertions.assertThat(ids.get(0)).isNotBlank();
        Assertions.assertThat(ids.get(1)).isNotBlank();
        Assertions.assertThat(ids.get(0)).isNotEqualTo(ids.get(1));

        awaitUntilPersisted();

        List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
        Assertions.assertThat(relevant).hasSize(2);

        EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
        Assertions.assertThat(firstMatch.score()).isCloseTo(1.0d, SCORE_TOLERANCE);
        Assertions.assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
        Assertions.assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
        Assertions.assertThat(firstMatch.embedded()).isNull();

        EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
        Assertions.assertThat(secondMatch.score())
                .isCloseTo(expectedScore(firstEmbedding, secondEmbedding), SCORE_TOLERANCE);
        Assertions.assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
        // The second embedding is compared by cosine similarity rather than equals().
        Assertions.assertThat(CosineSimilarity.between(secondMatch.embedding(), secondEmbedding))
                .isCloseTo(1.0d, EMBEDDING_TOLERANCE);
        Assertions.assertThat(secondMatch.embedded()).isNull();
    }

    /**
     * Two embeddings added in one call together with their segments are both returned,
     * each carrying its own segment.
     */
    @Test
    void should_add_multiple_embeddings_with_segments() {
        TextSegment firstSegment = TextSegment.from("hello");
        Embedding firstEmbedding = embed(firstSegment.text());
        TextSegment secondSegment = TextSegment.from("hi");
        Embedding secondEmbedding = embed(secondSegment.text());

        List<String> ids = embeddingStore().addAll(
                Arrays.asList(firstEmbedding, secondEmbedding),
                Arrays.asList(firstSegment, secondSegment));
        Assertions.assertThat(ids).hasSize(2);
        Assertions.assertThat(ids.get(0)).isNotBlank();
        Assertions.assertThat(ids.get(1)).isNotBlank();
        Assertions.assertThat(ids.get(0)).isNotEqualTo(ids.get(1));

        awaitUntilPersisted();

        List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
        Assertions.assertThat(relevant).hasSize(2);

        EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
        Assertions.assertThat(firstMatch.score()).isCloseTo(1.0d, SCORE_TOLERANCE);
        Assertions.assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
        Assertions.assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
        Assertions.assertThat(firstMatch.embedded()).isEqualTo(firstSegment);

        EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
        Assertions.assertThat(secondMatch.score())
                .isCloseTo(expectedScore(firstEmbedding, secondEmbedding), SCORE_TOLERANCE);
        Assertions.assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
        // The second embedding is compared by cosine similarity rather than equals().
        Assertions.assertThat(CosineSimilarity.between(secondMatch.embedding(), secondEmbedding))
                .isCloseTo(1.0d, EMBEDDING_TOLERANCE);
        Assertions.assertThat(secondMatch.embedded()).isEqualTo(secondSegment);
    }

    /**
     * The minimum-score argument of {@code findRelevant} is inclusive: a threshold equal to the
     * weaker match's score keeps it, a threshold 0.01 above that score drops it.
     */
    @Test
    void should_find_with_min_score() {
        String firstId = Utils.randomUUID();
        Embedding firstEmbedding = embed("hello");
        embeddingStore().add(firstId, firstEmbedding);

        String secondId = Utils.randomUUID();
        Embedding secondEmbedding = embed("hi");
        embeddingStore().add(secondId, secondEmbedding);

        awaitUntilPersisted();

        // Baseline search without a threshold: both entries, queried one first.
        List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
        Assertions.assertThat(relevant).hasSize(2);

        EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
        Assertions.assertThat(firstMatch.score()).isCloseTo(1.0d, SCORE_TOLERANCE);
        Assertions.assertThat(firstMatch.embeddingId()).isEqualTo(firstId);

        EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
        Assertions.assertThat(secondMatch.score())
                .isCloseTo(expectedScore(firstEmbedding, secondEmbedding), SCORE_TOLERANCE);
        Assertions.assertThat(secondMatch.embeddingId()).isEqualTo(secondId);

        double weakerScore = secondMatch.score();

        // Threshold slightly below the weaker score: both matches survive.
        List<EmbeddingMatch<TextSegment>> belowThreshold =
                embeddingStore().findRelevant(firstEmbedding, 10, weakerScore - 0.01d);
        Assertions.assertThat(belowThreshold).hasSize(2);
        Assertions.assertThat(belowThreshold.get(0).embeddingId()).isEqualTo(firstId);
        Assertions.assertThat(belowThreshold.get(1).embeddingId()).isEqualTo(secondId);

        // Threshold exactly at the weaker score: still both matches.
        List<EmbeddingMatch<TextSegment>> atThreshold =
                embeddingStore().findRelevant(firstEmbedding, 10, weakerScore);
        Assertions.assertThat(atThreshold).hasSize(2);
        Assertions.assertThat(atThreshold.get(0).embeddingId()).isEqualTo(firstId);
        Assertions.assertThat(atThreshold.get(1).embeddingId()).isEqualTo(secondId);

        // Threshold slightly above the weaker score: only the exact match remains.
        List<EmbeddingMatch<TextSegment>> aboveThreshold =
                embeddingStore().findRelevant(firstEmbedding, 10, weakerScore + 0.01d);
        Assertions.assertThat(aboveThreshold).hasSize(1);
        Assertions.assertThat(aboveThreshold.get(0).embeddingId()).isEqualTo(firstId);
    }

    /**
     * The score reported for a non-identical query equals the relevance score derived from the
     * cosine similarity between the stored and the queried embedding (within 1%).
     */
    @Test
    void should_return_correct_score() {
        Embedding stored = embed("hello");
        Assertions.assertThat(embeddingStore().add(stored)).isNotBlank();

        awaitUntilPersisted();

        Embedding query = embed("hi");
        List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(query, 1);
        Assertions.assertThat(relevant).hasSize(1);
        Assertions.assertThat(relevant.get(0).score())
                .isCloseTo(expectedScore(stored, query), SCORE_TOLERANCE);
    }

    /**
     * Hook invoked after writing to the store and before reading back. The default
     * implementation does nothing; subclasses presumably override it for stores whose writes
     * are not immediately visible to searches.
     */
    public void awaitUntilPersisted() {
    }

    /** Embeds {@code text} with {@link #embeddingModel()} and unwraps the resulting embedding. */
    private Embedding embed(String text) {
        return embeddingModel().embed(text).content();
    }

    /** Computes the relevance score expected for two embeddings from their cosine similarity. */
    private static double expectedScore(Embedding a, Embedding b) {
        return RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(a, b));
    }
}
