package org.linqs.psl.model.rule.arithmetic.expression;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.cli.HelpFormatter;
import org.linqs.psl.model.atom.Atom;
import org.linqs.psl.model.formula.Conjunction;
import org.linqs.psl.model.formula.Formula;
import org.linqs.psl.model.rule.arithmetic.expression.coefficient.Cardinality;
import org.linqs.psl.model.rule.arithmetic.expression.coefficient.Coefficient;
import org.linqs.psl.model.rule.arithmetic.expression.coefficient.ConstantNumber;
import org.linqs.psl.model.term.Term;
import org.linqs.psl.model.term.Variable;
import org.linqs.psl.reasoner.function.FunctionComparator;
import org.linqs.psl.util.HashCode;
import org.linqs.psl.util.MathUtils;

/**
 * An immutable arithmetic expression of the form
 * {@code c_1 * a_1 + c_2 * a_2 + ... + c_n * a_n <comparator> constant},
 * where each {@code a_i} is either a plain {@link Atom} or a {@link SummationAtom},
 * and each {@code c_i} is a {@link Coefficient}.
 *
 * <p>On construction the expression is validated:
 * <ul>
 *   <li>there is at least one atom, and exactly one coefficient per atom;</li>
 *   <li>every summation variable appears only once;</li>
 *   <li>no summation variable name is also used as an ordinary variable;</li>
 *   <li>(unless skipped) every {@link Cardinality} coefficient refers to a summation variable.</li>
 * </ul>
 */
public class ArithmeticRuleExpression implements Serializable {
    protected final List<Coefficient> coefficients;
    protected final List<SummationAtomOrAtom> atoms;
    protected final FunctionComparator comparator;
    protected final Coefficient constant;
    protected final Set<Variable> vars;
    protected final Map<SummationVariable, SummationAtom> summationMapping;
    private final int hash;

    /**
     * Builds an expression with the cardinality check enabled.
     *
     * @param coefficients one coefficient per atom, in the same order as {@code atoms}
     * @param atoms the (summation) atoms of the expression; must not be empty
     * @param comparator how the weighted sum is compared to {@code constant}
     * @param constant the right-hand side of the expression
     * @throws IllegalArgumentException if the expression is malformed (see class docs)
     */
    public ArithmeticRuleExpression(List<Coefficient> coefficients, List<SummationAtomOrAtom> atoms,
            FunctionComparator comparator, Coefficient constant) {
        this(coefficients, atoms, comparator, constant, false);
    }

    /**
     * Builds an expression.
     *
     * @param coefficients one coefficient per atom, in the same order as {@code atoms}
     * @param atoms the (summation) atoms of the expression; must not be empty
     * @param comparator how the weighted sum is compared to {@code constant}
     * @param constant the right-hand side of the expression
     * @param skipCardinalityCheck when true, {@link Cardinality} coefficients are not required
     *     to refer to one of this expression's summation variables
     * @throws IllegalArgumentException if the expression is malformed (see class docs)
     */
    public ArithmeticRuleExpression(List<Coefficient> coefficients, List<SummationAtomOrAtom> atoms,
            FunctionComparator comparator, Coefficient constant, boolean skipCardinalityCheck) {
        if (atoms.isEmpty()) {
            throw new IllegalArgumentException("Cannot have an arithmetic rule without atoms.");
        }
        // toString() and equals() pair coefficients with atoms by index, so the sizes must agree.
        if (coefficients.size() != atoms.size()) {
            throw new IllegalArgumentException(String.format(
                    "Each atom in an ArithmeticRuleExpression needs exactly one coefficient"
                    + " (got %d coefficients for %d atoms).",
                    coefficients.size(), atoms.size()));
        }

        // Defensive copies: the hash is cached below, so the contents must never change afterwards.
        this.coefficients = Collections.unmodifiableList(new ArrayList<>(coefficients));
        this.atoms = Collections.unmodifiableList(new ArrayList<>(atoms));
        this.comparator = comparator;
        this.constant = constant;

        Set<Variable> variables = new HashSet<>();
        Set<String> summationVariableNames = new HashSet<>();
        Map<SummationVariable, SummationAtom> mapping = new HashMap<>();

        // Read the field directly rather than calling the overridable getAtoms() from a constructor.
        collectVariables(this.atoms, variables, summationVariableNames, mapping);
        checkNoVariableConflicts(variables, summationVariableNames);
        if (!skipCardinalityCheck) {
            checkCardinalities(this.coefficients, summationVariableNames);
        }

        this.vars = Collections.unmodifiableSet(variables);
        this.summationMapping = Collections.unmodifiableMap(mapping);
        this.hash = computeHash(comparator, constant, this.coefficients, this.atoms);
    }

    /**
     * Records the ordinary variables, the summation-variable names, and the owning SummationAtom
     * of each summation variable found in the arguments of {@code atoms}.
     */
    private static void collectVariables(List<SummationAtomOrAtom> atoms, Set<Variable> variables,
            Set<String> summationVariableNames, Map<SummationVariable, SummationAtom> summationMapping) {
        for (SummationAtomOrAtom atom : atoms) {
            if (atom instanceof SummationAtom) {
                SummationAtom summationAtom = (SummationAtom) atom;
                for (SummationVariableOrTerm argument : summationAtom.getArguments()) {
                    if (argument instanceof Variable) {
                        variables.add((Variable) argument);
                    } else if (argument instanceof SummationVariable) {
                        SummationVariable summationVariable = (SummationVariable) argument;
                        if (summationMapping.containsKey(summationVariable)) {
                            throw new IllegalArgumentException("Each summation variable in an ArithmeticRuleExpression must be unique.");
                        }
                        summationVariableNames.add(summationVariable.getVariable().getName());
                        summationMapping.put(summationVariable, summationAtom);
                    }
                }
            } else {
                for (Term term : ((Atom) atom).getArguments()) {
                    if (term instanceof Variable) {
                        variables.add((Variable) term);
                    }
                }
            }
        }
    }

    /** Rejects expressions that use the same name as both a summation and an ordinary variable. */
    private static void checkNoVariableConflicts(Set<Variable> variables, Set<String> summationVariableNames) {
        for (Variable variable : variables) {
            if (summationVariableNames.contains(variable.getName())) {
                throw new IllegalArgumentException(String.format("Summation variable (+%s) cannot be used as a normal variable (%s).", variable.getName(), variable.getName()));
            }
        }
    }

    /** Rejects Cardinality coefficients whose variable is not one of the summation variables. */
    private static void checkCardinalities(List<Coefficient> coefficients, Set<String> summationVariableNames) {
        for (Coefficient coefficient : coefficients) {
            if (coefficient instanceof Cardinality) {
                String name = ((Cardinality) coefficient).getSummationVariable().getVariable().getName();
                if (!summationVariableNames.contains(name)) {
                    throw new IllegalArgumentException(String.format("Cannot use variable (%s) in cardinality. Only summation variables can be used in cardinality.", name));
                }
            }
        }
    }

    /** Order-sensitive hash over the comparator, constant, coefficients, and atoms. */
    private static int computeHash(FunctionComparator comparator, Coefficient constant,
            List<Coefficient> coefficients, List<SummationAtomOrAtom> atoms) {
        int result = HashCode.build(HashCode.build(comparator), constant);
        for (Coefficient coefficient : coefficients) {
            result = HashCode.build(result, coefficient);
        }
        for (SummationAtomOrAtom atom : atoms) {
            result = HashCode.build(result, atom);
        }
        return result;
    }

    @Override
    public int hashCode() {
        return this.hash;
    }

    /** Returns the unmodifiable list of coefficients, index-aligned with {@link #getAtoms()}. */
    public List<Coefficient> getAtomCoefficients() {
        return this.coefficients;
    }

    /** Returns the unmodifiable list of atoms, index-aligned with {@link #getAtomCoefficients()}. */
    public List<SummationAtomOrAtom> getAtoms() {
        return this.atoms;
    }

    /** Returns the comparator between the weighted sum and the final coefficient. */
    public FunctionComparator getComparator() {
        return this.comparator;
    }

    /** Returns the right-hand-side constant of the expression. */
    public Coefficient getFinalCoefficient() {
        return this.constant;
    }

    /** Returns the unmodifiable set of ordinary (non-summation) variables used by any atom. */
    public Set<Variable> getVariables() {
        return this.vars;
    }

    /** Returns the unmodifiable set of summation variables used by this expression. */
    public Set<SummationVariable> getSummationVariables() {
        return this.summationMapping.keySet();
    }

    /** Returns an unmodifiable map from each summation variable to the SummationAtom containing it. */
    public Map<SummationVariable, SummationAtom> getSummationMapping() {
        return this.summationMapping;
    }

    /**
     * Returns true when the expression has the shape {@code 1 * Atom(..., +X, ...) = 1}:
     * an equality, a single SummationAtom, a single constant coefficient equal to 1,
     * and a constant right-hand side equal to 1.
     */
    public boolean looksLikeFunctionalConstraint() {
        if (!FunctionComparator.EQ.equals(this.comparator)) {
            return false;
        }
        if (this.atoms.size() != 1 || !(this.atoms.get(0) instanceof SummationAtom)) {
            return false;
        }
        if (this.coefficients.size() != 1 || !(this.coefficients.get(0) instanceof ConstantNumber)) {
            return false;
        }
        if (!MathUtils.equals(1.0f, this.coefficients.get(0).getValue(null))) {
            return false;
        }
        return (this.constant instanceof ConstantNumber) && MathUtils.equals(1.0f, this.constant.getValue(null));
    }

    /**
     * Returns true when the expression has no summation variables, a single atom, is an equality,
     * and its right-hand side is a constant zero. The atom's coefficient is not inspected.
     */
    public boolean looksLikeNegativePrior() {
        return this.summationMapping.size() == 0
                && this.atoms.size() == 1
                && FunctionComparator.EQ.equals(this.comparator)
                && (this.constant instanceof ConstantNumber)
                && MathUtils.isZero(this.constant.getValue(null));
    }

    /**
     * Builds the formula used to query groundings of this expression: each SummationAtom is
     * replaced by its query atom, plain atoms are used as-is. A single atom is returned directly;
     * several atoms are combined into a {@link Conjunction}.
     */
    public Formula getQueryFormula() {
        List<Formula> queryAtoms = new ArrayList<>(this.atoms.size());
        for (SummationAtomOrAtom atom : this.atoms) {
            if (atom instanceof SummationAtom) {
                queryAtoms.add(((SummationAtom) atom).getQueryAtom());
            } else {
                queryAtoms.add((Atom) atom);
            }
        }

        if (queryAtoms.size() == 1) {
            return queryAtoms.get(0);
        }
        return new Conjunction(queryAtoms.toArray(new Formula[0]));
    }

    /** Renders the expression as {@code c1 * a1 + c2 * a2 ... <comparator> constant}. */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        // The constructor guarantees at least one atom and one coefficient per atom.
        for (int i = 0; i < this.coefficients.size(); i++) {
            if (i > 0) {
                builder.append(" + ");
            }
            builder.append(this.coefficients.get(i)).append(" * ").append(this.atoms.get(i));
        }
        builder.append(" ").append(this.comparator).append(" ").append(this.constant);
        return builder.toString();
    }

    /**
     * Two expressions are equal when they have the same comparator, the same constant, and the same
     * coefficients paired with the same atoms in the same order. Order matters because
     * {@link #hashCode()} is order-sensitive. A positional comparison also handles repeated atoms
     * correctly, which a first-occurrence {@code indexOf} lookup does not.
     */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (other == null || getClass() != other.getClass()) {
            return false;
        }

        ArithmeticRuleExpression that = (ArithmeticRuleExpression) other;
        // The cached hash is a cheap early-out before the element-wise comparisons.
        return this.hash == that.hash
                && this.comparator == that.comparator
                && this.constant.equals(that.constant)
                && this.coefficients.equals(that.coefficients)
                && this.atoms.equals(that.atoms);
    }
}
