package org.linqs.psl.database.rdbms;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Locale;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.DatabaseQuery;
import org.linqs.psl.model.atom.Atom;
import org.linqs.psl.model.formula.Conjunction;
import org.linqs.psl.model.formula.Formula;
import org.linqs.psl.model.predicate.ExternalFunctionalPredicate;
import org.linqs.psl.model.predicate.GroundingOnlyPredicate;
import org.linqs.psl.model.predicate.Predicate;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.term.Term;
import org.linqs.psl.model.term.Variable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/linqs/psl/database/rdbms/QueryRewriter.class */
/**
 * Cost-based rewriter for conjunctive queries.
 *
 * <p>Greedily drops atoms whose variables are all covered by other atoms in the query, as long as
 * the estimated query size stays within the configured cost-increase bounds.
 */
public class QueryRewriter {
    private static final Logger log = LoggerFactory.getLogger(QueryRewriter.class);

    /** Upper bound on (final rewritten cost / starting cost). */
    private final double allowedTotalCostIncrease = Options.QR_ALLOWED_TOTAL_INCREASE.getDouble();

    /** Upper bound on (cost after one removal / cost before that removal). */
    private final double allowedStepCostIncrease = Options.QR_ALLOWED_STEP_INCREASE.getDouble();

    /**
     * The estimator selected by name in the options.
     * Locale.ROOT keeps the upper-casing locale-independent (e.g. Turkish dotted/dotless i).
     */
    private final CostEstimator costEstimator =
            CostEstimator.valueOf(Options.QR_COST_ESTIMATOR.getString().toUpperCase(Locale.ROOT));

    /** Strategies for estimating the size of a conjunctive query. */
    public enum CostEstimator {
        /** Product of the row counts of every involved table. */
        SIZE,
        /** SIZE divided, per shared variable, by the smallest cardinality of its joined columns. */
        SELECTIVITY,
        /** SIZE scaled, per shared variable, by a join-histogram size estimate. */
        HISTOGRAM
    }

    /**
     * Greedily rewrites a conjunctive query to use fewer atoms.
     *
     * <p>On each iteration, an atom is a removal candidate when every one of its variables also
     * appears in at least one other remaining atom. The candidate whose removal gives the lowest
     * non-negative estimated cost is removed, but only if that cost is at most
     * {@code allowedTotalCostIncrease} times the starting cost AND at most
     * {@code allowedStepCostIncrease} times the previous iteration's cost. Iteration stops as soon
     * as no candidate qualifies.
     *
     * <p>Atoms of {@link GroundingOnlyPredicate}s are kept out of this search and added back to the
     * result unchanged; atoms of {@link ExternalFunctionalPredicate}s are dropped from the result.
     *
     * @param formula the query to rewrite; checked with {@code DatabaseQuery.validate} first
     * @param rDBMSDataStore data store used for table statistics and column names
     * @return {@code formula} itself if it is an atom; otherwise the single remaining atom, or a
     *     {@link Conjunction} of the remaining atoms
     */
    public Formula rewrite(Formula formula, RDBMSDataStore rDBMSDataStore) {
        DatabaseQuery.validate(formula);

        // A lone atom has nothing that could be dropped.
        if (formula instanceof Atom) {
            return formula;
        }

        Set<Atom> usedAtoms = formula.getAtoms(new HashSet<Atom>());
        Set<Atom> passthrough = filterBaseAtoms(usedAtoms);
        Map<Predicate, TableStats> tableStats = fetchTableStats(usedAtoms, rDBMSDataStore);

        double baseCost = estimateQuerySize(costEstimator, usedAtoms, null, tableStats, rDBMSDataStore);
        double currentCost = baseCost;
        log.trace("Starting cost: {}", baseCost);

        while (true) {
            Atom bestAtom = null;
            double bestCost = -1.0;

            for (Atom candidate : usedAtoms) {
                if (!canRemove(candidate, usedAtoms)) {
                    continue;
                }

                double cost = estimateQuerySize(costEstimator, usedAtoms, candidate, tableStats, rDBMSDataStore);
                if (cost < 0.0) {
                    // A negative estimate means the estimator could not produce a usable number.
                    log.trace("Planned Cost for ({} - {}): MAX", usedAtoms, candidate);
                    continue;
                }

                log.trace("Planned Cost for ({} - {}): {}", new Object[] {usedAtoms, candidate, cost});
                if (bestAtom == null || cost < bestCost) {
                    bestAtom = candidate;
                    bestCost = cost;
                }
            }

            // Stop when nothing is removable, or the cheapest removal exceeds either cost budget.
            boolean acceptable = bestAtom != null
                    && bestCost <= baseCost * allowedTotalCostIncrease
                    && bestCost <= currentCost * allowedStepCostIncrease;
            if (!acceptable) {
                break;
            }

            usedAtoms.remove(bestAtom);
            currentCost = bestCost;
            log.trace("Choose plan for iteration: {}: {}", usedAtoms, currentCost);
        }

        usedAtoms.addAll(passthrough);

        Formula query;
        if (usedAtoms.size() == 1) {
            query = usedAtoms.iterator().next();
        } else {
            query = new Conjunction(usedAtoms.toArray(new Formula[0]));
        }

        log.debug("Computed cost-based query rewrite for [{}]({}): [{}]({}).",
                new Object[] {formula, baseCost, query, currentCost});
        return query;
    }

    /**
     * Dispatches to the estimator implementation for {@code estimator}.
     *
     * @param ignore an atom to treat as absent from {@code atoms}, or {@code null}
     * @return the estimated size; negative when the estimate is unusable
     */
    private double estimateQuerySize(CostEstimator estimator, Set<Atom> atoms, Atom ignore,
            Map<Predicate, TableStats> tableStats, RDBMSDataStore dataStore) {
        switch (estimator) {
            case HISTOGRAM:
                return estimateQuerySizeWithHistogram(atoms, ignore, tableStats, dataStore);
            case SELECTIVITY:
                return estimateQuerySizeWithSelectivity(atoms, ignore, tableStats, dataStore);
            case SIZE:
                return estimateQuerySizeWithSize(atoms, ignore, tableStats, dataStore);
            default:
                throw new IllegalStateException("Unknown CostEstimator value: " + estimator);
        }
    }

    /** Returns the product of the row counts of each atom's table, skipping {@code ignore}. */
    private double productOfTableCounts(Set<Atom> atoms, Atom ignore, Map<Predicate, TableStats> tableStats) {
        double product = 1.0;
        for (Atom atom : atoms) {
            if (atom != ignore) {
                product *= tableStats.get(atom.getPredicate()).getCount();
            }
        }
        return product;
    }

    /**
     * HISTOGRAM estimator: the table-count product, scaled for each variable shared by several atoms
     * by (joined histogram size / product of the joined tables' counts).
     *
     * @return the estimate, or -1 when a joined histogram reports a size of Long.MAX_VALUE
     */
    private double estimateQuerySizeWithHistogram(Set<Atom> atoms, Atom ignore,
            Map<Predicate, TableStats> tableStats, RDBMSDataStore dataStore) {
        double cost = productOfTableCounts(atoms, ignore, tableStats);

        // NOTE(review): the usage map includes `ignore`, so a variable shared only with the ignored
        // atom still contributes a single-table histogram factor. Kept as-is to preserve estimates.
        for (Map.Entry<Variable, Set<Atom>> entry : getAllUsedVariables(atoms, null).entrySet()) {
            Set<Atom> involvedAtoms = entry.getValue();
            if (involvedAtoms.size() <= 1) {
                continue;
            }

            Variable variable = entry.getKey();
            SelectivityHistogram joinHistogram = null;
            double joinedTableSizes = 1.0;
            for (Atom atom : involvedAtoms) {
                if (atom == ignore) {
                    continue;
                }
                TableStats stats = tableStats.get(atom.getPredicate());
                joinedTableSizes *= stats.getCount();
                SelectivityHistogram histogram = stats.getHistogram(getColumnName(dataStore, atom, variable));
                joinHistogram = (joinHistogram == null) ? histogram : joinHistogram.join(histogram);
            }

            if (joinHistogram.size() == Long.MAX_VALUE) {
                return -1.0;
            }
            cost *= joinHistogram.size() / joinedTableSizes;
        }
        return cost;
    }

    /**
     * SELECTIVITY estimator: the table-count product, divided, for each variable appearing in more
     * than one non-ignored atom, by the smallest cardinality among that variable's columns.
     */
    private double estimateQuerySizeWithSelectivity(Set<Atom> atoms, Atom ignore,
            Map<Predicate, TableStats> tableStats, RDBMSDataStore dataStore) {
        double cost = productOfTableCounts(atoms, ignore, tableStats);

        for (Map.Entry<Variable, Set<Atom>> entry : getAllUsedVariables(atoms, null).entrySet()) {
            Variable variable = entry.getKey();
            int joinCount = 0;
            int minCardinality = 0;

            // The map values are exactly the atoms that use this variable.
            for (Atom atom : entry.getValue()) {
                if (atom == ignore) {
                    continue;
                }
                int cardinality = tableStats.get(atom.getPredicate())
                        .getCardinality(getColumnName(dataStore, atom, variable));
                if (joinCount == 0 || cardinality < minCardinality) {
                    minCardinality = cardinality;
                }
                joinCount++;
            }

            if (joinCount > 1) {
                cost /= minCardinality;
            }
        }
        return cost;
    }

    /** SIZE estimator: the product of the row counts of all non-ignored tables. */
    private double estimateQuerySizeWithSize(Set<Atom> atoms, Atom ignore,
            Map<Predicate, TableStats> tableStats, RDBMSDataStore dataStore) {
        return productOfTableCounts(atoms, ignore, tableStats);
    }

    /**
     * Returns the table column bound to the first argument of {@code atom} equal to {@code variable}.
     *
     * @throws NoSuchElementException if {@code variable} is not an argument of {@code atom}
     */
    private String getColumnName(RDBMSDataStore dataStore, Atom atom, Variable variable) {
        Term[] arguments = atom.getArguments();
        for (int index = 0; index < arguments.length; index++) {
            if (variable.equals(arguments[index])) {
                return dataStore.getPredicateInfo(atom.getPredicate()).argumentColumns().get(index);
            }
        }
        throw new NoSuchElementException(String.format("Could not find column name for variable %s in atom %s.", variable, atom));
    }

    /** Returns true when every variable of {@code atom} also appears in another atom of {@code atoms}. */
    private boolean canRemove(Atom atom, Set<Atom> atoms) {
        // Copy so we never mutate whatever collection getVariables() hands back.
        Set<Variable> uncovered = new HashSet<>(atom.getVariables());
        uncovered.removeAll(getAllUsedVariables(atoms, atom).keySet());
        return uncovered.isEmpty();
    }

    /**
     * Maps each variable used by the atoms in {@code atoms} (other than {@code ignore}) to the set of
     * atoms that use it.
     */
    private Map<Variable, Set<Atom>> getAllUsedVariables(Set<Atom> atoms, Atom ignore) {
        Map<Variable, Set<Atom>> usages = new HashMap<>();
        for (Atom atom : atoms) {
            if (atom == ignore) {
                continue;
            }
            for (Variable variable : atom.getVariables()) {
                Set<Atom> users = usages.get(variable);
                if (users == null) {
                    users = new HashSet<>();
                    usages.put(variable, users);
                }
                users.add(atom);
            }
        }
        return usages;
    }

    /** Fetches table statistics once per distinct predicate appearing in {@code atoms}. */
    private Map<Predicate, TableStats> fetchTableStats(Set<Atom> atoms, RDBMSDataStore dataStore) {
        Map<Predicate, TableStats> stats = new HashMap<>();
        for (Atom atom : atoms) {
            Predicate predicate = atom.getPredicate();
            if (!stats.containsKey(predicate)) {
                stats.put(predicate, dataStore.getPredicateInfo(predicate).getTableStats(dataStore.getDriver()));
            }
        }
        return stats;
    }

    /**
     * Removes every non-{@link StandardPredicate} atom from {@code atoms}, in place.
     *
     * @return the removed atoms whose predicates are {@link GroundingOnlyPredicate}s; removed
     *     {@link ExternalFunctionalPredicate} atoms are not returned
     * @throws IllegalStateException if an atom uses any other kind of predicate (in which case
     *     {@code atoms} is left unmodified)
     */
    private Set<Atom> filterBaseAtoms(Set<Atom> atoms) {
        Set<Atom> passthrough = new HashSet<>();
        Set<Atom> removed = new HashSet<>();
        for (Atom atom : atoms) {
            Predicate predicate = atom.getPredicate();
            if (predicate instanceof ExternalFunctionalPredicate) {
                removed.add(atom);
            } else if (predicate instanceof GroundingOnlyPredicate) {
                removed.add(atom);
                passthrough.add(atom);
            } else if (!(predicate instanceof StandardPredicate)) {
                throw new IllegalStateException("Unknown predicate type: " + predicate.getClass().getName());
            }
        }
        atoms.removeAll(removed);
        return passthrough;
    }
}
