package org.linqs.psl.application.learning.weight;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.model.rule.WeightedRule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/linqs/psl/application/learning/weight/VotedPerceptron.class */
/**
 * Base class for weight learning methods in the voted-perceptron family.
 *
 * <p>Each iteration runs MPE inference ({@link #computeExpectedIncompatibility()}) and compares the
 * resulting per-rule incompatibility with the incompatibility obtained when the random variables are
 * set to their labels ({@link #computeObservedIncompatibility()}). Every mutable rule weight is then
 * moved along that difference. Regularization, momentum, step-size decay, step cutting, weight
 * clipping and weight averaging are all configured through {@link Options} in the constructor.
 */
public abstract class VotedPerceptron extends WeightLearningApplication {
    private static final Logger log = LoggerFactory.getLogger(VotedPerceptron.class);

    /** Per-rule summed incompatibility of the ground rules with the random variables at their labels. */
    protected double[] observedIncompatibility;
    /** Per-rule summed incompatibility of the ground rules at the current MPE state. */
    protected double[] expectedIncompatibility;
    /** L2 penalty coefficient (gradient term {@code -l2 * weight}, regularizer {@code 0.5 * l2 * sum(w^2)}). */
    protected final double l2Regularization;
    /** L1 penalty coefficient (gradient term {@code -l1}, regularizer {@code l1 * sum(|w|)}). */
    protected final double l1Regularization;
    /** Whether each rule's gradient is divided by its grounding count (see {@link #computeScalingFactor()}). */
    protected final boolean scaleGradient;
    /** Multiplier applied to every gradient step; halved whenever a step is cut. */
    protected double baseStepSize;
    /** Whether the step at iteration {@code t} (0-based) is divided by {@code t + 1}. */
    protected boolean scaleStepSize;
    /** Whether the final weights are the average of the weights after each step. */
    protected boolean averageSteps;
    /** Whether all mutable weights are reset to 0 before learning. */
    protected boolean zeroInitialWeights;
    /** Whether updated weights are clipped at 0 from below. */
    protected boolean clipNegativeWeights;
    /** Whether a step that increases the training objective is undone and the step size halved. */
    protected boolean cutObjective;
    /** Momentum coefficient: fraction of the previous step added to the current one. */
    protected double inertia;
    /** Number of steps configured through {@link Options}; upper bound for {@link #setBudget(double)}. */
    protected final int maxNumSteps;
    /** Number of gradient steps {@link #doLearn()} will take. */
    protected int numSteps;
    /** Cached result of {@link #computeLoss()}; {@code NaN} means "not computed yet". */
    private double currentLoss;

    public VotedPerceptron(List<Rule> list, Database database, Database database2) {
        super(list, database, database2);
        this.observedIncompatibility = new double[this.mutableRules.size()];
        this.expectedIncompatibility = new double[this.mutableRules.size()];
        this.numSteps = Options.WLA_VP_NUM_STEPS.getInt();
        this.maxNumSteps = this.numSteps;
        this.baseStepSize = Options.WLA_VP_STEP.getDouble();
        this.inertia = Options.WLA_VP_INERTIA.getDouble();
        this.l1Regularization = Options.WLA_VP_L1.getDouble();
        this.l2Regularization = Options.WLA_VP_L2.getDouble();
        this.scaleGradient = Options.WLA_VP_SCALE_GRADIENT.getBoolean();
        this.averageSteps = Options.WLA_VP_AVERAGE_STEPS.getBoolean();
        this.scaleStepSize = Options.WLA_VP_SCALE_STEP.getBoolean();
        this.zeroInitialWeights = Options.WLA_VP_ZERO_INITIAL_WEIGHTS.getBoolean();
        this.clipNegativeWeights = Options.WLA_VP_CLIP_NEGATIVE_WEIGHTS.getBoolean();
        this.cutObjective = Options.WLA_VP_CUT_OBJECTIVE.getBoolean();
        this.currentLoss = Double.NaN;
    }

    /**
     * Warns when the training map contains latent variables, because they can make the
     * incompatibility-based gradient less accurate.
     *
     * <p>Kept {@code public} for compatibility with existing callers.
     */
    @Override
    public void postInitGroundModel() {
        if (trainingMap.getLatentVariables().size() > 0) {
            log.warn("Latent variable(s) found when using a VotedPerceptron-based weight learning method ({}). VotedPerceptron uses gradients to update weights, but latent variables may make the gradients less accurate. Weight learning may still perform sufficiently. Found {} latent variables. Example latent variable: [{}].",
                    getClass().getName(), trainingMap.getLatentVariables().size(), trainingMap.getLatentVariables().get(0));
        }
    }

    /**
     * Runs {@link #numSteps} gradient steps on the mutable rule weights.
     *
     * <p>If {@link #averageSteps} is set, the weights are finally replaced by the average of the
     * weights held after each step.
     */
    @Override
    protected void doLearn() {
        // Running sum of the weights held after each step (used for averaging).
        double[] weightSums = new double[mutableRules.size()];

        computeObservedIncompatibility();
        setDefaultRandomVariables();

        if (zeroInitialWeights) {
            for (WeightedRule rule : mutableRules) {
                rule.setWeight(0.0);
            }
        }

        if (log.isDebugEnabled() && evaluator != null) {
            computeMPEState();
            evaluator.compute(trainingMap);
            log.debug("Initial Training Objective: {}", -1.0 * evaluator.getNormalizedRepMetric());
        }

        double[] scalingFactor = computeScalingFactor();
        // Previous step of every weight, needed for momentum.
        double[] lastSteps = new double[mutableRules.size()];
        double lastObjective = -1.0;
        // Weights at the end of the previous iteration, restored when a step is cut.
        double[] lastWeights = currentWeights();

        for (int step = 0; step < numSteps; step++) {
            log.debug("Starting iteration {}", step);
            currentLoss = Double.NaN;
            computeExpectedIncompatibility();

            double incompatibilityNorm = takeGradientStep(step, scalingFactor, lastSteps, weightSums);
            // The weights changed, so any cached MPE state is stale.
            inMPEState = false;

            if (log.isDebugEnabled()) {
                getLoss();
            }

            double objective = -1.0;
            if ((cutObjective || log.isDebugEnabled()) && evaluator != null) {
                computeMPEState();
                evaluator.compute(trainingMap);
                objective = -1.0 * evaluator.getNormalizedRepMetric();

                if (!cutObjective || step <= 0 || objective <= lastObjective) {
                    lastObjective = objective;
                } else {
                    log.trace("Objective increased: {} -> {}, cutting step size: {} -> {}.",
                            lastObjective, objective, baseStepSize, baseStepSize / 2.0);
                    baseStepSize /= 2.0;
                    objective = lastObjective;
                    revertToWeights(lastWeights, lastSteps, weightSums);
                }
            }

            lastWeights = currentWeights();

            log.debug("Iteration {} complete. Likelihood: {}. Training Objective: {}, Icomp. L2-norm: {}",
                    step, currentLoss, objective, incompatibilityNorm);
            log.trace("Model {} ", mutableRules);
        }

        // Guard against 0/0 (NaN weights) when a budget of zero leaves no steps to average.
        if (averageSteps && numSteps > 0) {
            for (int i = 0; i < mutableRules.size(); i++) {
                mutableRules.get(i).setWeight(weightSums[i] / numSteps);
            }
        }
    }

    /**
     * Applies one gradient step (with momentum, optional step decay and optional clipping) to every
     * mutable rule.
     *
     * @param step the 0-based iteration index
     * @param scalingFactor per-rule divisor for the gradient
     * @param lastSteps previous step per rule; overwritten with this step
     * @param weightSums running weight sums; the new weights are added to them
     * @return the L2 norm of {@code expectedIncompatibility - observedIncompatibility}
     */
    private double takeGradientStep(int step, double[] scalingFactor, double[] lastSteps, double[] weightSums) {
        double squaredNorm = 0.0;
        for (int i = 0; i < mutableRules.size(); i++) {
            WeightedRule rule = mutableRules.get(i);
            double weight = rule.getWeight();
            double difference = expectedIncompatibility[i] - observedIncompatibility[i];

            double gradientStep = ((difference - l2Regularization * weight - l1Regularization) / scalingFactor[i]) * baseStepSize;
            if (scaleStepSize) {
                gradientStep /= step + 1;
            }

            double currentStep = gradientStep + inertia * lastSteps[i];
            double newWeight = clipNegativeWeights ? Math.max(0.0, weight + currentStep) : weight + currentStep;

            log.trace("Gradient: {} (without momentun: {}), Expected Incomp.: {}, Observed Incomp.: {} -- ({}) {}",
                    currentStep, currentStep - inertia * lastSteps[i],
                    expectedIncompatibility[i], observedIncompatibility[i], i, rule);

            rule.setWeight(newWeight);
            lastSteps[i] = currentStep;
            weightSums[i] += newWeight;
            squaredNorm += Math.pow(difference, 2.0);
        }
        return Math.sqrt(squaredNorm);
    }

    /**
     * Undoes the step just taken: restores {@code lastWeights}, clears momentum and fixes the
     * running weight sums.
     */
    private void revertToWeights(double[] lastWeights, double[] lastSteps, double[] weightSums) {
        for (int i = 0; i < mutableRules.size(); i++) {
            lastSteps[i] = 0.0;
            // Swap the rejected weight for the weight actually kept. Every step then contributes
            // exactly one weight, and dividing by numSteps stays unbiased.
            weightSums[i] += lastWeights[i] - mutableRules.get(i).getWeight();
            mutableRules.get(i).setWeight(lastWeights[i]);
        }
        // The MPE state was computed for the rejected weights; force recomputation next iteration.
        inMPEState = false;
    }

    /** Returns a snapshot of the current mutable rule weights, in rule order. */
    private double[] currentWeights() {
        double[] weights = new double[mutableRules.size()];
        for (int i = 0; i < weights.length; i++) {
            weights[i] = mutableRules.get(i).getWeight();
        }
        return weights;
    }

    /**
     * Computes {@code sum_i w_i * (observed_i - expected_i)} over the mutable rules, using the
     * current weights and incompatibility arrays.
     */
    protected double computeLoss() {
        double loss = 0.0;
        for (int i = 0; i < mutableRules.size(); i++) {
            loss += mutableRules.get(i).getWeight() * (observedIncompatibility[i] - expectedIncompatibility[i]);
        }
        return loss;
    }

    /**
     * Computes {@code 0.5 * l2 * sum(w^2) + l1 * sum(|w|)} over the mutable rule weights.
     * Returns 0 without iterating when both coefficients are 0.
     */
    protected double computeRegularizer() {
        if (l1Regularization == 0.0 && l2Regularization == 0.0) {
            return 0.0;
        }

        double squaredSum = 0.0;
        double absoluteSum = 0.0;
        for (WeightedRule rule : mutableRules) {
            squaredSum += Math.pow(rule.getWeight(), 2.0);
            absoluteSum += Math.abs(rule.getWeight());
        }
        return (0.5 * l2Regularization * squaredSum) + (l1Regularization * absoluteSum);
    }

    /** Returns the loss, computing and caching it via {@link #computeLoss()} if not yet known. */
    public double getLoss() {
        if (Double.isNaN(currentLoss)) {
            currentLoss = computeLoss();
        }
        return currentLoss;
    }

    /**
     * Computes the per-rule divisor applied to the gradient.
     *
     * <p>When {@link #scaleGradient} is enabled, each rule is scaled by its number of groundings,
     * floored at 1 so that ungrounded rules are not divided by zero. When it is disabled, every
     * factor is 1 (no scaling).
     */
    protected double[] computeScalingFactor() {
        double[] factor = new double[mutableRules.size()];
        for (int i = 0; i < factor.length; i++) {
            factor[i] = scaleGradient
                    ? Math.max(1.0, inference.getGroundRuleStore().count(mutableRules.get(i)))
                    : 1.0;
        }
        return factor;
    }

    /**
     * Sets the random variables to their labels and stores the per-rule summed incompatibility in
     * {@link #observedIncompatibility}.
     */
    protected void computeObservedIncompatibility() {
        setLabeledRandomVariables();
        sumIncompatibilities(observedIncompatibility);
    }

    /**
     * Computes the MPE state and stores the per-rule summed incompatibility in
     * {@link #expectedIncompatibility}.
     */
    protected void computeExpectedIncompatibility() {
        computeMPEState();
        sumIncompatibilities(expectedIncompatibility);
    }

    /** Zeroes {@code totals} and fills entry {@code i} with the summed incompatibility of mutable rule {@code i}. */
    private void sumIncompatibilities(double[] totals) {
        for (int i = 0; i < totals.length; i++) {
            totals[i] = 0.0;
        }
        for (int i = 0; i < mutableRules.size(); i++) {
            for (GroundRule groundRule : inference.getGroundRuleStore().getGroundRules(mutableRules.get(i))) {
                totals[i] += ((WeightedGroundRule) groundRule).getIncompatibility();
            }
        }
    }

    /**
     * Copies each label's observed value into its random variable atom and invalidates the MPE state.
     *
     * <p>Kept {@code public} for compatibility with existing callers.
     */
    public void setLabeledRandomVariables() {
        inMPEState = false;
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : trainingMap.getLabelMap().entrySet()) {
            entry.getKey().setValue(entry.getValue().getValue());
        }
    }

    /** Resets labeled and latent random variables to 0 and invalidates the MPE state. */
    protected void setDefaultRandomVariables() {
        inMPEState = false;
        for (RandomVariableAtom atom : trainingMap.getLabelMap().keySet()) {
            atom.setValue(0.0f);
        }
        for (RandomVariableAtom atom : trainingMap.getLatentVariables()) {
            atom.setValue(0.0f);
        }
    }

    /** Sets {@link #numSteps} to {@code ceil(budget * maxNumSteps)}. */
    @Override
    public void setBudget(double budget) {
        super.setBudget(budget);
        numSteps = (int) Math.ceil(budget * maxNumSteps);
    }
}
