package io.trino.sql.planner;

import com.google.common.collect.ImmutableList;
import io.trino.sql.ir.Expression;
import io.trino.sql.planner.iterative.GroupReference;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.ValuesNode;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;

/* loaded from: input_file:io/trino/sql/planner/ExpressionExtractor.class */
/**
 * Static helpers that collect the {@link Expression}s embedded in a plan.
 *
 * <p>Expressions are reported from aggregation arguments, filter predicates, projection
 * assignments, optional join filters and optional {@code VALUES} rows. A node's own
 * expressions are reported before the visitor delegates to its superclass for that node.
 */
public final class ExpressionExtractor {

    /**
     * Plan visitor that forwards every expression it encounters to a consumer.
     *
     * <p>When {@code recursive} is {@code false}, {@link #visitPlan} returns without delegating
     * to {@link SimplePlanVisitor}, so the walk stops after the starting node. (This assumes
     * {@code SimplePlanVisitor} reaches child nodes only through {@code visitPlan}; TODO confirm.)
     * {@link GroupReference}s are resolved through the supplied {@link Lookup} before they are
     * visited.
     */
    public static class Visitor extends SimplePlanVisitor<Void> {
        private final Consumer<Expression> consumer;
        private final boolean recursive;
        private final Lookup lookup;

        /**
         * @param consumer receives each extracted expression; must not be null
         * @param recursive whether to continue into the nodes reached through {@code visitPlan}
         * @param lookup resolves group references; must not be null
         */
        Visitor(Consumer<Expression> consumer, boolean recursive, Lookup lookup) {
            this.consumer = Objects.requireNonNull(consumer, "consumer is null");
            this.recursive = recursive;
            this.lookup = Objects.requireNonNull(lookup, "lookup is null");
        }

        @Override
        public Void visitPlan(PlanNode node, Void context) {
            if (!recursive) {
                // Non-recursive mode: only the starting node's own expressions are wanted.
                return null;
            }
            return super.visitPlan(node, context);
        }

        @Override
        public Void visitGroupReference(GroupReference node, Void context) {
            // Group references are placeholders in the iterative optimizer's memo;
            // visit the node they resolve to instead.
            return lookup.resolve(node).accept(this, context);
        }

        @Override
        public Void visitAggregation(AggregationNode node, Void context) {
            // Only the argument expressions of each aggregation are reported.
            for (AggregationNode.Aggregation aggregation : node.getAggregations().values()) {
                aggregation.getArguments().forEach(consumer);
            }
            return super.visitAggregation(node, context);
        }

        @Override
        public Void visitFilter(FilterNode node, Void context) {
            consumer.accept(node.getPredicate());
            return super.visitFilter(node, context);
        }

        @Override
        public Void visitProject(ProjectNode node, Void context) {
            node.getAssignments().getExpressions().forEach(consumer);
            return super.visitProject(node, context);
        }

        @Override
        public Void visitJoin(JoinNode node, Void context) {
            // The join filter is optional; nothing is reported when it is absent.
            node.getFilter().ifPresent(consumer);
            return super.visitJoin(node, context);
        }

        @Override
        public Void visitValues(ValuesNode node, Void context) {
            // Rows are optional; when present, each row expression is reported.
            node.getRows().ifPresent(rows -> rows.forEach(consumer));
            return super.visitValues(node, context);
        }
    }

    /**
     * Collects the expressions of {@code planNode} and of every node reached from it,
     * without resolving group references through a memo.
     *
     * @param planNode root of the plan to scan; must not be null
     * @return an immutable list of the extracted expressions, in visit order
     * @throws NullPointerException if {@code planNode} is null
     */
    public static List<Expression> extractExpressions(PlanNode planNode) {
        return extractExpressions(planNode, Lookup.noLookup());
    }

    /**
     * Collects the expressions of {@code planNode} and of every node reached from it,
     * resolving group references with {@code lookup}.
     *
     * @param planNode root of the plan to scan; must not be null
     * @param lookup resolver for {@link GroupReference} nodes; must not be null
     * @return an immutable list of the extracted expressions, in visit order
     * @throws NullPointerException if {@code planNode} or {@code lookup} is null
     */
    public static List<Expression> extractExpressions(PlanNode planNode, Lookup lookup) {
        Objects.requireNonNull(planNode, "plan is null");
        Objects.requireNonNull(lookup, "lookup is null");
        return collect(planNode, true, lookup);
    }

    /**
     * Collects only the expressions that belong to {@code planNode} itself.
     *
     * @param planNode node to scan; must not be null
     * @return an immutable list of the node's own expressions
     * @throws NullPointerException if {@code planNode} is null
     */
    public static List<Expression> extractExpressionsNonRecursive(PlanNode planNode) {
        Objects.requireNonNull(planNode, "plan is null");
        return collect(planNode, false, Lookup.noLookup());
    }

    /**
     * Streams every expression of {@code planNode} and of the nodes reached from it to
     * {@code consumer}, without building an intermediate list.
     *
     * @param planNode root of the plan to scan; must not be null
     * @param consumer receives each expression in visit order; must not be null
     * @throws NullPointerException if {@code planNode} or {@code consumer} is null
     */
    public static void forEachExpression(PlanNode planNode, Consumer<Expression> consumer) {
        Objects.requireNonNull(planNode, "plan is null");
        planNode.accept(new Visitor(consumer, true, Lookup.noLookup()), null);
    }

    /** Runs a {@link Visitor} over {@code planNode} and gathers its output into an immutable list. */
    private static List<Expression> collect(PlanNode planNode, boolean recursive, Lookup lookup) {
        ImmutableList.Builder<Expression> builder = ImmutableList.builder();
        planNode.accept(new Visitor(builder::add, recursive, lookup), null);
        return builder.build();
    }

    // Static utility holder; not instantiable.
    private ExpressionExtractor() {
    }
}
