package dev.langchain4j.store.embedding.mongodb;

import com.mongodb.MongoClientSettings;
import com.mongodb.MongoCommandException;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoDatabase;
import com.mongodb.client.model.Aggregates;
import com.mongodb.client.model.CreateCollectionOptions;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.Projections;
import com.mongodb.client.model.search.SearchPath;
import com.mongodb.client.model.search.VectorSearchOptions;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.bson.codecs.configuration.CodecProvider;
import org.bson.codecs.configuration.CodecRegistries;
import org.bson.codecs.configuration.CodecRegistry;
import org.bson.codecs.pojo.PojoCodecProvider;
import org.bson.conversions.Bson;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:dev/langchain4j/store/embedding/mongodb/MongoDbEmbeddingStore.class */
/**
 * {@link EmbeddingStore} implementation backed by a MongoDB (Atlas) collection and an Atlas
 * Vector Search index.
 *
 * <p>Each stored embedding becomes one {@code MongoDbDocument} in the configured collection.
 * Similarity queries run a {@code $vectorSearch} aggregation over the {@code embedding} field,
 * project the {@code embedding}, {@code metadata} and {@code text} fields plus the search score,
 * and keep only documents whose score is at least the requested minimum.
 *
 * <p>The constructor creates the collection when it does not exist yet. When {@code createIndex}
 * is {@code true}, it also creates the search index if no index with the given name exists.
 */
public class MongoDbEmbeddingStore implements EmbeddingStore<TextSegment> {

    private static final Logger log = LoggerFactory.getLogger(MongoDbEmbeddingStore.class);

    /** Default multiplier applied to {@code maxResults} to get the vector-search candidate count. */
    private static final long DEFAULT_MAX_RESULT_RATIO = 10L;

    private final MongoCollection<MongoDbDocument> collection;
    private final String indexName;
    private final long maxResultRatio;
    private final VectorSearchOptions vectorSearchOptions;

    /** Fluent builder collecting the constructor arguments of {@link MongoDbEmbeddingStore}. */
    public static class Builder {
        private MongoClient mongoClient;
        private String databaseName;
        private String collectionName;
        private String indexName;
        private Long maxResultRatio;
        private CreateCollectionOptions createCollectionOptions;
        private Bson filter;
        private IndexMapping indexMapping;
        private Boolean createIndex;

        /** Sets the client used to access the database (required). */
        public Builder fromClient(MongoClient mongoClient) {
            this.mongoClient = mongoClient;
            return this;
        }

        /** Sets the database name (required). */
        public Builder databaseName(String databaseName) {
            this.databaseName = databaseName;
            return this;
        }

        /** Sets the collection name (required); the collection is created if missing. */
        public Builder collectionName(String collectionName) {
            this.collectionName = collectionName;
            return this;
        }

        /** Sets the vector search index name (required). */
        public Builder indexName(String indexName) {
            this.indexName = indexName;
            return this;
        }

        /** Sets the candidate multiplier for vector search; {@code null} means 10. */
        public Builder maxResultRatio(Long maxResultRatio) {
            this.maxResultRatio = maxResultRatio;
            return this;
        }

        /** Sets options used when the collection has to be created; {@code null} means defaults. */
        public Builder createCollectionOptions(CreateCollectionOptions createCollectionOptions) {
            this.createCollectionOptions = createCollectionOptions;
            return this;
        }

        /** Sets an optional pre-filter passed to the vector search stage. */
        public Builder filter(Bson filter) {
            this.filter = filter;
            return this;
        }

        /** Sets the mapping used when the search index has to be created. */
        public Builder indexMapping(IndexMapping indexMapping) {
            this.indexMapping = indexMapping;
            return this;
        }

        /** Sets whether to create the search index when it does not exist; {@code null} means false. */
        public Builder createIndex(Boolean createIndex) {
            this.createIndex = createIndex;
            return this;
        }

        /** Builds the store; see the {@link MongoDbEmbeddingStore} constructor for side effects. */
        public MongoDbEmbeddingStore build() {
            return new MongoDbEmbeddingStore(mongoClient, databaseName, collectionName, indexName,
                    maxResultRatio, createCollectionOptions, filter, indexMapping, createIndex);
        }
    }

    /**
     * Creates the store. The collection is created when missing. When requested, the search index
     * is also created when missing.
     *
     * @param mongoClient             client used to reach the database; must not be null
     * @param databaseName            database name; must not be null
     * @param collectionName          collection name; must not be null
     * @param indexName               vector search index name; must not be null
     * @param maxResultRatio          candidate multiplier for vector search; {@code null} means 10
     * @param createCollectionOptions options used if the collection must be created; may be null
     * @param filter                  optional pre-filter for the vector search stage; may be null
     * @param indexMapping            mapping used if the index must be created; null means default
     * @param createIndex             {@code true} to create the index when missing; null means false
     */
    public MongoDbEmbeddingStore(MongoClient mongoClient,
                                 String databaseName,
                                 String collectionName,
                                 String indexName,
                                 Long maxResultRatio,
                                 CreateCollectionOptions createCollectionOptions,
                                 Bson filter,
                                 IndexMapping indexMapping,
                                 Boolean createIndex) {
        ValidationUtils.ensureNotNull(mongoClient, "mongoClient");
        String dbName = ValidationUtils.ensureNotNull(databaseName, "databaseName");
        String collName = ValidationUtils.ensureNotNull(collectionName, "collectionName");
        this.indexName = ValidationUtils.ensureNotNull(indexName, "indexName");
        this.maxResultRatio = Utils.getOrDefault(maxResultRatio, DEFAULT_MAX_RESULT_RATIO);

        // The POJO codec lets the driver (de)serialise MongoDbDocument / MongoDbMatchedDocument directly.
        CodecRegistry pojoCodecRegistry = CodecRegistries.fromRegistries(
                MongoClientSettings.getDefaultCodecRegistry(),
                CodecRegistries.fromProviders(PojoCodecProvider.builder()
                        .register(MongoDbDocument.class, MongoDbMatchedDocument.class)
                        .build()));

        MongoDatabase database = mongoClient.getDatabase(dbName);
        if (!collectionExists(database, collName)) {
            createCollection(database, collName,
                    Utils.getOrDefault(createCollectionOptions, new CreateCollectionOptions()));
        }
        this.collection = database.getCollection(collName, MongoDbDocument.class)
                .withCodecRegistry(pojoCodecRegistry);

        VectorSearchOptions options = VectorSearchOptions.vectorSearchOptions();
        this.vectorSearchOptions = filter == null ? options : options.filter(filter);

        if (Boolean.TRUE.equals(createIndex) && !searchIndexExists(this.indexName)) {
            createSearchIndex(this.indexName,
                    Utils.getOrDefault(indexMapping, IndexMapping.defaultIndexMapping()));
        }
    }

    /** Returns a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /** Stores {@code embedding} without a text segment under a random id, which is returned. */
    @Override
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        add(id, embedding);
        return id;
    }

    /** Stores {@code embedding} without a text segment under the given {@code id}. */
    @Override
    public void add(String id, Embedding embedding) {
        addInternal(id, embedding, null);
    }

    /** Stores {@code embedding} with its {@code textSegment} under a random id, which is returned. */
    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, textSegment);
        return id;
    }

    /** Stores all {@code embeddings} (no text segments) under random ids, returned in input order. */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = generateIds(embeddings.size());
        addAllInternal(ids, embeddings, null);
        return ids;
    }

    /**
     * Stores each embedding with the text segment at the same index, under random ids.
     *
     * @return the generated ids, in input order
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        List<String> ids = generateIds(embeddings.size());
        addAllInternal(ids, embeddings, embedded);
        return ids;
    }

    /**
     * Runs an Atlas vector search for the {@code maxResults} nearest neighbours of
     * {@code referenceEmbedding} and keeps those with a score of at least {@code minScore}.
     * The search considers {@code maxResults * maxResultRatio} candidates.
     *
     * @throws RuntimeException wrapping the driver's {@link MongoCommandException} if the
     *                          aggregation command fails
     */
    @Override
    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        List<Double> queryVector = referenceEmbedding.vectorAsList().stream()
                .map(Float::doubleValue)
                .collect(Collectors.toList());
        List<Bson> pipeline = Arrays.asList(
                Aggregates.vectorSearch(
                        SearchPath.fieldPath("embedding"),
                        queryVector,
                        indexName,
                        maxResults * maxResultRatio,
                        maxResults,
                        vectorSearchOptions),
                Aggregates.project(Projections.fields(
                        Projections.metaVectorSearchScore("score"),
                        Projections.include("embedding", "metadata", "text"))),
                Aggregates.match(Filters.gte("score", minScore)));
        try {
            // into(...) drains the cursor and always closes it, even if iteration fails.
            List<MongoDbMatchedDocument> matched = collection
                    .aggregate(pipeline, MongoDbMatchedDocument.class)
                    .into(new ArrayList<>());
            return matched.stream()
                    .map(MappingUtils::toEmbeddingMatch)
                    .collect(Collectors.toList());
        } catch (MongoCommandException e) {
            log.error("Error in MongoDBEmbeddingStore.findRelevant", e);
            throw new RuntimeException(e);
        }
    }

    /** Generates {@code count} random ids. */
    private static List<String> generateIds(int count) {
        List<String> ids = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            ids.add(Utils.randomUUID());
        }
        return ids;
    }

    /** Single-item convenience wrapper around {@link #addAllInternal}. */
    private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
        addAllInternal(Collections.singletonList(id), Collections.singletonList(embedding),
                textSegment == null ? null : Collections.singletonList(textSegment));
    }

    /**
     * Converts the parallel lists into documents and inserts them with a single {@code insertMany}.
     * Empty or null id/embedding lists are logged and ignored.
     *
     * @throws IllegalArgumentException (via {@code ValidationUtils.ensureTrue}) if the list sizes differ
     * @throws RuntimeException         if the server does not acknowledge the insert
     */
    private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("do not add empty embeddings to MongoDB Atlas");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(),
                "embeddings size is not equal to embedded size");

        List<MongoDbDocument> documents = new ArrayList<>(ids.size());
        for (int i = 0; i < ids.size(); i++) {
            TextSegment segment = embedded == null ? null : embedded.get(i);
            documents.add(MappingUtils.toMongoDbDocument(ids.get(i), embeddings.get(i), segment));
        }

        // Fail on an unacknowledged write regardless of the logger's level; previously the
        // exception was skipped entirely when WARN logging was disabled.
        if (!collection.insertMany(documents).wasAcknowledged()) {
            String errMsg = String.format("[MongoDbEmbeddingStore] Add document failed, Document=%s", documents);
            log.warn(errMsg);
            throw new RuntimeException(errMsg);
        }
    }

    /** Returns whether {@code database} already contains a collection named {@code collectionName}. */
    private boolean collectionExists(MongoDatabase database, String collectionName) {
        // into(...) consumes and closes the cursor; a short-circuiting stream would leave it open.
        return database.listCollectionNames().into(new ArrayList<>()).contains(collectionName);
    }

    /** Creates {@code collectionName} in {@code database} with the given options. */
    private void createCollection(MongoDatabase database, String collectionName, CreateCollectionOptions options) {
        database.createCollection(collectionName, options);
    }

    /** Returns whether the collection has a search index whose {@code name} field equals {@code name}. */
    private boolean searchIndexExists(String name) {
        return collection.listSearchIndexes().into(new ArrayList<>()).stream()
                .anyMatch(index -> name.equals(index.getString("name")));
    }

    /** Creates a search index named {@code name} from the given mapping. */
    private void createSearchIndex(String name, IndexMapping indexMapping) {
        collection.createSearchIndex(name, MappingUtils.fromIndexMapping(indexMapping));
    }
}
