package org.linqs.psl.application.learning.weight.em;

import java.util.Arrays;
import java.util.List;

import org.linqs.psl.config.Config;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.reasoner.admm.ADMMReasoner;
import org.linqs.psl.reasoner.admm.term.ADMMTermStore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Paired-dual weight learner built on {@link ExpectationMaximization}.
 *
 * <p>During learning, the {@link ADMMReasoner} is capped at a configurable number of iterations
 * per optimize call ({@link #ADMM_STEPS_KEY}). Per-rule incompatibilities are computed with
 * {@code ADMMReasoner.getDualIncompatibility}. The reported objective also includes the
 * reasoner's Lagrangian and augmented-Lagrangian penalties, taken once after the observed
 * (latent) computation and once after the expected (MPE) computation.
 *
 * <p>Requires an {@link ADMMReasoner} plus {@link ADMMTermStore}s for both the full and the latent
 * term store. {@link #doLearn()} rejects anything else with an {@link IllegalArgumentException}.
 */
public class PairedDualLearner extends ExpectationMaximization {
    private static final Logger log = LoggerFactory.getLogger(PairedDualLearner.class);

    /** Prefix shared by this learner's configuration keys. */
    public static final String CONFIG_PREFIX = "pairedduallearner";

    /**
     * Number of warm-up rounds run before learning. Each round calls {@code optimize} on both the
     * full and the latent term store. Must be nonnegative.
     */
    public static final String WARMUP_ROUNDS_KEY = CONFIG_PREFIX + ".warmuprounds";
    public static final int WARMUP_ROUNDS_DEFAULT = 0;

    /** Max ADMM iterations per optimize call while learning. Must be at least 1. */
    public static final String ADMM_STEPS_KEY = CONFIG_PREFIX + ".admmsteps";
    public static final int ADMM_STEPS_DEFAULT = 1;

    private final int warmupRounds;
    private final int admmIterations;

    /**
     * Creates a learner over all rules of {@code model}.
     *
     * @param model supplies the rules via {@code model.getRules()}
     * @param rvDB first database, passed through to {@link ExpectationMaximization}
     * @param observedDB second database, passed through to {@link ExpectationMaximization}
     * @throws IllegalArgumentException if a configuration value is out of range
     */
    public PairedDualLearner(Model model, Database rvDB, Database observedDB) {
        this(model.getRules(), rvDB, observedDB);
    }

    /**
     * Creates a learner over {@code rules} and reads this learner's configuration.
     *
     * @param rules rules handed to {@link ExpectationMaximization}
     * @param rvDB first database, passed through to {@link ExpectationMaximization}
     * @param observedDB second database, passed through to {@link ExpectationMaximization}
     * @throws IllegalArgumentException if {@link #WARMUP_ROUNDS_KEY} is negative or
     *     {@link #ADMM_STEPS_KEY} is less than 1
     */
    public PairedDualLearner(List<Rule> rules, Database rvDB, Database observedDB) {
        super(rules, rvDB, observedDB);

        warmupRounds = Config.getInt(WARMUP_ROUNDS_KEY, WARMUP_ROUNDS_DEFAULT);
        if (warmupRounds < 0) {
            throw new IllegalArgumentException(WARMUP_ROUNDS_KEY + " must be a nonnegative integer.");
        }

        admmIterations = Config.getInt(ADMM_STEPS_KEY, ADMM_STEPS_DEFAULT);
        if (admmIterations < 1) {
            throw new IllegalArgumentException(ADMM_STEPS_KEY + " must be a positive integer.");
        }
    }

    /**
     * Recomputes {@code expectedIncompatibility} against the full term store.
     *
     * <p>Calls {@code computeMPEState()} first. Then, for each mutable rule {@code i}, stores the
     * sum of {@code getDualIncompatibility} over that rule's ground rules from
     * {@code groundRuleStore}. Assumes {@code reasoner} is an {@link ADMMReasoner} and
     * {@code termStore} is an {@link ADMMTermStore}; otherwise the casts throw
     * {@link ClassCastException}.
     */
    @Override
    public void computeExpectedIncompatibility() {
        computeMPEState();
        Arrays.fill(expectedIncompatibility, 0.0);

        ADMMReasoner admm = (ADMMReasoner) reasoner;
        ADMMTermStore store = (ADMMTermStore) termStore;
        // One buffer sized to the store's global variable count, shared across all ground rules.
        // NOTE(review): presumably scratch space for getDualIncompatibility — confirm.
        float[] variableBuffer = new float[store.getNumGlobalVariables()];

        for (int i = 0; i < mutableRules.size(); i++) {
            for (GroundRule groundRule : groundRuleStore.getGroundRules(mutableRules.get(i))) {
                expectedIncompatibility[i] += admm.getDualIncompatibility(groundRule, store, variableBuffer);
            }
        }
    }

    /**
     * Recomputes {@code observedIncompatibility} against the latent term store.
     *
     * <p>Calls {@code setLabeledRandomVariables()} and {@code computeLatentMPEState()} first. Then,
     * for each mutable rule {@code i}, stores the sum of {@code getDualIncompatibility} over that
     * rule's ground rules from {@code latentGroundRuleStore}. Assumes {@code reasoner} is an
     * {@link ADMMReasoner} and {@code latentTermStore} is an {@link ADMMTermStore}.
     */
    @Override
    public void computeObservedIncompatibility() {
        setLabeledRandomVariables();
        computeLatentMPEState();
        Arrays.fill(observedIncompatibility, 0.0);

        ADMMReasoner admm = (ADMMReasoner) reasoner;
        ADMMTermStore store = (ADMMTermStore) latentTermStore;
        // NOTE(review): presumably scratch space for getDualIncompatibility — confirm.
        float[] variableBuffer = new float[store.getNumGlobalVariables()];

        for (int i = 0; i < mutableRules.size(); i++) {
            for (GroundRule groundRule : latentGroundRuleStore.getGroundRules(mutableRules.get(i))) {
                observedIncompatibility[i] += admm.getDualIncompatibility(groundRule, store, variableBuffer);
            }
        }
    }

    /**
     * Runs paired-dual learning and writes the learned weights back into the mutable rules.
     *
     * <p>While learning, the reasoner's max iterations is set to {@link #ADMM_STEPS_KEY}. The
     * previous value is restored in a {@code finally} block, so the shared reasoner is not left
     * capped if warm-up or learning throws.
     *
     * @throws IllegalArgumentException if the reasoner is not an {@link ADMMReasoner}, or either
     *     term store is not an {@link ADMMTermStore}
     */
    @Override
    public void doLearn() {
        if (!(reasoner instanceof ADMMReasoner)) {
            throw new IllegalArgumentException(String.format(
                    "PairedDualLearning can only be used with ADMMReasoner, found %s.",
                    reasoner.getClass().getName()));
        }
        if (!(termStore instanceof ADMMTermStore)) {
            throw new IllegalArgumentException(String.format(
                    "PairedDualLearning can only be used with ADMMTermStore, found %s.",
                    termStore.getClass().getName()));
        }
        if (!(latentTermStore instanceof ADMMTermStore)) {
            throw new IllegalArgumentException(String.format(
                    "PairedDualLearning (latent) can only be used with ADMMTermStore, found %s.",
                    latentTermStore.getClass().getName()));
        }

        ADMMReasoner admm = (ADMMReasoner) reasoner;
        int originalMaxIter = admm.getMaxIter();
        admm.setMaxIter(admmIterations);
        try {
            if (warmupRounds > 0) {
                log.debug("Warming up optimizer with {} iterations.", warmupRounds * admmIterations);
                for (int round = 0; round < warmupRounds; round++) {
                    reasoner.optimize(termStore);
                    reasoner.optimize(latentTermStore);
                }
            }
            subgrad();
        } finally {
            admm.setMaxIter(originalMaxIter);
        }
    }

    /**
     * Projected (sub)gradient descent on the rule weights, keeping them nonnegative.
     *
     * <p>Each step computes {@code delta = max(-w, -baseStepSize * g)} and applies
     * {@code w += delta}, so no weight drops below zero. Stops after {@code iterations} steps, or
     * earlier once the L2 norm of the step falls below {@code tolerance}. If
     * {@code averageSteps} is set and at least one step was taken, the running average of the
     * weights is used instead of the final iterate.
     */
    private void subgrad() {
        log.info("Starting optimization");

        int numRules = mutableRules.size();
        double[] weights = new double[numRules];
        for (int i = 0; i < numRules; i++) {
            weights[i] = mutableRules.get(i).getWeight();
        }

        // getValueAndGradient writes the gradient into this array. After each step it is
        // overwritten with the applied change, which getValueAndGradient then reads as a
        // "weight moved" mask. Seeding it with 1.0 makes the first call push every weight.
        double[] gradient = new double[numRules];
        Arrays.fill(gradient, 1.0);

        double[] avgWeights = new double[numRules];
        boolean tookStep = false;
        double objective = 0.0;

        for (emIteration = 0; emIteration < iterations; emIteration++) {
            objective = getValueAndGradient(gradient, weights);

            double projectedGradNormSq = 0.0;
            double changeNormSq = 0.0;
            double avgRate = 1.0 / (emIteration + 1.0);
            for (int i = 0; i < numRules; i++) {
                projectedGradNormSq += Math.pow(weights[i] - Math.max(0.0, weights[i] - gradient[i]), 2.0);

                double delta = Math.max(-weights[i], -baseStepSize * gradient[i]);
                weights[i] += delta;
                gradient[i] = delta;
                changeNormSq += Math.pow(delta, 2.0);

                avgWeights[i] = (1.0 - avgRate) * avgWeights[i] + avgRate * weights[i];
            }
            tookStep = true;

            double gradNorm = Math.sqrt(projectedGradNormSq);
            double change = Math.sqrt(changeNormSq);
            log.debug("Iter {}, obj: {}, norm grad: {}, change: {}", emIteration, objective, gradNorm, change);

            if (change < tolerance) {
                log.info("Change in w ({}) is less than tolerance. Finishing subgrad.", change);
                break;
            }
        }

        log.info("Learning finished with final objective value {}", objective);

        // With zero steps the running average is still all zeros; never apply it in that case.
        boolean useAverage = averageSteps && tookStep;
        for (int i = 0; i < numRules; i++) {
            mutableRules.get(i).setWeight(useAverage ? avgWeights[i] : weights[i]);
        }

        // Rule weights were just overwritten, so mark both cached inference states as stale.
        inMPEState = false;
        inLatentMPEState = false;
    }

    /**
     * Pushes {@code weights} into the rules, recomputes both incompatibilities and returns the
     * objective, optionally writing the gradient.
     *
     * <p>The returned value is {@code sum_i w_i * (observed_i - expected_i)}, plus the Lagrangian
     * and augmented-Lagrangian penalties read after the observed computation, minus those read
     * after the expected computation, plus {@code computeRegularizer()}.
     *
     * @param gradient on entry, rule {@code i}'s weight is pushed only if {@code gradient[i] != 0}.
     *     On exit it holds {@code observed - expected}, divided by the rule's ground-rule count
     *     when {@code scaleGradient} is set and that count is positive, plus
     *     {@code l2Regularization * w + l1Regularization}. If {@code null}, every weight is pushed
     *     and no gradient is written.
     * @param weights current weight for each mutable rule (same indexing as {@code mutableRules})
     * @return the objective value described above
     */
    private double getValueAndGradient(double[] gradient, double[] weights) {
        for (int i = 0; i < mutableRules.size(); i++) {
            if (gradient == null || gradient[i] != 0.0) {
                mutableRules.get(i).setWeight(weights[i]);
            }
        }

        // Force both inference states to be recomputed under the new weights.
        inMPEState = false;
        inLatentMPEState = false;

        ADMMReasoner admm = (ADMMReasoner) reasoner;

        computeObservedIncompatibility();
        double observedPenalty = admm.getLagrangianPenalty();
        double observedAugPenalty = admm.getAugmentedLagrangianPenalty();

        computeExpectedIncompatibility();
        double expectedPenalty = admm.getLagrangianPenalty();
        double expectedAugPenalty = admm.getAugmentedLagrangianPenalty();

        double objective = 0.0;
        for (int i = 0; i < mutableRules.size(); i++) {
            objective += weights[i] * (observedIncompatibility[i] - expectedIncompatibility[i]);
        }
        objective += observedPenalty + observedAugPenalty - expectedPenalty - expectedAugPenalty;

        for (int i = 0; i < mutableRules.size(); i++) {
            log.debug("Incompatibility for rule {}", mutableRules.get(i));
            log.debug("Truth incompatbility {}, expected incompatibility {}",
                    observedIncompatibility[i], expectedIncompatibility[i]);
        }
        log.debug("E Penalty: {}, E Aug Penalty: {}, M Penalty: {}, M Aug Penalty: {}",
                observedPenalty, observedAugPenalty, expectedPenalty, expectedAugPenalty);

        double regularizer = computeRegularizer();

        if (gradient != null) {
            for (int i = 0; i < mutableRules.size(); i++) {
                gradient[i] = observedIncompatibility[i] - expectedIncompatibility[i];
                if (scaleGradient) {
                    // Hoisted: count() was previously evaluated twice per rule.
                    double groundCount = groundRuleStore.count(mutableRules.get(i));
                    if (groundCount > 0.0) {
                        gradient[i] = gradient[i] / groundCount;
                    }
                }
                // Keep the original left-to-right association: (g + l2*w) + l1.
                gradient[i] = gradient[i] + l2Regularization * weights[i] + l1Regularization;
            }
        }

        return objective + regularizer;
    }
}
