package org.linqs.psl.database.atom;

import com.healthmarketscience.sqlbuilder.SelectQuery;
import com.healthmarketscience.sqlbuilder.SetOperationQuery;
import com.healthmarketscience.sqlbuilder.UnionQuery;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.ResultList;
import org.linqs.psl.database.rdbms.Formula2SQL;
import org.linqs.psl.database.rdbms.RDBMSDatabase;
import org.linqs.psl.grounding.GroundRuleStore;
import org.linqs.psl.model.atom.Atom;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.formula.Formula;
import org.linqs.psl.model.predicate.Predicate;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.arithmetic.AbstractArithmeticRule;
import org.linqs.psl.model.rule.logical.AbstractLogicalRule;
import org.linqs.psl.model.term.Constant;
import org.linqs.psl.model.term.Variable;
import org.linqs.psl.model.term.VariableTypeMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/linqs/psl/database/atom/LazyAtomManager.class */
/**
 * An atom manager that tolerates random variable atoms that are not yet persisted.
 *
 * <p>Every non-persisted {@link RandomVariableAtom} returned by {@link #getAtom} is remembered as
 * a "lazy" atom. Lazy atoms can later be activated. Activation commits them to the database, adds
 * them to the persisted cache and grounds every rule whose body references one of their predicates.
 *
 * <p>Only {@link RDBMSDatabase} is supported, because incremental grounding is done by generating
 * SQL queries.
 */
public class LazyAtomManager extends PersistedAtomManager {
    private static final Logger log = LoggerFactory.getLogger(LazyAtomManager.class);

    /** Non-persisted random variable atoms seen through getAtom() that have not been activated yet. */
    private final Set<RandomVariableAtom> lazyAtoms;

    /** activateAtoms(rules, store) activates lazy atoms whose value is at least this threshold. */
    private final double activation;

    /**
     * @param database the database whose atoms are managed; must be an {@link RDBMSDatabase}
     * @throws IllegalArgumentException if {@code database} is not an {@link RDBMSDatabase}
     */
    public LazyAtomManager(Database database) {
        // Validate before the superclass constructor runs, so an unusable database fails fast.
        super(requireRDBMSDatabase(database));
        this.lazyAtoms = new HashSet<>();
        this.activation = Options.LAM_ACTIVATION_THRESHOLD.getDouble();
    }

    /** Returns {@code database} unchanged, or throws if it is not an RDBMSDatabase. */
    private static Database requireRDBMSDatabase(Database database) {
        if (!(database instanceof RDBMSDatabase)) {
            throw new IllegalArgumentException("LazyAtomManagers require RDBMSDatabase.");
        }
        return database;
    }

    /**
     * Fetches an atom from the database. If it is a random variable atom that is not persisted,
     * it is also recorded as a lazy atom.
     */
    @Override
    public synchronized GroundAtom getAtom(Predicate predicate, Constant... arguments) {
        GroundAtom atom = db.getAtom(predicate, arguments);
        if (atom instanceof RandomVariableAtom && !((RandomVariableAtom) atom).getPersisted()) {
            lazyAtoms.add((RandomVariableAtom) atom);
        }
        return atom;
    }

    /**
     * Intentionally a no-op.
     * NOTE(review): presumably the parent reports accesses to non-persisted atoms; this manager
     * expects such accesses and tracks those atoms as lazy instead. Confirm against
     * PersistedAtomManager.
     */
    @Override
    public void reportAccessException(RuntimeException runtimeException, GroundAtom groundAtom) {
    }

    /** Returns a read-only live view of the atoms that are currently lazy. */
    public Set<RandomVariableAtom> getLazyAtoms() {
        return Collections.unmodifiableSet(lazyAtoms);
    }

    /** Counts the lazy atoms whose value is at or above the activation threshold. */
    public int countActivatableAtoms() {
        int count = 0;
        for (RandomVariableAtom atom : lazyAtoms) {
            if (atom.getValue() >= activation) {
                count++;
            }
        }
        return count;
    }

    /**
     * Activates every lazy atom whose value is at or above the activation threshold.
     * Activated atoms are no longer tracked as lazy.
     *
     * @return the number of atoms activated
     */
    public int activateAtoms(List<Rule> rules, GroundRuleStore groundRuleStore) {
        if (lazyAtoms.isEmpty()) {
            return 0;
        }

        Set<RandomVariableAtom> toActivate = new HashSet<>();
        Iterator<RandomVariableAtom> iterator = lazyAtoms.iterator();
        while (iterator.hasNext()) {
            RandomVariableAtom atom = iterator.next();
            if (atom.getValue() >= activation) {
                toActivate.add(atom);
                iterator.remove();
            }
        }

        activate(toActivate, rules, groundRuleStore);
        return toActivate.size();
    }

    /**
     * Activates the given atoms. Atoms that are not currently lazy are ignored.
     * Note that {@code atoms} is modified: any non-lazy atoms are removed from it.
     * Activated atoms are no longer tracked as lazy.
     *
     * @return the number of atoms activated
     */
    public int activateAtoms(Set<RandomVariableAtom> atoms, List<Rule> rules, GroundRuleStore groundRuleStore) {
        atoms.retainAll(lazyAtoms);

        // Activated atoms stop being lazy. Otherwise a later activateAtoms(rules, store) or
        // countActivatableAtoms() would still see them, and they could be activated twice.
        lazyAtoms.removeAll(atoms);

        activate(atoms, rules, groundRuleStore);
        return atoms.size();
    }

    /** Commits the atoms, then grounds every rule that references one of their predicates. */
    private void activate(Set<RandomVariableAtom> atoms, List<Rule> rules, GroundRuleStore groundRuleStore) {
        if (atoms.isEmpty()) {
            return;
        }

        // NOTE(review): partition id -1 is passed to both commit and moveToWritePartition;
        // presumably it denotes a dedicated lazy partition. Confirm in RDBMSDatabase.
        db.commit(atoms, -1);
        addToPersistedCache(atoms);

        Set<StandardPredicate> lazyPredicates = getLazyPredicates(atoms);
        Set<Rule> lazyRules = getLazyRules(rules, lazyPredicates);

        // Order matters: rewritable rules are grounded incrementally against the newly committed
        // atoms before those atoms are moved to the write partition...
        for (Rule rule : lazyRules) {
            if (rule.supportsGroundingQueryRewriting()) {
                lazySimpleGround(rule, lazyPredicates, groundRuleStore);
            }
        }

        for (StandardPredicate predicate : lazyPredicates) {
            db.moveToWritePartition(predicate, -1);
        }

        // ...and non-rewritable rules are fully regrounded only after the move.
        // getLazyRules only yields logical or arithmetic rules; this cast assumes every
        // non-rewritable one is arithmetic.
        for (Rule rule : lazyRules) {
            if (!rule.supportsGroundingQueryRewriting()) {
                lazyComplexGround((AbstractArithmeticRule) rule, groundRuleStore);
            }
        }
    }

    /** Drops all existing ground rules of {@code rule} and grounds it again from scratch. */
    private void lazyComplexGround(AbstractArithmeticRule rule, GroundRuleStore groundRuleStore) {
        log.trace("Complex lazy grounding on rule [{}]", rule);
        groundRuleStore.removeGroundRules(rule);
        rule.groundAll(this, groundRuleStore);
    }

    /**
     * Grounds {@code rule} incrementally, using only substitutions found by the lazy grounding
     * query. Null ground rules produced by {@code rule.ground} are skipped.
     *
     * @throws UnsupportedOperationException if the rule does not support query rewriting
     */
    private void lazySimpleGround(Rule rule, Set<StandardPredicate> lazyPredicates, GroundRuleStore groundRuleStore) {
        if (!rule.supportsGroundingQueryRewriting()) {
            throw new UnsupportedOperationException("Rule requires full regrounding: " + rule);
        }

        Formula formula = rule.getRewritableGroundingFormula(this);
        ResultList results = getLazyGroundingResults(formula, lazyPredicates);
        if (results == null) {
            // The formula references none of the lazy predicates.
            return;
        }

        log.trace("Simple lazy grounding on rule: [{}], formula: [{}]", rule, formula);

        List<GroundRule> groundRules = new ArrayList<>();
        for (int i = 0; i < results.size(); i++) {
            // The buffer is reused across result rows.
            groundRules.clear();
            rule.ground(results.get(i), results.getVariableMap(), this, groundRules);
            for (GroundRule groundRule : groundRules) {
                if (groundRule != null) {
                    groundRuleStore.addGroundRule(groundRule);
                }
            }
        }
    }

    /**
     * Runs the lazy grounding query for {@code formula}.
     *
     * @return the query results, or {@code null} if no atom in the formula uses a lazy predicate
     */
    private ResultList getLazyGroundingResults(Formula formula, Set<StandardPredicate> lazyPredicates) {
        List<Atom> lazyTargets = new ArrayList<>();
        for (Atom atom : formula.getAtoms(new HashSet<Atom>())) {
            if (lazyPredicates.contains(atom.getPredicate())) {
                lazyTargets.add(atom);
            }
        }

        if (lazyTargets.isEmpty()) {
            return null;
        }
        return lazyGround(formula, lazyTargets);
    }

    /**
     * Builds one grounding query per lazy target atom, combines them with a SQL UNION and
     * executes the result.
     *
     * @throws IllegalArgumentException if {@code lazyTargets} is empty
     */
    private ResultList lazyGround(Formula formula, List<Atom> lazyTargets) {
        if (lazyTargets.isEmpty()) {
            throw new IllegalArgumentException("Lazy grounding requires at least one lazy target atom.");
        }

        RDBMSDatabase rdbmsDatabase = (RDBMSDatabase) db;
        VariableTypeMap variables = formula.collectVariables(new VariableTypeMap());

        List<SelectQuery> queries = new ArrayList<>();
        Map<Variable, Integer> projectionMap = null;
        for (Atom lazyTarget : lazyTargets) {
            Formula2SQL formula2SQL = new Formula2SQL(variables.getVariables(), rdbmsDatabase, false, lazyTarget);
            queries.add(formula2SQL.getQuery(formula));

            // Only the first projection map is kept. NOTE(review): this assumes every sub-query
            // projects the same columns in the same order, since all are built from the same
            // variable set.
            if (projectionMap == null) {
                projectionMap = formula2SQL.getProjectionMap();
            }
        }

        // UNION (rather than UNION ALL) removes duplicate rows, so a substitution matched through
        // more than one lazy target is returned only once.
        String unionSQL = new UnionQuery(SetOperationQuery.Type.UNION, queries.toArray(new SelectQuery[0]))
                .validate()
                .toString();

        return rdbmsDatabase.executeQuery(projectionMap, variables, unionSQL);
    }

    /** Collects the standard predicates of the given atoms; atoms with other predicate kinds are skipped. */
    private Set<StandardPredicate> getLazyPredicates(Set<RandomVariableAtom> atoms) {
        Set<StandardPredicate> predicates = new HashSet<>();
        for (RandomVariableAtom atom : atoms) {
            if (atom.getPredicate() instanceof StandardPredicate) {
                predicates.add((StandardPredicate) atom.getPredicate());
            }
        }
        return predicates;
    }

    /**
     * Selects the rules that reference any of {@code lazyPredicates}.
     * Logical rules are checked through the atoms of their negated DNF query formula;
     * arithmetic rules are checked through their body predicates.
     * The result preserves the order of {@code rules}, which keeps grounding order deterministic.
     *
     * @throws IllegalStateException if a rule is neither logical nor arithmetic
     */
    private Set<Rule> getLazyRules(List<Rule> rules, Set<StandardPredicate> lazyPredicates) {
        Set<Rule> lazyRules = new LinkedHashSet<>();
        for (Rule rule : rules) {
            if (rule instanceof AbstractLogicalRule) {
                for (Atom atom : ((AbstractLogicalRule) rule).getNegatedDNF().getQueryFormula().getAtoms(new HashSet<Atom>())) {
                    if (lazyPredicates.contains(atom.getPredicate())) {
                        lazyRules.add(rule);
                        break;
                    }
                }
            } else if (rule instanceof AbstractArithmeticRule) {
                for (Predicate predicate : ((AbstractArithmeticRule) rule).getBodyPredicates()) {
                    if (lazyPredicates.contains(predicate)) {
                        lazyRules.add(rule);
                        break;
                    }
                }
            } else {
                throw new IllegalStateException("Unknown rule type: " + rule.getClass().getName());
            }
        }
        return lazyRules;
    }
}
