package org.linqs.psl.application.learning.weight;

import java.util.Iterator;
import java.util.List;
import org.linqs.psl.config.Config;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/linqs/psl/application/learning/weight/VotedPerceptron.class */
/**
 * Base class for weight learning by repeated gradient-style steps.
 *
 * <p>Each iteration moves every mutable rule's weight by a step proportional to the difference
 * between the rule's expected and observed incompatibility, minus the regularization terms.
 * Subclasses supply the incompatibility and loss computations through the methods inherited
 * from {@link WeightLearningApplication}.
 *
 * <p>All options are read from {@link Config} in the constructor; see the {@code *_KEY}
 * constants for the individual settings.
 */
public abstract class VotedPerceptron extends WeightLearningApplication {
    private static final Logger log = LoggerFactory.getLogger(VotedPerceptron.class);

    /** Prefix shared by all configuration keys of this class. */
    public static final String CONFIG_PREFIX = "votedperceptron";

    /** Key for the non-negative L2 regularization coefficient. */
    public static final String L2_REGULARIZATION_KEY = "votedperceptron.l2regularization";
    public static final double L2_REGULARIZATION_DEFAULT = 0.0d;

    /** Key for the non-negative L1 regularization coefficient. */
    public static final String L1_REGULARIZATION_KEY = "votedperceptron.l1regularization";
    public static final double L1_REGULARIZATION_DEFAULT = 0.0d;

    /** Key for the (strictly positive) base step size. */
    public static final String STEP_SIZE_KEY = "votedperceptron.stepsize";
    public static final double STEP_SIZE_DEFAULT = 0.2d;

    /** Key for the inertia (fraction of the previous step re-applied); must lie in [0, 1). */
    public static final String INERTIA_KEY = "votedperceptron.inertia";
    public static final double INERTIA_DEFAULT = 0.0d;

    /** Key for the gradient scaling flag (see the note on {@link #scaleGradient}). */
    public static final String SCALE_GRADIENT_KEY = "votedperceptron.scalegradient";
    public static final boolean SCALE_GRADIENT_DEFAULT = true;

    /** Key selecting whether the final weights are the average over the learning steps. */
    public static final String AVERAGE_STEPS_KEY = "votedperceptron.averagesteps";
    public static final boolean AVERAGE_STEPS_DEFAULT = false;

    /** Key for the (strictly positive) number of learning iterations. */
    public static final String NUM_STEPS_KEY = "votedperceptron.numsteps";
    public static final int NUM_STEPS_DEFAULT = 25;

    /** Key selecting whether updated weights are clamped to be at least zero. */
    public static final String CLIP_NEGATIVE_WEIGHTS_KEY = "votedperceptron.clipnegativeweights";
    public static final boolean CLIP_NEGATIVE_WEIGHTS_DEFAULT = true;

    /** Key selecting whether a step that increases the training objective is undone. */
    public static final String CUT_OBJECTIVE_KEY = "votedperceptron.cutobjective";
    public static final boolean CUT_OBJECTIVE_DEFAULT = false;

    /** Key selecting whether the step is divided by the (1-based) iteration number. */
    public static final String SCALE_STEP_SIZE_KEY = "votedperceptron.scalestepsize";
    public static final boolean SCALE_STEP_SIZE_DEFAULT = true;

    /** Key selecting whether all weights are reset to zero before learning starts. */
    public static final String ZERO_INITIAL_WEIGHTS_KEY = "votedperceptron.zeroinitialweights";
    public static final boolean ZERO_INITIAL_WEIGHTS_DEFAULT = false;

    protected final double l2Regularization;
    protected final double l1Regularization;

    // NOTE(review): this flag is read from the configuration but never consulted in this class;
    // doLearn() always divides by computeScalingFactor(). Confirm whether subclasses rely on it.
    protected final boolean scaleGradient;

    /** Current base step size; halved by doLearn() whenever a step is rolled back. */
    protected double baseStepSize;
    protected boolean scaleStepSize;
    protected boolean averageSteps;
    protected boolean zeroInitialWeights;
    protected boolean clipNegativeWeights;
    protected boolean cutObjective;
    protected double inertia;

    /** The configured number of steps; setBudget() scales numSteps relative to this value. */
    protected final int maxNumSteps;
    protected int numSteps;

    /** Cached result of computeLoss(); NaN means "not computed for the current iteration". */
    private double currentLoss;

    /**
     * Reads and validates all options from {@link Config}.
     *
     * <p>The four arguments are passed unchanged to the superclass constructor.
     *
     * @throws IllegalArgumentException if the step size or number of steps is not positive, the
     *         inertia is outside [0, 1), or a regularization coefficient is negative
     */
    public VotedPerceptron(List<Rule> list, Database database, Database database2, boolean z) {
        super(list, database, database2, z);

        this.baseStepSize = Config.getDouble(STEP_SIZE_KEY, STEP_SIZE_DEFAULT);
        if (this.baseStepSize <= 0.0d) {
            throw new IllegalArgumentException("Step size must be positive.");
        }

        this.inertia = Config.getDouble(INERTIA_KEY, INERTIA_DEFAULT);
        if (this.inertia < 0.0d || this.inertia >= 1.0d) {
            throw new IllegalArgumentException("Inertia must be in [0, 1), found: " + this.inertia);
        }

        // Validate before publishing the value into either field.
        int configuredSteps = Config.getInt(NUM_STEPS_KEY, NUM_STEPS_DEFAULT);
        if (configuredSteps <= 0) {
            throw new IllegalArgumentException("Number of steps must be positive.");
        }
        this.numSteps = configuredSteps;
        this.maxNumSteps = configuredSteps;

        this.l2Regularization = Config.getDouble(L2_REGULARIZATION_KEY, L2_REGULARIZATION_DEFAULT);
        if (this.l2Regularization < 0.0d) {
            throw new IllegalArgumentException("L2 regularization parameter must be non-negative.");
        }

        this.l1Regularization = Config.getDouble(L1_REGULARIZATION_KEY, L1_REGULARIZATION_DEFAULT);
        if (this.l1Regularization < 0.0d) {
            throw new IllegalArgumentException("L1 regularization parameter must be non-negative.");
        }

        this.scaleGradient = Config.getBoolean(SCALE_GRADIENT_KEY, SCALE_GRADIENT_DEFAULT);
        this.averageSteps = Config.getBoolean(AVERAGE_STEPS_KEY, AVERAGE_STEPS_DEFAULT);
        this.scaleStepSize = Config.getBoolean(SCALE_STEP_SIZE_KEY, SCALE_STEP_SIZE_DEFAULT);
        this.zeroInitialWeights = Config.getBoolean(ZERO_INITIAL_WEIGHTS_KEY, ZERO_INITIAL_WEIGHTS_DEFAULT);
        this.clipNegativeWeights = Config.getBoolean(CLIP_NEGATIVE_WEIGHTS_KEY, CLIP_NEGATIVE_WEIGHTS_DEFAULT);
        this.cutObjective = Config.getBoolean(CUT_OBJECTIVE_KEY, CUT_OBJECTIVE_DEFAULT);
        this.currentLoss = Double.NaN;
    }

    /**
     * Runs {@link #numSteps} learning iterations over the mutable rules.
     *
     * <p>Per iteration and rule, the step is
     * {@code (expected - observed - l2 * weight - l1) / scalingFactor * baseStepSize},
     * optionally divided by the 1-based iteration number, plus {@code inertia} times the
     * previous step. The new weight is optionally clamped at zero.
     *
     * <p>With {@code cutObjective} enabled and an evaluator present, an iteration (other than the
     * first) whose training objective does not satisfy {@code objective <= lastObjective} is
     * rolled back: the weights are restored, the inertia history is cleared and the base step
     * size is halved.
     *
     * <p>With {@code averageSteps} enabled, the final weights are the mean of the weights
     * produced by the iterations that were kept. Rolled-back iterations are excluded from both
     * the sum and the divisor, and averaging is skipped if no iteration was kept (for example
     * when the budget reduced the number of steps to zero), so the weights never become NaN.
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // org.linqs.psl.application.learning.weight.WeightLearningApplication
    public void doLearn() {
        double[] weightSums = new double[this.mutableRules.size()];

        computeObservedIncompatibility();
        setDefaultRandomVariables();

        if (this.zeroInitialWeights) {
            for (WeightedRule rule : this.mutableRules) {
                rule.setWeight(0.0d);
            }
        }

        if (log.isDebugEnabled() && this.evaluator != null) {
            log.debug("Initial Training Objective: {}", evaluateTrainingObjective());
        }

        double[] scalingFactor = computeScalingFactor();
        double[] previousSteps = new double[this.mutableRules.size()];
        double lastObjective = -1.0d;
        double[] lastWeights = new double[this.mutableRules.size()];
        copyCurrentWeights(lastWeights);

        // Number of iterations whose weights are currently part of weightSums.
        int acceptedSteps = 0;

        for (int step = 0; step < this.numSteps; step++) {
            log.debug("Starting iteration {}", step);

            // Force the loss to be recomputed for this iteration's weights.
            this.currentLoss = Double.NaN;
            computeExpectedIncompatibility();

            double squaredGapSum = 0.0d;
            for (int i = 0; i < this.mutableRules.size(); i++) {
                double weight = this.mutableRules.get(i).getWeight();
                double gradientStep = ((((this.expectedIncompatibility[i] - this.observedIncompatibility[i])
                        - (this.l2Regularization * weight)) - this.l1Regularization) / scalingFactor[i])
                        * this.baseStepSize;
                if (this.scaleStepSize) {
                    gradientStep /= step + 1;
                }

                double totalStep = gradientStep + (this.inertia * previousSteps[i]);
                double newWeight = this.clipNegativeWeights
                        ? Math.max(0.0d, weight + totalStep)
                        : weight + totalStep;

                log.trace("Gradient: {} (without momentun: {}), Expected Incomp.: {}, Observed Incomp.: {} -- ({}) {}",
                        totalStep, totalStep - (this.inertia * previousSteps[i]),
                        this.expectedIncompatibility[i], this.observedIncompatibility[i],
                        i, this.mutableRules.get(i));

                this.mutableRules.get(i).setWeight(newWeight);
                previousSteps[i] = totalStep;
                weightSums[i] += newWeight;
                squaredGapSum += Math.pow(this.expectedIncompatibility[i] - this.observedIncompatibility[i], 2.0d);
            }
            acceptedSteps++;

            // The weights changed, so any previously computed MPE state no longer matches them.
            invalidateMPEState();

            double incompatibilityNorm = Math.sqrt(squaredGapSum);

            if (log.isDebugEnabled()) {
                // Populates currentLoss for the log line below.
                getLoss();
            }

            double objective = -1.0d;
            if ((this.cutObjective || log.isDebugEnabled()) && this.evaluator != null) {
                objective = evaluateTrainingObjective();

                if (!this.cutObjective || step <= 0 || objective <= lastObjective) {
                    lastObjective = objective;
                } else {
                    log.trace("Objective increased: {} -> {}, cutting step size: {} -> {}.",
                            lastObjective, objective, this.baseStepSize, this.baseStepSize / 2.0d);
                    this.baseStepSize /= 2.0d;
                    objective = lastObjective;

                    for (int i = 0; i < this.mutableRules.size(); i++) {
                        previousSteps[i] = 0.0d;
                        // Remove the rejected weight from the running sum and restore the old one.
                        weightSums[i] -= this.mutableRules.get(i).getWeight();
                        this.mutableRules.get(i).setWeight(lastWeights[i]);
                    }
                    // The rejected iteration must not count towards the average's divisor.
                    acceptedSteps--;

                    // The MPE state computed above belongs to the rejected weights.
                    // NOTE(review): assumes a false flag makes the superclass recompute the state,
                    // as the reset after the update loop suggests; confirm against the superclass.
                    invalidateMPEState();
                }
            }

            copyCurrentWeights(lastWeights);

            log.debug("Iteration {} complete. Likelihood: {}. Training Objective: {}, Icomp. L2-norm: {}",
                    step, this.currentLoss, objective, incompatibilityNorm);
            log.trace("Model {} ", this.mutableRules);
        }

        // Guard against an empty average (no accepted steps), which would yield NaN weights.
        if (this.averageSteps && acceptedSteps > 0) {
            for (int i = 0; i < this.mutableRules.size(); i++) {
                this.mutableRules.get(i).setWeight(weightSums[i] / acceptedSteps);
            }
            invalidateMPEState();
        }
    }

    /**
     * Computes the MPE state, runs the evaluator on the training map and returns its
     * representative metric, negated when a higher metric is better so that lower is always
     * better. Callers must ensure {@code evaluator} is non-null.
     */
    private double evaluateTrainingObjective() {
        computeMPEState();
        this.evaluator.compute(this.trainingMap);
        double metric = this.evaluator.getRepresentativeMetric();
        return this.evaluator.isHigherRepresentativeBetter() ? (-1.0d) * metric : metric;
    }

    /** Copies the current weight of every mutable rule into {@code target}, index by index. */
    private void copyCurrentWeights(double[] target) {
        for (int i = 0; i < this.mutableRules.size(); i++) {
            target[i] = this.mutableRules.get(i).getWeight();
        }
    }

    /** Clears both MPE-state flags; called whenever rule weights have been modified. */
    private void invalidateMPEState() {
        this.inMPEState = false;
        this.inLatentMPEState = false;
    }

    /**
     * Computes the regularization penalty for the current weights.
     *
     * @return {@code 0.5 * l2 * sum(w^2) + l1 * sum(|w|)} over the mutable rules, or 0 when both
     *         coefficients are zero
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public double computeRegularizer() {
        if (this.l1Regularization == 0.0d && this.l2Regularization == 0.0d) {
            return 0.0d;
        }

        double squaredSum = 0.0d;
        double absoluteSum = 0.0d;
        for (WeightedRule rule : this.mutableRules) {
            squaredSum += Math.pow(rule.getWeight(), 2.0d);
            absoluteSum += Math.abs(rule.getWeight());
        }
        return (0.5d * this.l2Regularization * squaredSum) + (this.l1Regularization * absoluteSum);
    }

    /**
     * Returns the loss, computing it with {@code computeLoss()} only if no value is cached.
     * The cache is cleared at the start of every learning iteration.
     */
    public double getLoss() {
        if (Double.isNaN(this.currentLoss)) {
            this.currentLoss = computeLoss();
        }
        return this.currentLoss;
    }

    /**
     * Computes the per-rule divisor applied to each gradient step.
     *
     * @return one entry per mutable rule: the count reported by the ground rule store for that
     *         rule, but never less than 1 so that the division is always safe
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public double[] computeScalingFactor() {
        double[] factors = new double[this.mutableRules.size()];
        for (int i = 0; i < factors.length; i++) {
            factors[i] = Math.max(1.0d, this.groundRuleStore.count(this.mutableRules.get(i)));
        }
        return factors;
    }

    /**
     * Forwards the budget to the superclass and scales the number of learning steps to
     * {@code ceil(d * maxNumSteps)}. A budget of zero or less therefore results in no learning
     * iterations being run.
     */
    @Override // org.linqs.psl.application.learning.weight.WeightLearningApplication
    public void setBudget(double d) {
        super.setBudget(d);
        this.numSteps = (int) Math.ceil(d * this.maxNumSteps);
    }
}
