package org.linqs.psl.application.learning.weight.maxlikelihood;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.linqs.psl.application.learning.weight.VotedPerceptron;
import org.linqs.psl.config.Config;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.util.Parallel;
import org.linqs.psl.util.RandUtils;

/**
 * Weight learning by maximizing a sampled estimate of the piecewise pseudo-likelihood.
 *
 * <p>For every mutable rule, that rule's ground rules are grouped by each
 * {@link RandomVariableAtom} they contain. Each (rule, atom) group is scored in isolation:
 * the atom's value is replaced by draws from {@link Random#nextFloat()} (uniform in [0, 1)),
 * the group's incompatibility is re-evaluated for every draw, and the draws are combined into
 * Monte Carlo estimates of the expected incompatibility and of the local normalizer used by
 * {@link #computeLoss()}. Because fresh samples are drawn on every call, both estimates are
 * stochastic.
 *
 * <p>There is one {@link Random} per {@link Parallel} thread. The array is sized by
 * {@code Parallel.getNumThreads()} and indexed by the worker's id.
 */
public class MaxPiecewisePseudoLikelihood extends VotedPerceptron {
    /** Prefix shared by this learner's configuration keys. */
    public static final String CONFIG_PREFIX = "maxpiecewisepseudolikelihood";

    /** Config key for the number of samples drawn per (rule, atom) group at full budget. */
    public static final String NUM_SAMPLES_KEY = CONFIG_PREFIX + ".numsamples";

    /** Default passed to {@code Config.getInt} for {@link #NUM_SAMPLES_KEY}. */
    public static final int NUM_SAMPLES_DEFAULT = 100;

    /** Configured sample count; {@link #setBudget(double)} scales the active count from this. */
    private final int maxNumSamples;

    /** Samples drawn per (rule, atom) group by the current computations; always at least one. */
    private int numSamples;

    /**
     * One map per entry of {@code mutableRules} (same index): random variable atom to the
     * ground rules of that rule which contain the atom. Null until {@link #postInitGroundModel()}.
     */
    private List<Map<RandomVariableAtom, List<WeightedGroundRule>>> ruleRandomVariableMap;

    /** One generator per Parallel thread, indexed by the worker id. */
    private Random[] rands;

    /**
     * Creates a learner for every rule of {@code model}.
     *
     * @param model the model whose rules are learned
     * @param rvDB first database, forwarded unchanged to the superclass
     *     (NOTE(review): presumably the random-variable database — confirm)
     * @param observedDB second database, forwarded unchanged to the superclass
     *     (NOTE(review): presumably the observed/truth database — confirm)
     * @throws IllegalArgumentException if the configured sample count is not positive
     */
    public MaxPiecewisePseudoLikelihood(Model model, Database rvDB, Database observedDB) {
        this(model.getRules(), rvDB, observedDB);
    }

    /**
     * Creates a learner for the given rules.
     *
     * @param rules the rules whose weights are learned
     * @param rvDB first database, forwarded unchanged to the superclass
     *     (NOTE(review): presumably the random-variable database — confirm)
     * @param observedDB second database, forwarded unchanged to the superclass
     *     (NOTE(review): presumably the observed/truth database — confirm)
     * @throws IllegalArgumentException if the configured sample count is not positive
     */
    public MaxPiecewisePseudoLikelihood(List<Rule> rules, Database rvDB, Database observedDB) {
        // NOTE(review): the trailing `false` is a superclass flag, presumably
        // "supports latent variables" — confirm in VotedPerceptron.
        super(rules, rvDB, observedDB, false);

        maxNumSamples = Config.getInt(NUM_SAMPLES_KEY, NUM_SAMPLES_DEFAULT);
        if (maxNumSamples <= 0) {
            throw new IllegalArgumentException(
                    "Number of samples must be positive. Got: " + maxNumSamples);
        }
        numSamples = maxNumSamples;

        rands = new Random[Parallel.getNumThreads()];
        for (int i = 0; i < rands.length; i++) {
            // Per-thread generators are seeded from RandUtils
            // (NOTE(review): presumably the project's global, seedable generator).
            rands[i] = new Random(RandUtils.nextLong());
        }

        ruleRandomVariableMap = null;
        // averageSteps is a VotedPerceptron flag; this learner turns it off.
        averageSteps = false;
    }

    /** Builds the per-rule atom-to-ground-rules index once the ground model exists. */
    @Override
    public void postInitGroundModel() {
        populateRandomVariableMap();
    }

    /**
     * Rebuilds {@link #ruleRandomVariableMap}. A ground rule is listed once under each random
     * variable atom it contains; observed (non-random-variable) atoms are skipped.
     */
    private void populateRandomVariableMap() {
        ruleRandomVariableMap = new ArrayList<>(mutableRules.size());
        for (WeightedRule rule : mutableRules) {
            Map<RandomVariableAtom, List<WeightedGroundRule>> atomToGroundRules = new HashMap<>();
            for (GroundRule groundRule : groundRuleStore.getGroundRules(rule)) {
                for (GroundAtom atom : groundRule.getAtoms()) {
                    if (!(atom instanceof RandomVariableAtom)) {
                        continue;
                    }

                    RandomVariableAtom rvAtom = (RandomVariableAtom) atom;
                    List<WeightedGroundRule> groundRules = atomToGroundRules.get(rvAtom);
                    if (groundRules == null) {
                        groundRules = new ArrayList<>();
                        atomToGroundRules.put(rvAtom, groundRules);
                    }
                    // Ground rules of a WeightedRule are expected to be WeightedGroundRules.
                    groundRules.add((WeightedGroundRule) groundRule);
                }
            }
            ruleRandomVariableMap.add(atomToGroundRules);
        }
    }

    /**
     * For every mutable rule, sums over its (rule, atom) groups the expected group
     * incompatibility under sampled atom values. Each sample is weighted by
     * exp(-weight * incompatibility), and the result is written to {@code expectedIncompatibility}.
     * Rules are processed in parallel, one per work item.
     */
    @Override
    public void computeExpectedIncompatibility() {
        setLabeledRandomVariables();

        Parallel.count(mutableRules.size(), new Parallel.Worker<Integer>() {
            @Override
            public void work(int ruleIndex, Integer ignored) {
                double weight = mutableRules.get(ruleIndex).getWeight();
                Random rand = rands[this.id];
                // Reused for every atom of this rule; numSamples is at least one.
                double[] sampled = new double[numSamples];

                double incompatibility = 0.0;
                for (Map.Entry<RandomVariableAtom, List<WeightedGroundRule>> entry
                        : ruleRandomVariableMap.get(ruleIndex).entrySet()) {
                    sampleIncompatibilities(entry.getKey(), entry.getValue(), rand, sampled);
                    incompatibility += weightedMeanIncompatibility(weight, sampled);
                }
                expectedIncompatibility[ruleIndex] = incompatibility;
            }
        });
    }

    /**
     * Estimates the regularized piecewise log pseudo-likelihood.
     *
     * <p>Each (rule, atom) group contributes two terms:
     * <ul>
     *   <li>the group's log-potential at the current atom values, set by
     *       {@code setLabeledRandomVariables()}: -weight * sum of incompatibilities;</li>
     *   <li>minus the log of the sampled mean of exp(-weight * incompatibility), which
     *       estimates the group's normalizer.</li>
     * </ul>
     * Each rule also contributes -0.5 * l2Regularization * weight^2.
     *
     * <p>NOTE(review): the value grows as the fit improves (it is a log-likelihood, not a
     * negated one) — confirm this is the sign the caller expects from a "loss".
     *
     * @return the sum of the per-rule contributions
     */
    @Override
    public double computeLoss() {
        setLabeledRandomVariables();

        final double[] ruleLosses = new double[mutableRules.size()];
        Parallel.count(mutableRules.size(), new Parallel.Worker<Integer>() {
            @Override
            public void work(int ruleIndex, Integer ignored) {
                double weight = mutableRules.get(ruleIndex).getWeight();
                Random rand = rands[this.id];
                // Reused for every atom of this rule; numSamples is at least one.
                double[] sampled = new double[numSamples];

                double ruleLoss = 0.0;
                for (Map.Entry<RandomVariableAtom, List<WeightedGroundRule>> entry
                        : ruleRandomVariableMap.get(ruleIndex).entrySet()) {
                    List<WeightedGroundRule> groundRules = entry.getValue();

                    double observedLogPotential = 0.0;
                    for (WeightedGroundRule groundRule : groundRules) {
                        observedLogPotential += -weight * groundRule.getIncompatibility();
                    }

                    sampleIncompatibilities(entry.getKey(), groundRules, rand, sampled);
                    ruleLoss += observedLogPotential - logMeanExp(weight, sampled);
                }

                // Each index is written by exactly one work item, so no synchronization is needed.
                ruleLosses[ruleIndex] = ruleLoss - 0.5 * l2Regularization * weight * weight;
            }
        });

        double loss = 0.0;
        for (double ruleLoss : ruleLosses) {
            loss += ruleLoss;
        }
        return loss;
    }

    /**
     * Writes each mutable rule's summed ground-rule incompatibility at the current atom values
     * into {@code observedIncompatibility}. A ground rule is counted once per random variable
     * atom it contains, matching the per-atom terms of
     * {@link #computeExpectedIncompatibility()}.
     */
    @Override
    public void computeObservedIncompatibility() {
        setLabeledRandomVariables();

        for (int ruleIndex = 0; ruleIndex < mutableRules.size(); ruleIndex++) {
            double incompatibility = 0.0;
            for (List<WeightedGroundRule> groundRules : ruleRandomVariableMap.get(ruleIndex).values()) {
                for (WeightedGroundRule groundRule : groundRules) {
                    incompatibility += groundRule.getIncompatibility();
                }
            }
            observedIncompatibility[ruleIndex] = incompatibility;
        }
    }

    /**
     * Forwards the budget to the superclass and sets the active sample count to
     * ceil(budget * maxNumSamples), never below one.
     *
     * @param budget NOTE(review): presumably a fraction of the full budget in [0, 1] — confirm
     */
    @Override
    public void setBudget(double budget) {
        super.setBudget(budget);
        // The constructor requires a positive sample count. Zero samples would make the
        // sampled normalizer 0 / 0 and the loss NaN, so keep the same invariant here.
        numSamples = Math.max(1, (int) Math.ceil(budget * maxNumSamples));
    }

    /**
     * Fills {@code out} with one entry per draw. For each draw, {@code atom} is set to
     * {@code rand.nextFloat()} and the incompatibilities of {@code groundRules} are summed at
     * that value.
     */
    private static void sampleIncompatibilities(
            RandomVariableAtom atom, List<WeightedGroundRule> groundRules, Random rand, double[] out) {
        for (int sample = 0; sample < out.length; sample++) {
            float value = rand.nextFloat();
            double total = 0.0;
            for (WeightedGroundRule groundRule : groundRules) {
                total += groundRule.getIncompatibility(atom, value);
            }
            out[sample] = total;
        }
    }

    /**
     * Returns the largest exponent -weight * s over the samples. It is used as a shift so
     * that exp() of the largest shifted exponent is exactly 1. Requires a non-empty array.
     */
    private static double maxExponent(double weight, double[] incompatibilities) {
        double max = Double.NEGATIVE_INFINITY;
        for (double incompatibility : incompatibilities) {
            max = Math.max(max, -weight * incompatibility);
        }
        return max;
    }

    /**
     * Returns sum_k p_k * s_k, where p_k is proportional to exp(-weight * s_k).
     *
     * <p>The exponents are shifted by their maximum, which makes the denominator at least 1.
     * The result therefore stays meaningful when every unshifted exp() would underflow. This
     * replaces the earlier 1e-6 denominator epsilon. Requires a non-empty array.
     */
    private static double weightedMeanIncompatibility(double weight, double[] incompatibilities) {
        double shift = maxExponent(weight, incompatibilities);
        double numerator = 0.0;
        double denominator = 0.0;
        for (double incompatibility : incompatibilities) {
            double density = Math.exp(-weight * incompatibility - shift);
            numerator += density * incompatibility;
            denominator += density;
        }
        return numerator / denominator;
    }

    /**
     * Returns log(mean_k exp(-weight * s_k)), computed with the max-shift (log-sum-exp) trick.
     * This prevents large weights or incompatibilities from underflowing every exp() to zero
     * and turning the result into -Infinity. Requires a non-empty array.
     */
    private static double logMeanExp(double weight, double[] incompatibilities) {
        double shift = maxExponent(weight, incompatibilities);
        double shiftedSum = 0.0;
        for (double incompatibility : incompatibilities) {
            shiftedSum += Math.exp(-weight * incompatibility - shift);
        }
        return shift + Math.log(shiftedSum / incompatibilities.length);
    }
}
