package org.linqs.psl.grounding.collective;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.DatabaseQuery;
import org.linqs.psl.database.rdbms.Formula2SQL;
import org.linqs.psl.database.rdbms.RDBMSDataStore;
import org.linqs.psl.database.rdbms.RDBMSDatabase;
import org.linqs.psl.database.rdbms.driver.DatabaseDriver;
import org.linqs.psl.grounding.collective.SearchFringe;
import org.linqs.psl.model.atom.Atom;
import org.linqs.psl.model.formula.Conjunction;
import org.linqs.psl.model.formula.Formula;
import org.linqs.psl.model.predicate.ExternalFunctionalPredicate;
import org.linqs.psl.model.predicate.GroundingOnlyPredicate;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.term.Variable;
import org.linqs.psl.util.BitUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Generates candidate grounding queries for a rule.
 *
 * <p>Candidates are produced by searching over subsets of the atoms in the rule's rewritable
 * grounding formula. Each explored subset is scored using the database driver's EXPLAIN output,
 * and EXPLAIN results are cached per formula string so repeated formulas are not re-explained.
 * The search strategy and the EXPLAIN budget come from {@link Options}.
 */
public class CandidateGeneration {
    private static final Logger log = LoggerFactory.getLogger(CandidateGeneration.class);

    public static final double CANDIDATE_SIZE_ADJUSTMENT = 1.0d;

    public static final double OPTIMISTIC_QUERY_COST_MULTIPLIER = 0.018d;
    public static final double OPTIMISTIC_INSTANTIATION_COST_MULTIPLIER = 0.001d;

    public static final double PESSIMISTIC_QUERY_COST_MULTIPLIER = 0.02d;
    public static final double PESSIMISTIC_INSTANTIATION_COST_MULTIPLIER = 0.002d;

    /**
     * Orders atoms by their string form. This gives a deterministic bit assignment to atoms
     * regardless of the iteration order of the set they were collected in.
     */
    private static final Comparator<Atom> ATOM_STRING_ORDER = new Comparator<Atom>() {
        @Override
        public int compare(Atom first, Atom second) {
            return first.toString().compareTo(second.toString());
        }
    };

    private final SearchType searchType = SearchType.valueOf(Options.GROUNDING_COLLECTIVE_CANDIDATE_SEARCH_TYPE.getString());

    /** Maximum number of new (uncached) EXPLAIN calls per search; a non-positive value means unlimited. */
    private final int budget = Options.GROUNDING_COLLECTIVE_CANDIDATE_SEARCH_BUDGET.getInt();

    /** EXPLAIN results keyed by the string form of the explained formula. */
    private final Map<String, DatabaseDriver.ExplainResult> explains = new ConcurrentHashMap<String, DatabaseDriver.ExplainResult>();

    /** The fringe strategy used when searching the space of atom subsets. */
    public enum SearchType {
        BFS,
        DFS,
        UCS,
        BoundedUCS,
        BoundedDFS
    }

    /**
     * Search for candidate queries for a rule and collect the best ones.
     *
     * @param rule the rule whose grounding formula is searched
     * @param database the database used to EXPLAIN candidate queries
     * @param maxCandidates the maximum number of candidates to add to {@code results}
     * @param results receives up to {@code maxCandidates} candidates, in
     *     {@link CandidateQuery}'s natural ordering
     */
    public void generateCandidates(Rule rule, RDBMSDatabase database, int maxCandidates, Collection<CandidateQuery> results) {
        List<CandidateQuery> candidates = search(createFringe(), rule, database);
        Collections.sort(candidates);

        int count = Math.min(candidates.size(), maxCandidates);
        for (int i = 0; i < count; i++) {
            results.add(candidates.get(i));
        }
    }

    /**
     * Explore subsets of the rule's grounding atoms.
     *
     * <p>The search starts from the subset containing every atom. Each popped node is explained
     * (scored) and recorded as a candidate. It is then expanded by dropping one active atom at a
     * time. A subset is only visited once, the all-zero bit set is never visited, and subsets
     * that leave some variable without any standard-predicate atom are rejected. The search stops
     * when the fringe is empty or the EXPLAIN budget is spent.
     *
     * <p>NOTE(review): subsets are encoded as a {@code long} via {@link BitUtils#toBitSet}. This
     * presumably limits rules to at most 64 atoms. Confirm against BitUtils.
     */
    private List<CandidateQuery> search(SearchFringe fringe, Rule rule, RDBMSDatabase database) {
        fringe.clear();

        Formula groundingFormula = rule.getRewritableGroundingFormula();
        DatabaseQuery.validate(groundingFormula);

        // A lone atom has no proper subsets to explore.
        if (groundingFormula instanceof Atom) {
            return singleAtomSearch(rule, groundingFormula, database);
        }

        List<CandidateQuery> candidates = new ArrayList<CandidateQuery>();

        // Collect atoms. This set is cleared below and reused as the scratch set for formula construction.
        Set<Atom> scratch = new HashSet<Atom>();
        groundingFormula.getAtoms(scratch);

        Set<Atom> specialAtoms = filterSpecialAtoms(scratch);
        Map<Variable, Set<Atom>> variableUsages = getAllUsedVariables(scratch);

        List<Atom> searchAtoms = new ArrayList<Atom>(scratch);
        Collections.sort(searchAtoms, ATOM_STRING_ORDER);
        scratch.clear();

        Set<Long> seenBitSets = new HashSet<Long>();
        boolean[] activeAtoms = new boolean[searchAtoms.size()];
        for (int i = 0; i < activeAtoms.length; i++) {
            activeAtoms[i] = true;
        }

        fringe.push(validateAndCreateNode(activeAtoms, searchAtoms, specialAtoms, variableUsages, scratch, 0.0d, 0.0d));
        seenBitSets.add(BitUtils.toBitSet(activeAtoms));

        int explainCount = 0;
        while (fringe.size() > 0 && (budget <= 0 || explainCount < budget)) {
            CandidateSearchNode node = fringe.pop();
            BitUtils.toBits(node.atomsBitSet, activeAtoms);

            // Only EXPLAINs that actually hit the database (not cache hits) count against the budget.
            if (explainNode(node, database)) {
                explainCount++;
            }

            candidates.add(new CandidateQuery(rule, node.formula, node.optimisticCost));
            fringe.newPessimisticCost(node.pessimisticCost);

            // Expand: each child drops exactly one currently-active atom.
            for (int i = 0; i < searchAtoms.size(); i++) {
                if (!activeAtoms[i]) {
                    continue;
                }

                activeAtoms[i] = false;

                long childBits = BitUtils.toBitSet(activeAtoms);
                if (childBits != 0 && seenBitSets.add(childBits)) {
                    CandidateSearchNode child = validateAndCreateNode(activeAtoms, searchAtoms, specialAtoms,
                            variableUsages, scratch, node.optimisticCost, node.pessimisticCost);
                    if (child != null) {
                        fringe.push(child);
                    }
                }

                activeAtoms[i] = true;
            }
        }

        return candidates;
    }

    /**
     * Candidate generation for a grounding formula that is a single atom.
     *
     * <p>The formula itself is always returned as a candidate. If the rule has exactly two core
     * atoms, and at least one of them uses a standard predicate that is open in {@code database},
     * the other core atom is also returned as a candidate.
     */
    private List<CandidateQuery> singleAtomSearch(Rule rule, Formula formula, RDBMSDatabase database) {
        assert formula instanceof Atom;

        List<CandidateQuery> candidates = new ArrayList<CandidateQuery>(2);

        CandidateSearchNode primary = new CandidateSearchNode(0L, formula, 1, 1.0d, 1.0d);
        explainNode(primary, database);
        candidates.add(new CandidateQuery(rule, primary.formula, primary.optimisticCost));

        Set<Atom> coreAtoms = new HashSet<Atom>();
        rule.getCoreAtoms(coreAtoms);
        if (coreAtoms.size() != 2) {
            return candidates;
        }

        boolean hasOpenAtom = false;
        for (Atom atom : coreAtoms) {
            if ((atom.getPredicate() instanceof StandardPredicate)
                    && !database.isClosed((StandardPredicate) atom.getPredicate())) {
                hasOpenAtom = true;
                break;
            }
        }

        if (!hasOpenAtom) {
            return candidates;
        }

        coreAtoms.remove(formula);
        if (coreAtoms.size() != 1) {
            return candidates;
        }

        CandidateSearchNode alternate = new CandidateSearchNode(0L, coreAtoms.iterator().next(), 1, 1.0d, 1.0d);
        explainNode(alternate, database);
        candidates.add(new CandidateQuery(rule, alternate.formula, alternate.optimisticCost));

        return candidates;
    }

    /**
     * Build a search node for the subset described by {@code activeAtoms}, or return null if the
     * subset is invalid.
     *
     * <p>A subset is invalid when some variable in {@code variableUsages} is not used by any atom
     * in the constructed formula. A new node's costs are the parent's costs scaled by
     * {@code n / (n + 1)}, where {@code n} is the number of atoms in the new formula.
     *
     * @param scratch must be empty on entry; it is always left empty on return
     */
    private CandidateSearchNode validateAndCreateNode(boolean[] activeAtoms, List<Atom> searchAtoms, Set<Atom> specialAtoms,
            Map<Variable, Set<Atom>> variableUsages, Set<Atom> scratch,
            double parentOptimisticCost, double parentPessimisticCost) {
        Formula formula = constructFormula(activeAtoms, searchAtoms, specialAtoms, scratch);

        // Every variable must still appear in at least one included atom.
        for (Set<Atom> usages : variableUsages.values()) {
            if (Collections.disjoint(scratch, usages)) {
                scratch.clear();
                return null;
            }
        }

        int numAtoms = scratch.size();
        scratch.clear();

        assert numAtoms > 0;

        return new CandidateSearchNode(BitUtils.toBitSet(activeAtoms), formula, numAtoms,
                (parentOptimisticCost * numAtoms) / (numAtoms + 1),
                (parentPessimisticCost * numAtoms) / (numAtoms + 1));
    }

    /**
     * Score a node with the database's EXPLAIN output and set its optimistic and pessimistic costs.
     * EXPLAIN results are cached by the formula's string form.
     *
     * @return true if a new EXPLAIN was issued, false if the cached result was reused
     */
    private boolean explainNode(CandidateSearchNode node, RDBMSDatabase database) {
        String key = node.formula.toString();

        // ConcurrentHashMap never stores null values, so a null here means "not cached".
        DatabaseDriver.ExplainResult explain = explains.get(key);
        boolean freshExplain = (explain == null);
        if (freshExplain) {
            explain = ((RDBMSDataStore) database.getDataStore()).getDriver()
                    .explain(Formula2SQL.getQuery(node.formula, database, false));
            explains.put(key, explain);
        }

        node.approximateCost = false;
        node.optimisticCost = ((explain.totalCost * OPTIMISTIC_QUERY_COST_MULTIPLIER)
                + (explain.rows * OPTIMISTIC_INSTANTIATION_COST_MULTIPLIER))
                * node.numAtoms * CANDIDATE_SIZE_ADJUSTMENT;
        node.pessimisticCost = ((explain.totalCost * PESSIMISTIC_QUERY_COST_MULTIPLIER)
                + (explain.rows * PESSIMISTIC_INSTANTIATION_COST_MULTIPLIER))
                * node.numAtoms * CANDIDATE_SIZE_ADJUSTMENT;

        if (freshExplain) {
            log.trace("Scored candidate: {}", node);
        }

        return freshExplain;
    }

    /**
     * Fill {@code scratch} with the special atoms plus every atom whose flag in
     * {@code activeAtoms} is set. Return the lone atom if there is exactly one; otherwise return
     * a {@link Conjunction} of them all.
     *
     * <p>{@code scratch} must be empty on entry and is left populated. The caller is responsible
     * for clearing it.
     */
    private Formula constructFormula(boolean[] activeAtoms, List<Atom> searchAtoms, Set<Atom> specialAtoms, Set<Atom> scratch) {
        assert scratch.isEmpty();

        scratch.addAll(specialAtoms);
        for (int i = 0; i < activeAtoms.length; i++) {
            if (activeAtoms[i]) {
                scratch.add(searchAtoms.get(i));
            }
        }

        if (scratch.size() == 1) {
            return scratch.iterator().next();
        }

        return new Conjunction(scratch.toArray(new Formula[0]));
    }

    /**
     * Map each variable to the set of standard-predicate atoms (from {@code atoms}) that use it.
     * Atoms with other predicate types are ignored.
     */
    private Map<Variable, Set<Atom>> getAllUsedVariables(Set<Atom> atoms) {
        Map<Variable, Set<Atom>> usages = new HashMap<Variable, Set<Atom>>();

        for (Atom atom : atoms) {
            if (!(atom.getPredicate() instanceof StandardPredicate)) {
                continue;
            }

            for (Variable variable : atom.getVariables()) {
                Set<Atom> users = usages.get(variable);
                if (users == null) {
                    users = new HashSet<Atom>();
                    usages.put(variable, users);
                }
                users.add(atom);
            }
        }

        return usages;
    }

    /**
     * Remove atoms with an {@link ExternalFunctionalPredicate} from {@code atoms} (in place), and
     * return the "special" atoms that must be included in every candidate.
     *
     * <p>As written, no atom is classified as special, so the returned set is always empty.
     * Grounding-only and standard predicate atoms are left in {@code atoms}.
     *
     * @throws IllegalStateException if an atom has some other predicate type
     */
    private Set<Atom> filterSpecialAtoms(Set<Atom> atoms) {
        Set<Atom> specialAtoms = new HashSet<Atom>();
        Set<Atom> functionalAtoms = new HashSet<Atom>();

        for (Atom atom : atoms) {
            if (atom.getPredicate() instanceof ExternalFunctionalPredicate) {
                functionalAtoms.add(atom);
            } else if (!(atom.getPredicate() instanceof GroundingOnlyPredicate)
                    && !(atom.getPredicate() instanceof StandardPredicate)) {
                throw new IllegalStateException("Unknown predicate type: " + atom.getPredicate().getClass().getName());
            }
        }

        atoms.removeAll(functionalAtoms);
        return specialAtoms;
    }

    /** Create a fresh search fringe for the configured {@link SearchType}. */
    private SearchFringe createFringe() {
        switch (searchType) {
            case BFS:
                return new SearchFringe.BFSSearchFringe();
            case DFS:
                return new SearchFringe.DFSSearchFringe();
            case UCS:
                return new SearchFringe.UCSSearchFringe();
            case BoundedUCS:
                return new SearchFringe.BoundedUCSSearchFringe();
            case BoundedDFS:
                return new SearchFringe.BoundedDFSSearchFringe();
            default:
                throw new IllegalStateException("Unknown search type: " + searchType);
        }
    }
}
