package ai.knowly.langtorch.store.vectordb;

import ai.knowly.langtorch.processor.EmbeddingProcessor;
import ai.knowly.langtorch.schema.embeddings.EmbeddingInput;
import ai.knowly.langtorch.schema.io.DomainDocument;
import ai.knowly.langtorch.schema.io.Metadata;
import ai.knowly.langtorch.store.vectordb.integration.VectorStore;
import ai.knowly.langtorch.store.vectordb.integration.pgvector.PGVectorService;
import ai.knowly.langtorch.store.vectordb.integration.pgvector.SqlCommandProvider;
import ai.knowly.langtorch.store.vectordb.integration.pgvector.schema.PGVectorQueryParameters;
import ai.knowly.langtorch.store.vectordb.integration.pgvector.schema.PGVectorStoreSpec;
import ai.knowly.langtorch.store.vectordb.integration.pgvector.schema.PGVectorValues;
import ai.knowly.langtorch.store.vectordb.integration.pgvector.schema.distance.DistanceStrategy;
import ai.knowly.langtorch.store.vectordb.integration.schema.SimilaritySearchQuery;
import com.google.common.flogger.FluentLogger;
import com.google.common.primitives.Floats;
import com.google.inject.Inject;
import com.pgvector.PGvector;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import lombok.NonNull;

/* loaded from: input_file:ai/knowly/langtorch/store/vectordb/PGVectorStore.class */
/**
 * {@link VectorStore} implementation backed by PostgreSQL with the pgvector extension.
 *
 * <p>Two tables are created through {@link SqlCommandProvider} when the store is constructed: an
 * embeddings table holding one (document id, vector) row per document, and a metadata table
 * holding one (row id, key, value, document id) row per metadata entry. Embeddings are computed
 * with the injected {@link EmbeddingProcessor} using the model configured in {@link
 * PGVectorStoreSpec}.
 *
 * <p>NOTE(review): no transaction handling is visible in this class, so a failure between the
 * embeddings insert and the metadata insert can leave embeddings without their metadata.
 */
public class PGVectorStore implements VectorStore {
    // Layout of one "(?, ?)" placeholder group in the embeddings INSERT (0-based offsets).
    private static final int EMBEDDINGS_COLUMN_COUNT = 2;
    private static final int EMBEDDINGS_INDEX_ID = 0;
    private static final int EMBEDDINGS_INDEX_VECTOR = 1;

    // Layout of one "(?, ?, ?, ?)" placeholder group in the metadata INSERT (0-based offsets).
    private static final int METADATA_COLUMN_COUNT = 4;
    private static final int METADATA_INDEX_ID = 0;
    private static final int METADATA_INDEX_KEY = 1;
    private static final int METADATA_INDEX_VALUE = 2;
    private static final int METADATA_INDEX_VECTOR_ID = 3;

    // 1-based JDBC column positions in rows returned by the select-embeddings query:
    // document id, stored vector, metadata key, metadata value.
    private static final int RESULT_INDEX_ID = 1;
    private static final int RESULT_INDEX_VECTOR = 2;
    private static final int RESULT_INDEX_METADATA_KEY = 3;
    private static final int RESULT_INDEX_METADATA_VALUE = 4;

    private static final FluentLogger logger = FluentLogger.forEnclosingClass();

    @NonNull
    private final EmbeddingProcessor embeddingsProcessor;
    private final PGVectorStoreSpec pgVectorStoreSpec;
    private final SqlCommandProvider sqlCommandProvider;

    @NonNull
    private final PGVectorService pgVectorService;
    private final DistanceStrategy distanceStrategy;

    /**
     * Creates the store and immediately creates (or, depending on the spec, overwrites) the
     * embeddings and metadata tables.
     *
     * @param embeddingProcessor computes document embeddings; must not be null
     * @param pGVectorStoreSpec database name, table-overwrite flag, vector dimensions, model, and
     *     optional text key
     * @param pGVectorService executes SQL against the database; must not be null
     * @param distanceStrategy supplies the SQL distance operator and computes similarity scores
     * @throws NullPointerException if {@code embeddingProcessor} or {@code pGVectorService} is null
     * @throws SQLException if creating either table fails
     */
    @Inject
    public PGVectorStore(@NonNull EmbeddingProcessor embeddingProcessor, PGVectorStoreSpec pGVectorStoreSpec, @NonNull PGVectorService pGVectorService, DistanceStrategy distanceStrategy) throws SQLException {
        if (embeddingProcessor == null) {
            throw new NullPointerException("embeddingsProcessor is marked non-null but is null");
        }
        if (pGVectorService == null) {
            throw new NullPointerException("pgVectorService is marked non-null but is null");
        }
        this.distanceStrategy = distanceStrategy;
        this.pgVectorService = pGVectorService;
        this.embeddingsProcessor = embeddingProcessor;
        this.pgVectorStoreSpec = pGVectorStoreSpec;
        this.sqlCommandProvider = new SqlCommandProvider(pGVectorStoreSpec.getDatabaseName(), pGVectorStoreSpec.isOverwriteExistingTables());
        createNecessaryTables();
    }

    /** Creates the embeddings table followed by the metadata table. */
    private void createNecessaryTables() throws SQLException {
        createEmbeddingsTable();
        createMetadataTable();
    }

    /**
     * Embeds the given documents and inserts them, together with their metadata, into the store.
     *
     * <p>Documents without an id are assigned a random UUID. All statements are closed before
     * returning. The metadata INSERT is skipped when no document carries any metadata entry.
     *
     * @param list documents to store; an empty list is a no-op that returns {@code true}
     * @return {@code true} if every embedding row (and every metadata row, when there are any) was
     *     inserted; {@code false} on a row-count mismatch or an {@link SQLException}
     */
    @Override // ai.knowly.langtorch.store.vectordb.integration.VectorStore
    public boolean addDocuments(List<DomainDocument> list) {
        if (list.isEmpty()) {
            return true;
        }
        PGVectorQueryParameters queryParameters = getVectorQueryParameters(list);
        List<PGVectorValues> vectorValues = queryParameters.getVectorValues();
        try {
            String insertEmbeddings = this.sqlCommandProvider.getInsertEmbeddingsQuery(queryParameters.getVectorParameters());
            try (PreparedStatement embeddingsStatement = this.pgVectorService.prepareStatement(insertEmbeddings)) {
                bindEmbeddingsParameters(vectorValues, embeddingsStatement);
                if (embeddingsStatement.executeUpdate() != vectorValues.size()) {
                    return false;
                }
            }
            if (queryParameters.getMetadataSize() == 0) {
                // No "(?, ?, ?, ?)" groups were generated, so the metadata INSERT would presumably
                // have an empty VALUES list (invalid SQL). Nothing to insert: report success.
                return true;
            }
            String insertMetadata = this.sqlCommandProvider.getInsertMetadataQuery(queryParameters.getMetadataParameters());
            try (PreparedStatement metadataStatement = this.pgVectorService.prepareStatement(insertMetadata)) {
                bindMetadataParameters(vectorValues, metadataStatement);
                return metadataStatement.executeUpdate() == queryParameters.getMetadataSize();
            }
        } catch (SQLException e) {
            logger.atSevere().withCause(e).log("Error with SQL Exception");
            return false;
        }
    }

    /**
     * Finds stored documents closest to the query vector.
     *
     * <p>Each returned document's similarity score is the value computed by {@link
     * DistanceStrategy#calculateDistance} between the query vector and the stored vector. The
     * result set may contain several rows per document (one per metadata entry). Those rows are
     * merged into a single document whose metadata holds every key/value pair. If a key equals the
     * spec's text key, its value becomes the page content. Documents are returned in order of first
     * appearance in the result set.
     *
     * <p>NOTE(review): {@code topK} is passed to the SQL query. If that query joins metadata rows,
     * the limit may apply to rows rather than documents. Confirm in {@link SqlCommandProvider}.
     *
     * @param similaritySearchQuery query vector and {@code topK}
     * @return matching documents; on an {@link SQLException} the error is logged and the documents
     *     collected so far are returned
     */
    @Override // ai.knowly.langtorch.store.vectordb.integration.VectorStore
    public List<DomainDocument> similaritySearch(SimilaritySearchQuery similaritySearchQuery) {
        float[] queryVector = getFloatVectorValues(similaritySearchQuery.getQuery());
        double[] queryVectorAsDouble = getDoubleVectorValues(queryVector);
        Map<String, DomainDocument> documentsById = new LinkedHashMap<>();
        try {
            String selectQuery = this.sqlCommandProvider.getSelectEmbeddingsQuery(this.distanceStrategy.getSyntax(), similaritySearchQuery.getTopK().longValue());
            try (PreparedStatement statement = this.pgVectorService.prepareStatement(selectQuery)) {
                statement.setObject(1, new PGvector(queryVector));
                try (ResultSet resultSet = statement.executeQuery()) {
                    while (resultSet.next()) {
                        String id = (String) resultSet.getObject(RESULT_INDEX_ID);
                        PGvector storedVector = (PGvector) resultSet.getObject(RESULT_INDEX_VECTOR);
                        String metadataKey = (String) resultSet.getObject(RESULT_INDEX_METADATA_KEY);
                        String metadataValue = (String) resultSet.getObject(RESULT_INDEX_METADATA_VALUE);
                        double distance = this.distanceStrategy.calculateDistance(queryVectorAsDouble, getDoubleVectorValues(storedVector.toArray()));
                        DomainDocument document = documentsById.computeIfAbsent(id, key -> newSearchResult(key, distance));
                        saveValueToMetadataIfPresent(document, metadataKey, metadataValue);
                        documentsById.put(id, getDocumentWithScoreWithPageContent(document, metadataKey, metadataValue));
                    }
                }
            }
        } catch (SQLException e) {
            logger.atSevere().withCause(e).log("Error with SQL Exception");
        }
        return new ArrayList<>(documentsById.values());
    }

    /** Builds the initial result document for {@code id}: empty page content and empty metadata. */
    private static DomainDocument newSearchResult(String id, double distance) {
        return DomainDocument.builder().setId(id).setPageContent("").setSimilarityScore(Optional.of(Double.valueOf(distance))).setMetadata(Metadata.builder().build()).build();
    }

    /** Creates the embeddings table sized to the spec's vector dimensions. */
    private void createEmbeddingsTable() throws SQLException {
        this.pgVectorService.executeUpdate(this.sqlCommandProvider.getCreateEmbeddingsTableQuery(this.pgVectorStoreSpec.getVectorDimensions()));
    }

    /** Creates the metadata table. */
    private void createMetadataTable() throws SQLException {
        this.pgVectorService.executeUpdate(this.sqlCommandProvider.getCreateMetadataTableQuery());
    }

    /**
     * Embeds every document and builds the placeholder strings for the two INSERT statements.
     *
     * @param list non-empty list of documents
     * @return per-document values, the "(?, ?), …" embeddings placeholders, the
     *     "(?, ?, ?, ?), …" metadata placeholders, and the total number of metadata entries
     */
    private PGVectorQueryParameters getVectorQueryParameters(List<DomainDocument> list) {
        List<PGVectorValues> values = new ArrayList<>(list.size());
        StringBuilder vectorPlaceholders = new StringBuilder();
        StringBuilder metadataPlaceholders = new StringBuilder();
        int metadataSize = 0;
        for (DomainDocument document : list) {
            // orElseGet: only generate a UUID when the document has no id.
            String id = document.getId().orElseGet(() -> UUID.randomUUID().toString());
            values.add(buildPGVectorValues(id, createVector(document), document.getMetadata()));
            vectorPlaceholders.append(getVectorParameters());
            metadataSize += processMetadata(metadataPlaceholders, document.getMetadata());
        }
        trimStringBuilder(vectorPlaceholders);
        trimStringBuilder(metadataPlaceholders);
        return buildPGVectorQueryParameters(values, vectorPlaceholders.toString(), metadataPlaceholders.toString(), metadataSize);
    }

    /** Bundles a document's id, float vector, and metadata (empty metadata when absent). */
    private PGVectorValues buildPGVectorValues(String str, List<Double> list, Optional<Metadata> optional) {
        return PGVectorValues.builder().setId(str).setValues(getFloatVectorValues(list)).setMetadata(optional.orElseGet(() -> Metadata.builder().build())).build();
    }

    /** Placeholder group for one embeddings row, with a trailing separator. */
    private String getVectorParameters() {
        return "(?, ?), ";
    }

    /**
     * Appends one "(?, ?, ?, ?), " group per metadata entry.
     *
     * @return number of metadata entries (0 when metadata is absent)
     */
    private int processMetadata(StringBuilder sb, Optional<Metadata> optional) {
        if (!optional.isPresent()) {
            return 0;
        }
        int entryCount = optional.get().getValue().size();
        for (int i = 0; i < entryCount; i++) {
            sb.append("(?, ?, ?, ?), ");
        }
        return entryCount;
    }

    /** Removes the last ", " separator and anything after it, if one exists past index 0. */
    private void trimStringBuilder(StringBuilder sb) {
        int lastIndexOf = sb.lastIndexOf(", ");
        if (lastIndexOf > 0) {
            sb.delete(lastIndexOf, sb.length());
        }
    }

    private PGVectorQueryParameters buildPGVectorQueryParameters(List<PGVectorValues> list, String str, String str2, int i) {
        return PGVectorQueryParameters.builder().setVectorValues(list).setVectorParameters(str).setMetadataParameters(str2).setMetadataSize(i).build();
    }

    /** Embeds a single document's page content with the spec's model and returns the first vector. */
    private List<Double> createVector(DomainDocument domainDocument) {
        return this.embeddingsProcessor.run(EmbeddingInput.builder().setModel(this.pgVectorStoreSpec.getModel()).setInput(Collections.singletonList(domainDocument.getPageContent())).build()).getValue().get(0).getVector();
    }

    /**
     * Binds one metadata row per entry, starting at JDBC parameter index {@code i}.
     *
     * @return the next unused parameter index
     */
    private int setMetadataQueryParameters(PGVectorValues pGVectorValues, int i, PreparedStatement preparedStatement) throws SQLException {
        String vectorId = pGVectorValues.getId();
        for (Map.Entry<String, String> entry : pGVectorValues.getMetadata().getValue().entrySet()) {
            // Row id is the document id directly concatenated with the key.
            // NOTE(review): without a separator, ids can collide (e.g. "a"+"bc" vs "ab"+"c"); changing
            // the scheme would alter the ids of rows that are already stored.
            preparedStatement.setString(i + METADATA_INDEX_ID, vectorId + entry.getKey());
            preparedStatement.setString(i + METADATA_INDEX_KEY, entry.getKey());
            preparedStatement.setString(i + METADATA_INDEX_VALUE, entry.getValue());
            preparedStatement.setString(i + METADATA_INDEX_VECTOR_ID, vectorId);
            i += METADATA_COLUMN_COUNT;
        }
        return i;
    }

    /**
     * Binds one embeddings row (id, vector) starting at JDBC parameter index {@code i}.
     *
     * @return the next unused parameter index
     */
    private int setVectorQueryParameters(PGVectorValues pGVectorValues, int i, PreparedStatement preparedStatement) throws SQLException {
        preparedStatement.setString(i + EMBEDDINGS_INDEX_ID, pGVectorValues.getId());
        preparedStatement.setObject(i + EMBEDDINGS_INDEX_VECTOR, new PGvector(pGVectorValues.getValues()));
        return i + EMBEDDINGS_COLUMN_COUNT;
    }

    /** Binds the embeddings rows of every document, in list order, starting at parameter 1. */
    private void bindEmbeddingsParameters(List<PGVectorValues> list, PreparedStatement preparedStatement) throws SQLException {
        int index = 1;
        for (PGVectorValues values : list) {
            index = setVectorQueryParameters(values, index, preparedStatement);
        }
    }

    /** Binds the metadata rows of every document, in list order, starting at parameter 1. */
    private void bindMetadataParameters(List<PGVectorValues> list, PreparedStatement preparedStatement) throws SQLException {
        int index = 1;
        for (PGVectorValues values : list) {
            index = setMetadataQueryParameters(values, index, preparedStatement);
        }
    }

    /** Stores {@code str -> str2} in the document's metadata map when metadata exists and the key is non-null. */
    private void saveValueToMetadataIfPresent(DomainDocument domainDocument, String str, String str2) {
        Optional<Metadata> metadata = domainDocument.getMetadata();
        if (!metadata.isPresent() || str == null) {
            return;
        }
        metadata.get().getValue().put(str, str2);
    }

    /**
     * Returns a copy of the document with {@code str2} as page content when {@code str} equals the
     * spec's configured text key; otherwise returns the document unchanged.
     */
    private DomainDocument getDocumentWithScoreWithPageContent(DomainDocument domainDocument, String str, String str2) {
        if (str == null) {
            return domainDocument;
        }
        Optional<String> textKey = this.pgVectorStoreSpec.getTextKey();
        if (textKey.isPresent() && str.equals(textKey.get())) {
            return domainDocument.toBuilder().setPageContent(str2).build();
        }
        return domainDocument;
    }

    /** Narrows a list of doubles to a float array (pgvector stores single-precision values). */
    private float[] getFloatVectorValues(List<Double> list) {
        return Floats.toArray(list);
    }

    /** Widens a float array to a double array for distance computation. */
    private double[] getDoubleVectorValues(float[] fArr) {
        double[] dArr = new double[fArr.length];
        for (int i = 0; i < fArr.length; i++) {
            dArr[i] = fArr[i];
        }
        return dArr;
    }
}
