package org.linqs.psl.reasoner.bool;

import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
import org.linqs.psl.config.Config;
import org.linqs.psl.grounding.GroundRules;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.model.rule.logical.WeightedGroundLogicalRule;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.reasoner.term.blocker.ConstraintBlockerTerm;
import org.linqs.psl.reasoner.term.blocker.ConstraintBlockerTermStore;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.RandUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/linqs/psl/reasoner/bool/BooleanMaxWalkSat.class */
/**
 * Local-search reasoner over blocks of random-variable atoms.
 *
 * <p>Starting from a random initialization of the term store, the search repeatedly
 * picks a weighted ground rule whose incompatibility is positive, gathers the blocks
 * that contain that rule's random-variable atoms, and re-assigns exactly one of those
 * blocks. With probability {@code noise} the block and its new setting are drawn at
 * random; otherwise the (block, setting) pair with the lowest total weighted
 * incompatibility over the block's incident rules is chosen. The search ends when no
 * candidate rule is left or after {@code maxFlips} flips.
 */
public class BooleanMaxWalkSat implements Reasoner {
    private static final Logger log = LoggerFactory.getLogger(BooleanMaxWalkSat.class);

    /** Prefix shared by the configuration keys of this reasoner. */
    public static final String CONFIG_PREFIX = "booleanmaxwalksat";

    /** Configuration key for the maximum number of flips; the value must be positive. */
    public static final String MAX_FLIPS_KEY = CONFIG_PREFIX + ".maxflips";
    public static final int MAX_FLIPS_DEFAULT = 50000;

    /** Configuration key for the probability of taking a random step; the value must be in [0, 1]. */
    public static final String NOISE_KEY = CONFIG_PREFIX + ".noise";
    public static final double NOISE_DEFAULT = 0.01d;

    /** Progress is logged every this many flips. */
    private static final int LOG_PERIOD = 5000;

    /** Block setting that matches no atom index, i.e. every atom in the block is set to 0. */
    private static final int ALL_ZERO = -1;

    private final int maxFlips;
    private final double noise;

    /**
     * Reads {@link #MAX_FLIPS_KEY} and {@link #NOISE_KEY} from the configuration.
     *
     * @throws IllegalArgumentException if max flips is not positive or noise is outside [0, 1]
     */
    public BooleanMaxWalkSat() {
        maxFlips = Config.getInt(MAX_FLIPS_KEY, MAX_FLIPS_DEFAULT);
        if (maxFlips <= 0) {
            throw new IllegalArgumentException("Max flips must be positive.");
        }

        noise = Config.getDouble(NOISE_KEY, NOISE_DEFAULT);
        // Written as a negated range check so that NaN is rejected as well.
        if (!(noise >= 0.0 && noise <= 1.0)) {
            throw new IllegalArgumentException("Noise must be in [0,1].");
        }
    }

    /**
     * Runs the local search, mutating the values of the random-variable atoms in place.
     *
     * @param termStore the store to optimize; must be a {@link ConstraintBlockerTermStore}
     * @throws IllegalArgumentException if {@code termStore} is of any other type
     */
    @Override
    public void optimize(TermStore termStore) {
        if (!(termStore instanceof ConstraintBlockerTermStore)) {
            throw new IllegalArgumentException("ConstraintBlockerTermStore required.");
        }
        ConstraintBlockerTermStore blocker = (ConstraintBlockerTermStore) termStore;
        blocker.randomlyInitialize();

        // Candidate rules: weighted ground rules that are currently violated.
        Set<GroundRule> unsatisfied = new HashSet<GroundRule>();
        for (GroundRule rule : blocker.getGroundRuleStore().getGroundRules()) {
            if (rule instanceof WeightedGroundRule && ((WeightedGroundRule) rule).getIncompatibility() > 0.0) {
                unsatisfied.add(rule);
            }
        }

        Set<ConstraintBlockerTerm> touchedBlocks = new HashSet<ConstraintBlockerTerm>();
        int flip = 0;
        while (flip < maxFlips && !unsatisfied.isEmpty()) {
            GroundRule target = selectAtRandom(unsatisfied);
            collectBlocks(blocker, target, touchedBlocks);

            if (touchedBlocks.isEmpty()) {
                // None of this rule's atoms maps to a block, so this search has nothing to
                // re-assign for it. Drop it from the candidates instead of retrying it without
                // counting a flip; otherwise the loop cannot terminate once only such rules remain.
                unsatisfied.remove(target);
                continue;
            }

            // Unpack the blocks into parallel arrays indexed by block position.
            int blockCount = touchedBlocks.size();
            RandomVariableAtom[][] atoms = new RandomVariableAtom[blockCount][];
            WeightedGroundRule[][] incidentRules = new WeightedGroundRule[blockCount][];
            boolean[] exactlyOne = new boolean[blockCount];
            int position = 0;
            for (ConstraintBlockerTerm block : touchedBlocks) {
                atoms[position] = block.getAtoms();
                exactlyOne[position] = block.getExactlyOne();
                incidentRules[position] = block.getIncidentGRs();
                position++;
            }

            int[] move;
            if (RandUtils.nextDouble() <= noise) {
                move = chooseRandomMove(atoms, exactlyOne);
            } else {
                move = chooseGreedyMove(atoms, incidentRules, exactlyOne);
            }
            int changedBlock = move[0];

            if (changedBlock >= 0) {
                assign(atoms[changedBlock], move[1]);

                // Only rules incident to the changed block can have changed status.
                for (WeightedGroundRule rule : incidentRules[changedBlock]) {
                    if (rule.getIncompatibility() > 0.0) {
                        unsatisfied.add(rule);
                    } else {
                        unsatisfied.remove(rule);
                    }
                }
            }
            // else: the greedy scan found no candidate setting at all; the flip is still
            // counted below so that the loop is guaranteed to make progress.

            if (flip % LOG_PERIOD == 0) {
                log.info("Flip {}, Total weighted incompatibility: {}, Infeasbility norm: {}",
                        flip,
                        GroundRules.getTotalWeightedIncompatibility(blocker.getGroundRuleStore().getCompatibilityRules()),
                        GroundRules.getInfeasibilityNorm(blocker.getGroundRuleStore().getConstraintRules()));
            }
            flip++;
        }
    }

    /** Replaces the contents of {@code out} with the blocks holding the random-variable atoms of {@code rule}. */
    private static void collectBlocks(ConstraintBlockerTermStore blocker, GroundRule rule,
            Set<ConstraintBlockerTerm> out) {
        out.clear();
        for (GroundAtom atom : rule.getAtoms()) {
            if (!(atom instanceof RandomVariableAtom)) {
                continue;
            }
            int blockIndex = blocker.getBlockIndex((RandomVariableAtom) atom);
            // -1 is treated as "atom has no block".
            if (blockIndex != -1) {
                out.add(blocker.get(blockIndex));
            }
        }
    }

    /**
     * Picks a block and a new setting for it at random.
     *
     * @return {@code {blockIndex, setting}}; the setting is an atom index or {@link #ALL_ZERO}
     */
    private static int[] chooseRandomMove(RandomVariableAtom[][] atoms, boolean[] exactlyOne) {
        int block = RandUtils.nextInt(atoms.length);
        RandomVariableAtom[] blockAtoms = atoms[block];
        int setting = RandUtils.nextInt(blockAtoms.length);

        if (exactlyOne[block]) {
            // Resample until the drawn atom is not the one already set to 1. A single-atom
            // block has only one possible draw, so resampling it would never terminate.
            while (blockAtoms.length > 1 && blockAtoms[setting].getValue() == 1.0) {
                setting = RandUtils.nextInt(blockAtoms.length);
            }
        } else if (blockAtoms[setting].getValue() == 1.0) {
            // Drawing the atom that is already 1 means switching the whole block off.
            setting = ALL_ZERO;
        }
        return new int[] {block, setting};
    }

    /**
     * Tries every setting of every block and picks the one with the lowest total weighted
     * incompatibility over that block's incident rules. Atom values are restored before returning.
     *
     * @return {@code {blockIndex, setting}}, or a block index of -1 if no setting was evaluated
     *         or none had a cost below positive infinity
     */
    private static int[] chooseGreedyMove(RandomVariableAtom[][] atoms, WeightedGroundRule[][] incidentRules,
            boolean[] exactlyOne) {
        int bestBlock = -1;
        int bestSetting = ALL_ZERO;
        double bestCost = Double.POSITIVE_INFINITY;

        for (int block = 0; block < atoms.length; block++) {
            RandomVariableAtom[] blockAtoms = atoms[block];

            float[] saved = new float[blockAtoms.length];
            for (int k = 0; k < blockAtoms.length; k++) {
                saved[k] = blockAtoms[k].getValue();
            }

            // Without the exactly-one restriction there is one extra setting: index ==
            // length matches no atom, which leaves every atom in the block at 0.
            int settingCount = exactlyOne[block] ? blockAtoms.length : blockAtoms.length + 1;
            for (int setting = 0; setting < settingCount; setting++) {
                assign(blockAtoms, setting);

                double cost = 0.0;
                for (WeightedGroundRule rule : incidentRules[block]) {
                    cost += rule.getWeight() * rule.getIncompatibility();
                }

                if (cost < bestCost) {
                    bestCost = cost;
                    bestBlock = block;
                    bestSetting = setting;
                    if (MathUtils.isZero(bestCost)) {
                        break;
                    }
                }
            }

            for (int k = 0; k < blockAtoms.length; k++) {
                blockAtoms[k].setValue(saved[k]);
            }

            // A zero-cost setting cannot be beaten; stop scanning further blocks.
            if (MathUtils.isZero(bestCost)) {
                break;
            }
        }
        return new int[] {bestBlock, bestSetting};
    }

    /** Sets the atom at index {@code setting} to 1 and all others to 0; an out-of-range setting zeroes all. */
    private static void assign(RandomVariableAtom[] blockAtoms, int setting) {
        for (int k = 0; k < blockAtoms.length; k++) {
            blockAtoms[k].setValue(k == setting ? 1.0f : 0.0f);
        }
    }

    /**
     * Returns an element of {@code collection} chosen by a random index, found by walking
     * the collection's iteration order.
     */
    private static <T> T selectAtRandom(Collection<? extends T> collection) {
        int target = RandUtils.nextInt(collection.size());
        int index = 0;
        for (T item : collection) {
            if (index == target) {
                return item;
            }
            index++;
        }
        // Only reachable if the collection yields fewer elements than size() reported.
        throw new IllegalStateException("Collection yielded fewer elements than its reported size.");
    }

    @Override
    public void close() {
    }
}
