package org.linqs.psl.application.learning.weight.gradient.minimizer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.linqs.psl.application.learning.weight.gradient.GradientDescent;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.atom.UnmanagedObservedAtom;
import org.linqs.psl.model.predicate.Predicate;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.model.rule.arithmetic.WeightedArithmeticRule;
import org.linqs.psl.model.rule.arithmetic.WeightedGroundArithmeticRule;
import org.linqs.psl.model.rule.arithmetic.expression.ArithmeticRuleExpression;
import org.linqs.psl.model.rule.arithmetic.expression.coefficient.ConstantNumber;
import org.linqs.psl.reasoner.duallcqp.DualBCDReasoner;
import org.linqs.psl.reasoner.function.FunctionComparator;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.reasoner.term.TermState;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;

/**
 * Abstract gradient-descent weight learner that minimizes an augmented Lagrangian of the
 * difference between an "augmented" inference objective and the MAP inference objective.
 *
 * <p>For every unfixed atom, {@link #postInitGroundModel()} adds a proximity ("prox") rule that
 * ties the atom to an {@link UnmanagedObservedAtom} copy of it. The values of those observed copies
 * (the "prox rule constants") are learned alongside the rule weights. Augmented inference is MAP
 * inference with the prox rules active; full inference runs with them inactive.
 *
 * <p>The learning loss (see {@link #computeLearningLoss()}) is
 * {@code (squaredPenaltyCoefficient / 2) * d^2 + linearPenaltyCoefficient * d + supervisedLoss},
 * where {@code d} is the augmented-minus-MAP energy difference plus the total prox value.
 * Subclasses supply the supervised loss and its gradient with respect to the prox constants.
 */
public abstract class Minimizer extends GradientDescent {
    private static final Logger log = Logger.getLogger(Minimizer.class);

    /**
     * Exponent {@code e} in {@code constraintTolerance = 1 / squaredPenaltyCoefficient^e}. Shared by the
     * constructor and {@link #gradientStep(int)} so the initial and the reset tolerance are derived
     * with the identical exponent.
     */
    private static final double CONSTRAINT_TOLERANCE_EXPONENT = 0.1;

    /** Exponent used to shrink the constraint tolerance after an outer iteration that satisfied it. */
    private static final double CONSTRAINT_TOLERANCE_DECAY_EXPONENT = 0.9;

    /** One entry per mutable rule. Allocated here but not read or written elsewhere in this class. */
    protected float[] latentInferenceIncompatibility;
    /** Warm-start term state for latent inference (inference with labeled atoms fixed). */
    protected TermState[] latentInferenceTermState;
    /** Warm-start atom values for latent inference. */
    protected float[] latentInferenceAtomValueState;
    /** Per-rule incompatibility of the most recent full (prox rules inactive) MAP state. */
    protected float[] mapIncompatibility;
    /** Per-rule squared hinge loss of the most recent full MAP state. */
    protected float[] mapSquaredIncompatibility;
    /** Per-rule incompatibility of the most recent augmented (prox rules active) MAP state. */
    protected float[] augmentedInferenceIncompatibility;
    /** Per-rule squared hinge loss of the most recent augmented MAP state. */
    protected float[] augmentedInferenceSquaredIncompatibility;
    /** Warm-start term state for augmented inference. */
    protected TermState[] augmentedInferenceTermState;
    /** Warm-start atom values for augmented inference; also read as the augmented MAP atom values. */
    protected float[] augmentedInferenceAtomValueState;
    /** Optimal-value gradient w.r.t. random variable atoms, from augmented inference. */
    protected float[] augmentedRVAtomGradient;
    /** Optimal-value gradient w.r.t. deep atoms, from augmented inference. */
    protected float[] augmentedDeepAtomGradient;
    /** Maps an atom-store index of an original atom to its prox index, or -1 for fixed atoms. */
    protected List<Integer> rvAtomIndexToProxIndex;
    /** Maps a prox index back to the atom-store index of the atom its prox rule constrains. */
    protected List<Integer> proxIndexToRVAtomIndex;
    /** One prox rule per unfixed atom, indexed by prox index. */
    protected WeightedArithmeticRule[] proxRules;
    /** The observed copy (prox rule constant) paired with each prox rule. */
    protected UnmanagedObservedAtom[] proxRuleObservedAtoms;
    /** Atom-store index of each entry of {@link #proxRuleObservedAtoms}. */
    protected int[] proxRuleObservedAtomIndexes;
    /** Gradient of the learning loss w.r.t. each prox rule constant. */
    protected float[] proxRuleObservedAtomValueGradient;
    /** Weight given to every prox rule. */
    protected final float proxRuleWeight;
    /** Parameter movement below which an outer (penalty-update) iteration is triggered. */
    protected float parameterMovementTolerance;
    /** Parameter movement required, together with {@link #finalConstraintTolerance}, to stop updating. */
    protected float finalParameterMovementTolerance;
    /** Objective difference below which the linear penalty is updated instead of the squared one. */
    protected float constraintTolerance;
    /** Objective difference below which learning is considered converged. */
    protected float finalConstraintTolerance;
    /** Whether {@link #initializeProximityRuleConstants()} has run. */
    protected boolean initializedProxRuleConstants;
    /** Count of outer (penalty-update) iterations, starting at 1. */
    protected int outerIteration;
    /** Squared penalty coefficient at construction time. */
    protected final float initialSquaredPenaltyCoefficient;
    /** Current coefficient of the {@code d^2 / 2} penalty term. */
    protected float squaredPenaltyCoefficient;
    /** Factor applied to {@link #squaredPenaltyCoefficient} when the constraint tolerance is not met. */
    protected float squaredPenaltyCoefficientIncreaseRate;
    /** Linear penalty coefficient at construction time. */
    protected final float initialLinearPenaltyCoefficient;
    /** Current coefficient of the linear {@code d} penalty term. */
    protected float linearPenaltyCoefficient;

    /**
     * Creates the learner and reads its penalty and tolerance settings from {@link Options}.
     *
     * <p>All arguments are forwarded unchanged, in order, to the {@link GradientDescent} constructor.
     * NOTE(review): the parameter names follow the presumed superclass signature; confirm against it.
     *
     * @param rules the model rules
     * @param trainTargetDatabase first database argument of the superclass constructor
     * @param trainTruthDatabase second database argument of the superclass constructor
     * @param validationTargetDatabase third database argument of the superclass constructor
     * @param validationTruthDatabase fourth database argument of the superclass constructor
     * @param runValidation boolean flag of the superclass constructor
     */
    public Minimizer(List<Rule> rules, Database trainTargetDatabase, Database trainTruthDatabase,
            Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) {
        super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation);

        this.latentInferenceIncompatibility = new float[this.mutableRules.size()];
        this.latentInferenceTermState = null;
        this.latentInferenceAtomValueState = null;

        this.mapIncompatibility = new float[this.mutableRules.size()];
        this.mapSquaredIncompatibility = new float[this.mutableRules.size()];

        this.augmentedInferenceIncompatibility = new float[this.mutableRules.size()];
        this.augmentedInferenceSquaredIncompatibility = new float[this.mutableRules.size()];
        this.augmentedInferenceTermState = null;
        this.augmentedInferenceAtomValueState = null;

        this.rvAtomIndexToProxIndex = new ArrayList<>();
        this.proxIndexToRVAtomIndex = new ArrayList<>();
        this.proxRules = null;
        this.proxRuleObservedAtoms = null;
        this.proxRuleObservedAtomValueGradient = null;
        this.proxRuleWeight = Options.MINIMIZER_PROX_RULE_WEIGHT.getFloat();

        this.initialSquaredPenaltyCoefficient = Options.MINIMIZER_INITIAL_SQUARED_PENALTY.getFloat();
        this.squaredPenaltyCoefficient = this.initialSquaredPenaltyCoefficient;
        this.squaredPenaltyCoefficientIncreaseRate = Options.MINIMIZER_SQUARED_PENALTY_INCREASE_RATE.getFloat();
        this.initialLinearPenaltyCoefficient = Options.MINIMIZER_INITIAL_LINEAR_PENALTY.getFloat();
        this.linearPenaltyCoefficient = this.initialLinearPenaltyCoefficient;

        this.parameterMovementTolerance = 1.0f / this.initialSquaredPenaltyCoefficient;
        this.finalParameterMovementTolerance = Options.MINIMIZER_FINAL_PARAMETER_MOVEMENT_CONVERGENCE_TOLERANCE.getFloat();
        this.constraintTolerance = (float) (1.0d / Math.pow(this.initialSquaredPenaltyCoefficient, CONSTRAINT_TOLERANCE_EXPONENT));
        this.finalConstraintTolerance = Options.MINIMIZER_OBJECTIVE_DIFFERENCE_TOLERANCE.getFloat();

        this.initializedProxRuleConstants = false;
        this.outerIteration = 1;
    }

    /**
     * Adds one inactive prox rule per unfixed atom to the training term store. Then it snapshots
     * the term and atom-value state used to warm-start latent and augmented inference.
     *
     * <p>Each prox rule ties an unfixed atom {@code a} to a new observed atom {@code o} of the
     * predicate {@code augmented<Name>} through the ground constraints {@code a - o <= 0} and
     * {@code a - o >= 0}. Constant merging in the term generator is disabled while these terms are
     * added. It is restored afterwards, even if adding a term throws.
     */
    @Override
    public void postInitGroundModel() {
        AtomStore atomStore = this.trainInferenceApplication.getTermStore().getDatabase().getAtomStore();

        int unfixedAtomCount = 0;
        for (Iterator<GroundAtom> atoms = atomStore.iterator(); atoms.hasNext(); ) {
            if (!atoms.next().isFixed()) {
                unfixedAtomCount++;
            }
        }

        boolean mergeConstants = this.trainInferenceApplication.getTermStore().getTermGenerator().getMergeConstants();
        this.trainInferenceApplication.getTermStore().getTermGenerator().setMergeConstants(false);
        try {
            this.proxRules = new WeightedArithmeticRule[unfixedAtomCount];
            this.proxRuleObservedAtoms = new UnmanagedObservedAtom[unfixedAtomCount];
            this.proxRuleObservedAtomIndexes = new int[unfixedAtomCount];
            this.proxRuleObservedAtomValueGradient = new float[unfixedAtomCount];

            // Snapshot the size up front: addProxRule appends the new observed atoms to this store.
            int originalAtomCount = atomStore.size();
            int proxIndex = 0;
            for (int atomIndex = 0; atomIndex < originalAtomCount; atomIndex++) {
                GroundAtom atom = atomStore.getAtom(atomIndex);
                if (atom.isFixed()) {
                    this.rvAtomIndexToProxIndex.add(-1);
                    continue;
                }

                this.rvAtomIndexToProxIndex.add(proxIndex);
                this.proxIndexToRVAtomIndex.add(atomIndex);
                addProxRule(atomStore, atom, proxIndex);
                proxIndex++;
            }
        } finally {
            this.trainInferenceApplication.getTermStore().getTermGenerator().setMergeConstants(mergeConstants);
        }

        super.postInitGroundModel();

        float[] atomValues = this.trainInferenceApplication.getDatabase().getAtomStore().getAtomValues();
        this.latentInferenceTermState = this.trainInferenceApplication.getTermStore().saveState();
        this.latentInferenceAtomValueState = Arrays.copyOf(atomValues, atomValues.length);
        this.augmentedInferenceTermState = this.trainInferenceApplication.getTermStore().saveState();
        this.augmentedInferenceAtomValueState = Arrays.copyOf(atomValues, atomValues.length);
        this.augmentedRVAtomGradient = new float[atomValues.length];
        this.augmentedDeepAtomGradient = new float[atomValues.length];
    }

    /** Creates the observed copy of {@code atom} and its inactive prox rule at slot {@code proxIndex}. */
    private void addProxRule(AtomStore atomStore, GroundAtom atom, int proxIndex) {
        StandardPredicate augmentedPredicate = getAugmentedPredicate(atom);

        UnmanagedObservedAtom observedAtom = new UnmanagedObservedAtom(augmentedPredicate, atom.getArguments(), atom.getValue());
        this.proxRuleObservedAtoms[proxIndex] = observedAtom;
        atomStore.addAtom(observedAtom);
        this.proxRuleObservedAtomIndexes[proxIndex] = atomStore.getAtomIndex(observedAtom);

        WeightedArithmeticRule proxRule = new WeightedArithmeticRule(
                new ArithmeticRuleExpression(
                        Arrays.asList(new ConstantNumber(1.0f), new ConstantNumber(-1.0f)),
                        Arrays.asList(atom, observedAtom),
                        FunctionComparator.EQ,
                        new ConstantNumber(0.0f)),
                this.proxRuleWeight, true);
        this.proxRules[proxIndex] = proxRule;
        proxRule.setActive(false);

        // The rule's EQ constraint is added to the term store as an LTE/GTE pair of ground rules.
        addProxGroundRule(proxRule, atom, observedAtom, FunctionComparator.LTE);
        addProxGroundRule(proxRule, atom, observedAtom, FunctionComparator.GTE);
    }

    /** Adds the ground rule {@code atom - observedAtom <comparator> 0} for {@code proxRule} to the term store. */
    private void addProxGroundRule(WeightedArithmeticRule proxRule, GroundAtom atom, GroundAtom observedAtom,
            FunctionComparator comparator) {
        List<Float> coefficients = Arrays.asList(1.0f, -1.0f);
        List<GroundAtom> atoms = Arrays.asList(atom, observedAtom);
        this.trainInferenceApplication.getTermStore().add(
                new WeightedGroundArithmeticRule(proxRule, coefficients, atoms, comparator, 0.0f));
    }

    /**
     * Returns the {@code augmented<Name>} predicate for {@code atom}'s predicate and registers it if
     * no predicate of that name exists yet. The prefix is reserved: when assertions are enabled, an
     * already-registered predicate of that name that differs from the one built here is rejected.
     */
    private static StandardPredicate getAugmentedPredicate(GroundAtom atom) {
        StandardPredicate augmentedPredicate = StandardPredicate.get(
                String.format("augmented%s", atom.getPredicate().getName()),
                atom.getPredicate().getArgumentTypes());

        Predicate registeredPredicate = Predicate.get(augmentedPredicate.getName());
        if (registeredPredicate == null) {
            Predicate.registerPredicate(augmentedPredicate);
        } else {
            assert registeredPredicate.equals(augmentedPredicate)
                    : "The 'augmented' prefix on predicate names is reserved for weight learning functionality.";
        }
        return augmentedPredicate;
    }

    /**
     * Stops learning once the step count exceeds {@code maxNumSteps}. Unless full iterations are
     * requested, it also stops once the objective difference drops below
     * {@link #finalConstraintTolerance}. The two objective arguments are not used here.
     */
    @Override
    protected boolean breakOptimization(int iteration, float objective, float oldObjective) {
        if (iteration > this.maxNumSteps) {
            log.trace("Breaking Weight Learning. Reached maximum number of iterations: {}", this.maxNumSteps);
            return true;
        }

        if (this.runFullIterations) {
            return false;
        }

        float objectiveDifference = computeObjectiveDifference();
        if (objectiveDifference >= this.finalConstraintTolerance) {
            return false;
        }

        log.trace("Breaking Weight Learning. Objective difference {} is less than final constraint tolerance {}.",
                objectiveDifference, this.finalConstraintTolerance);
        return true;
    }

    /**
     * Takes one step on the weights, the prox rule constants and the atoms. When the combined
     * parameter movement falls below its tolerance (after the first iteration), it also performs
     * an outer augmented-Lagrangian update.
     *
     * <p>If the objective difference {@code d} still meets or exceeds {@link #constraintTolerance}, the
     * squared penalty is multiplied by its increase rate and both tolerances are reset from it.
     * Otherwise the method returns immediately (without logging) when the final tolerances are met.
     * If they are not met, the linear penalty grows by {@code 2 * squaredPenaltyCoefficient * d}
     * and both tolerances are tightened.
     */
    @Override
    protected void gradientStep(int iteration) {
        this.parameterMovement = 0.0f;
        this.parameterMovement += weightGradientStep(iteration);
        this.parameterMovement += internalParameterGradientStep(iteration);
        this.parameterMovement += atomGradientStep();

        float objectiveDifference = computeObjectiveDifference();
        if (iteration > 0 && this.parameterMovement < this.parameterMovementTolerance) {
            this.outerIteration++;

            if (objectiveDifference >= this.constraintTolerance) {
                // Constraint not yet satisfied to the current tolerance: strengthen the squared penalty.
                this.squaredPenaltyCoefficient = this.squaredPenaltyCoefficientIncreaseRate * this.squaredPenaltyCoefficient;
                this.constraintTolerance = (float) (1.0d / Math.pow(this.squaredPenaltyCoefficient, CONSTRAINT_TOLERANCE_EXPONENT));
                this.parameterMovementTolerance = 1.0f / this.squaredPenaltyCoefficient;
            } else {
                if (objectiveDifference < this.finalConstraintTolerance
                        && this.parameterMovement < this.finalParameterMovementTolerance) {
                    return;
                }

                // Constraint satisfied to the current tolerance: update the linear multiplier and tighten.
                this.linearPenaltyCoefficient += 2.0f * this.squaredPenaltyCoefficient * objectiveDifference;
                this.constraintTolerance = (float) (this.constraintTolerance
                        / Math.pow(this.squaredPenaltyCoefficient, CONSTRAINT_TOLERANCE_DECAY_EXPONENT));
                this.parameterMovementTolerance /= this.squaredPenaltyCoefficient;
            }
        }

        log.trace("Outer iteration: {}, Objective Difference: {}, Parameter Movement: {}, Squared Penalty Coefficient: {}, Linear Penalty Coefficient: {}, Constraint Tolerance: {}, parameterMovementTolerance: {}.",
                this.outerIteration, objectiveDifference, this.parameterMovement, this.squaredPenaltyCoefficient,
                this.linearPenaltyCoefficient, this.constraintTolerance, this.parameterMovementTolerance);
    }

    /**
     * Takes a projected gradient step on every prox rule constant and clamps the result to [0, 1].
     * The new value is written into the observed atom, the atom store and the augmented warm-start
     * state.
     *
     * @param iteration the current iteration, used to compute the step size
     * @return the sum of absolute changes of the prox rule constants
     */
    @Override
    protected float internalParameterGradientStep(int iteration) {
        float stepSize = computeStepSize(iteration);
        float[] atomValues = this.trainInferenceApplication.getTermStore().getDatabase().getAtomStore().getAtomValues();

        float movement = 0.0f;
        for (int proxIndex = 0; proxIndex < this.proxRules.length; proxIndex++) {
            float currentValue = this.proxRuleObservedAtoms[proxIndex].getValue();
            float steppedValue = currentValue - (stepSize * this.proxRuleObservedAtomValueGradient[proxIndex]);
            float newValue = Math.min(Math.max(steppedValue, 0.0f), 1.0f);

            movement += Math.abs(currentValue - newValue);
            setProxRuleObservedAtomValue(proxIndex, newValue, atomValues);
        }
        return movement;
    }

    /** Sets prox constant {@code proxIndex} on the atom, in {@code atomValues}, and in the augmented warm-start state. */
    private void setProxRuleObservedAtomValue(int proxIndex, float value, float[] atomValues) {
        int observedAtomIndex = this.proxRuleObservedAtomIndexes[proxIndex];
        this.proxRuleObservedAtoms[proxIndex]._assumeValue(value);
        atomValues[observedAtomIndex] = value;
        this.augmentedInferenceAtomValueState[observedAtomIndex] = value;
    }

    /**
     * Runs latent inference, with labeled random variable atoms replaced by their observed label
     * values, and copies the resulting atom values into the augmented warm-start state. It then
     * sets every prox rule constant to the latent value of its atom. Constants of labeled atoms are
     * finally overwritten with their label values, and the constants are marked as initialized.
     */
    protected void initializeProximityRuleConstants() {
        fixLabeledRandomVariables();
        log.trace("Performing Latent Inference.");
        computeMAPStateWithWarmStart(this.trainInferenceApplication, this.latentInferenceTermState, this.latentInferenceAtomValueState);
        this.inTrainingMAPState = true;
        unfixLabeledRandomVariables();

        AtomStore atomStore = this.trainInferenceApplication.getDatabase().getAtomStore();
        float[] atomValues = atomStore.getAtomValues();
        System.arraycopy(this.latentInferenceAtomValueState, 0, this.augmentedInferenceAtomValueState, 0,
                this.latentInferenceAtomValueState.length);

        for (int proxIndex = 0; proxIndex < this.proxRules.length; proxIndex++) {
            int rvAtomIndex = this.proxIndexToRVAtomIndex.get(proxIndex);
            setProxRuleObservedAtomValue(proxIndex, this.latentInferenceAtomValueState[rvAtomIndex], atomValues);
        }

        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getLabelMap().entrySet()) {
            int proxIndex = this.rvAtomIndexToProxIndex.get(atomStore.getAtomIndex(entry.getKey()));
            setProxRuleObservedAtomValue(proxIndex, entry.getValue().getValue(), atomValues);
        }

        this.initializedProxRuleConstants = true;
    }

    /**
     * Runs full inference and then augmented inference, initializing the prox rule constants first
     * if needed. It then recomputes the prox-constant gradient.
     */
    @Override
    protected void computeIterationStatistics() {
        computeFullInferenceStatistics();

        if (!this.initializedProxRuleConstants) {
            initializeProximityRuleConstants();
        }

        computeAugmentedInferenceStatistics();
        computeProxRuleObservedAtomValueGradient();
    }

    /**
     * Sets the RV and deep atom gradients of the learning loss. Each is
     * {@code squaredPenaltyCoefficient * d * delta + linearPenaltyCoefficient * delta}, where
     * {@code delta} is the augmented minus the MAP optimal-value gradient.
     */
    @Override
    protected void computeTotalAtomGradient() {
        float objectiveDifference = computeObjectiveDifference();
        int atomCount = this.trainInferenceApplication.getDatabase().getAtomStore().size();

        for (int i = 0; i < atomCount; i++) {
            float rvDelta = this.augmentedRVAtomGradient[i] - this.MAPRVAtomGradient[i];
            float deepDelta = this.augmentedDeepAtomGradient[i] - this.MAPDeepAtomGradient[i];

            this.rvAtomGradient[i] = (this.squaredPenaltyCoefficient * objectiveDifference * rvDelta)
                    + (this.linearPenaltyCoefficient * rvDelta);
            this.deepAtomGradient[i] = (this.squaredPenaltyCoefficient * objectiveDifference * deepDelta)
                    + (this.linearPenaltyCoefficient * deepDelta);
        }
    }

    /** Resets the prox-constant gradient and accumulates its supervised and augmented-Lagrangian parts. */
    protected void computeProxRuleObservedAtomValueGradient() {
        Arrays.fill(this.proxRuleObservedAtomValueGradient, 0.0f);
        addSupervisedProxRuleObservedAtomValueGradient();
        addAugmentedLagrangianProxRuleConstantsGradient();
    }

    /** Runs full (prox rules inactive) MAP inference and records its incompatibilities and value gradients. */
    private void computeFullInferenceStatistics() {
        log.trace("Running Inference.");
        computeMAPStateWithWarmStart(this.trainInferenceApplication, this.trainMAPTermState, this.trainMAPAtomValueState);
        this.inTrainingMAPState = true;

        computeCurrentIncompatibility(this.mapIncompatibility);
        computeCurrentSquaredIncompatibility(this.mapSquaredIncompatibility);
        this.trainInferenceApplication.getReasoner().computeOptimalValueGradient(
                this.trainInferenceApplication.getTermStore(), this.MAPRVAtomGradient, this.MAPDeepAtomGradient);
    }

    /**
     * Runs MAP inference with the prox rules temporarily active. It records the augmented
     * incompatibilities and value gradients, then deactivates the prox rules again.
     */
    protected void computeAugmentedInferenceStatistics() {
        setProxRulesActive(true);

        log.trace("Running Augmented Inference.");
        computeMAPStateWithWarmStart(this.trainInferenceApplication, this.augmentedInferenceTermState, this.augmentedInferenceAtomValueState);
        this.inTrainingMAPState = true;

        computeCurrentIncompatibility(this.augmentedInferenceIncompatibility);
        computeCurrentSquaredIncompatibility(this.augmentedInferenceSquaredIncompatibility);
        this.trainInferenceApplication.getReasoner().computeOptimalValueGradient(
                this.trainInferenceApplication.getTermStore(), this.augmentedRVAtomGradient, this.augmentedDeepAtomGradient);

        setProxRulesActive(false);
    }

    /** Sets every prox rule's active flag and clears {@code inTrainingMAPState}. */
    private void setProxRulesActive(boolean active) {
        for (WeightedArithmeticRule proxRule : this.proxRules) {
            proxRule.setActive(active);
        }
        this.inTrainingMAPState = false;
    }

    /**
     * Returns {@code (squaredPenaltyCoefficient / 2) * d^2 + linearPenaltyCoefficient * d + supervisedLoss},
     * where {@code d} is {@link #computeObjectiveDifference()}.
     */
    @Override
    protected float computeLearningLoss() {
        float objectiveDifference = computeObjectiveDifference();
        float supervisedLoss = computeSupervisedLoss();

        log.trace("Total Prox Loss: {}, Total objective difference: {}, Supervised Loss: {}",
                computeTotalProxValue(new float[this.proxRuleObservedAtoms.length]), objectiveDifference, supervisedLoss);

        return ((this.squaredPenaltyCoefficient / 2.0f) * ((float) Math.pow(objectiveDifference, 2.0d)))
                + (this.linearPenaltyCoefficient * objectiveDifference)
                + supervisedLoss;
    }

    /** Returns the total energy difference plus the total prox value (the constraint quantity {@code d}). */
    private float computeObjectiveDifference() {
        return computeTotalEnergyDifference(new float[this.mutableRules.size()])
                + computeTotalProxValue(new float[this.proxRuleObservedAtoms.length]);
    }

    /** Returns the supervised component of the learning loss. */
    protected abstract float computeSupervisedLoss();

    /**
     * Adds the gradient of the penalty terms with respect to each prox rule constant. The derivative
     * of the total prox value with respect to constant {@code i} is
     * {@code 2 * proxRuleWeight * (constant_i - augmentedValue_i)}.
     */
    protected void addAugmentedLagrangianProxRuleConstantsGradient() {
        float energyDifference = computeTotalEnergyDifference(new float[this.mutableRules.size()]);
        float[] proxDifference = new float[this.proxRuleObservedAtoms.length];
        float proxValue = computeTotalProxValue(proxDifference);
        float objectiveDifference = energyDifference + proxValue;

        for (int i = 0; i < this.proxRuleObservedAtoms.length; i++) {
            float proxValueGradient = 2.0f * this.proxRuleWeight * proxDifference[i];
            this.proxRuleObservedAtomValueGradient[i] += this.linearPenaltyCoefficient * proxValueGradient;
            this.proxRuleObservedAtomValueGradient[i] += this.squaredPenaltyCoefficient * objectiveDifference * proxValueGradient;
        }
    }

    /** Adds the gradient of the supervised loss w.r.t. the prox rule constants to {@link #proxRuleObservedAtomValueGradient}. */
    protected abstract void addSupervisedProxRuleObservedAtomValueGradient();

    /**
     * Adds the gradient of the penalty terms with respect to each mutable rule weight. Per rule,
     * this is the augmented-minus-MAP incompatibility difference scaled by
     * {@code linearPenaltyCoefficient + squaredPenaltyCoefficient * d}.
     */
    @Override
    protected void addLearningLossWeightGradient() {
        float[] incompatibilityDifference = new float[this.mutableRules.size()];
        float energyDifference = computeTotalEnergyDifference(incompatibilityDifference);
        float proxValue = computeTotalProxValue(new float[this.proxRuleObservedAtoms.length]);
        float objectiveDifference = energyDifference + proxValue;

        for (int i = 0; i < this.mutableRules.size(); i++) {
            this.weightGradient[i] += this.linearPenaltyCoefficient * incompatibilityDifference[i];
            this.weightGradient[i] += this.squaredPenaltyCoefficient * objectiveDifference * incompatibilityDifference[i];
        }
    }

    /**
     * Computes the augmented-minus-MAP energy difference over the mutable rules. It includes the
     * regularization of a {@link DualBCDReasoner} (taken as zero for any other reasoner) applied to
     * rule incompatibilities and to unfixed atom values.
     *
     * @param incompatibilityDifference output array with one entry per mutable rule; receives the
     *     augmented minus MAP incompatibility
     * @return the total energy difference
     */
    private float computeTotalEnergyDifference(float[] incompatibilityDifference) {
        float regularization = 0.0f;
        if (this.trainInferenceApplication.getReasoner() instanceof DualBCDReasoner) {
            regularization = (float) ((DualBCDReasoner) this.trainInferenceApplication.getReasoner()).regularizationParameter;
        }

        float energyDifference = 0.0f;
        for (int i = 0; i < this.mutableRules.size(); i++) {
            incompatibilityDifference[i] = this.augmentedInferenceIncompatibility[i] - this.mapIncompatibility[i];
            if (this.mutableRules.get(i).isSquared()) {
                energyDifference += (this.mutableRules.get(i).getWeight() + regularization) * incompatibilityDifference[i];
            } else {
                energyDifference += this.mutableRules.get(i).getWeight() * incompatibilityDifference[i];
                energyDifference += regularization
                        * (this.augmentedInferenceSquaredIncompatibility[i] - this.mapSquaredIncompatibility[i]);
            }
        }

        // Regularization on unfixed atom values, augmented state minus MAP state.
        GroundAtom[] atoms = this.trainInferenceApplication.getDatabase().getAtomStore().getAtoms();
        int atomCount = this.trainInferenceApplication.getDatabase().getAtomStore().size();
        float augmentedAtomRegularization = 0.0f;
        float mapAtomRegularization = 0.0f;
        for (int i = 0; i < atomCount; i++) {
            if (atoms[i].isFixed()) {
                continue;
            }
            augmentedAtomRegularization += regularization * Math.pow(this.augmentedInferenceAtomValueState[i], 2.0d);
            mapAtomRegularization += regularization * Math.pow(this.trainMAPAtomValueState[i], 2.0d);
        }

        return energyDifference + (augmentedAtomRegularization - mapAtomRegularization);
    }

    /**
     * Computes {@code proxRuleWeight * sum_i (constant_i - augmentedValue_i)^2}.
     *
     * @param proxDifference output array with one entry per prox rule; receives
     *     {@code constant_i - augmentedValue_i}
     * @return the total prox value
     */
    private float computeTotalProxValue(float[] proxDifference) {
        float sumOfSquares = 0.0f;
        for (int i = 0; i < this.proxRules.length; i++) {
            proxDifference[i] = this.proxRuleObservedAtoms[i].getValue()
                    - this.augmentedInferenceAtomValueState[this.proxIndexToRVAtomIndex.get(i)];
            sumOfSquares += Math.pow(proxDifference[i], 2.0d);
        }
        return this.proxRuleWeight * sumOfSquares;
    }

    /**
     * Zeroes gradient entries that would push a prox constant outside [0, 1]. These are positive
     * entries for constants at 0 and negative entries for constants at 1, because steps move
     * against the gradient.
     */
    private void clipProxRuleObservedAtomValueGradient(float[] gradient) {
        for (int i = 0; i < gradient.length; i++) {
            float value = this.proxRuleObservedAtoms[i].getValue();
            if (MathUtils.isZero(value) && gradient[i] > 0.0f) {
                gradient[i] = 0.0f;
            } else if (MathUtils.equals(value, 1.0f) && gradient[i] < 0.0f) {
                gradient[i] = 0.0f;
            }
        }
    }

    /**
     * Returns the superclass gradient norm plus the {@code stoppingGradientNorm}-norm of the clipped
     * prox-constant gradient. The stored gradient itself is not modified.
     */
    @Override
    public float computeGradientNorm() {
        float gradientNorm = super.computeGradientNorm();

        float[] clippedProxGradient = this.proxRuleObservedAtomValueGradient.clone();
        clipProxRuleObservedAtomValueGradient(clippedProxGradient);

        return gradientNorm + MathUtils.pNorm(clippedProxGradient, this.stoppingGradientNorm);
    }

    /**
     * Fills {@code squaredIncompatibility} with, per mutable weighted rule, the sum of its terms'
     * squared hinge losses at the current atom values. Terms of rules not in {@code ruleIndexMap}
     * are skipped.
     */
    protected void computeCurrentSquaredIncompatibility(float[] squaredIncompatibility) {
        Arrays.fill(squaredIncompatibility, 0.0f);
        float[] atomValues = this.trainInferenceApplication.getDatabase().getAtomStore().getAtomValues();

        Iterator<?> terms = this.trainInferenceApplication.getTermStore().iterator();
        while (terms.hasNext()) {
            ReasonerTerm term = (ReasonerTerm) terms.next();
            if (!(term.getRule() instanceof WeightedRule)) {
                continue;
            }

            Integer ruleIndex = this.ruleIndexMap.get((WeightedRule) term.getRule());
            if (ruleIndex == null) {
                continue;
            }

            squaredIncompatibility[ruleIndex] += term.evaluateSquaredHingeLoss(atomValues);
        }
    }

    /**
     * Replaces every labeled random variable atom in the atom store with its observed label atom. The
     * label value is also copied into the store's values, the latent warm-start state and the random
     * variable atom itself.
     */
    protected void fixLabeledRandomVariables() {
        AtomStore atomStore = this.trainInferenceApplication.getTermStore().getDatabase().getAtomStore();

        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getLabelMap().entrySet()) {
            RandomVariableAtom rvAtom = entry.getKey();
            ObservedAtom label = entry.getValue();

            // Look up the index before the slot is overwritten with the observed atom.
            int atomIndex = atomStore.getAtomIndex(rvAtom);
            atomStore.getAtoms()[atomIndex] = label;
            atomStore.getAtomValues()[atomIndex] = label.getValue();
            this.latentInferenceAtomValueState[atomIndex] = label.getValue();
            rvAtom.setValue(label.getValue());
        }

        this.inTrainingMAPState = false;
    }

    /**
     * Puts each labeled random variable atom back into its atom-store slot. Atom values are left
     * unchanged.
     */
    protected void unfixLabeledRandomVariables() {
        AtomStore atomStore = this.trainInferenceApplication.getDatabase().getAtomStore();

        for (RandomVariableAtom rvAtom : this.trainingMap.getLabelMap().keySet()) {
            atomStore.getAtoms()[atomStore.getAtomIndex(rvAtom)] = rvAtom;
        }

        this.inTrainingMAPState = false;
    }
}
