package org.apache.asterix.lang.sqlpp.rewrites;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.asterix.common.exceptions.CompilationException;
import org.apache.asterix.lang.common.base.AbstractClause;
import org.apache.asterix.lang.common.base.Expression;
import org.apache.asterix.lang.common.base.ILangExpression;
import org.apache.asterix.lang.common.base.IParserFactory;
import org.apache.asterix.lang.common.base.IQueryRewriter;
import org.apache.asterix.lang.common.base.IReturningStatement;
import org.apache.asterix.lang.common.expression.CallExpr;
import org.apache.asterix.lang.common.expression.ListSliceExpression;
import org.apache.asterix.lang.common.expression.VariableExpr;
import org.apache.asterix.lang.common.parser.FunctionParser;
import org.apache.asterix.lang.common.rewrites.LangRewritingContext;
import org.apache.asterix.lang.common.statement.FunctionDecl;
import org.apache.asterix.lang.common.struct.Identifier;
import org.apache.asterix.lang.common.struct.VarIdentifier;
import org.apache.asterix.lang.common.util.FunctionUtil;
import org.apache.asterix.lang.common.visitor.GatherFunctionCallsVisitor;
import org.apache.asterix.lang.common.visitor.base.ILangVisitor;
import org.apache.asterix.lang.sqlpp.clause.AbstractBinaryCorrelateClause;
import org.apache.asterix.lang.sqlpp.clause.FromClause;
import org.apache.asterix.lang.sqlpp.clause.FromTerm;
import org.apache.asterix.lang.sqlpp.clause.HavingClause;
import org.apache.asterix.lang.sqlpp.clause.JoinClause;
import org.apache.asterix.lang.sqlpp.clause.NestClause;
import org.apache.asterix.lang.sqlpp.clause.Projection;
import org.apache.asterix.lang.sqlpp.clause.SelectBlock;
import org.apache.asterix.lang.sqlpp.clause.SelectClause;
import org.apache.asterix.lang.sqlpp.clause.SelectElement;
import org.apache.asterix.lang.sqlpp.clause.SelectRegular;
import org.apache.asterix.lang.sqlpp.clause.SelectSetOperation;
import org.apache.asterix.lang.sqlpp.clause.UnnestClause;
import org.apache.asterix.lang.sqlpp.expression.CaseExpression;
import org.apache.asterix.lang.sqlpp.expression.SelectExpression;
import org.apache.asterix.lang.sqlpp.expression.WindowExpression;
import org.apache.asterix.lang.sqlpp.rewrites.visitor.GenerateColumnNameVisitor;
import org.apache.asterix.lang.sqlpp.rewrites.visitor.InlineColumnAliasVisitor;
import org.apache.asterix.lang.sqlpp.rewrites.visitor.InlineWithExpressionVisitor;
import org.apache.asterix.lang.sqlpp.rewrites.visitor.OperatorExpressionVisitor;
import org.apache.asterix.lang.sqlpp.rewrites.visitor.SetOperationVisitor;
import org.apache.asterix.lang.sqlpp.rewrites.visitor.SqlppCaseAggregateExtractionVisitor;
import org.apache.asterix.lang.sqlpp.rewrites.visitor.SqlppCaseExpressionVisitor;
import org.apache.asterix.lang.sqlpp.rewrites.visitor.SqlppFunctionCallResolverVisitor;
import org.apache.asterix.lang.sqlpp.rewrites.visitor.SqlppGroupByAggregationSugarVisitor;
import org.apache.asterix.lang.sqlpp.rewrites.visitor.SqlppGroupByVisitor;
import org.apache.asterix.lang.sqlpp.rewrites.visitor.SqlppGroupingSetsVisitor;
import org.apache.asterix.lang.sqlpp.rewrites.visitor.SqlppInlineUdfsVisitor;
import org.apache.asterix.lang.sqlpp.rewrites.visitor.SqlppListInputFunctionRewriteVisitor;
import org.apache.asterix.lang.sqlpp.rewrites.visitor.SqlppSpecialFunctionNameRewriteVisitor;
import org.apache.asterix.lang.sqlpp.rewrites.visitor.SqlppWindowAggregationSugarVisitor;
import org.apache.asterix.lang.sqlpp.rewrites.visitor.SqlppWindowRewriteVisitor;
import org.apache.asterix.lang.sqlpp.rewrites.visitor.SubstituteGroupbyExpressionWithVariableVisitor;
import org.apache.asterix.lang.sqlpp.rewrites.visitor.VariableCheckAndRewriteVisitor;
import org.apache.asterix.lang.sqlpp.struct.SetOperationRight;
import org.apache.asterix.lang.sqlpp.util.SqlppAstPrintUtil;
import org.apache.asterix.lang.sqlpp.util.SqlppVariableUtil;
import org.apache.asterix.lang.sqlpp.visitor.base.ISqlppVisitor;
import org.apache.asterix.metadata.declared.MetadataProvider;
import org.apache.hyracks.algebricks.common.utils.Pair;
import org.apache.hyracks.util.LogRedactionUtil;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * SQL++ implementation of {@link IQueryRewriter}: applies a fixed, ordered pipeline of AST rewrite visitors to a
 * returning statement (see {@link #rewrite}).
 *
 * <p>All per-call state ({@code topExpr}, {@code context}, {@code declaredFunctions}, ...) is stored in mutable
 * fields by {@link #setup}. NOTE(review): this makes one instance unsuitable for concurrent {@code rewrite} calls —
 * presumably each compilation uses its own instance; confirm with callers.
 */
public class SqlppQueryRewriter implements IQueryRewriter {
    private static final Logger LOGGER = LogManager.getLogger(SqlppQueryRewriter.class);

    /** Name of the boolean metadata property that controls whether {@link #inlineWithExpressions()} runs. */
    public static final String INLINE_WITH_OPTION = "inline_with";

    /** Fallback value passed to {@code getBooleanProperty} when reading {@link #INLINE_WITH_OPTION}. */
    private static final boolean INLINE_WITH_OPTION_DEFAULT = true;

    private final IParserFactory parserFactory;
    private final FunctionParser functionParser;

    /** The statement currently being rewritten; every rewrite step visits this node. */
    private IReturningStatement topExpr;
    /**
     * Functions declared by the query. Stored-function declarations are appended temporarily while
     * {@link #inlineDeclaredUdfs(boolean)} runs and removed again afterwards.
     */
    private List<FunctionDecl> declaredFunctions;
    private LangRewritingContext context;
    private MetadataProvider metadataProvider;
    /** External variable identifiers; never {@code null} after {@link #setup}. */
    private Collection<VarIdentifier> externalVars;
    /** TRACE-level flag captured once in {@link #setup} so per-step logging can be skipped cheaply. */
    private boolean isLogEnabled;

    /**
     * Function-call collector that can also walk SQL++-specific AST nodes.
     *
     * <p>{@link GatherFunctionCallsVisitor} supplies the collection itself and {@code getCalls()}; the methods below
     * only descend into the child clauses/expressions of SQL++ constructs so that calls nested inside them are
     * reached.
     */
    private static class GatherFunctionCalls extends GatherFunctionCallsVisitor implements ISqlppVisitor<Void, Void> {

        @Override
        public Void visit(FromClause fromClause, Void arg) throws CompilationException {
            for (FromTerm fromTerm : fromClause.getFromTerms()) {
                fromTerm.accept(this, arg);
            }
            return null;
        }

        @Override
        public Void visit(FromTerm fromTerm, Void arg) throws CompilationException {
            fromTerm.getLeftExpression().accept(this, arg);
            for (AbstractBinaryCorrelateClause correlateClause : fromTerm.getCorrelateClauses()) {
                correlateClause.accept(this, arg);
            }
            return null;
        }

        @Override
        public Void visit(JoinClause joinClause, Void arg) throws CompilationException {
            joinClause.getRightExpression().accept(this, arg);
            joinClause.getConditionExpression().accept(this, arg);
            return null;
        }

        @Override
        public Void visit(NestClause nestClause, Void arg) throws CompilationException {
            nestClause.getRightExpression().accept(this, arg);
            nestClause.getConditionExpression().accept(this, arg);
            return null;
        }

        @Override
        public Void visit(Projection projection, Void arg) throws CompilationException {
            // Star projections are not traversed; only projections carrying an expression are visited.
            if (!projection.star()) {
                projection.getExpression().accept(this, arg);
            }
            return null;
        }

        @Override
        public Void visit(SelectBlock selectBlock, Void arg) throws CompilationException {
            if (selectBlock.hasFromClause()) {
                selectBlock.getFromClause().accept(this, arg);
            }
            if (selectBlock.hasLetWhereClauses()) {
                for (AbstractClause letWhereClause : selectBlock.getLetWhereList()) {
                    letWhereClause.accept(this, arg);
                }
            }
            if (selectBlock.hasGroupbyClause()) {
                selectBlock.getGroupbyClause().accept(this, arg);
            }
            if (selectBlock.hasLetHavingClausesAfterGroupby()) {
                for (AbstractClause letHavingClause : selectBlock.getLetHavingListAfterGroupby()) {
                    letHavingClause.accept(this, arg);
                }
            }
            selectBlock.getSelectClause().accept(this, arg);
            return null;
        }

        @Override
        public Void visit(SelectClause selectClause, Void arg) throws CompilationException {
            if (selectClause.selectElement()) {
                selectClause.getSelectElement().accept(this, arg);
            } else {
                selectClause.getSelectRegular().accept(this, arg);
            }
            return null;
        }

        @Override
        public Void visit(SelectElement selectElement, Void arg) throws CompilationException {
            selectElement.getExpression().accept(this, arg);
            return null;
        }

        @Override
        public Void visit(SelectRegular selectRegular, Void arg) throws CompilationException {
            for (Projection projection : selectRegular.getProjections()) {
                projection.accept(this, arg);
            }
            return null;
        }

        @Override
        public Void visit(SelectSetOperation selectSetOperation, Void arg) throws CompilationException {
            selectSetOperation.getLeftInput().accept(this, arg);
            for (SetOperationRight right : selectSetOperation.getRightInputs()) {
                right.getSetOperationRightInput().accept(this, arg);
            }
            return null;
        }

        @Override
        public Void visit(SelectExpression selectExpression, Void arg) throws CompilationException {
            selectExpression.getSelectSetOperation().accept(this, arg);
            if (selectExpression.hasOrderby()) {
                selectExpression.getOrderbyClause().accept(this, arg);
            }
            if (selectExpression.hasLimit()) {
                selectExpression.getLimitClause().accept(this, arg);
            }
            return null;
        }

        @Override
        public Void visit(UnnestClause unnestClause, Void arg) throws CompilationException {
            unnestClause.getRightExpression().accept(this, arg);
            return null;
        }

        @Override
        public Void visit(HavingClause havingClause, Void arg) throws CompilationException {
            havingClause.getFilterExpression().accept(this, arg);
            return null;
        }

        @Override
        public Void visit(CaseExpression caseExpression, Void arg) throws CompilationException {
            caseExpression.getConditionExpr().accept(this, arg);
            for (Expression whenExpr : caseExpression.getWhenExprs()) {
                whenExpr.accept(this, arg);
            }
            for (Expression thenExpr : caseExpression.getThenExprs()) {
                thenExpr.accept(this, arg);
            }
            caseExpression.getElseExpr().accept(this, arg);
            return null;
        }

        @Override
        public Void visit(WindowExpression windowExpression, Void arg) throws CompilationException {
            if (windowExpression.hasPartitionList()) {
                for (Expression partitionExpr : windowExpression.getPartitionList()) {
                    partitionExpr.accept(this, arg);
                }
            }
            if (windowExpression.hasOrderByList()) {
                for (Expression orderByExpr : windowExpression.getOrderbyList()) {
                    orderByExpr.accept(this, arg);
                }
            }
            if (windowExpression.hasFrameStartExpr()) {
                windowExpression.getFrameStartExpr().accept(this, arg);
            }
            if (windowExpression.hasFrameEndExpr()) {
                windowExpression.getFrameEndExpr().accept(this, arg);
            }
            if (windowExpression.hasWindowFieldList()) {
                // Only the expression half of each (expression, identifier) pair can contain calls.
                for (Pair<Expression, Identifier> windowField : windowExpression.getWindowFieldList()) {
                    windowField.first.accept(this, arg);
                }
            }
            if (windowExpression.hasAggregateFilterExpr()) {
                windowExpression.getAggregateFilterExpr().accept(this, arg);
            }
            for (Expression expr : windowExpression.getExprList()) {
                expr.accept(this, arg);
            }
            return null;
        }

        public Void visit(ListSliceExpression listSliceExpression, Void arg) throws CompilationException {
            listSliceExpression.getExpr().accept(this, arg);
            listSliceExpression.getStartIndexExpression().accept(this, arg);
            if (listSliceExpression.hasEndExpression()) {
                listSliceExpression.getEndIndexExpression().accept(this, arg);
            }
            return null;
        }
    }

    /**
     * Creates a rewriter.
     *
     * @param parserFactory factory used both for the {@link FunctionParser} and for the function-body rewriter
     *                      factory handed to the UDF-inlining visitor
     */
    public SqlppQueryRewriter(IParserFactory parserFactory) {
        this.parserFactory = parserFactory;
        this.functionParser = new FunctionParser(parserFactory);
    }

    /**
     * Stores the per-call state used by every rewrite step and logs the initial AST (at TRACE level).
     *
     * @param declaredFunctions functions declared by the query; this list is temporarily extended (and then
     *                          restored) by {@link #inlineDeclaredUdfs(boolean)}
     * @param topStatement      statement to rewrite
     * @param metadataProvider  metadata access used by several rewrite steps
     * @param context           rewriting context passed to most visitors
     * @param externalVars      external variable identifiers; {@code null} is treated as an empty collection
     * @throws CompilationException if printing the AST for the trace log fails
     */
    public void setup(List<FunctionDecl> declaredFunctions, IReturningStatement topStatement,
            MetadataProvider metadataProvider, LangRewritingContext context, Collection<VarIdentifier> externalVars)
            throws CompilationException {
        this.topExpr = topStatement;
        this.context = context;
        this.declaredFunctions = declaredFunctions;
        this.metadataProvider = metadataProvider;
        this.externalVars = externalVars != null ? externalVars : Collections.emptyList();
        this.isLogEnabled = LOGGER.isTraceEnabled();
        logExpression("Starting AST rewrites on", "");
    }

    /**
     * Runs the full SQL++ rewrite pipeline on {@code topStatement}, in place. Does nothing if the statement is
     * {@code null}. After the last step the statement's variable counter is synchronized with the context's.
     *
     * @param declaredFunctions functions declared by the query
     * @param topStatement      statement to rewrite (may be {@code null})
     * @param metadataProvider  metadata access used by several rewrite steps
     * @param context           rewriting context
     * @param inlineUdfs        whether declared/stored UDF calls should be inlined
     * @param externalVars      external variable identifiers, or {@code null} for none
     * @throws CompilationException if any rewrite step fails
     */
    public void rewrite(List<FunctionDecl> declaredFunctions, IReturningStatement topStatement,
            MetadataProvider metadataProvider, LangRewritingContext context, boolean inlineUdfs,
            Collection<VarIdentifier> externalVars) throws CompilationException {
        if (topStatement == null) {
            return;
        }
        setup(declaredFunctions, topStatement, metadataProvider, context, externalVars);
        // The order of these steps matters: later visitors rely on the AST shape produced by earlier ones.
        resolveFunctionCalls();
        generateColumnNames();
        substituteGroupbyKeyExpression();
        rewriteGroupBys();
        rewriteSetOperations();
        inlineColumnAlias();
        rewriteWindowExpressions();
        rewriteGroupingSets();
        variableCheckAndRewrite();
        extractAggregatesFromCaseExpressions();
        rewriteGroupByAggregationSugar();
        rewriteWindowAggregationSugar();
        rewriteOperatorExpression();
        rewriteCaseExpressions();
        rewriteListInputFunctions();
        inlineDeclaredUdfs(inlineUdfs);
        rewriteSpecialFunctionNames();
        inlineWithExpressions();
        topStatement.setVarCounter(context.getVarCounter().get());
    }

    /** Applies {@link SqlppGroupByAggregationSugarVisitor}. */
    public void rewriteGroupByAggregationSugar() throws CompilationException {
        rewriteTopExpr(new SqlppGroupByAggregationSugarVisitor(context, externalVars), null);
    }

    /** Applies {@link SqlppListInputFunctionRewriteVisitor}. */
    public void rewriteListInputFunctions() throws CompilationException {
        rewriteTopExpr(new SqlppListInputFunctionRewriteVisitor(), null);
    }

    /** Applies {@link SqlppFunctionCallResolverVisitor} using the metadata provider and declared functions. */
    public void resolveFunctionCalls() throws CompilationException {
        rewriteTopExpr(new SqlppFunctionCallResolverVisitor(metadataProvider, declaredFunctions), null);
    }

    /** Applies {@link SqlppSpecialFunctionNameRewriteVisitor}. */
    protected void rewriteSpecialFunctionNames() throws CompilationException {
        rewriteTopExpr(new SqlppSpecialFunctionNameRewriteVisitor(), null);
    }

    /**
     * Applies {@link InlineWithExpressionVisitor}, unless the {@link #INLINE_WITH_OPTION} metadata property is
     * set to {@code false}.
     */
    protected void inlineWithExpressions() throws CompilationException {
        if (metadataProvider.getBooleanProperty(INLINE_WITH_OPTION, INLINE_WITH_OPTION_DEFAULT)) {
            rewriteTopExpr(new InlineWithExpressionVisitor(context, metadataProvider), null);
        }
    }

    /** Applies {@link GenerateColumnNameVisitor}. */
    public void generateColumnNames() throws CompilationException {
        rewriteTopExpr(new GenerateColumnNameVisitor(context), null);
    }

    /** Applies {@link SubstituteGroupbyExpressionWithVariableVisitor}. */
    public void substituteGroupbyKeyExpression() throws CompilationException {
        rewriteTopExpr(new SubstituteGroupbyExpressionWithVariableVisitor(context), null);
    }

    /** Applies {@link SetOperationVisitor}. */
    public void rewriteSetOperations() throws CompilationException {
        rewriteTopExpr(new SetOperationVisitor(context), null);
    }

    /** Applies {@link OperatorExpressionVisitor}. */
    public void rewriteOperatorExpression() throws CompilationException {
        rewriteTopExpr(new OperatorExpressionVisitor(context), null);
    }

    /** Applies {@link InlineColumnAliasVisitor}. */
    public void inlineColumnAlias() throws CompilationException {
        rewriteTopExpr(new InlineColumnAliasVisitor(context), null);
    }

    /** Applies {@link VariableCheckAndRewriteVisitor} with the metadata provider and external variables. */
    public void variableCheckAndRewrite() throws CompilationException {
        rewriteTopExpr(new VariableCheckAndRewriteVisitor(context, metadataProvider, externalVars), null);
    }

    /** Applies {@link SqlppGroupByVisitor}. */
    public void rewriteGroupBys() throws CompilationException {
        rewriteTopExpr(new SqlppGroupByVisitor(context), null);
    }

    /** Applies {@link SqlppGroupingSetsVisitor}. */
    public void rewriteGroupingSets() throws CompilationException {
        rewriteTopExpr(new SqlppGroupingSetsVisitor(context), null);
    }

    /** Applies {@link SqlppWindowRewriteVisitor}. */
    public void rewriteWindowExpressions() throws CompilationException {
        rewriteTopExpr(new SqlppWindowRewriteVisitor(context), null);
    }

    /** Applies {@link SqlppWindowAggregationSugarVisitor}. */
    public void rewriteWindowAggregationSugar() throws CompilationException {
        rewriteTopExpr(new SqlppWindowAggregationSugarVisitor(context), null);
    }

    /** Applies {@link SqlppCaseAggregateExtractionVisitor}. */
    public void extractAggregatesFromCaseExpressions() throws CompilationException {
        rewriteTopExpr(new SqlppCaseAggregateExtractionVisitor(context), null);
    }

    /** Applies {@link SqlppCaseExpressionVisitor}. */
    public void rewriteCaseExpressions() throws CompilationException {
        rewriteTopExpr(new SqlppCaseExpressionVisitor(), null);
    }

    /**
     * Looks up the stored functions used by the statement and, if requested, repeatedly applies
     * {@link SqlppInlineUdfsVisitor} until a pass reports no change.
     *
     * <p>The stored-function declarations are appended to {@code declaredFunctions} only for the duration of this
     * call. They are removed again in a {@code finally} block so the caller's list is restored even if inlining
     * fails. Stored functions are looked up even when {@code inlineUdfs} is {@code false}.
     *
     * @param inlineUdfs whether to run the inlining passes
     * @throws CompilationException if stored-function lookup or inlining fails
     */
    @SuppressWarnings({ "rawtypes", "unchecked" })
    public void inlineDeclaredUdfs(boolean inlineUdfs) throws CompilationException {
        // NOTE(review): holds the FunctionSignature of each declared function. It is raw only because that type
        // is not imported in this file; consider importing it and typing the list.
        List declaredSignatures = new ArrayList();
        for (FunctionDecl declaredFunction : declaredFunctions) {
            declaredSignatures.add(declaredFunction.getSignature());
        }

        List<FunctionDecl> usedStoredFunctionDecls = new ArrayList<>();
        for (Expression topLevelExpr : topExpr.getDirectlyEnclosedExpressions()) {
            usedStoredFunctionDecls.addAll(FunctionUtil.retrieveUsedStoredFunctions(metadataProvider, topLevelExpr,
                    declaredSignatures, (List) null, this::getFunctionCalls, functionParser,
                    metadataProvider.getDefaultDataverseName()));
        }

        declaredFunctions.addAll(usedStoredFunctionDecls);
        try {
            if (inlineUdfs && !declaredFunctions.isEmpty()) {
                // Each pass uses a fresh visitor; keep inlining until a pass reports that nothing changed
                // (inlined bodies may themselves contain further UDF calls).
                boolean changed;
                do {
                    changed = rewriteTopExpr(new SqlppInlineUdfsVisitor(context,
                            new SqlppFunctionBodyRewriterFactory(parserFactory), declaredFunctions,
                            metadataProvider), declaredFunctions);
                } while (changed);
            }
        } finally {
            declaredFunctions.removeAll(usedStoredFunctionDecls);
        }
    }

    /**
     * Runs {@code visitor} over the top statement and logs the resulting AST (at TRACE level).
     *
     * @return whatever the visitor returned for the top statement
     */
    private <R, T> R rewriteTopExpr(ILangVisitor<R, T> visitor, T arg) throws CompilationException {
        R result = topExpr.accept(visitor, arg);
        logExpression(">>>> AST After", visitor.getClass().getSimpleName());
        return result;
    }

    /** Logs {@code prefix}, {@code stepName} and the redacted, printed AST when TRACE logging is enabled. */
    private void logExpression(String prefix, String stepName) throws CompilationException {
        if (isLogEnabled) {
            LOGGER.trace("{} {}\n{}", prefix, stepName,
                    LogRedactionUtil.userData(SqlppAstPrintUtil.toString((ILangExpression) topExpr)));
        }
    }

    /**
     * Collects the function calls reachable from {@code expression}, including those nested inside SQL++ clauses.
     *
     * @param expression expression to scan
     * @return the calls gathered by {@link GatherFunctionCalls}
     */
    public Set<CallExpr> getFunctionCalls(Expression expression) throws CompilationException {
        GatherFunctionCalls gatherer = new GatherFunctionCalls();
        // The visitor's argument type is Void, so the argument must be a plain null (not an Object-typed null).
        expression.accept(gatherer, null);
        return gatherer.getCalls();
    }

    /**
     * Returns the free variables of {@code expression} that are external-variable references.
     *
     * @param expression expression to inspect
     * @return a new mutable set; empty if there are no such references
     */
    public Set<VariableExpr> getExternalVariables(Expression expression) throws CompilationException {
        Set<VariableExpr> externalRefs = new HashSet<>();
        for (VariableExpr freeVar : SqlppVariableUtil.getFreeVariables(expression)) {
            if (SqlppVariableUtil.isExternalVariableReference(freeVar)) {
                externalRefs.add(freeVar);
            }
        }
        return externalRefs;
    }
}
