package org.linqs.psl.application.inference;

import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.linqs.psl.application.ModelApplication;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.atom.PersistedAtomManager;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.grounding.GroundRuleStore;
import org.linqs.psl.grounding.Grounding;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.UnweightedRule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.InitialValue;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.term.TermGenerator;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.IteratorUtils;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.Reflection;

/* loaded from: input_file:org/linqs/psl/application/inference/InferenceApplication.class */
/**
 * Base class for applications that ground a set of rules against a database and run a
 * {@link Reasoner} over the resulting objective terms.
 *
 * <p>All setup happens eagerly in the constructor via {@link #initialize()}:
 * <ul>
 *   <li>create the atom manager;</li>
 *   <li>reset random variable atoms to their initial values;</li>
 *   <li>optionally normalize rule weights and relax hard constraints;</li>
 *   <li>create the reasoner, term generator, term store and ground rule store from
 *       {@link Options};</li>
 *   <li>ground the model and, unless inference is skipped, generate objective terms.</li>
 * </ul>
 *
 * <p>NOTE(review): this source appears to be decompiled. The decompiler reports that the
 * constructors and {@link #initialize()} were originally {@code protected}. They are kept
 * {@code public} here so existing callers keep compiling.
 */
public abstract class InferenceApplication implements ModelApplication {
    private static final Logger log = Logger.getLogger(InferenceApplication.class);

    protected List<Rule> rules;
    protected Database database;
    protected Reasoner reasoner;
    protected InitialValue initialValue;
    protected boolean skipInference;
    protected boolean normalizeWeights;
    protected boolean relaxHardConstraints;
    protected float relaxationMultiplier;
    protected boolean relaxationSquared;
    protected GroundRuleStore groundRuleStore;
    protected TermStore termStore;
    protected TermGenerator termGenerator;
    protected PersistedAtomManager atomManager;

    // Set by commit(), cleared after each inference run, so commit() only writes once per run.
    private boolean atomsCommitted;

    /**
     * Creates the application.
     *
     * <p>Whether hard constraints are relaxed is read from {@link Options#INFERENCE_RELAX}.
     *
     * @param rules the model's rules; the list is copied
     * @param database the database used to build the atom manager
     */
    public InferenceApplication(List<Rule> rules, Database database) {
        this(rules, database, Options.INFERENCE_RELAX.getBoolean());
    }

    /**
     * Creates the application, reads the remaining inference settings from {@link Options},
     * and runs {@link #initialize()}.
     *
     * <p>Because {@link #initialize()} (and the {@code create*} factory methods it calls) are
     * overridable and invoked here, subclass overrides run before any subclass field
     * initializers have executed.
     *
     * @param rules the model's rules; the list is copied, so the caller's list is never modified
     *     (for example by hard-constraint relaxation)
     * @param database the database used to build the atom manager
     * @param relaxHardConstraints whether unweighted rules should be relaxed into weighted ones
     */
    public InferenceApplication(List<Rule> rules, Database database, boolean relaxHardConstraints) {
        this.rules = new ArrayList<>(rules);
        this.database = database;
        this.atomsCommitted = false;
        this.initialValue = InitialValue.valueOf(Options.INFERENCE_INITIAL_VARIABLE_VALUE.getString());
        this.skipInference = Options.INFERENCE_SKIP_INFERENCE.getBoolean();
        this.normalizeWeights = Options.INFERENCE_NORMALIZE_WEIGHTS.getBoolean();
        this.relaxHardConstraints = relaxHardConstraints;
        this.relaxationMultiplier = Options.INFERENCE_RELAX_MULTIPLIER.getFloat();
        this.relaxationSquared = Options.INFERENCE_RELAX_SQUARED.getBoolean();
        initialize();
    }

    /**
     * Builds all inference machinery. Called once from the constructor.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>create the atom manager and reset atom values;</li>
     *   <li>optionally normalize weights and relax hard constraints;</li>
     *   <li>create the reasoner, term generator, term store and ground rule store;</li>
     *   <li>size the term store for the cached random variable atoms;</li>
     *   <li>delegate to {@link #completeInitialize()}.</li>
     * </ol>
     */
    public void initialize() {
        log.debug("Creating persisted atom manager.");
        this.atomManager = createAtomManager(this.database);
        log.debug("Atom manager initialization complete.");

        initializeAtoms();

        if (this.normalizeWeights) {
            normalizeWeights();
        }

        if (this.relaxHardConstraints) {
            relaxHardConstraints();
        }

        this.reasoner = createReasoner();
        this.termGenerator = createTermGenerator();
        this.termStore = createTermStore();
        this.groundRuleStore = createGroundRuleStore();

        this.termStore.ensureVariableCapacity(this.atomManager.getCachedRVACount());

        completeInitialize();
    }

    /**
     * Creates the atom manager for the given database using this application's initial value.
     *
     * @param database the database to wrap
     * @return a new atom manager
     */
    protected PersistedAtomManager createAtomManager(Database database) {
        return new PersistedAtomManager(database, false, this.initialValue);
    }

    /**
     * Creates the ground rule store.
     *
     * @return the store configured by {@link Options#INFERENCE_GRS}
     */
    protected GroundRuleStore createGroundRuleStore() {
        return (GroundRuleStore) Options.INFERENCE_GRS.getNewObject();
    }

    /**
     * Creates the reasoner.
     *
     * @return the reasoner configured by {@link Options#INFERENCE_REASONER}
     */
    protected Reasoner createReasoner() {
        return (Reasoner) Options.INFERENCE_REASONER.getNewObject();
    }

    /**
     * Creates the term generator.
     *
     * @return the term generator configured by {@link Options#INFERENCE_TG}
     */
    protected TermGenerator createTermGenerator() {
        return (TermGenerator) Options.INFERENCE_TG.getNewObject();
    }

    /**
     * Creates the term store.
     *
     * @return the term store configured by {@link Options#INFERENCE_TS}
     */
    protected TermStore createTermStore() {
        return (TermStore) Options.INFERENCE_TS.getNewObject();
    }

    /**
     * Grounds all rules into the ground rule store.
     *
     * <p>Unless inference is skipped, it then generates objective terms from the ground rules
     * into the term store.
     */
    protected void completeInitialize() {
        log.info("Grounding out model.");

        // Disable the atom manager's closed-atom DB queries while grounding. The call returns
        // what is presumably the previous setting; restore it in `finally` so a grounding
        // failure does not leave the manager in the wrong mode.
        boolean previousQuerySetting = this.atomManager.queryDBForClosedAtoms(false);
        long groundRuleCount;
        try {
            groundRuleCount = Grounding.groundAll(this.rules, this.atomManager, this.groundRuleStore);
        } finally {
            this.atomManager.queryDBForClosedAtoms(previousQuerySetting);
        }

        log.info("Grounding complete.");
        log.debug("Generated {} ground rules.", groundRuleCount);

        if (this.skipInference) {
            return;
        }

        log.debug("Initializing objective terms for {} ground rules.", groundRuleCount);
        long termCount = this.termGenerator.generateTerms(this.groundRuleStore, this.termStore);
        log.debug("Generated {} objective terms from {} ground rules.", termCount, groundRuleCount);
    }

    /**
     * Runs inference and commits the results, without resetting atom values first.
     *
     * @return the reasoner's result, or {@code -1.0} if inference is skipped
     */
    public double inference() {
        return inference(true, false);
    }

    /**
     * Runs inference without evaluators.
     *
     * @param commitAtoms whether to write atom values to the database afterwards
     * @param reset whether to reset atom values (and the term store, if present) beforehand
     * @return the reasoner's result, or {@code -1.0} if inference is skipped
     */
    public double inference(boolean commitAtoms, boolean reset) {
        return inference(commitAtoms, reset, null, null);
    }

    /**
     * Runs inference, optionally with evaluators that are handed to the reasoner.
     *
     * <p>Evaluation support (a {@link TrainingMap} plus the set of predicates that have at least
     * one ground atom in {@code truthDatabase}) is built only when {@code truthDatabase} is
     * non-null and {@code evaluators} is non-null and non-empty.
     *
     * @param commitAtoms whether to write atom values to the database afterwards
     * @param reset whether to reset atom values (and the term store, if present) beforehand
     * @param evaluators evaluators passed to the reasoner; may be {@code null}
     * @param truthDatabase database used to build the training map; may be {@code null}
     *     (presumably the ground-truth database — confirm against callers)
     * @return the reasoner's result, or {@code -1.0} if inference is skipped
     */
    public double inference(boolean commitAtoms, boolean reset, List<Evaluator> evaluators, Database truthDatabase) {
        if (reset) {
            initializeAtoms();
            if (this.termStore != null) {
                this.termStore.reset();
            }
        }

        if (this.skipInference) {
            log.info("Skipping inference.");
            return -1.0d;
        }

        TrainingMap trainingMap = null;
        Set<StandardPredicate> truthPredicates = null;
        // Guard against a null evaluator list: previously a non-null database with null
        // evaluators threw a NullPointerException here.
        if (truthDatabase != null && evaluators != null && !evaluators.isEmpty()) {
            trainingMap = new TrainingMap(this.atomManager, truthDatabase);
            truthPredicates = collectPopulatedPredicates(truthDatabase);
        }

        log.info("Beginning inference.");
        double objective = internalInference(evaluators, trainingMap, truthPredicates);
        log.info("Inference complete.");

        // Fresh values were just inferred, so any previous commit is stale.
        this.atomsCommitted = false;

        if (commitAtoms) {
            commit();
        }

        return objective;
    }

    /**
     * Collects the predicates that have at least one ground atom in the given database.
     *
     * <p>Candidates are the predicates registered with this application's data store.
     */
    private Set<StandardPredicate> collectPopulatedPredicates(Database truthDatabase) {
        Set<StandardPredicate> populated = new HashSet<>();
        for (StandardPredicate predicate : this.database.getDataStore().getRegisteredPredicates()) {
            if (truthDatabase.countAllGroundAtoms(predicate) > 0) {
                populated.add(predicate);
            }
        }
        return populated;
    }

    /**
     * Performs the actual optimization by delegating to the reasoner.
     *
     * @param evaluators evaluators passed to the reasoner; may be {@code null}
     * @param trainingMap training map passed to the reasoner; may be {@code null}
     * @param truthPredicates predicates with truth atoms; may be {@code null}
     * @return the value returned by {@link Reasoner#optimize}
     */
    protected double internalInference(List<Evaluator> evaluators, TrainingMap trainingMap, Set<StandardPredicate> truthPredicates) {
        return this.reasoner.optimize(this.termStore, evaluators, trainingMap, truthPredicates);
    }

    /** @return the reasoner, or {@code null} after {@link #close()} */
    public Reasoner getReasoner() {
        return this.reasoner;
    }

    /** @return the ground rule store, or {@code null} after {@link #close()} */
    public GroundRuleStore getGroundRuleStore() {
        return this.groundRuleStore;
    }

    /** @return the term store, or {@code null} after {@link #close()} */
    public TermStore getTermStore() {
        return this.termStore;
    }

    /** @return the persisted atom manager */
    public PersistedAtomManager getAtomManager() {
        return this.atomManager;
    }

    /**
     * Forwards a budget to the reasoner.
     *
     * <p>The meaning of the budget is defined by the reasoner implementation.
     *
     * @param budget the budget value to pass through
     */
    public void setBudget(double budget) {
        this.reasoner.setBudget(budget);
    }

    /**
     * Resets every cached random variable atom to the value chosen by this application's
     * {@link InitialValue}.
     */
    public void initializeAtoms() {
        for (RandomVariableAtom atom : this.atomManager.getDatabase().getAllCachedRandomVariableAtoms()) {
            atom.setValue(this.initialValue.getVariableValue(atom));
        }
    }

    /**
     * Writes persisted atom values to the database.
     *
     * <p>Does nothing if they were already committed since the last inference run.
     */
    public void commit() {
        if (this.atomsCommitted) {
            return;
        }

        log.info("Writing results to Database.");
        this.atomManager.commitPersistedAtoms();
        log.info("Results committed to database.");

        this.atomsCommitted = true;
    }

    /**
     * Closes the term store, ground rule store and reasoner, then drops references to the rules
     * and database.
     *
     * <p>Safe to call more than once.
     */
    @Override
    public void close() {
        if (this.termStore != null) {
            this.termStore.close();
            this.termStore = null;
        }

        if (this.groundRuleStore != null) {
            this.groundRuleStore.close();
            this.groundRuleStore = null;
        }

        if (this.reasoner != null) {
            this.reasoner.close();
            this.reasoner = null;
        }

        this.rules = null;
        this.database = null;
    }

    /**
     * Divides every weighted rule's weight by the largest weight among weighted rules.
     *
     * <p>If that largest weight is zero (per {@link MathUtils#isZero}), every weight is set to
     * {@code 1.0} instead of dividing by zero. Does nothing when there are no weighted rules.
     */
    protected void normalizeWeights() {
        float maxWeight = 0.0f;
        boolean hasWeightedRule = false;
        for (WeightedRule rule : IteratorUtils.filterClass(this.rules, WeightedRule.class)) {
            float weight = rule.getWeight();
            // The first weight seen always seeds the maximum.
            if (!hasWeightedRule || weight > maxWeight) {
                maxWeight = weight;
                hasWeightedRule = true;
            }
        }

        if (!hasWeightedRule) {
            return;
        }

        for (WeightedRule rule : IteratorUtils.filterClass(this.rules, WeightedRule.class)) {
            float oldWeight = rule.getWeight();
            float newWeight = 1.0f;
            if (!MathUtils.isZero(maxWeight)) {
                newWeight = oldWeight / maxWeight;
            }

            log.debug("Normalizing rule weight (old weight: {}, new weight: {}): {}", oldWeight, newWeight, rule);
            rule.setWeight(newWeight);
        }
    }

    /**
     * Replaces every {@link UnweightedRule} in {@link #rules} with its relaxed form.
     *
     * <p>The relaxed weight is {@code max(1, maxWeightedRuleWeight * relaxationMultiplier)},
     * where the maximum weighted-rule weight starts at 0. Squaring follows
     * {@link #relaxationSquared}.
     */
    protected void relaxHardConstraints() {
        float maxWeight = 0.0f;
        boolean hasNonWeightedRule = false;
        for (Rule rule : this.rules) {
            if (rule instanceof WeightedRule) {
                float weight = ((WeightedRule) rule).getWeight();
                if (weight > maxWeight) {
                    maxWeight = weight;
                }
            } else {
                hasNonWeightedRule = true;
            }
        }

        if (!hasNonWeightedRule) {
            return;
        }

        float relaxedWeight = Math.max(1.0f, maxWeight * this.relaxationMultiplier);
        for (int i = 0; i < this.rules.size(); i++) {
            Rule rule = this.rules.get(i);
            if (rule instanceof UnweightedRule) {
                log.debug("Relaxing hard constraint (weight: {}, squared: {}): {}", relaxedWeight, this.relaxationSquared, rule);
                this.rules.set(i, ((UnweightedRule) rule).relax(relaxedWeight, this.relaxationSquared));
            }
        }
    }

    /**
     * Reflectively instantiates an inference application through its public
     * {@code (List, Database)} constructor.
     *
     * @param className class name; resolved via {@link Reflection#resolveClassName}
     * @param rules rules passed to the constructor
     * @param database database passed to the constructor
     * @return the constructed application
     * @throws IllegalArgumentException if the class cannot be found, is not an
     *     {@code InferenceApplication}, or has no suitable constructor
     * @throws RuntimeException if construction fails for any other reflective reason
     */
    public static InferenceApplication getInferenceApplication(String className, List<Rule> rules, Database database) {
        String resolvedName = Reflection.resolveClassName(className);

        Class<?> rawClass;
        try {
            rawClass = Class.forName(resolvedName);
        } catch (ClassNotFoundException ex) {
            throw new IllegalArgumentException("Could not find class: " + resolvedName, ex);
        }

        // Check the type before constructing. Construction runs initialize(), which by default
        // grounds the whole model; that work is wasted if the result cannot be used.
        if (!InferenceApplication.class.isAssignableFrom(rawClass)) {
            throw new IllegalArgumentException("Class is not an inference application: " + resolvedName);
        }
        Class<? extends InferenceApplication> appClass = rawClass.asSubclass(InferenceApplication.class);

        try {
            return appClass.getConstructor(List.class, Database.class).newInstance(rules, database);
        } catch (NoSuchMethodException ex) {
            throw new IllegalArgumentException("No suitable constructor (List<Rules>, Database) found for inference application: " + resolvedName + ".", ex);
        } catch (IllegalAccessException ex) {
            throw new RuntimeException("Insufficient access to constructor for " + resolvedName, ex);
        } catch (InstantiationException ex) {
            throw new RuntimeException("Unable to instantiate inference application (" + resolvedName + ")", ex);
        } catch (InvocationTargetException ex) {
            throw new RuntimeException("Error thrown while constructing " + resolvedName, ex);
        }
    }
}
