package org.linqs.psl.application.learning.weight.gradient.minimizer;

import java.util.List;
import java.util.Map;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.Rule;

/* loaded from: input_file:org/linqs/psl/application/learning/weight/gradient/minimizer/BinaryCrossEntropy.class */
/**
 * A {@link Minimizer} whose supervised loss is the binary cross-entropy between each label in
 * {@code trainingMap.getLabelMap()} and the value of the matching entry in
 * {@code proxRuleObservedAtoms}.
 *
 * <p>Both log arguments are clamped from below by {@link #PROBABILITY_EPSILON}, so values at or
 * near 0 or 1 never produce an infinite loss or a division by zero in the gradient.
 */
public class BinaryCrossEntropy extends Minimizer {
    /** Lower bound applied to {@code p} and {@code 1 - p} before taking a log or dividing by them. */
    private static final float PROBABILITY_EPSILON = 1.0E-4f;

    /**
     * Forwards all arguments unchanged to the {@link Minimizer} constructor.
     *
     * <p>NOTE(review): the roles of the four {@link Database} arguments and the boolean flag are
     * defined by {@code Minimizer}, which is not visible here. Consult it before relying on a
     * particular order.
     */
    public BinaryCrossEntropy(List<Rule> list, Database database, Database database2, Database database3, Database database4, boolean z) {
        super(list, database, database2, database3, database4, z);
    }

    /**
     * Computes the summed binary cross-entropy over every labeled random variable atom.
     *
     * <p>For label {@code y} and value {@code p}, each term is
     * {@code -(y * ln(max(p, eps)) + (1 - y) * ln(max(1 - p, eps)))}.
     *
     * @return the total loss, narrowed to {@code float} only once at the end
     */
    @Override // org.linqs.psl.application.learning.weight.gradient.minimizer.Minimizer
    protected float computeSupervisedLoss() {
        AtomStore atomStore = this.trainInferenceApplication.getDatabase().getAtomStore();

        // Accumulate in double: the per-term logs are already double, and narrowing the running
        // sum to float on every iteration (as before) compounds rounding error over many labels.
        double loss = 0.0;
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getLabelMap().entrySet()) {
            int proxIndex = lookupProxIndex(atomStore, entry.getKey());
            float label = entry.getValue().getValue();
            float prediction = this.proxRuleObservedAtoms[proxIndex].getValue();

            loss -= label * Math.log(Math.max(prediction, PROBABILITY_EPSILON))
                    + (1.0f - label) * Math.log(Math.max(1.0f - prediction, PROBABILITY_EPSILON));
        }
        return (float) loss;
    }

    /**
     * Adds the derivative of the cross-entropy with respect to each prox-rule observed atom value
     * into {@code proxRuleObservedAtomValueGradient}.
     *
     * <p>The added quantity per labeled atom is
     * {@code -(y / max(p, eps) - (1 - y) / max(1 - p, eps))}. The existing contents of the
     * gradient array are incremented, not overwritten.
     */
    @Override // org.linqs.psl.application.learning.weight.gradient.minimizer.Minimizer
    protected void addSupervisedProxRuleObservedAtomValueGradient() {
        AtomStore atomStore = this.trainInferenceApplication.getDatabase().getAtomStore();
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getLabelMap().entrySet()) {
            int proxIndex = lookupProxIndex(atomStore, entry.getKey());
            float label = entry.getValue().getValue();
            float prediction = this.proxRuleObservedAtoms[proxIndex].getValue();

            this.proxRuleObservedAtomValueGradient[proxIndex] -=
                    label / Math.max(prediction, PROBABILITY_EPSILON)
                    - (1.0f - label) / Math.max(1.0f - prediction, PROBABILITY_EPSILON);
        }
    }

    /**
     * Maps a random variable atom to its index in the prox-rule arrays, going through the atom
     * store's index.
     *
     * <p>NOTE(review): throws a {@code NullPointerException} if the atom has no entry in
     * {@code rvAtomIndexToProxIndex}. This matches the original unboxing behavior.
     */
    private int lookupProxIndex(AtomStore atomStore, RandomVariableAtom atom) {
        return this.rvAtomIndexToProxIndex.get(atomStore.getAtomIndex(atom)).intValue();
    }
}
