package dev.langchain4j.rag.content.retriever;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.ContentMetadata;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.spi.ServiceHelper;
import dev.langchain4j.spi.model.embedding.EmbeddingModelFactory;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
 * A {@link ContentRetriever} that answers a {@link Query} by embedding the query text with an
 * {@link EmbeddingModel} and searching an {@link EmbeddingStore} of {@link TextSegment}s.
 *
 * <p>The maximum number of results, the minimum score and the metadata {@link Filter} are resolved
 * per query through provider functions, so they may differ from one query to the next. Each
 * returned {@link Content} carries the match's score and embedding id as
 * {@link ContentMetadata#SCORE} and {@link ContentMetadata#EMBEDDING_ID} metadata.
 *
 * <p>If no {@link EmbeddingModel} is supplied, one is created from the single
 * {@link EmbeddingModelFactory} discovered through {@link ServiceHelper#loadFactories(Class)}.
 */
public class EmbeddingStoreContentRetriever implements ContentRetriever {

    /** Default provider for the maximum number of results: always {@code 3}. */
    public static final Function<Query, Integer> DEFAULT_MAX_RESULTS = query -> 3;

    /** Default provider for the minimum score: always {@code 0.0}. */
    public static final Function<Query, Double> DEFAULT_MIN_SCORE = query -> 0.0;

    /** Default provider for the metadata filter: always {@code null}, so no filter is set on the search request. */
    public static final Function<Query, Filter> DEFAULT_FILTER = query -> null;

    /** Display name used when none is supplied. */
    public static final String DEFAULT_DISPLAY_NAME = "Default";

    private final EmbeddingStore<TextSegment> embeddingStore;
    private final EmbeddingModel embeddingModel;
    private final Function<Query, Integer> maxResultsProvider;
    private final Function<Query, Double> minScoreProvider;
    private final Function<Query, Filter> filterProvider;
    private final String displayName;

    /**
     * Builder for {@link EmbeddingStoreContentRetriever}; obtain one via
     * {@link EmbeddingStoreContentRetriever#builder()}. Options left unset fall back to the
     * corresponding {@code DEFAULT_*} value (and, for the embedding model, to SPI discovery) when
     * {@link #build()} is called.
     */
    public static class EmbeddingStoreContentRetrieverBuilder {
        private String displayName;
        private EmbeddingStore<TextSegment> embeddingStore;
        private EmbeddingModel embeddingModel;
        private Function<Query, Integer> dynamicMaxResults;
        private Function<Query, Double> dynamicMinScore;
        private Function<Query, Filter> dynamicFilter;

        EmbeddingStoreContentRetrieverBuilder() {
        }

        /**
         * Uses the same maximum number of results for every query.
         *
         * <p>The value is validated once, here, so an invalid value is reported while configuring
         * the builder rather than on the first retrieval. Whatever
         * {@code ValidationUtils.ensureGreaterThanZero} throws for a non-positive value propagates.
         * A {@code null} argument is ignored and leaves any previously configured provider in place.
         *
         * @param maxResults maximum number of results; must be greater than zero
         * @return this builder
         */
        public EmbeddingStoreContentRetrieverBuilder maxResults(Integer maxResults) {
            if (maxResults != null) {
                int validMaxResults = ValidationUtils.ensureGreaterThanZero(maxResults, "maxResults");
                this.dynamicMaxResults = query -> validMaxResults;
            }
            return this;
        }

        /**
         * Uses the same minimum score for every query.
         *
         * <p>The value is validated once, here, to lie between {@code 0.0} and {@code 1.0} via
         * {@code ValidationUtils.ensureBetween}; whatever that helper throws propagates. A
         * {@code null} argument is ignored and leaves any previously configured provider in place.
         *
         * @param minScore minimum score, between {@code 0.0} and {@code 1.0}
         * @return this builder
         */
        public EmbeddingStoreContentRetrieverBuilder minScore(Double minScore) {
            if (minScore != null) {
                double validMinScore = ValidationUtils.ensureBetween(minScore, 0.0, 1.0, "minScore");
                this.dynamicMinScore = query -> validMinScore;
            }
            return this;
        }

        /**
         * Uses the same metadata filter for every query. A {@code null} argument is ignored and
         * leaves any previously configured provider in place.
         *
         * @param filter filter applied to every search
         * @return this builder
         */
        public EmbeddingStoreContentRetrieverBuilder filter(Filter filter) {
            if (filter != null) {
                this.dynamicFilter = query -> filter;
            }
            return this;
        }

        /**
         * Sets the display name reported by {@code toString()}; {@code null} means
         * {@link #DEFAULT_DISPLAY_NAME}.
         *
         * @param displayName display name
         * @return this builder
         */
        public EmbeddingStoreContentRetrieverBuilder displayName(String displayName) {
            this.displayName = displayName;
            return this;
        }

        /**
         * Sets the store to search. It is required: {@link #build()} rejects a {@code null} store.
         *
         * @param embeddingStore store of embedded text segments
         * @return this builder
         */
        public EmbeddingStoreContentRetrieverBuilder embeddingStore(EmbeddingStore<TextSegment> embeddingStore) {
            this.embeddingStore = embeddingStore;
            return this;
        }

        /**
         * Sets the model used to embed query text. If left {@code null}, {@link #build()} tries to
         * create one from a single SPI-registered {@link EmbeddingModelFactory}.
         *
         * @param embeddingModel model used to embed queries
         * @return this builder
         */
        public EmbeddingStoreContentRetrieverBuilder embeddingModel(EmbeddingModel embeddingModel) {
            this.embeddingModel = embeddingModel;
            return this;
        }

        /**
         * Sets a per-query provider for the maximum number of results. It replaces any value set
         * via {@link #maxResults(Integer)}; {@code null} restores the default. The provider's
         * results are not validated here.
         *
         * @param maxResultsProvider function computing the maximum number of results for a query
         * @return this builder
         */
        public EmbeddingStoreContentRetrieverBuilder dynamicMaxResults(Function<Query, Integer> maxResultsProvider) {
            this.dynamicMaxResults = maxResultsProvider;
            return this;
        }

        /**
         * Sets a per-query provider for the minimum score. It replaces any value set via
         * {@link #minScore(Double)}; {@code null} restores the default. The provider's results are
         * not validated here.
         *
         * @param minScoreProvider function computing the minimum score for a query
         * @return this builder
         */
        public EmbeddingStoreContentRetrieverBuilder dynamicMinScore(Function<Query, Double> minScoreProvider) {
            this.dynamicMinScore = minScoreProvider;
            return this;
        }

        /**
         * Sets a per-query provider for the metadata filter. It replaces any value set via
         * {@link #filter(Filter)}; {@code null} restores the default.
         *
         * @param filterProvider function computing the filter for a query
         * @return this builder
         */
        public EmbeddingStoreContentRetrieverBuilder dynamicFilter(Function<Query, Filter> filterProvider) {
            this.dynamicFilter = filterProvider;
            return this;
        }

        /**
         * Creates the retriever from the configured options.
         *
         * @return a new {@link EmbeddingStoreContentRetriever}
         */
        public EmbeddingStoreContentRetriever build() {
            return new EmbeddingStoreContentRetriever(
                    displayName, embeddingStore, embeddingModel, dynamicMaxResults, dynamicMinScore, dynamicFilter);
        }
    }

    /**
     * Creates a retriever with the default display name, result limit, minimum score and filter.
     *
     * @param embeddingStore store to search; must not be {@code null}
     * @param embeddingModel model used to embed queries; if {@code null}, one is loaded via SPI
     */
    public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore, EmbeddingModel embeddingModel) {
        this(DEFAULT_DISPLAY_NAME, embeddingStore, embeddingModel, DEFAULT_MAX_RESULTS, DEFAULT_MIN_SCORE, DEFAULT_FILTER);
    }

    /**
     * Creates a retriever with a fixed maximum number of results and the default minimum score
     * and filter. {@code maxResults} is passed to the search request as-is; it is not validated here.
     *
     * @param embeddingStore store to search; must not be {@code null}
     * @param embeddingModel model used to embed queries; if {@code null}, one is loaded via SPI
     * @param maxResults     maximum number of results per query
     */
    public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore, EmbeddingModel embeddingModel, int maxResults) {
        this(DEFAULT_DISPLAY_NAME, embeddingStore, embeddingModel, query -> maxResults, DEFAULT_MIN_SCORE, DEFAULT_FILTER);
    }

    /**
     * Creates a retriever with a fixed maximum number of results and minimum score, and the default
     * filter. Both values (including {@code null}) are passed to the search request as-is; they are
     * not validated here.
     *
     * @param embeddingStore store to search; must not be {@code null}
     * @param embeddingModel model used to embed queries; if {@code null}, one is loaded via SPI
     * @param maxResults     maximum number of results per query
     * @param minScore       minimum score per query
     */
    public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore, EmbeddingModel embeddingModel, Integer maxResults, Double minScore) {
        this(DEFAULT_DISPLAY_NAME, embeddingStore, embeddingModel, query -> maxResults, query -> minScore, DEFAULT_FILTER);
    }

    private EmbeddingStoreContentRetriever(String displayName,
                                           EmbeddingStore<TextSegment> embeddingStore,
                                           EmbeddingModel embeddingModel,
                                           Function<Query, Integer> maxResultsProvider,
                                           Function<Query, Double> minScoreProvider,
                                           Function<Query, Filter> filterProvider) {
        this.displayName = Utils.getOrDefault(displayName, DEFAULT_DISPLAY_NAME);
        this.embeddingStore = ValidationUtils.ensureNotNull(embeddingStore, "embeddingStore");
        // The cast selects the Supplier overload of getOrDefault, which presumably only calls
        // loadEmbeddingModel when no model was given. A null result (no SPI factory) is rejected.
        this.embeddingModel = ValidationUtils.ensureNotNull(
                Utils.getOrDefault(embeddingModel,
                        (Supplier<EmbeddingModel>) EmbeddingStoreContentRetriever::loadEmbeddingModel),
                "embeddingModel");
        this.maxResultsProvider = Utils.getOrDefault(maxResultsProvider, DEFAULT_MAX_RESULTS);
        this.minScoreProvider = Utils.getOrDefault(minScoreProvider, DEFAULT_MIN_SCORE);
        this.filterProvider = Utils.getOrDefault(filterProvider, DEFAULT_FILTER);
    }

    /**
     * Creates an {@link EmbeddingModel} from the {@link EmbeddingModelFactory} discovered through
     * {@link ServiceHelper#loadFactories(Class)}.
     *
     * @return the model created by the single discovered factory, or {@code null} if none was found
     * @throws IllegalStateException if more than one factory was found
     */
    private static EmbeddingModel loadEmbeddingModel() {
        Collection<EmbeddingModelFactory> factories = ServiceHelper.loadFactories(EmbeddingModelFactory.class);
        if (factories.size() > 1) {
            throw new IllegalStateException("Conflict: multiple embedding models have been found in the classpath. "
                    + "Please explicitly specify the one you wish to use.");
        }
        return factories.isEmpty() ? null : factories.iterator().next().create();
    }

    /**
     * Returns a new builder.
     *
     * @return a fresh {@link EmbeddingStoreContentRetrieverBuilder}
     */
    public static EmbeddingStoreContentRetrieverBuilder builder() {
        return new EmbeddingStoreContentRetrieverBuilder();
    }

    /**
     * Creates a retriever over {@code embeddingStore} with every other option at its default. The
     * embedding model is therefore loaded via SPI.
     *
     * @param embeddingStore store to search; must not be {@code null}
     * @return a new retriever
     */
    public static EmbeddingStoreContentRetriever from(EmbeddingStore<TextSegment> embeddingStore) {
        return builder().embeddingStore(embeddingStore).build();
    }

    /**
     * Embeds the query text, searches the store and converts every match into {@link Content}.
     *
     * <p>The result limit, minimum score and filter are obtained from the configured providers for
     * this particular query. Each {@link Content} holds the matched {@link TextSegment} plus its
     * score and embedding id as metadata. The metadata map is built with {@code Map.of}, which
     * rejects {@code null} values.
     *
     * @param query query to answer
     * @return matching contents, in the order returned by the store
     */
    @Override
    public List<Content> retrieve(Query query) {
        EmbeddingSearchRequest searchRequest = EmbeddingSearchRequest.builder()
                .queryEmbedding(embeddingModel.embed(query.text()).content())
                .maxResults(maxResultsProvider.apply(query))
                .minScore(minScoreProvider.apply(query))
                .filter(filterProvider.apply(query))
                .build();

        return embeddingStore.search(searchRequest).matches().stream()
                .map(match -> Content.from(
                        match.embedded(),
                        Map.of(
                                ContentMetadata.SCORE, match.score(),
                                ContentMetadata.EMBEDDING_ID, match.embeddingId())))
                .collect(Collectors.toList());
    }

    @Override
    public String toString() {
        return "EmbeddingStoreContentRetriever{displayName='" + displayName + "'}";
    }
}
