package org.linqs.psl.reasoner.term;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.linqs.psl.config.Options;
import org.linqs.psl.grounding.GroundRuleStore;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.UnweightedGroundRule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.function.ConstraintTerm;
import org.linqs.psl.reasoner.function.FunctionComparator;
import org.linqs.psl.reasoner.function.FunctionTerm;
import org.linqs.psl.reasoner.function.GeneralFunction;
import org.linqs.psl.reasoner.term.ReasonerLocalVariable;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.Parallel;

/**
 * Base {@link TermGenerator} that converts ground rules into reasoner terms by first reducing
 * each rule's function to a {@link Hyperplane} over local variables.
 *
 * <p>Subclasses choose the concrete term/variable types and how hyperplanes become terms via
 * {@link #getLocalVariableType()}, {@link #createLossTerm} and {@link #createLinearConstraintTerm}.
 *
 * @param <T> the reasoner term type produced.
 * @param <V> the local variable type stored in hyperplanes.
 */
public abstract class HyperplaneTermGenerator<T extends ReasonerTerm, V extends ReasonerLocalVariable> implements TermGenerator<T, V> {
    private static final Logger log = Logger.getLogger(HyperplaneTermGenerator.class);

    /** When true, negative-weight ground rules are replaced by their negation(s); when false they are skipped. */
    protected boolean invertNegativeWeight = Options.HYPERPLANE_TG_INVERT_NEGATIVE_WEIGHTS.getBoolean();

    // The deter* settings are read from config but not used in this class.
    // NOTE(review): presumably consumed by subclasses -- confirm.
    protected boolean addDeterTerms = Options.HYPERPLANE_TG_ADD_DETER.getBoolean();
    protected boolean collectiveDeter = Options.HYPERPLANE_TG_DETER_COLLECTIVE.getBoolean();
    protected float deterWeight = Options.HYPERPLANE_TG_DETER_WEIGHT.getFloat();
    protected float deterEpsilon = Options.HYPERPLANE_TG_DETER_EPSILON.getFloat();
    protected float deterConstant = Options.HYPERPLANE_TG_DETER_CONSTANT.getFloat();

    /**
     * Forwarded to the rules' function/constraint definitions. When false, observed atoms are
     * turned into local variables instead of being folded into the hyperplane constant.
     */
    protected boolean mergeConstants;

    /**
     * Parallel worker that turns one ground rule at a time into terms and adds them to a term store.
     * Each instance owns its own scratch buffers; {@link #clone()} produces a worker with fresh buffers.
     */
    private class GeneratorWorker extends Parallel.Worker<GroundRule> {
        private final TermStore<T, V> termStore;

        // Scratch buffers reused across work() calls to avoid per-rule allocation.
        private final List<T> newTerms = new ArrayList<>(2);
        private final List<Hyperplane> newHyperplane = new ArrayList<>(1);

        public GeneratorWorker(TermStore<T, V> termStore) {
            this.termStore = termStore;
        }

        @Override
        public Object clone() {
            return new GeneratorWorker(termStore);
        }

        /**
         * Generates terms for a single ground rule.
         *
         * <p>Rules with a negative weight are skipped unless {@code invertNegativeWeight} is set,
         * in which case each rule returned by {@code negate()} is processed instead. The generated
         * terms are still attributed to the original rule.
         */
        @Override // org.linqs.psl.util.Parallel.Worker
        public void work(long index, GroundRule groundRule) {
            newTerms.clear();
            newHyperplane.clear();

            boolean negativeWeight = (groundRule instanceof WeightedGroundRule)
                    && ((WeightedGroundRule) groundRule).getWeight() < 0.0;

            if (!negativeWeight) {
                addTerms(groundRule, groundRule);
                return;
            }

            if (!HyperplaneTermGenerator.this.invertNegativeWeight) {
                // Already reported once by generateTerms(); nothing to generate.
                return;
            }

            for (GroundRule negatedRule : groundRule.negate()) {
                addTerms(groundRule, negatedRule);
            }
        }

        /**
         * Creates terms for {@code source} and adds them to the store under {@code owner}.
         * Clears the scratch buffers afterwards.
         */
        private void addTerms(GroundRule owner, GroundRule source) {
            HyperplaneTermGenerator.this.createTerm(source, termStore, newTerms, newHyperplane);
            for (T term : newTerms) {
                // createTerm() records at most one hyperplane, shared by every term it produced.
                termStore.add(owner, term, newHyperplane.get(0));
            }
            newTerms.clear();
            newHyperplane.clear();
        }
    }

    /**
     * @param mergeConstants forwarded to the ground rules' function/constraint definitions;
     *     when false, observed atoms are kept as local variables.
     */
    public HyperplaneTermGenerator(boolean mergeConstants) {
        this.mergeConstants = mergeConstants;
    }

    /**
     * Generates terms for every ground rule in {@code groundRuleStore} (in parallel) and adds them
     * to {@code termStore}.
     *
     * @return the number of terms added to the term store by this call.
     */
    @Override // org.linqs.psl.reasoner.term.TermGenerator
    public long generateTerms(GroundRuleStore groundRuleStore, TermStore<T, V> termStore) {
        long initialSize = termStore.size();
        termStore.ensureCapacity(initialSize + groundRuleStore.size());

        // Negative-weight rules are only dropped when inversion is disabled, so only warn then.
        if (!invertNegativeWeight) {
            warnAboutNegativeWeights(groundRuleStore);
        }

        Parallel.foreach(groundRuleStore.getGroundRules(), new GeneratorWorker(termStore));
        return termStore.size() - initialSize;
    }

    /** Logs one warning per distinct weighted rule with a negative weight. Such rules will be skipped. */
    private void warnAboutNegativeWeights(GroundRuleStore groundRuleStore) {
        HashSet<WeightedRule> rules = new HashSet<>();
        for (GroundRule groundRule : groundRuleStore.getGroundRules()) {
            if (groundRule instanceof WeightedGroundRule) {
                rules.add((WeightedRule) groundRule.getRule());
            }
        }

        for (WeightedRule rule : rules) {
            if (rule.getWeight() < 0.0) {
                log.warn("Found a rule with a negative weight, but config says not to invert it... skipping: " + rule);
            }
        }
    }

    /**
     * Builds the hyperplane for a single ground rule and delegates term creation to the subclass.
     *
     * @param groundRule a {@link WeightedGroundRule} (becomes a loss term) or an
     *     {@link UnweightedGroundRule} (becomes a linear constraint term).
     * @param termStore used to create local variables.
     * @param newTerms receives the created terms.
     * @param newHyperplanes if non-null, receives the hyperplane when at least one term was created.
     * @return the number of terms created. This is 0 when the rule reduces to no hyperplane
     *     (for example, it has no variables).
     * @throws IllegalArgumentException if the ground rule is neither weighted nor unweighted, or
     *     contains an unexpected non-constant term.
     */
    public int createTerm(GroundRule groundRule, TermStore<T, V> termStore, Collection<T> newTerms, Collection<Hyperplane> newHyperplanes) {
        Hyperplane<V> hyperplane;
        int termCount;

        if (groundRule instanceof WeightedGroundRule) {
            GeneralFunction function = ((WeightedGroundRule) groundRule).getFunctionDefinition(mergeConstants);
            hyperplane = processHyperplane(function, termStore);
            if (hyperplane == null) {
                return 0;
            }

            termCount = createLossTerm(newTerms, termStore, function.isNonNegative(), function.isSquared(), groundRule, hyperplane);
        } else if (groundRule instanceof UnweightedGroundRule) {
            ConstraintTerm constraint = ((UnweightedGroundRule) groundRule).getConstraintDefinition(mergeConstants);
            hyperplane = processHyperplane(constraint.getFunction(), termStore);
            if (hyperplane == null) {
                return 0;
            }

            // Move the constraint's right-hand side into the hyperplane constant.
            hyperplane.setConstant(constraint.getValue() + hyperplane.getConstant());
            termCount = createLinearConstraintTerm(newTerms, termStore, groundRule, hyperplane, constraint.getComparator());
        } else {
            throw new IllegalArgumentException("Unsupported ground rule: " + groundRule);
        }

        if (newHyperplanes != null && termCount > 0) {
            newHyperplanes.add(hyperplane);
        }

        return termCount;
    }

    /**
     * Reduces a function to a hyperplane: variables with coefficients plus a constant.
     *
     * <p>The constant starts as the negated function constant. Fixed-mirror random variable atoms
     * and other constant terms are folded into it using their current values. Random variable
     * atoms (and observed atoms when {@code mergeConstants} is false) become local variables;
     * repeated variables have their coefficients summed.
     *
     * @return the hyperplane, or {@code null} if it has no variables. Also {@code null} if the
     *     function is non-negative and the same variable appears with coefficients of differing sign.
     * @throws IllegalArgumentException on a summand that is neither a variable nor constant.
     */
    private Hyperplane<V> processHyperplane(GeneralFunction function, TermStore<T, V> termStore) {
        Hyperplane<V> hyperplane = new Hyperplane<>(getLocalVariableType(), function.size(), -function.getConstant());

        for (int i = 0; i < function.size(); i++) {
            float coefficient = function.getCoefficient(i);
            FunctionTerm term = function.getTerm(i);

            if ((term instanceof RandomVariableAtom) && ((RandomVariableAtom) term).getPredicate().isFixedMirror()) {
                // Treated as a constant at its current value, but remembered on the hyperplane.
                hyperplane.setConstant(hyperplane.getConstant() - (coefficient * term.getValue()));
                hyperplane.addIntegratedRVA((RandomVariableAtom) term, -coefficient);
            } else if ((term instanceof RandomVariableAtom) || (!mergeConstants && (term instanceof ObservedAtom))) {
                V variable = termStore.createLocalVariable((GroundAtom) term);

                // Linear scan; hyperplanes usually hold only a few variables.
                int localIndex = hyperplane.indexOfVariable(variable);
                if (localIndex == -1) {
                    hyperplane.addTerm(variable, coefficient);
                } else {
                    if (function.isNonNegative() && !MathUtils.signsMatch(hyperplane.getCoefficient(localIndex), coefficient)) {
                        return null;
                    }
                    hyperplane.appendCoefficient(localIndex, coefficient);
                }
            } else if (term.isConstant()) {
                hyperplane.setConstant(hyperplane.getConstant() - (coefficient * term.getValue()));
            } else {
                throw new IllegalArgumentException("Unexpected summand: " + function + "[" + i + "] (" + term + ").");
            }
        }

        if (hyperplane.size() == 0) {
            return null;
        }

        return hyperplane;
    }

    /** @return the runtime class of the local variables stored in this generator's hyperplanes. */
    public abstract Class<V> getLocalVariableType();

    /**
     * Creates the term(s) for a weighted ground rule's hyperplane.
     *
     * @param newTerms receives the created terms.
     * @param nonNegative the function's {@code isNonNegative()} flag.
     * @param squared the function's {@code isSquared()} flag.
     * @return the number of terms created. A positive value causes the hyperplane to be recorded
     *     by {@link #createTerm}.
     */
    public abstract int createLossTerm(Collection<T> newTerms, TermStore<T, V> termStore, boolean nonNegative, boolean squared, GroundRule groundRule, Hyperplane<V> hyperplane);

    /**
     * Creates the term(s) for an unweighted ground rule's hyperplane. The constraint's value has
     * already been folded into the hyperplane constant.
     *
     * @param newTerms receives the created terms.
     * @return the number of terms created.
     */
    public abstract int createLinearConstraintTerm(Collection<T> newTerms, TermStore<T, V> termStore, GroundRule groundRule, Hyperplane<V> hyperplane, FunctionComparator comparator);
}
