package org.linqs.psl.application.learning.weight.gradient.optimalvalue;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.linqs.psl.application.learning.weight.gradient.GradientDescent;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.reasoner.term.TermState;
import org.linqs.psl.util.Logger;

/* loaded from: input_file:org/linqs/psl/application/learning/weight/gradient/optimalvalue/OptimalValue.class */
public abstract class OptimalValue extends GradientDescent {
    // Registered under this class (previously GradientDescent.class, which misattributed
    // every trace message emitted here to the parent class).
    private static final Logger log = Logger.getLogger(OptimalValue.class);

    /** Incompatibility per mutable rule at the latent-inference MAP state; indexed like {@code mutableRules}. */
    protected float[] latentInferenceIncompatibility;
    /** Term-store state saved in {@link #postInitGroundModel()} and passed to warm-started latent inference. */
    protected TermState[] latentInferenceTermState;
    /** Copy of the atom values saved in {@link #postInitGroundModel()} and passed to warm-started latent inference. */
    protected float[] latentInferenceAtomValueState;
    /** Output buffer (one slot per atom value) filled by the reasoner's optimal-value gradient computation. */
    protected float[] rvLatentAtomGradient;
    /** Output buffer (one slot per atom value) filled by the reasoner's optimal-value gradient computation. */
    protected float[] deepLatentAtomGradient;

    /**
     * Forwards every argument unchanged to the {@link GradientDescent} constructor, then allocates
     * one latent-inference incompatibility slot per mutable rule.
     * The saved latent-inference state stays null until {@link #postInitGroundModel()} runs.
     */
    public OptimalValue(List<Rule> list, Database database, Database database2, Database database3, Database database4, boolean z) {
        super(list, database, database2, database3, database4, z);
        this.latentInferenceIncompatibility = new float[this.mutableRules.size()];
        // Explicit resets kept deliberately: they run after the superclass constructor finishes.
        this.latentInferenceTermState = null;
        this.latentInferenceAtomValueState = null;
    }

    /**
     * Runs the superclass hook, then snapshots the training term-store state and a copy of the
     * training atom values for later warm starts of latent inference. It also allocates the two
     * latent gradient buffers, each sized to the number of atom values.
     */
    @Override
    public void postInitGroundModel() {
        super.postInitGroundModel();

        this.latentInferenceTermState = this.trainInferenceApplication.getTermStore().saveState();

        float[] atomValues = this.trainInferenceApplication.getDatabase().getAtomStore().getAtomValues();
        // Copy rather than alias so later writes to the store's value array do not alter the snapshot.
        this.latentInferenceAtomValueState = Arrays.copyOf(atomValues, atomValues.length);
        this.rvLatentAtomGradient = new float[atomValues.length];
        this.deepLatentAtomGradient = new float[atomValues.length];
    }

    /**
     * Runs latent inference with the labeled random variables clamped to their observed values.
     *
     * <p>Steps:
     * <ol>
     *   <li>Clamp each labeled random variable (see {@link #fixLabeledRandomVariables()}).</li>
     *   <li>Compute a MAP state seeded with the saved latent-inference term and atom-value state.</li>
     *   <li>Record the per-rule incompatibility into {@link #latentInferenceIncompatibility}.</li>
     *   <li>Fill the two latent atom-gradient buffers from the reasoner.</li>
     *   <li>Restore the original random-variable atoms.</li>
     * </ol>
     *
     * <p>The restore step runs in {@code finally}. A failure during inference therefore cannot leave
     * observed atoms swapped into the training atom store.
     */
    public void computeLatentInferenceIncompatibility() {
        fixLabeledRandomVariables();
        try {
            log.trace("Running Latent Inference.");
            computeMAPStateWithWarmStart(this.trainInferenceApplication, this.latentInferenceTermState, this.latentInferenceAtomValueState);
            this.inTrainingMAPState = true;

            computeCurrentIncompatibility(this.latentInferenceIncompatibility);
            this.trainInferenceApplication.getReasoner().computeOptimalValueGradient(
                    this.trainInferenceApplication.getTermStore(), this.rvLatentAtomGradient, this.deepLatentAtomGradient);

            for (int i = 0; i < this.mutableRules.size(); i++) {
                log.trace("Rule: {} , Latent inference incompatibility: {}",
                        this.mutableRules.get(i), Float.valueOf(this.latentInferenceIncompatibility[i]));
            }
        } finally {
            unfixLabeledRandomVariables();
        }
    }

    /**
     * For every (random variable, observed atom) pair in the training map's label map, does the
     * following:
     * <ul>
     *   <li>Replaces the random variable in the atom store's atom array with the observed atom.</li>
     *   <li>Writes the observed value into the atom store's value array.</li>
     *   <li>Writes the observed value into {@link #latentInferenceAtomValueState}.</li>
     *   <li>Writes the observed value into the random-variable atom itself.</li>
     * </ul>
     * Afterwards it resets {@code inTrainingMAPState} to false.
     *
     * <p>NOTE(review): the atom store is reached through the term store's database here, but through
     * the inference application's database in {@link #unfixLabeledRandomVariables()}. Presumably both
     * paths yield the same AtomStore; confirm.
     */
    protected void fixLabeledRandomVariables() {
        AtomStore atomStore = this.trainInferenceApplication.getTermStore().getDatabase().getAtomStore();
        for (Map.Entry<RandomVariableAtom, ObservedAtom> label : this.trainingMap.getLabelMap().entrySet()) {
            RandomVariableAtom rvAtom = label.getKey();
            ObservedAtom observedAtom = label.getValue();
            float observedValue = observedAtom.getValue();

            int atomIndex = atomStore.getAtomIndex(rvAtom);
            atomStore.getAtoms()[atomIndex] = observedAtom;
            atomStore.getAtomValues()[atomIndex] = observedValue;
            this.latentInferenceAtomValueState[atomIndex] = observedValue;
            rvAtom.setValue(observedValue);
        }
        this.inTrainingMAPState = false;
    }

    /**
     * Puts each labeled random-variable atom back into the atom store's atom array at its own
     * index, undoing the atom swap made by {@link #fixLabeledRandomVariables()}. Afterwards it
     * resets {@code inTrainingMAPState} to false.
     *
     * <p>Only atom references are restored. The values written by
     * {@link #fixLabeledRandomVariables()} are not reverted: the atom-value array, the saved
     * latent state, and the random-variable atoms' own values all keep them.
     */
    protected void unfixLabeledRandomVariables() {
        AtomStore atomStore = this.trainInferenceApplication.getDatabase().getAtomStore();
        for (RandomVariableAtom rvAtom : this.trainingMap.getLabelMap().keySet()) {
            atomStore.getAtoms()[atomStore.getAtomIndex(rvAtom)] = rvAtom;
        }
        this.inTrainingMAPState = false;
    }
}
