package org.linqs.psl.application.learning.weight;

import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.List;
import org.linqs.psl.application.ModelApplication;
import org.linqs.psl.application.inference.InferenceApplication;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Database;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.util.RandUtils;
import org.linqs.psl.util.Reflection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/linqs/psl/application/learning/weight/WeightLearningApplication.class */
/**
 * Base class for weight learning applications.
 *
 * <p>Subclasses implement {@link #doLearn()}. {@link #learn()} first builds the ground model
 * lazily (an {@link InferenceApplication} plus a {@link TrainingMap}) and then delegates to it.
 */
public abstract class WeightLearningApplication implements ModelApplication {
    private static final Logger log = LoggerFactory.getLogger(WeightLearningApplication.class);

    /**
     * Database handed to the inference application when the ground model is built.
     * Presumably this holds the random-variable (target) atoms; TODO confirm.
     */
    protected Database rvDB;

    /**
     * Database handed to the {@link TrainingMap} when the ground model is built.
     * Presumably this holds the observed (truth) values; TODO confirm.
     */
    protected Database observedDB;

    /** Every rule passed to the constructor, in the order given. */
    protected List<Rule> allRules = new ArrayList<>();

    /** The subset of {@link #allRules} that are {@link WeightedRule}s, i.e. whose weights can be changed. */
    protected List<WeightedRule> mutableRules = new ArrayList<>();

    /** Set by {@link #initGroundModel(InferenceApplication, TrainingMap)}; null before that and after {@link #close()}. */
    protected TrainingMap trainingMap;

    /** Set by {@link #initGroundModel(InferenceApplication, TrainingMap)}; null before that and after {@link #close()}. */
    protected InferenceApplication inference;

    /** Evaluator instantiated from {@code Options.WLA_EVAL} at construction time. */
    protected Evaluator evaluator;

    /** True once the ground model has been built; makes every initGroundModel overload a no-op afterwards. */
    private boolean groundModelInit;

    /**
     * True after {@link #computeMPEState()} has run inference; it short-circuits further calls.
     * NOTE(review): subclasses presumably reset this when weights change; verify.
     */
    protected boolean inMPEState;

    /**
     * Creates a weight learner over the given rules.
     *
     * @param rules all rules of the model; those that are {@link WeightedRule}s are also
     *     collected into {@link #mutableRules}
     * @param rvDB database later passed to the inference application
     * @param observedDB database later passed to the {@link TrainingMap}
     */
    public WeightLearningApplication(List<Rule> rules, Database rvDB, Database observedDB) {
        this.rvDB = rvDB;
        this.observedDB = observedDB;

        for (Rule rule : rules) {
            allRules.add(rule);
            if (rule instanceof WeightedRule) {
                mutableRules.add((WeightedRule) rule);
            }
        }

        groundModelInit = false;
        inMPEState = false;
        evaluator = (Evaluator) Options.WLA_EVAL.getNewObject();
    }

    /** Builds the ground model if needed, then runs the subclass-specific learning procedure. */
    public void learn() {
        initGroundModel();
        doLearn();
    }

    /** Performs the actual weight learning; the ground model is already initialized when called. */
    protected abstract void doLearn();

    /**
     * Forwards a budget to the underlying inference application.
     *
     * @param budget value passed unchanged to {@code InferenceApplication.setBudget}
     * @throws IllegalStateException if the ground model has not been initialized or the
     *     learner has been closed
     */
    public void setBudget(double budget) {
        if (inference == null) {
            throw new IllegalStateException(
                    "Cannot set a budget before the ground model is initialized (or after close()).");
        }
        inference.setBudget(budget);
    }

    /** Returns the inference application, or null before initialization / after {@link #close()}. */
    public InferenceApplication getInferenceApplication() {
        return inference;
    }

    /**
     * Builds the ground model using the inference application named by
     * {@code Options.WLA_INFERENCE}, unless it has already been built.
     */
    protected void initGroundModel() {
        if (groundModelInit) {
            return;
        }

        initGroundModel(InferenceApplication.getInferenceApplication(
                Options.WLA_INFERENCE.getString(), allRules, rvDB));
    }

    /** Builds the ground model around the given inference application with a fresh {@link TrainingMap}. */
    private void initGroundModel(InferenceApplication inferenceApplication) {
        if (groundModelInit) {
            return;
        }

        initGroundModel(inferenceApplication,
                new TrainingMap(inferenceApplication.getAtomManager(), observedDB));
    }

    /**
     * Installs an externally built inference application and training map as the ground model.
     * This does nothing if the ground model has already been initialized.
     *
     * <p>When {@code Options.WLA_RANDOM_WEIGHTS} is enabled, every mutable rule gets a random
     * weight. {@link #postInitGroundModel()} is then invoked as a subclass hook.
     */
    public void initGroundModel(InferenceApplication inferenceApplication, TrainingMap trainingMap) {
        if (groundModelInit) {
            return;
        }

        this.inference = inferenceApplication;
        this.trainingMap = trainingMap;

        if (Options.WLA_RANDOM_WEIGHTS.getBoolean()) {
            initRandomWeights();
        }

        postInitGroundModel();
        groundModelInit = true;
    }

    /** Assigns a {@code RandUtils.nextFloat()} weight to every mutable rule, tracing each result. */
    private void initRandomWeights() {
        log.trace("Randomly Weighted Rules:");
        for (WeightedRule rule : mutableRules) {
            rule.setWeight(RandUtils.nextFloat());
            // Parameterized logging: the rule is only stringified when TRACE is enabled.
            log.trace("    {}", rule);
        }
    }

    /**
     * Hook invoked once, right after the ground model is installed. The default does nothing.
     * (The decompiler reports this as originally protected; it is kept public for compatibility.)
     */
    public void postInitGroundModel() {
    }

    /**
     * Runs inference once and marks the learner as being in the MPE state. Subsequent calls
     * are no-ops until {@link #inMPEState} is cleared.
     * NOTE(review): the meaning of the {@code (false, true)} arguments is defined by
     * InferenceApplication; confirm there.
     */
    public void computeMPEState() {
        if (inMPEState) {
            return;
        }

        inference.inference(false, true);
        inMPEState = true;
    }

    /**
     * Commits and closes the inference application (if any) and drops all database references.
     * The inference application is closed and released even if the commit throws.
     */
    @Override
    public void close() {
        try {
            if (inference != null) {
                inference.commit();
            }
        } finally {
            if (inference != null) {
                inference.close();
                inference = null;
            }

            trainingMap = null;
            rvDB = null;
            observedDB = null;
        }
    }

    /**
     * Reflectively constructs a weight learner by (possibly abbreviated) class name.
     *
     * @param name class name, resolved through {@code Reflection.resolveClassName}
     * @param rules rules passed to the learner's constructor
     * @param rvDB first database passed to the learner's constructor
     * @param observedDB second database passed to the learner's constructor
     * @return the newly constructed learner
     * @throws IllegalArgumentException if the class cannot be resolved or found, or has no
     *     {@code (List, Database, Database)} constructor
     * @throws ClassCastException if the class is not a {@code WeightLearningApplication}
     *     (checked before any instance is constructed)
     * @throws RuntimeException if the constructor is inaccessible, the class is abstract,
     *     or the constructor itself throws
     */
    public static WeightLearningApplication getWLA(
            String name, List<Rule> rules, Database rvDB, Database observedDB) {
        String className = Reflection.resolveClassName(name);
        if (className == null) {
            throw new IllegalArgumentException("Could not find class: " + name);
        }

        Class<? extends WeightLearningApplication> learnerClass;
        try {
            // asSubclass rejects non-learners up front instead of constructing them and then failing a cast.
            learnerClass = Class.forName(className).asSubclass(WeightLearningApplication.class);
        } catch (ClassNotFoundException ex) {
            throw new IllegalArgumentException("Could not find class: " + className, ex);
        }

        try {
            return learnerClass
                    .getConstructor(List.class, Database.class, Database.class)
                    .newInstance(rules, rvDB, observedDB);
        } catch (NoSuchMethodException ex) {
            throw new IllegalArgumentException(
                    "No suitable constructor found for weight learner: " + className + ".", ex);
        } catch (IllegalAccessException ex) {
            throw new RuntimeException("Insufficient access to constructor for " + className, ex);
        } catch (InstantiationException ex) {
            throw new RuntimeException("Unable to instantiate weight learner (" + className + ")", ex);
        } catch (InvocationTargetException ex) {
            throw new RuntimeException("Error thrown while constructing " + className, ex);
        }
    }
}
