package org.linqs.psl.parser;

import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import org.antlr.v4.runtime.BailErrorStrategy;
import org.antlr.v4.runtime.BaseErrorListener;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonToken;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.RecognitionException;
import org.antlr.v4.runtime.Recognizer;
import org.antlr.v4.runtime.misc.Interval;
import org.antlr.v4.runtime.misc.ParseCancellationException;
import org.apache.commons.configuration2.tree.DefaultExpressionEngineSymbols;
import org.apache.commons.lang3.StringUtils;
import org.linqs.psl.database.DataStore;
import org.linqs.psl.database.loading.Inserter;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.atom.Atom;
import org.linqs.psl.model.atom.QueryAtom;
import org.linqs.psl.model.formula.Conjunction;
import org.linqs.psl.model.formula.Disjunction;
import org.linqs.psl.model.formula.Formula;
import org.linqs.psl.model.formula.Implication;
import org.linqs.psl.model.formula.Negation;
import org.linqs.psl.model.predicate.GroundingOnlyPredicate;
import org.linqs.psl.model.predicate.Predicate;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.arithmetic.UnweightedArithmeticRule;
import org.linqs.psl.model.rule.arithmetic.WeightedArithmeticRule;
import org.linqs.psl.model.rule.arithmetic.expression.ArithmeticRuleExpression;
import org.linqs.psl.model.rule.arithmetic.expression.SummationAtom;
import org.linqs.psl.model.rule.arithmetic.expression.SummationAtomOrAtom;
import org.linqs.psl.model.rule.arithmetic.expression.SummationVariable;
import org.linqs.psl.model.rule.arithmetic.expression.SummationVariableOrTerm;
import org.linqs.psl.model.rule.arithmetic.expression.coefficient.Add;
import org.linqs.psl.model.rule.arithmetic.expression.coefficient.Cardinality;
import org.linqs.psl.model.rule.arithmetic.expression.coefficient.Coefficient;
import org.linqs.psl.model.rule.arithmetic.expression.coefficient.ConstantNumber;
import org.linqs.psl.model.rule.arithmetic.expression.coefficient.Divide;
import org.linqs.psl.model.rule.arithmetic.expression.coefficient.Max;
import org.linqs.psl.model.rule.arithmetic.expression.coefficient.Min;
import org.linqs.psl.model.rule.arithmetic.expression.coefficient.Multiply;
import org.linqs.psl.model.rule.arithmetic.expression.coefficient.Subtract;
import org.linqs.psl.model.rule.logical.UnweightedLogicalRule;
import org.linqs.psl.model.rule.logical.WeightedLogicalRule;
import org.linqs.psl.model.term.Constant;
import org.linqs.psl.model.term.StringAttribute;
import org.linqs.psl.model.term.Term;
import org.linqs.psl.model.term.Variable;
import org.linqs.psl.parser.antlr.PSLBaseVisitor;
import org.linqs.psl.parser.antlr.PSLLexer;
import org.linqs.psl.parser.antlr.PSLParser;
import org.linqs.psl.reasoner.function.FunctionComparator;

/* loaded from: input_file:org/linqs/psl/parser/ModelLoader.class */
public class ModelLoader extends PSLBaseVisitor<Object> {
    /** The data store handed to this loader when it was constructed. */
    private final DataStore data;

    /**
     * Intermediate result of visiting one arithmetic operand: an optional atom together with an optional
     * coefficient. Either part may remain {@code null}; the visitor fills in only what the operand contains.
     */
    public static class ArithmeticCoefficientOperand {
        SummationAtomOrAtom atom = null;
        Coefficient coefficient = null;

        private ArithmeticCoefficientOperand() {
            // Both parts start out null and are populated by the visitor.
        }
    }

    /**
     * Intermediate result of visiting one filter clause: the summation variable being filtered and the
     * boolean formula used as its filter.
     */
    public static class FilterClause {
        /** The summation variable this filter applies to. */
        SummationVariable v;

        /** The filtering formula. */
        Formula f;

        private FilterClause() {
            // Fields are assigned one at a time by the visitor.
        }
    }

    /**
     * Intermediate result of visiting a linear arithmetic expression. {@code coefficients} and {@code atoms}
     * are parallel lists: entry {@code i} of {@code coefficients} multiplies entry {@code i} of {@code atoms}.
     * {@code nonAtomCoefficient} holds the constant term not attached to any atom, or {@code null} if none.
     */
    public static class LinearArithmeticExpression {
        // ArrayList rather than LinkedList: the visitors in this file access these lists by index (get/set),
        // which costs O(n) per call on a LinkedList. The typed diamond also removes the raw-type warnings.
        public List<Coefficient> coefficients = new ArrayList<>();
        public List<SummationAtomOrAtom> atoms = new ArrayList<>();
        public Coefficient nonAtomCoefficient = null;
    }

    /**
     * Parses a single rule partial: a complete rule, a logical rule expression, or an arithmetic rule
     * expression (the latter optionally followed by filter clauses).
     *
     * @param dataStore the data store handed to the loader
     * @param str the text to parse
     * @return the parsed partial
     * @throws RuntimeException the underlying syntax error when the text cannot be lexed or parsed
     */
    public static RulePartial loadRulePartial(DataStore dataStore, String str) {
        try {
            try {
                return new ModelLoader(dataStore).visitPslRulePartial(getParser(str).pslRulePartial());
            } catch (ParseCancellationException e) {
                // Surface the real syntax error. If there is no RuntimeException cause, rethrow the wrapper
                // rather than throwing null (an NPE) or failing with a ClassCastException.
                Throwable cause = e.getCause();
                throw (cause instanceof RuntimeException) ? (RuntimeException) cause : e;
            }
        } catch (IOException e2) {
            throw new RuntimeException("Failed to lex rule partial.", e2);
        }
    }

    /**
     * Parses text that must contain exactly one rule and returns that rule.
     *
     * @param dataStore the data store handed to the loader
     * @param str the program text
     * @return the single rule in the text
     * @throws IllegalArgumentException if the text does not contain exactly one rule
     */
    public static Rule loadRule(DataStore dataStore, String str) {
        Model parsed = load(dataStore, new StringReader(str));

        Rule firstRule = null;
        int ruleCount = 0;
        for (Rule candidate : parsed.getRules()) {
            if (ruleCount++ == 0) {
                firstRule = candidate;
            }
        }

        if (ruleCount != 1) {
            throw new IllegalArgumentException(String.format("Expected 1 rule, found %d.", ruleCount));
        }
        return firstRule;
    }

    /** Parses a complete program held in a string; delegates to {@link #load(DataStore, Reader)}. */
    public static Model load(DataStore dataStore, String str) {
        Reader source = new StringReader(str);
        return load(dataStore, source);
    }

    /**
     * Parses a complete PSL program and builds a {@link Model} from its rules.
     *
     * @param dataStore the data store handed to the loader
     * @param reader source of the program text
     * @return the model containing the program's rules
     * @throws RuntimeException the underlying syntax error when the program cannot be parsed; a wrapping
     *     exception when a rule fails to compile (see visitProgram) or the input cannot be read
     */
    public static Model load(DataStore dataStore, Reader reader) {
        try {
            PSLParser parser = getParser(reader);
            try {
                return new ModelLoader(dataStore).visitProgram(parser.program(), parser);
            } catch (ParseCancellationException e) {
                // Surface the real syntax error. If there is no RuntimeException cause, rethrow the wrapper
                // rather than throwing null (an NPE) or failing with a ClassCastException.
                Throwable cause = e.getCause();
                throw (cause instanceof RuntimeException) ? (RuntimeException) cause : e;
            }
        } catch (IOException e2) {
            // This overload lexes a whole program, not a rule partial.
            throw new RuntimeException("Failed to lex model.", e2);
        }
    }

    /**
     * Builds a parser over the given input that stops at the first error. A lexer error is thrown as a
     * {@link ParseCancellationException} whose message is "line L:C msg" and whose cause is the lexer's
     * {@link RecognitionException}. Parser errors are handled by {@link BailErrorStrategy}.
     */
    private static PSLParser getParser(Reader reader) throws IOException {
        BaseErrorListener failOnLexError = new BaseErrorListener() {
            @Override
            public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol, int line, int charPositionInLine, String msg, RecognitionException cause) throws ParseCancellationException {
                throw new ParseCancellationException("line " + line + ":" + charPositionInLine + " " + msg, cause);
            }
        };

        PSLLexer lexer = new PSLLexer(CharStreams.fromReader(reader));
        lexer.addErrorListener(failOnLexError);

        CommonTokenStream tokens = new CommonTokenStream(lexer);
        PSLParser parser = new PSLParser(tokens);
        parser.setErrorHandler(new BailErrorStrategy());
        return parser;
    }

    /** Convenience overload of {@link #getParser(Reader)} for in-memory text. */
    private static PSLParser getParser(String str) throws IOException {
        Reader source = new StringReader(str);
        return getParser(source);
    }

    /** Private: instances are created by this class's static entry points. */
    private ModelLoader(DataStore dataStore) {
        data = dataStore;
    }

    /**
     * Visits every rule in a parsed program and adds each resulting rule to a new model.
     *
     * @param programContext parse tree of the whole program
     * @param pSLParser the parser that produced the tree; used to recover a rule's source text for errors
     * @return the populated model
     * @throws RuntimeException wrapping the original failure, with the offending rule's text in the message
     */
    public Model visitProgram(PSLParser.ProgramContext programContext, PSLParser pSLParser) {
        Model model = new Model();
        for (PSLParser.PslRuleContext pslRuleContext : programContext.pslRule()) {
            try {
                model.addRule((Rule) visit(pslRuleContext));
            } catch (RuntimeException e) {
                // Plain "]" literal. It was previously spelled through an unrelated commons-configuration
                // constant that merely happens to equal "]".
                throw new RuntimeException("Failed to compile rule: [" + pSLParser.getTokenStream().getText(pslRuleContext) + "]", e);
            }
        }
        return model;
    }

    /**
     * Visits a rule partial. The first child must evaluate to a rule, a formula, or an arithmetic rule
     * expression. Only an arithmetic expression may have more than two children; the children between the
     * first and the last are its filter clauses.
     *
     * @throws IllegalStateException if the partial is null or has an unexpected shape
     */
    @Override
    public RulePartial visitPslRulePartial(PSLParser.PslRulePartialContext pslRulePartialContext) {
        if (pslRulePartialContext == null || pslRulePartialContext.getChildCount() < 2) {
            throw new IllegalStateException();
        }

        Object head = visit(pslRulePartialContext.getChild(0));
        boolean arithmetic = head instanceof ArithmeticRuleExpression;
        if (!(arithmetic || head instanceof Rule || head instanceof Formula)) {
            throw new IllegalStateException();
        }

        int childCount = pslRulePartialContext.getChildCount();
        if (childCount == 2) {
            return new RulePartial(head);
        }
        if (!arithmetic) {
            throw new IllegalStateException();
        }

        HashMap<SummationVariable, Formula> filters = new HashMap<>();
        for (int index = 1; index < childCount - 1; index++) {
            FilterClause clause = visitFilterClause((PSLParser.FilterClauseContext) pslRulePartialContext.getChild(index));
            filters.put(clause.v, clause.f);
        }
        return new RulePartial((ArithmeticRuleExpression) head, filters);
    }

    /** Builds a weighted logical rule. It is squared only when an exponent token is present and reads "^2". */
    @Override
    public WeightedLogicalRule visitWeightedLogicalRule(PSLParser.WeightedLogicalRuleContext weightedLogicalRuleContext) {
        Double weight = visitWeightExpression(weightedLogicalRuleContext.weightExpression());
        Formula formula = visitLogicalRuleExpression(weightedLogicalRuleContext.logicalRuleExpression());
        boolean squared = weightedLogicalRuleContext.EXPONENT_EXPRESSION() != null
                && weightedLogicalRuleContext.EXPONENT_EXPRESSION().getText().equals("^2");
        return new WeightedLogicalRule(formula, weight.doubleValue(), squared);
    }

    /** Builds an unweighted logical rule from the rule's logical expression. */
    @Override
    public UnweightedLogicalRule visitUnweightedLogicalRule(PSLParser.UnweightedLogicalRuleContext unweightedLogicalRuleContext) {
        Formula formula = visitLogicalRuleExpression(unweightedLogicalRuleContext.logicalRuleExpression());
        return new UnweightedLogicalRule(formula);
    }

    /**
     * Visits a logical rule body, which is either a disjunctive expression or an implication.
     *
     * @throws IllegalStateException if neither alternative is present
     */
    @Override
    public Formula visitLogicalRuleExpression(PSLParser.LogicalRuleExpressionContext logicalRuleExpressionContext) {
        PSLParser.LogicalDisjunctiveExpressionContext disjunction = logicalRuleExpressionContext.logicalDisjunctiveExpression();
        if (disjunction != null) {
            return visitLogicalDisjunctiveExpression(disjunction);
        }

        PSLParser.LogicalImplicationExpressionContext implication = logicalRuleExpressionContext.logicalImplicationExpression();
        if (implication != null) {
            return visitLogicalImplicationExpression(implication);
        }

        throw new IllegalStateException();
    }

    /** Builds an implication whose body is the conjunctive side and whose head is the disjunctive side. */
    @Override
    public Formula visitLogicalImplicationExpression(PSLParser.LogicalImplicationExpressionContext logicalImplicationExpressionContext) {
        Formula body = visitLogicalConjunctiveExpression(logicalImplicationExpressionContext.logicalConjunctiveExpression());
        Formula head = visitLogicalDisjunctiveExpression(logicalImplicationExpressionContext.logicalDisjunctiveExpression());
        return new Implication(body, head);
    }

    /**
     * Visits a disjunctive expression. A single child is just a value. Otherwise the nested expression and the
     * trailing value are combined into a flattened {@link Disjunction}.
     */
    @Override
    public Formula visitLogicalDisjunctiveExpression(PSLParser.LogicalDisjunctiveExpressionContext logicalDisjunctiveExpressionContext) {
        if (logicalDisjunctiveExpressionContext.getChildCount() == 1) {
            return visitLogicalDisjunctiveValue(logicalDisjunctiveExpressionContext.logicalDisjunctiveValue());
        }

        Formula left = visitLogicalDisjunctiveExpression(logicalDisjunctiveExpressionContext.logicalDisjunctiveExpression());
        Formula right = visitLogicalDisjunctiveValue(logicalDisjunctiveExpressionContext.logicalDisjunctiveValue());
        return new Disjunction(left, right).flatten();
    }

    /**
     * Visits a conjunctive expression. A single child is just a value. Otherwise the nested expression and the
     * trailing value are combined into a flattened {@link Conjunction}.
     */
    @Override
    public Formula visitLogicalConjunctiveExpression(PSLParser.LogicalConjunctiveExpressionContext logicalConjunctiveExpressionContext) {
        if (logicalConjunctiveExpressionContext.getChildCount() == 1) {
            return visitLogicalConjunctiveValue(logicalConjunctiveExpressionContext.logicalConjunctiveValue());
        }

        Formula left = visitLogicalConjunctiveExpression(logicalConjunctiveExpressionContext.logicalConjunctiveExpression());
        Formula right = visitLogicalConjunctiveValue(logicalConjunctiveExpressionContext.logicalConjunctiveValue());
        return new Conjunction(left, right).flatten();
    }

    /** Visits a disjunctive value: a negation value when there is one child, otherwise a nested disjunction. */
    @Override
    public Formula visitLogicalDisjunctiveValue(PSLParser.LogicalDisjunctiveValueContext logicalDisjunctiveValueContext) {
        if (logicalDisjunctiveValueContext.getChildCount() != 1) {
            return visitLogicalDisjunctiveExpression(logicalDisjunctiveValueContext.logicalDisjunctiveExpression());
        }
        return visitLogicalNegationValue(logicalDisjunctiveValueContext.logicalNegationValue());
    }

    /** Visits a conjunctive value: a negation value when there is one child, otherwise a nested conjunction. */
    @Override
    public Formula visitLogicalConjunctiveValue(PSLParser.LogicalConjunctiveValueContext logicalConjunctiveValueContext) {
        if (logicalConjunctiveValueContext.getChildCount() != 1) {
            return visitLogicalConjunctiveExpression(logicalConjunctiveValueContext.logicalConjunctiveExpression());
        }
        return visitLogicalNegationValue(logicalConjunctiveValueContext.logicalNegationValue());
    }

    /**
     * Visits a negation value by child count: one child is an atom, two children negate the nested value, and
     * any other count simply unwraps the nested value.
     */
    @Override
    public Formula visitLogicalNegationValue(PSLParser.LogicalNegationValueContext logicalNegationValueContext) {
        switch (logicalNegationValueContext.getChildCount()) {
            case 1:
                return visitAtom(logicalNegationValueContext.atom());
            case 2:
                return new Negation(visitLogicalNegationValue(logicalNegationValueContext.logicalNegationValue()));
            default:
                return visitLogicalNegationValue(logicalNegationValueContext.logicalNegationValue());
        }
    }

    /**
     * Builds a weighted arithmetic rule with its filter clauses. It is squared only when an exponent token is
     * present and reads "^2".
     */
    @Override
    public WeightedArithmeticRule visitWeightedArithmeticRule(PSLParser.WeightedArithmeticRuleContext weightedArithmeticRuleContext) {
        Double weight = visitWeightExpression(weightedArithmeticRuleContext.weightExpression());
        ArithmeticRuleExpression expression = visitArithmeticRuleExpression(weightedArithmeticRuleContext.arithmeticRuleExpression());

        HashMap<SummationVariable, Formula> filters = new HashMap<>();
        for (PSLParser.FilterClauseContext filterContext : weightedArithmeticRuleContext.filterClause()) {
            FilterClause clause = visitFilterClause(filterContext);
            filters.put(clause.v, clause.f);
        }

        boolean squared = weightedArithmeticRuleContext.EXPONENT_EXPRESSION() != null
                && weightedArithmeticRuleContext.EXPONENT_EXPRESSION().getText().equals("^2");
        return new WeightedArithmeticRule(expression, filters, weight.doubleValue(), squared);
    }

    /** Builds an unweighted arithmetic rule with its filter clauses. */
    @Override
    public UnweightedArithmeticRule visitUnweightedArithmeticRule(PSLParser.UnweightedArithmeticRuleContext unweightedArithmeticRuleContext) {
        ArithmeticRuleExpression expression = visitArithmeticRuleExpression(unweightedArithmeticRuleContext.arithmeticRuleExpression());

        HashMap<SummationVariable, Formula> filters = new HashMap<>();
        for (PSLParser.FilterClauseContext filterContext : unweightedArithmeticRuleContext.filterClause()) {
            FilterClause clause = visitFilterClause(filterContext);
            filters.put(clause.v, clause.f);
        }

        return new UnweightedArithmeticRule(expression, filters);
    }

    /**
     * Normalizes {@code lhs REL rhs} into one linear form. Every atom is moved to the left: right-hand
     * coefficients are multiplied by -1 and appended to the left-hand lists. Every constant is moved to the
     * right: the left-hand constant is multiplied by -1 and added to the right-hand constant, and a missing
     * constant becomes 0. All coefficients and the constant are simplified before the expression is built.
     */
    @Override
    public ArithmeticRuleExpression visitArithmeticRuleExpression(PSLParser.ArithmeticRuleExpressionContext arithmeticRuleExpressionContext) {
        LinearArithmeticExpression lhs = visitLinearArithmeticExpression((PSLParser.LinearArithmeticExpressionContext) arithmeticRuleExpressionContext.getChild(0));
        FunctionComparator comparator = visitArithmeticRuleRelation((PSLParser.ArithmeticRuleRelationContext) arithmeticRuleExpressionContext.getChild(1));
        LinearArithmeticExpression rhs = visitLinearArithmeticExpression((PSLParser.LinearArithmeticExpressionContext) arithmeticRuleExpressionContext.getChild(2));

        // The left-hand lists are reused; the negated right-hand terms are appended to them.
        List<Coefficient> coefficients = lhs.coefficients;
        List<SummationAtomOrAtom> atoms = lhs.atoms;
        for (int k = 0; k < rhs.atoms.size(); k++) {
            coefficients.add(new Multiply(new ConstantNumber(-1.0f), rhs.coefficients.get(k)));
            atoms.add(rhs.atoms.get(k));
        }

        Coefficient constant = null;
        if (lhs.nonAtomCoefficient != null) {
            constant = new Multiply(new ConstantNumber(-1.0f), lhs.nonAtomCoefficient);
        }
        if (rhs.nonAtomCoefficient != null) {
            if (constant == null) {
                constant = rhs.nonAtomCoefficient;
            } else {
                constant = new Add(constant, rhs.nonAtomCoefficient);
            }
        }
        if (constant == null) {
            constant = new ConstantNumber(0.0f);
        }

        for (int k = 0; k < coefficients.size(); k++) {
            Coefficient simplified = coefficients.get(k).simplify();
            coefficients.set(k, simplified);
        }
        return new ArithmeticRuleExpression(coefficients, atoms, comparator, constant.simplify());
    }

    /**
     * Visits a linear expression. A single child is one operand. Three children are
     * {@code <expression> <operator> <operand>}: the operand's atoms are appended to the expression (with
     * coefficients negated for subtraction), and the two constant terms are combined under the same operator.
     *
     * @throws IllegalStateException if the node has neither one nor three children
     */
    @Override
    public LinearArithmeticExpression visitLinearArithmeticExpression(PSLParser.LinearArithmeticExpressionContext linearArithmeticExpressionContext) {
        if (linearArithmeticExpressionContext.getChildCount() == 1) {
            return visitLinearArithmeticOperand((PSLParser.LinearArithmeticOperandContext) linearArithmeticExpressionContext.getChild(0));
        }
        if (linearArithmeticExpressionContext.getChildCount() != 3) {
            throw new IllegalStateException("Expecting three children.");
        }

        LinearArithmeticExpression expression = visitLinearArithmeticExpression((PSLParser.LinearArithmeticExpressionContext) linearArithmeticExpressionContext.getChild(0));
        // true for '+', false for '-'.
        boolean add = visitLinearOperator((PSLParser.LinearOperatorContext) linearArithmeticExpressionContext.getChild(1)).booleanValue();
        LinearArithmeticExpression operand = visitLinearArithmeticOperand((PSLParser.LinearArithmeticOperandContext) linearArithmeticExpressionContext.getChild(2));

        for (int i = 0; i < operand.atoms.size(); i++) {
            Coefficient coefficient = operand.coefficients.get(i);
            if (!add) {
                coefficient = new Multiply(new ConstantNumber(-1.0f), coefficient);
            }
            expression.atoms.add(operand.atoms.get(i));
            expression.coefficients.add(coefficient);
        }

        // Combine the constant terms. When only the operand has a constant, subtraction must still negate it:
        // "A(X) - 1" has constant -1, not +1. This mirrors the atom negation above.
        Coefficient constant = expression.nonAtomCoefficient;
        if (operand.nonAtomCoefficient != null) {
            if (constant != null) {
                constant = add
                        ? new Add(constant, operand.nonAtomCoefficient)
                        : new Subtract(constant, operand.nonAtomCoefficient);
            } else if (add) {
                constant = operand.nonAtomCoefficient;
            } else {
                constant = new Multiply(new ConstantNumber(-1.0f), operand.nonAtomCoefficient);
            }
        }
        expression.nonAtomCoefficient = constant;
        return expression;
    }

    /**
     * Visits a single linear operand. With three children, the middle child is a nested linear expression that
     * is returned directly. With one child, it is a coefficient/atom term: a term with an atom contributes one
     * (coefficient, atom) pair, with the coefficient defaulting to 1; a term without an atom becomes the
     * expression's constant.
     *
     * @throws IllegalStateException if the node has neither one nor three children
     */
    @Override
    public LinearArithmeticExpression visitLinearArithmeticOperand(PSLParser.LinearArithmeticOperandContext linearArithmeticOperandContext) {
        if (linearArithmeticOperandContext.getChildCount() == 3) {
            return visitLinearArithmeticExpression((PSLParser.LinearArithmeticExpressionContext) linearArithmeticOperandContext.getChild(1));
        }
        if (linearArithmeticOperandContext.getChildCount() != 1) {
            throw new IllegalStateException("Expecting one or three children.");
        }

        LinearArithmeticExpression linearArithmeticExpression = new LinearArithmeticExpression();
        ArithmeticCoefficientOperand term = visitArithmeticCoefficientOperand((PSLParser.ArithmeticCoefficientOperandContext) linearArithmeticOperandContext.getChild(0));

        // Declared as Coefficient, not ConstantNumber: the parsed coefficient may be any Coefficient subtype,
        // and assigning it to a ConstantNumber variable does not compile.
        Coefficient coefficient = (term.coefficient != null) ? term.coefficient : new ConstantNumber(1.0f);
        if (term.atom != null) {
            linearArithmeticExpression.coefficients.add(coefficient);
            linearArithmeticExpression.atoms.add(term.atom);
        } else {
            linearArithmeticExpression.nonAtomCoefficient = coefficient;
        }
        return linearArithmeticExpression;
    }

    /**
     * Visits one arithmetic operand. It is either a lone coefficient expression, or an atom optionally preceded
     * by a coefficient (with an optional token between them) and optionally followed by a token and a divisor.
     * A divisor turns the coefficient into {@code coefficient / divisor}, using 1 when no leading coefficient
     * was given.
     */
    @Override
    public ArithmeticCoefficientOperand visitArithmeticCoefficientOperand(PSLParser.ArithmeticCoefficientOperandContext arithmeticCoefficientOperandContext) {
        ArithmeticCoefficientOperand result = new ArithmeticCoefficientOperand();

        if (arithmeticCoefficientOperandContext.getChildCount() == 1
                && arithmeticCoefficientOperandContext.getChild(0).getPayload() instanceof PSLParser.CoefficientExpressionContext) {
            // A bare coefficient with no atom.
            result.coefficient = visitCoefficientExpression((PSLParser.CoefficientExpressionContext) arithmeticCoefficientOperandContext.getChild(0));
            return result;
        }

        int atomIndex = 0;
        if (arithmeticCoefficientOperandContext.getChild(0).getPayload() instanceof PSLParser.CoefficientExpressionContext) {
            result.coefficient = visitCoefficientExpression((PSLParser.CoefficientExpressionContext) arithmeticCoefficientOperandContext.getChild(0));
            // A terminal token right after the coefficient pushes the atom one position further.
            boolean tokenBetween = arithmeticCoefficientOperandContext.getChild(1).getPayload() instanceof CommonToken;
            atomIndex = tokenBetween ? 2 : 1;
        }
        result.atom = (SummationAtomOrAtom) visit(arithmeticCoefficientOperandContext.getChild(atomIndex));

        if (arithmeticCoefficientOperandContext.getChildCount() > atomIndex + 1) {
            // The divisor sits two positions after the atom (one token in between).
            Coefficient divisor = visitCoefficientExpression((PSLParser.CoefficientExpressionContext) arithmeticCoefficientOperandContext.getChild(atomIndex + 2));
            Coefficient numerator = (result.coefficient == null) ? new ConstantNumber(1.0f) : result.coefficient;
            result.coefficient = new Divide(numerator, divisor);
        }
        return result;
    }

    /**
     * Visits a predicate applied to arguments. If any argument is a summation variable, the result is a
     * {@link SummationAtom}; otherwise every argument is a term and the result is a {@link QueryAtom}.
     *
     * @throws IllegalStateException if an argument is neither a summation variable nor a term
     */
    @Override
    public SummationAtomOrAtom visitSummationAtom(PSLParser.SummationAtomContext summationAtomContext) {
        Predicate predicate = visitPredicate(summationAtomContext.predicate());

        // Arguments sit at child indices 2, 4, 6, ...; the odd indices in between are not visited.
        int argumentCount = (summationAtomContext.getChildCount() / 2) - 1;
        SummationVariableOrTerm[] arguments = new SummationVariableOrTerm[argumentCount];
        for (int slot = 0; slot < argumentCount; slot++) {
            int childIndex = 2 * (slot + 1);
            Object payload = summationAtomContext.getChild(childIndex).getPayload();
            if (payload instanceof PSLParser.SummationVariableContext) {
                arguments[slot] = visitSummationVariable((PSLParser.SummationVariableContext) payload);
            } else if (payload instanceof PSLParser.TermContext) {
                arguments[slot] = (Term) visit(summationAtomContext.getChild(childIndex));
            } else {
                throw new IllegalStateException();
            }
        }

        boolean hasSummationVariable = false;
        for (SummationVariableOrTerm argument : arguments) {
            if (argument instanceof SummationVariable) {
                hasSummationVariable = true;
                break;
            }
        }
        if (hasSummationVariable) {
            return new SummationAtom(predicate, arguments);
        }

        Term[] terms = new Term[arguments.length];
        for (int slot = 0; slot < terms.length; slot++) {
            terms[slot] = (Term) arguments[slot];
        }
        return new QueryAtom(predicate, terms);
    }

    /** Creates a summation variable named by the node's identifier text. */
    @Override
    public SummationVariable visitSummationVariable(PSLParser.SummationVariableContext summationVariableContext) {
        String name = summationVariableContext.IDENTIFIER().getText();
        return new SummationVariable(name);
    }

    /** A coefficient expression is a thin wrapper around its additive expression child. */
    @Override
    public Coefficient visitCoefficientExpression(PSLParser.CoefficientExpressionContext coefficientExpressionContext) {
        PSLParser.CoefficientAdditiveExpressionContext additive = (PSLParser.CoefficientAdditiveExpressionContext) coefficientExpressionContext.getChild(0);
        return visitCoefficientAdditiveExpression(additive);
    }

    /**
     * Visits an additive coefficient expression: a single multiplicative term, or {@code left + right} /
     * {@code left - right} (subtraction whenever no PLUS token is present).
     */
    @Override
    public Coefficient visitCoefficientAdditiveExpression(PSLParser.CoefficientAdditiveExpressionContext coefficientAdditiveExpressionContext) {
        if (coefficientAdditiveExpressionContext.getChildCount() == 1) {
            return visitCoefficientMultiplicativeExpression((PSLParser.CoefficientMultiplicativeExpressionContext) coefficientAdditiveExpressionContext.getChild(0));
        }

        Coefficient left = visitCoefficientAdditiveExpression((PSLParser.CoefficientAdditiveExpressionContext) coefficientAdditiveExpressionContext.getChild(0));
        Coefficient right = visitCoefficientMultiplicativeExpression((PSLParser.CoefficientMultiplicativeExpressionContext) coefficientAdditiveExpressionContext.getChild(2));
        if (coefficientAdditiveExpressionContext.PLUS() == null) {
            return new Subtract(left, right);
        }
        return new Add(left, right);
    }

    /**
     * Visits a multiplicative coefficient expression: a single coefficient, or {@code left * right} /
     * {@code left / right} (division whenever no MULT token is present).
     */
    @Override
    public Coefficient visitCoefficientMultiplicativeExpression(PSLParser.CoefficientMultiplicativeExpressionContext coefficientMultiplicativeExpressionContext) {
        if (coefficientMultiplicativeExpressionContext.getChildCount() == 1) {
            return visitCoefficient((PSLParser.CoefficientContext) coefficientMultiplicativeExpressionContext.getChild(0));
        }

        Coefficient left = visitCoefficientMultiplicativeExpression((PSLParser.CoefficientMultiplicativeExpressionContext) coefficientMultiplicativeExpressionContext.getChild(0));
        Coefficient right = visitCoefficient((PSLParser.CoefficientContext) coefficientMultiplicativeExpressionContext.getChild(2));
        if (coefficientMultiplicativeExpressionContext.MULT() == null) {
            return new Divide(left, right);
        }
        return new Multiply(left, right);
    }

    /**
     * Visits a coefficient. A number becomes a float constant, three children wrap a nested coefficient
     * expression, and anything else is a coefficient operator.
     */
    @Override
    public Coefficient visitCoefficient(PSLParser.CoefficientContext coefficientContext) {
        if (coefficientContext.number() != null) {
            return new ConstantNumber(visitNumber(coefficientContext.number()).floatValue());
        }
        if (coefficientContext.getChildCount() == 3) {
            return visitCoefficientExpression((PSLParser.CoefficientExpressionContext) coefficientContext.getChild(1));
        }
        return visitCoefficientOperator((PSLParser.CoefficientOperatorContext) coefficientContext.getChild(0));
    }

    /** Visits a coefficient operator: three children denote the cardinality of a variable, otherwise a function. */
    @Override
    public Coefficient visitCoefficientOperator(PSLParser.CoefficientOperatorContext coefficientOperatorContext) {
        if (coefficientOperatorContext.getChildCount() != 3) {
            return visitCoefficientFunction((PSLParser.CoefficientFunctionContext) coefficientOperatorContext.getChild(0));
        }
        String variableName = coefficientOperatorContext.variable().getText();
        return new Cardinality(new SummationVariable(variableName));
    }

    /** Visits a two-argument coefficient function: MAX when that token is present, MIN otherwise. */
    @Override
    public Coefficient visitCoefficientFunction(PSLParser.CoefficientFunctionContext coefficientFunctionContext) {
        Coefficient first = visitCoefficientExpression((PSLParser.CoefficientExpressionContext) coefficientFunctionContext.getChild(2));
        Coefficient second = visitCoefficientExpression((PSLParser.CoefficientExpressionContext) coefficientFunctionContext.getChild(4));
        if (coefficientFunctionContext.coefficientFunctionOperator().MAX() == null) {
            return new Min(first, second);
        }
        return new Max(first, second);
    }

    /**
     * Maps the relation token to its comparator (EQUAL, LESS_THAN_EQUAL, or GREATER_THAN_EQUAL).
     *
     * @throws IllegalStateException if none of those tokens is present
     */
    @Override
    public FunctionComparator visitArithmeticRuleRelation(PSLParser.ArithmeticRuleRelationContext arithmeticRuleRelationContext) {
        FunctionComparator comparator = null;
        if (arithmeticRuleRelationContext.EQUAL() != null) {
            comparator = FunctionComparator.EQ;
        } else if (arithmeticRuleRelationContext.LESS_THAN_EQUAL() != null) {
            comparator = FunctionComparator.LTE;
        } else if (arithmeticRuleRelationContext.GREATER_THAN_EQUAL() != null) {
            comparator = FunctionComparator.GTE;
        }

        if (comparator == null) {
            throw new IllegalStateException();
        }
        return comparator;
    }

    /**
     * Returns {@code true} for PLUS and {@code false} for MINUS.
     *
     * @throws IllegalStateException if neither token is present
     */
    @Override
    public Boolean visitLinearOperator(PSLParser.LinearOperatorContext linearOperatorContext) {
        boolean isPlus = linearOperatorContext.PLUS() != null;
        if (isPlus) {
            return Boolean.TRUE;
        }
        if (linearOperatorContext.MINUS() == null) {
            throw new IllegalStateException();
        }
        return Boolean.FALSE;
    }

    /** Pairs the clause's variable (as a summation variable) with its boolean filter formula. */
    @Override
    public FilterClause visitFilterClause(PSLParser.FilterClauseContext filterClauseContext) {
        FilterClause clause = new FilterClause();
        String variableName = filterClauseContext.variable().getText();
        clause.v = new SummationVariable(variableName);
        clause.f = visitBooleanExpression(filterClauseContext.booleanExpression());
        return clause;
    }

    /** Visits a boolean value: a negation value if present, otherwise a nested boolean expression. */
    @Override
    public Formula visitBooleanValue(PSLParser.BooleanValueContext booleanValueContext) {
        PSLParser.LogicalNegationValueContext negation = booleanValueContext.logicalNegationValue();
        if (negation == null) {
            return visitBooleanExpression(booleanValueContext.booleanExpression());
        }
        return visitLogicalNegationValue(negation);
    }

    /**
     * Visits a conjunctive boolean expression.
     *
     * <p>A node with a single child is just a boolean value. Otherwise the node is treated as a
     * nested conjunctive expression joined with a trailing boolean value, and the two are combined
     * into a {@link Conjunction} (left operand visited first).
     *
     * @param booleanConjunctiveExpressionContext parse node for the expression
     * @return the resulting formula
     */
    @Override // org.linqs.psl.parser.antlr.PSLBaseVisitor, org.linqs.psl.parser.antlr.PSLVisitor
    public Formula visitBooleanConjunctiveExpression(PSLParser.BooleanConjunctiveExpressionContext booleanConjunctiveExpressionContext) {
        final PSLParser.BooleanConjunctiveExpressionContext node = booleanConjunctiveExpressionContext;
        if (node.getChildCount() != 1) {
            final Formula left = visitBooleanConjunctiveExpression(node.booleanConjunctiveExpression());
            final Formula right = visitBooleanValue(node.booleanValue());
            return new Conjunction(left, right);
        }
        return visitBooleanValue(node.booleanValue());
    }

    /**
     * Visits a disjunctive boolean expression.
     *
     * <p>A node with a single child is just a conjunctive expression. Otherwise the node is treated
     * as a nested disjunctive expression joined with a trailing conjunctive expression, and the two
     * are combined into a {@link Disjunction} (left operand visited first).
     *
     * @param booleanDisjunctiveExpressionContext parse node for the expression
     * @return the resulting formula
     */
    @Override // org.linqs.psl.parser.antlr.PSLBaseVisitor, org.linqs.psl.parser.antlr.PSLVisitor
    public Formula visitBooleanDisjunctiveExpression(PSLParser.BooleanDisjunctiveExpressionContext booleanDisjunctiveExpressionContext) {
        final PSLParser.BooleanDisjunctiveExpressionContext node = booleanDisjunctiveExpressionContext;
        if (node.getChildCount() != 1) {
            final Formula left = visitBooleanDisjunctiveExpression(node.booleanDisjunctiveExpression());
            final Formula right = visitBooleanConjunctiveExpression(node.booleanConjunctiveExpression());
            return new Disjunction(left, right);
        }
        return visitBooleanConjunctiveExpression(node.booleanConjunctiveExpression());
    }

    /**
     * Visits a top-level boolean expression by delegating to its disjunctive sub-expression.
     *
     * @param booleanExpressionContext parse node for the expression
     * @return the formula produced for the disjunctive sub-expression
     */
    @Override // org.linqs.psl.parser.antlr.PSLBaseVisitor, org.linqs.psl.parser.antlr.PSLVisitor
    public Formula visitBooleanExpression(PSLParser.BooleanExpressionContext booleanExpressionContext) {
        final PSLParser.BooleanDisjunctiveExpressionContext disjunction = booleanExpressionContext.booleanDisjunctiveExpression();
        return visitBooleanDisjunctiveExpression(disjunction);
    }

    /**
     * Parses the numeric weight attached to a weighted rule.
     *
     * @param weightExpressionContext parse node wrapping the weight's number
     * @return the weight as a boxed double
     * @throws NumberFormatException if the number's text is not a valid double
     */
    @Override // org.linqs.psl.parser.antlr.PSLBaseVisitor, org.linqs.psl.parser.antlr.PSLVisitor
    public Double visitWeightExpression(PSLParser.WeightExpressionContext weightExpressionContext) {
        final String weightText = weightExpressionContext.number().getText();
        return Double.valueOf(weightText);
    }

    /**
     * Builds a {@link QueryAtom} from an atom parse node.
     *
     * <p>There are two shapes:
     * <ul>
     *   <li>A predicate applied to its terms, resolved through {@link #visitPredicate}.</li>
     *   <li>A binary term comparison, mapped to a {@link GroundingOnlyPredicate} over the first two
     *       terms. The mapping is {@code NotEqual}, {@code Equal} or {@code NonSymmetric}.</li>
     * </ul>
     *
     * @param atomContext parse node for the atom
     * @return the constructed atom
     * @throws IllegalStateException if the node has neither a predicate nor a recognized term operator
     */
    @Override // org.linqs.psl.parser.antlr.PSLBaseVisitor, org.linqs.psl.parser.antlr.PSLVisitor
    public Atom visitAtom(PSLParser.AtomContext atomContext) {
        if (atomContext.predicate() != null) {
            final Predicate predicate = visitPredicate(atomContext.predicate());
            final List<PSLParser.TermContext> termContexts = atomContext.term();
            final Term[] arguments = new Term[termContexts.size()];
            int position = 0;
            for (PSLParser.TermContext termContext : termContexts) {
                arguments[position++] = (Term) visit(termContext);
            }
            return new QueryAtom(predicate, arguments);
        }

        final PSLParser.TermOperatorContext operator = atomContext.termOperator();
        if (operator == null) {
            throw new IllegalStateException();
        }

        final GroundingOnlyPredicate comparison;
        if (operator.notEqual() != null) {
            comparison = GroundingOnlyPredicate.NotEqual;
        } else if (operator.termEqual() != null) {
            comparison = GroundingOnlyPredicate.Equal;
        } else if (operator.nonSymmetric() != null) {
            comparison = GroundingOnlyPredicate.NonSymmetric;
        } else {
            throw new IllegalStateException();
        }

        final Term left = (Term) visit(atomContext.term(0));
        final Term right = (Term) visit(atomContext.term(1));
        return new QueryAtom(comparison, left, right);
    }

    /**
     * Resolves a predicate name to a previously registered {@link Predicate}.
     *
     * @param predicateContext parse node holding the predicate's identifier
     * @return the predicate returned by {@code Predicate.get} for that name
     * @throws IllegalStateException if {@code Predicate.get} returns {@code null} for the name
     */
    @Override // org.linqs.psl.parser.antlr.PSLBaseVisitor, org.linqs.psl.parser.antlr.PSLVisitor
    public Predicate visitPredicate(PSLParser.PredicateContext predicateContext) {
        final String predicateName = predicateContext.IDENTIFIER().getText();
        final Predicate resolved = Predicate.get(predicateName);
        if (resolved == null) {
            throw new IllegalStateException("Undefined predicate " + predicateName);
        }
        return resolved;
    }

    /**
     * Creates a {@link Variable} named by the node's identifier.
     *
     * @param variableContext parse node holding the variable's identifier
     * @return a new variable with that name
     */
    @Override // org.linqs.psl.parser.antlr.PSLBaseVisitor, org.linqs.psl.parser.antlr.PSLVisitor
    public Variable visitVariable(PSLParser.VariableContext variableContext) {
        final String variableName = variableContext.IDENTIFIER().getText();
        return new Variable(variableName);
    }

    /**
     * Builds a string constant from a quoted literal.
     *
     * <p>The raw source text spanning the node's first through last token is read straight from
     * the input stream. Its first and last characters (the quotes) are dropped, and the remaining
     * escape sequences are decoded with {@code replaceLiterals}.
     *
     * @param constantContext parse node for the quoted constant
     * @return a {@link StringAttribute} holding the unquoted, unescaped text
     */
    @Override // org.linqs.psl.parser.antlr.PSLBaseVisitor, org.linqs.psl.parser.antlr.PSLVisitor
    public Constant visitConstant(PSLParser.ConstantContext constantContext) {
        final int firstChar = constantContext.start.getStartIndex();
        final int lastChar = constantContext.stop.getStopIndex();
        final String quoted = constantContext.start.getInputStream().getText(new Interval(firstChar, lastChar));
        // Strip the surrounding quote characters before decoding escapes.
        final String unquoted = quoted.substring(1, quoted.length() - 1);
        return new StringAttribute(replaceLiterals(unquoted));
    }

    /**
     * Decodes the backslash escape sequences permitted inside a quoted constant.
     *
     * <p>Recognized escapes are {@code \'}, {@code \"}, {@code \t}, {@code \n}, {@code \r} and
     * {@code \\}. A backslash followed by any other character, or a trailing lone backslash, is
     * copied through unchanged.
     *
     * <p>Decoding happens in a single left-to-right pass so that the second character of an escaped
     * backslash is never re-read as the start of another escape. For example, {@code \\n} decodes
     * to a backslash followed by {@code n}, not to a backslash followed by a newline. Chained
     * {@code String.replace} calls get this wrong.
     *
     * @param str constant text with its surrounding quotes already removed
     * @return the decoded text, or {@code str} itself when it contains no backslash
     */
    private String replaceLiterals(String str) {
        if (str.indexOf('\\') < 0) {
            return str;
        }

        StringBuilder decoded = new StringBuilder(str.length());
        for (int i = 0; i < str.length(); i++) {
            char current = str.charAt(i);
            if (current != '\\' || i + 1 == str.length()) {
                decoded.append(current);
                continue;
            }

            switch (str.charAt(i + 1)) {
                case '\'':
                    decoded.append('\'');
                    break;
                case '"':
                    decoded.append('"');
                    break;
                case 't':
                    // Literal tab; intentionally independent of any data-loading delimiter setting.
                    decoded.append('\t');
                    break;
                case 'n':
                    decoded.append('\n');
                    break;
                case 'r':
                    decoded.append('\r');
                    break;
                case '\\':
                    decoded.append('\\');
                    break;
                default:
                    // Unrecognized escape: keep the backslash and let the next iteration handle
                    // the following character normally.
                    decoded.append(current);
                    continue;
            }
            // A recognized escape consumed two characters; skip the second one.
            i++;
        }
        return decoded.toString();
    }

    /**
     * Parses a numeric literal.
     *
     * @param numberContext parse node for the number
     * @return the literal's value as a boxed double
     * @throws NumberFormatException if the node's text is not a valid double
     */
    @Override // org.linqs.psl.parser.antlr.PSLBaseVisitor, org.linqs.psl.parser.antlr.PSLVisitor
    public Double visitNumber(PSLParser.NumberContext numberContext) {
        final String literal = numberContext.getText();
        return Double.valueOf(literal);
    }
}
