package org.linqs.psl.application.learning.weight.em;

import java.util.List;
import org.linqs.psl.application.learning.weight.VotedPerceptron;
import org.linqs.psl.config.Config;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.rule.Rule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for weight learners that alternate an expectation step with a maximization step
 * (expectation maximization).
 *
 * <p>Each EM iteration calls {@link #eStep()} and then {@link #mStep()}. It then measures the
 * Euclidean (L2) norm of the change in the mutable rules' weights since the previous iteration.
 * Learning stops when that norm is at most {@link #tolerance}, or after {@link #iterations}
 * iterations, whichever comes first.
 */
public abstract class ExpectationMaximization extends VotedPerceptron {
    private static final Logger log = LoggerFactory.getLogger(ExpectationMaximization.class);

    /** Prefix shared by all configuration keys read by this class. */
    public static final String CONFIG_PREFIX = "em";

    /** Config key for the maximum number of EM iterations ({@value}). */
    public static final String ITER_KEY = CONFIG_PREFIX + ".iterations";
    /** Default maximum number of EM iterations. */
    public static final int ITER_DEFAULT = 10;

    /** Config key for the convergence tolerance on the weight-change norm ({@value}). */
    public static final String TOLERANCE_KEY = CONFIG_PREFIX + ".tolerance";
    /** Default convergence tolerance on the weight-change norm. */
    public static final double TOLERANCE_DEFAULT = 0.001d;

    /** Maximum number of EM iterations, read from {@link #ITER_KEY}. */
    protected final int iterations;

    /**
     * EM stops early once the L2 norm of one iteration's weight change is at most this value.
     * Read from {@link #TOLERANCE_KEY}.
     */
    protected final double tolerance;

    /**
     * Zero-based index of the EM iteration currently running.
     * Set by {@link #doLearn()}, so subclasses can read it inside {@link #eStep()} / {@link #mStep()}.
     */
    protected int emIteration;

    /**
     * Creates the learner. All arguments, plus {@code true}, go unchanged to the
     * {@link VotedPerceptron} constructor. NOTE(review): the roles of the two databases and the
     * meaning of the boolean flag are defined in VotedPerceptron, not here.
     *
     * @param rules the model's rules
     * @param database first database passed to {@link VotedPerceptron}
     * @param database2 second database passed to {@link VotedPerceptron}
     */
    public ExpectationMaximization(List<Rule> rules, Database database, Database database2) {
        super(rules, database, database2, true);
        this.iterations = Config.getInt(ITER_KEY, ITER_DEFAULT);
        this.tolerance = Config.getDouble(TOLERANCE_KEY, TOLERANCE_DEFAULT);
    }

    /**
     * Runs up to {@link #iterations} rounds of E-step followed by M-step.
     *
     * <p>After each round it logs the weight-change norm, the loss, the regularizer and their sum.
     * It returns early once the weight-change norm is at most {@link #tolerance}.
     *
     * <p>NOTE(review): the decompiler reported this method as {@code protected} in the original
     * bytecode. It is kept {@code public} here to preserve the current interface.
     */
    @Override
    public void doLearn() {
        double[] previousWeights = snapshotWeights();

        for (emIteration = 0; emIteration < iterations; emIteration++) {
            log.debug("Beginning EM iteration {} of {}", emIteration, iterations);

            eStep();
            mStep();

            double change = updateWeightsAndComputeChange(previousWeights);
            double loss = getLoss();
            double regularizer = computeRegularizer();
            log.info("Finished EM iteration {} with m-step norm {}. Loss: {}, regularizer: {}, objective: {}",
                    emIteration, change, loss, regularizer, loss + regularizer);

            // Return before the loop increment, so emIteration still names the converged iteration.
            if (change <= tolerance) {
                log.info("EM converged.");
                return;
            }
        }
    }

    /** Returns a copy of the current mutable-rule weights, in {@code mutableRules} order. */
    private double[] snapshotWeights() {
        double[] weights = new double[mutableRules.size()];
        for (int i = 0; i < weights.length; i++) {
            weights[i] = mutableRules.get(i).getWeight();
        }
        return weights;
    }

    /**
     * Computes the L2 norm of the difference between {@code previousWeights} and the current
     * mutable-rule weights. Then overwrites {@code previousWeights} with the current weights, so the
     * array is ready for the next iteration.
     */
    private double updateWeightsAndComputeChange(double[] previousWeights) {
        double sumOfSquares = 0.0d;
        for (int i = 0; i < mutableRules.size(); i++) {
            double current = mutableRules.get(i).getWeight();
            double delta = previousWeights[i] - current;
            sumOfSquares += delta * delta;
            previousWeights[i] = current;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Expectation step. Delegates to the inherited {@code computeLatentMPEState()}.
     * NOTE(review): presumably this fixes the latent variables at their most probable state under
     * the current weights; confirm against the superclass.
     */
    protected void eStep() {
        computeLatentMPEState();
    }

    /**
     * Maximization step. Runs the superclass ({@link VotedPerceptron}) learning procedure once to
     * update the rule weights.
     */
    protected void mStep() {
        super.doLearn();
    }
}
