package org.linqs.psl.application.learning.weight;

import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.List;
import org.linqs.psl.application.ModelApplication;
import org.linqs.psl.application.inference.InferenceApplication;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Database;
import org.linqs.psl.evaluation.EvaluationInstance;
import org.linqs.psl.model.deep.DeepModelPredicate;
import org.linqs.psl.model.predicate.DeepPredicate;
import org.linqs.psl.model.predicate.Predicate;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.RandUtils;
import org.linqs.psl.util.Reflection;

/* loaded from: input_file:org/linqs/psl/application/learning/weight/WeightLearningApplication.class */
/**
 * Base class for weight learning applications.
 *
 * <p>An instance holds the training and validation databases, the rules of the model
 * (tracking the {@link WeightedRule}s separately as the rules whose weights may be changed),
 * and lazily builds the inference applications and training maps the first time
 * {@link #learn()} (or one of the {@code initGroundModel} overloads) is invoked.
 * The actual learning procedure is supplied by subclasses through {@link #doLearn()}.
 */
public abstract class WeightLearningApplication implements ModelApplication {
    public static final String DELIM = ":";
    private static final Logger log = Logger.getLogger(WeightLearningApplication.class);

    protected Database trainTargetDatabase;
    protected Database trainTruthDatabase;
    protected Database validationTargetDatabase;
    protected Database validationTruthDatabase;
    protected boolean runValidation;

    protected TrainingMap trainingMap;
    protected TrainingMap validationMap;
    protected InferenceApplication trainInferenceApplication;
    protected InferenceApplication validationInferenceApplication;
    protected EvaluationInstance evaluation;

    /** Set once the ground model has been initialized; makes every initGroundModel overload a no-op afterwards. */
    private boolean groundModelInit;
    protected boolean inTrainingMAPState;
    protected boolean inValidationMAPState;

    protected List<DeepPredicate> deepPredicates = new ArrayList<>();
    protected List<DeepModelPredicate> deepModelPredicates = new ArrayList<>();
    protected List<DeepModelPredicate> validationDeepModelPredicates = new ArrayList<>();
    protected List<Rule> allRules = new ArrayList<>();
    protected List<WeightedRule> mutableRules = new ArrayList<>();

    /**
     * Creates a weight learner over the given rules and databases.
     *
     * <p>Every rule is recorded in {@code allRules}; rules that are {@link WeightedRule}s are
     * additionally recorded in {@code mutableRules}. No inference application is created here;
     * that happens on ground model initialization.
     *
     * @param list the rules of the model
     * @param database the training target database
     * @param database2 the training truth database
     * @param database3 the validation target database
     * @param database4 the validation truth database
     * @param bool whether validation should be run; must not be {@code null} (it is unboxed here)
     */
    public WeightLearningApplication(List<Rule> list, Database database, Database database2, Database database3, Database database4, Boolean bool) {
        this.trainTargetDatabase = database;
        this.trainTruthDatabase = database2;
        this.validationTargetDatabase = database3;
        this.validationTruthDatabase = database4;
        this.runValidation = bool.booleanValue();

        for (Rule rule : list) {
            this.allRules.add(rule);
            if (rule instanceof WeightedRule) {
                this.mutableRules.add((WeightedRule) rule);
            }
        }

        this.trainInferenceApplication = null;
        this.validationInferenceApplication = null;
        this.trainingMap = null;
        this.validationMap = null;
        this.groundModelInit = false;
        this.inTrainingMAPState = false;
        this.inValidationMAPState = false;
        this.evaluation = null;
    }

    /**
     * Sets the evaluation instance stored in {@code evaluation}.
     *
     * @param evaluationInstance the evaluation to use
     */
    public void setEvaluation(EvaluationInstance evaluationInstance) {
        this.evaluation = evaluationInstance;
    }

    /**
     * Initializes the ground model (if that has not happened yet) and then runs {@link #doLearn()}.
     */
    public void learn() {
        initGroundModel();
        doLearn();
    }

    /**
     * Performs the actual weight learning. Called by {@link #learn()} after the ground model
     * has been initialized.
     */
    protected abstract void doLearn();

    /**
     * Forwards a budget to the training inference application.
     *
     * @param d the budget value passed to {@code InferenceApplication.setBudget}
     * @throws IllegalStateException if there is no training inference application, i.e. the
     *         ground model has not been initialized yet or this learner has been closed
     */
    public void setBudget(double d) {
        if (this.trainInferenceApplication == null) {
            // Fail with a descriptive message instead of a bare NullPointerException.
            throw new IllegalStateException(
                    "Cannot set a budget without a training inference application:"
                    + " the ground model is not initialized or the learner was closed.");
        }
        this.trainInferenceApplication.setBudget(d);
    }

    /**
     * Initializes the ground model with inference applications built from the
     * {@code WLA_INFERENCE} option, one over the training target database and one over the
     * validation target database. Deep predicates are loaded (with the argument
     * {@code "learning"}) on the training application only.
     *
     * <p>Does nothing if the ground model has already been initialized.
     */
    protected void initGroundModel() {
        if (this.groundModelInit) {
            return;
        }

        InferenceApplication trainApplication = InferenceApplication.getInferenceApplication(
                Options.WLA_INFERENCE.getString(), this.allRules, this.trainTargetDatabase);
        trainApplication.loadDeepPredicates("learning");

        InferenceApplication validationApplication = InferenceApplication.getInferenceApplication(
                Options.WLA_INFERENCE.getString(), this.allRules, this.validationTargetDatabase);

        initGroundModel(trainApplication, validationApplication);
    }

    /**
     * Initializes the ground model from the given inference applications, pairing each
     * application's database with the matching truth database in a new {@link TrainingMap}.
     */
    private void initGroundModel(InferenceApplication inferenceApplication, InferenceApplication inferenceApplication2) {
        if (this.groundModelInit) {
            return;
        }

        TrainingMap trainMap = new TrainingMap(inferenceApplication.getDatabase(), this.trainTruthDatabase);
        TrainingMap validMap = new TrainingMap(inferenceApplication2.getDatabase(), this.validationTruthDatabase);
        initGroundModel(inferenceApplication, trainMap, inferenceApplication2, validMap);
    }

    /**
     * Initializes the ground model with explicitly supplied inference applications and
     * training maps.
     *
     * <p>Besides storing the arguments, this optionally randomizes the weights of the mutable
     * rules (when the {@code WLA_RANDOM_WEIGHTS} option is set), collects every registered
     * {@link DeepPredicate} together with its deep model, and creates a copy of each deep model
     * bound to the atom store of the validation application's database. Finally
     * {@link #postInitGroundModel()} is invoked.
     *
     * <p>Does nothing if the ground model has already been initialized.
     *
     * @param inferenceApplication the inference application used for training
     * @param trainingMap the training map for the training data
     * @param inferenceApplication2 the inference application used for validation
     * @param trainingMap2 the training map for the validation data
     */
    public void initGroundModel(InferenceApplication inferenceApplication, TrainingMap trainingMap, InferenceApplication inferenceApplication2, TrainingMap trainingMap2) {
        if (this.groundModelInit) {
            return;
        }

        this.trainInferenceApplication = inferenceApplication;
        this.trainingMap = trainingMap;
        this.validationInferenceApplication = inferenceApplication2;
        this.validationMap = trainingMap2;

        if (Options.WLA_RANDOM_WEIGHTS.getBoolean()) {
            initRandomWeights();
        }

        for (Predicate predicate : Predicate.getAll()) {
            if (!(predicate instanceof DeepPredicate)) {
                continue;
            }

            DeepPredicate deepPredicate = (DeepPredicate) predicate;
            this.deepPredicates.add(deepPredicate);
            this.deepModelPredicates.add(deepPredicate.getDeepModel());

            // The validation side works on its own copy of the deep model,
            // attached to the validation database's atom store.
            DeepModelPredicate validationModel = deepPredicate.getDeepModel().copy();
            validationModel.setAtomStore(inferenceApplication2.getDatabase().getAtomStore(), true);
            this.validationDeepModelPredicates.add(validationModel);
        }

        postInitGroundModel();
        // Only mark as initialized once every step above has completed.
        this.groundModelInit = true;
    }

    /**
     * Assigns every mutable rule a weight obtained from {@code RandUtils.nextFloat()},
     * logging each rule at trace level.
     */
    private void initRandomWeights() {
        log.trace("Randomly Weighted Rules:");
        for (WeightedRule weightedRule : this.mutableRules) {
            weightedRule.setWeight(RandUtils.nextFloat());
            log.trace("    " + weightedRule.toString());
        }
    }

    /**
     * Hook invoked at the end of ground model initialization, before the model is flagged as
     * initialized. The default implementation does nothing.
     */
    public void postInitGroundModel() {
    }

    /**
     * Runs MAP inference on the training inference application unless
     * {@code inTrainingMAPState} is already set, then sets that flag.
     */
    public void computeTrainingMAPState() {
        if (this.inTrainingMAPState) {
            return;
        }

        computeMAPState(this.trainInferenceApplication);
        this.inTrainingMAPState = true;
    }

    /**
     * Runs MAP inference on the validation inference application unless
     * {@code inValidationMAPState} is already set, then sets that flag.
     */
    protected void computeValidationMAPState() {
        if (this.inValidationMAPState) {
            return;
        }

        computeMAPState(this.validationInferenceApplication);
        this.inValidationMAPState = true;
    }

    /**
     * Runs inference on the given application by calling {@code inference(false, true)};
     * the meaning of the two flags is defined by {@link InferenceApplication}.
     *
     * @param inferenceApplication the application to run inference on
     */
    public void computeMAPState(InferenceApplication inferenceApplication) {
        inferenceApplication.inference(false, true);
    }

    /**
     * Commits and closes both inference applications (when present) and drops the references
     * to the training maps and databases.
     *
     * <p>A failure while releasing the training application does not prevent the validation
     * application from being released, and all references are cleared even when an exception
     * propagates. If several steps fail, the exception of the last failing step is the one
     * that propagates.
     */
    @Override // org.linqs.psl.application.ModelApplication
    public void close() {
        try {
            try {
                commitAndClose(this.trainInferenceApplication);
            } finally {
                this.trainInferenceApplication = null;
                commitAndClose(this.validationInferenceApplication);
            }
        } finally {
            this.validationInferenceApplication = null;
            this.trainingMap = null;
            this.trainTargetDatabase = null;
            this.trainTruthDatabase = null;
            this.validationMap = null;
            this.validationTargetDatabase = null;
            this.validationTruthDatabase = null;
        }
    }

    /**
     * Commits the given application and then closes it; the close is attempted even if the
     * commit throws. A {@code null} application is ignored.
     */
    private static void commitAndClose(InferenceApplication application) {
        if (application == null) {
            return;
        }

        try {
            application.commit();
        } finally {
            application.close();
        }
    }

    /**
     * Reflectively constructs a weight learner.
     *
     * <p>The name is resolved through {@code Reflection.resolveClassName}, and the resolved
     * class must expose a public constructor taking
     * {@code (List, Database, Database, Database, Database, boolean)}.
     *
     * @param str the (possibly abbreviated) name of the weight learner class
     * @param list the rules of the model
     * @param database the training target database
     * @param database2 the training truth database
     * @param database3 the validation target database
     * @param database4 the validation truth database
     * @param z whether validation should be run
     * @return the newly constructed weight learner
     * @throws IllegalArgumentException if the class cannot be resolved or loaded, or if it has
     *         no matching constructor
     * @throws RuntimeException if the constructor is inaccessible, the class cannot be
     *         instantiated, or the constructor itself throws
     */
    public static WeightLearningApplication getWLA(String str, List<Rule> list, Database database, Database database2, Database database3, Database database4, boolean z) {
        String resolveClassName = Reflection.resolveClassName(str);
        if (resolveClassName == null) {
            throw new IllegalArgumentException("Could not find class: " + str);
        }

        try {
            return (WeightLearningApplication) Class.forName(resolveClassName)
                    .getConstructor(List.class, Database.class, Database.class, Database.class, Database.class, Boolean.TYPE)
                    .newInstance(list, database, database2, database3, database4, Boolean.valueOf(z));
        } catch (ClassNotFoundException ex) {
            throw new IllegalArgumentException("Could not find class: " + resolveClassName, ex);
        } catch (NoSuchMethodException ex) {
            throw new IllegalArgumentException("No suitable constructor found for weight learner: " + resolveClassName + ".", ex);
        } catch (IllegalAccessException ex) {
            throw new RuntimeException("Insufficient access to constructor for " + resolveClassName, ex);
        } catch (InstantiationException ex) {
            throw new RuntimeException("Unable to instantiate weight learner (" + resolveClassName + ")", ex);
        } catch (InvocationTargetException ex) {
            throw new RuntimeException("Error thrown while constructing " + resolveClassName, ex);
        }
    }
}
