package com.gengoai.apollo.ml.model.embedding;

import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.apollo.math.statistics.measure.Measure;
import com.gengoai.apollo.math.statistics.measure.Similarity;
import com.gengoai.collection.Sets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Stream;

/* loaded from: input_file:com/gengoai/apollo/ml/model/embedding/VSQuery.class */
/**
 * Fluent, mutable description of a query against a vector space of {@link NDArray} embeddings.
 *
 * <p>Each side of the query (positive and negative) is defined either by terms, which are
 * resolved to vectors through a {@link WordEmbedding}, or by explicit vectors. Setting terms on a
 * side clears that side's vectors and vice versa. {@link #queryVector(WordEmbedding)} combines the
 * two sides into a single vector. {@link #applyFilters(Stream)} post-processes a stream of
 * candidate results using the threshold, measure, term exclusion and limit configured here.
 *
 * <p>Instances perform no synchronization.
 */
public final class VSQuery {
    /** Maximum number of results; values {@code <= 0} or {@link Integer#MAX_VALUE} mean "no limit". */
    private int k = Integer.MAX_VALUE;
    /** Measure whose optimum drives threshold filtering and result ordering. */
    private Measure measure = Similarity.Cosine;
    private final List<String> negativeTerms = new ArrayList<>();
    private final List<NDArray> negativeVectors = new ArrayList<>();
    private final List<String> positiveTerms = new ArrayList<>();
    private final List<NDArray> positiveVectors = new ArrayList<>();
    /** Score cut-off; a non-finite value (the default) disables threshold filtering. */
    private double threshold = Double.NEGATIVE_INFINITY;

    /**
     * Creates a query from positive and negative terms.
     *
     * @param iterable  the positive terms
     * @param iterable2 the negative terms
     * @return a new query
     * @throws NullPointerException if either argument is {@code null}
     */
    public static VSQuery compositeTermQuery(Iterable<String> iterable, Iterable<String> iterable2) {
        return new VSQuery().positiveTerms(iterable).negativeTerms(iterable2);
    }

    /**
     * Creates a query from positive and negative vectors.
     *
     * @param iterable  the positive vectors
     * @param iterable2 the negative vectors
     * @return a new query
     * @throws NullPointerException if either argument is {@code null}
     */
    public static VSQuery compositeVectorQuery(Iterable<NDArray> iterable, Iterable<NDArray> iterable2) {
        return new VSQuery().positiveVectors(iterable).negativeVectors(iterable2);
    }

    /**
     * Creates a query with a single positive term.
     *
     * @param str the term
     * @return a new query
     */
    public static VSQuery termQuery(String str) {
        return new VSQuery().query(str);
    }

    /**
     * Creates a query with a single positive vector.
     *
     * @param nDArray the vector
     * @return a new query
     */
    public static VSQuery vectorQuery(NDArray nDArray) {
        return new VSQuery().query(nDArray);
    }

    /**
     * Applies this query's result filters to a stream of scored candidates.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>If the threshold is finite, keep only candidates whose weight passes
     *       {@code measure().getOptimum().test(weight, threshold)}.</li>
     *   <li>Sort using {@code measure().getOptimum().compare} on the candidates' weights.</li>
     *   <li>Drop candidates whose label is one of the query terms ({@link #getExcludedLabels()}).</li>
     *   <li>Truncate to {@link #limit()} results when the limit is positive and below
     *       {@link Integer#MAX_VALUE}.</li>
     * </ol>
     *
     * @param stream the candidate vectors, each carrying a weight and label
     * @return the filtered, ordered and possibly truncated stream
     */
    public Stream<NDArray> applyFilters(Stream<NDArray> stream) {
        // Snapshot the configuration so the lazily evaluated lambdas see the same settings that
        // the eager checks below (threshold finiteness, limit bounds) were made against.
        final Measure activeMeasure = measure();
        final double cutoff = threshold();
        final int maxResults = limit();

        Stream<NDArray> result = stream;
        if (Double.isFinite(cutoff)) {
            result = result.filter(v -> activeMeasure.getOptimum().test(v.getWeight(), cutoff));
        }
        result = result.sorted(
                (a, b) -> activeMeasure.getOptimum().compare(a.getWeight(), b.getWeight()));

        final Set<String> excludedLabels = getExcludedLabels();
        if (!excludedLabels.isEmpty()) {
            result = result.filter(v -> !excludedLabels.contains(v.getLabel()));
        }
        if (maxResults > 0 && maxResults < Integer.MAX_VALUE) {
            result = result.limit(maxResults);
        }
        return result;
    }

    /**
     * Removes all positive and negative terms and vectors. Limit, measure and threshold are left
     * unchanged.
     */
    public void clearQuery() {
        this.positiveVectors.clear();
        this.negativeVectors.clear();
        this.positiveTerms.clear();
        this.negativeTerms.clear();
    }

    /**
     * Returns the labels that {@link #applyFilters(Stream)} excludes from results. These are the
     * union of the positive and negative query terms. Vector-based sides contribute no labels.
     *
     * @return the set of excluded labels
     */
    public Set<String> getExcludedLabels() {
        return Sets.union(this.positiveTerms, this.negativeTerms);
    }

    /**
     * Sums one side of the query into a single vector.
     *
     * <p>Explicit vectors take precedence. The terms are embedded only when no vectors are set.
     * The sum starts from a new dense array of the embedding's dimension, which is assumed to be
     * zero-initialized. An accumulation loop is used instead of {@code Stream.reduce}, because the
     * in-place {@code addi} would mutate the reduction identity and violate its contract.
     */
    private NDArray getVector(List<NDArray> vectors, List<String> terms, WordEmbedding wordEmbedding) {
        Objects.requireNonNull(wordEmbedding, "wordEmbedding");
        NDArray sum = NDArrayFactory.DENSE.array(wordEmbedding.dimension());
        if (!vectors.isEmpty()) {
            for (NDArray vector : vectors) {
                sum = sum.addi(vector);
            }
        } else {
            for (String term : terms) {
                sum = sum.addi(wordEmbedding.embed(term));
            }
        }
        return sum;
    }

    /**
     * Sets the maximum number of results returned by {@link #applyFilters(Stream)}.
     *
     * @param i the limit; values {@code <= 0} or {@link Integer#MAX_VALUE} disable truncation
     * @return this query
     */
    public VSQuery limit(int i) {
        this.k = i;
        return this;
    }

    /** @return the configured result limit */
    public int limit() {
        return this.k;
    }

    /**
     * Sets the measure whose optimum is used for threshold filtering and ordering.
     *
     * @param measure the measure
     * @return this query
     */
    public VSQuery measure(Measure measure) {
        this.measure = measure;
        return this;
    }

    /** @return the configured measure (defaults to {@code Similarity.Cosine}) */
    public Measure measure() {
        return this.measure;
    }

    /**
     * Replaces the negative side with the given terms and clears any negative vectors.
     *
     * @param strArr the negative terms
     * @return this query
     * @throws NullPointerException if {@code strArr} is {@code null}; the query is left unchanged
     */
    public VSQuery negativeTerms(String... strArr) {
        Objects.requireNonNull(strArr, "strArr");
        this.negativeVectors.clear();
        this.negativeTerms.clear();
        Collections.addAll(this.negativeTerms, strArr);
        return this;
    }

    /**
     * Replaces the negative side with the given terms and clears any negative vectors.
     *
     * @param iterable the negative terms
     * @return this query
     * @throws NullPointerException if {@code iterable} is {@code null}; the query is left unchanged
     */
    public VSQuery negativeTerms(Iterable<String> iterable) {
        Objects.requireNonNull(iterable, "iterable");
        this.negativeVectors.clear();
        this.negativeTerms.clear();
        iterable.forEach(this.negativeTerms::add);
        return this;
    }

    /**
     * Replaces the negative side with the given vectors and clears any negative terms.
     *
     * @param nDArrayArr the negative vectors
     * @return this query
     * @throws NullPointerException if {@code nDArrayArr} is {@code null}; the query is left unchanged
     */
    public VSQuery negativeVectors(NDArray... nDArrayArr) {
        Objects.requireNonNull(nDArrayArr, "nDArrayArr");
        this.negativeVectors.clear();
        this.negativeTerms.clear();
        Collections.addAll(this.negativeVectors, nDArrayArr);
        return this;
    }

    /**
     * Replaces the negative side with the given vectors and clears any negative terms.
     *
     * @param iterable the negative vectors
     * @return this query
     * @throws NullPointerException if {@code iterable} is {@code null}; the query is left unchanged
     */
    public VSQuery negativeVectors(Iterable<NDArray> iterable) {
        Objects.requireNonNull(iterable, "iterable");
        this.negativeVectors.clear();
        this.negativeTerms.clear();
        iterable.forEach(this.negativeVectors::add);
        return this;
    }

    /**
     * Replaces the positive side with the given terms and clears any positive vectors.
     *
     * @param strArr the positive terms
     * @return this query
     * @throws NullPointerException if {@code strArr} is {@code null}; the query is left unchanged
     */
    public VSQuery positiveTerms(String... strArr) {
        Objects.requireNonNull(strArr, "strArr");
        this.positiveVectors.clear();
        this.positiveTerms.clear();
        Collections.addAll(this.positiveTerms, strArr);
        return this;
    }

    /**
     * Replaces the positive side with the given terms and clears any positive vectors.
     *
     * @param iterable the positive terms
     * @return this query
     * @throws NullPointerException if {@code iterable} is {@code null}; the query is left unchanged
     */
    public VSQuery positiveTerms(Iterable<String> iterable) {
        Objects.requireNonNull(iterable, "iterable");
        this.positiveVectors.clear();
        this.positiveTerms.clear();
        iterable.forEach(this.positiveTerms::add);
        return this;
    }

    /**
     * Replaces the positive side with the given vectors and clears any positive terms.
     *
     * @param nDArrayArr the positive vectors
     * @return this query
     * @throws NullPointerException if {@code nDArrayArr} is {@code null}; the query is left unchanged
     */
    public VSQuery positiveVectors(NDArray... nDArrayArr) {
        Objects.requireNonNull(nDArrayArr, "nDArrayArr");
        this.positiveVectors.clear();
        this.positiveTerms.clear();
        Collections.addAll(this.positiveVectors, nDArrayArr);
        return this;
    }

    /**
     * Replaces the positive side with the given vectors and clears any positive terms.
     *
     * @param iterable the positive vectors
     * @return this query
     * @throws NullPointerException if {@code iterable} is {@code null}; the query is left unchanged
     */
    public VSQuery positiveVectors(Iterable<NDArray> iterable) {
        Objects.requireNonNull(iterable, "iterable");
        this.positiveVectors.clear();
        this.positiveTerms.clear();
        iterable.forEach(this.positiveVectors::add);
        return this;
    }

    /**
     * Replaces the whole query (both sides) with the given positive vectors.
     *
     * @param nDArrayArr the positive vectors
     * @return this query
     * @throws NullPointerException if {@code nDArrayArr} is {@code null}; the query is left unchanged
     */
    public VSQuery query(NDArray... nDArrayArr) {
        Objects.requireNonNull(nDArrayArr, "nDArrayArr");
        clearQuery();
        Collections.addAll(this.positiveVectors, nDArrayArr);
        return this;
    }

    /**
     * Replaces the whole query (both sides) with the given positive terms.
     *
     * @param strArr the positive terms
     * @return this query
     * @throws NullPointerException if {@code strArr} is {@code null}; the query is left unchanged
     */
    public VSQuery query(String... strArr) {
        Objects.requireNonNull(strArr, "strArr");
        clearQuery();
        Collections.addAll(this.positiveTerms, strArr);
        return this;
    }

    /**
     * Builds the query vector: the summed positive side minus the summed negative side.
     *
     * <p>Each side uses its explicit vectors if any are set, otherwise its terms embedded with
     * {@code wordEmbedding}. An empty side contributes a new array of the embedding's dimension.
     *
     * @param wordEmbedding the embedding used to resolve terms and to size the result
     * @return the combined query vector
     * @throws NullPointerException if {@code wordEmbedding} is {@code null}
     */
    public NDArray queryVector(WordEmbedding wordEmbedding) {
        NDArray positive = getVector(this.positiveVectors, this.positiveTerms, wordEmbedding);
        NDArray negative = getVector(this.negativeVectors, this.negativeTerms, wordEmbedding);
        return positive.subi(negative);
    }

    /**
     * Sets the score cut-off used by {@link #applyFilters(Stream)}.
     *
     * @param d the threshold; a non-finite value disables threshold filtering
     * @return this query
     */
    public VSQuery threshold(double d) {
        this.threshold = d;
        return this;
    }

    /** @return the configured threshold (defaults to {@link Double#NEGATIVE_INFINITY}) */
    public double threshold() {
        return this.threshold;
    }
}
