package io.prestosql.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import io.prestosql.Session;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.sql.planner.PartitioningScheme;
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.SymbolAllocator;
import io.prestosql.sql.planner.SymbolsExtractor;
import io.prestosql.sql.planner.TypeAnalyzer;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.iterative.rule.PruneTableScanColumns;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.ApplyNode;
import io.prestosql.sql.planner.plan.AssignUniqueId;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.CorrelatedJoinNode;
import io.prestosql.sql.planner.plan.DeleteNode;
import io.prestosql.sql.planner.plan.DistinctLimitNode;
import io.prestosql.sql.planner.plan.ExceptNode;
import io.prestosql.sql.planner.plan.ExchangeNode;
import io.prestosql.sql.planner.plan.ExplainAnalyzeNode;
import io.prestosql.sql.planner.plan.FilterNode;
import io.prestosql.sql.planner.plan.GroupIdNode;
import io.prestosql.sql.planner.plan.IndexJoinNode;
import io.prestosql.sql.planner.plan.IndexSourceNode;
import io.prestosql.sql.planner.plan.IntersectNode;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.LimitNode;
import io.prestosql.sql.planner.plan.MarkDistinctNode;
import io.prestosql.sql.planner.plan.OffsetNode;
import io.prestosql.sql.planner.plan.OutputNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.RowNumberNode;
import io.prestosql.sql.planner.plan.SemiJoinNode;
import io.prestosql.sql.planner.plan.SetOperationNode;
import io.prestosql.sql.planner.plan.SimplePlanRewriter;
import io.prestosql.sql.planner.plan.SortNode;
import io.prestosql.sql.planner.plan.SpatialJoinNode;
import io.prestosql.sql.planner.plan.StatisticAggregations;
import io.prestosql.sql.planner.plan.StatisticsWriterNode;
import io.prestosql.sql.planner.plan.TableFinishNode;
import io.prestosql.sql.planner.plan.TableScanNode;
import io.prestosql.sql.planner.plan.TableWriterNode;
import io.prestosql.sql.planner.plan.TopNNode;
import io.prestosql.sql.planner.plan.TopNRowNumberNode;
import io.prestosql.sql.planner.plan.UnionNode;
import io.prestosql.sql.planner.plan.UnnestNode;
import io.prestosql.sql.planner.plan.ValuesNode;
import io.prestosql.sql.planner.plan.WindowNode;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.Expression;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:io/prestosql/sql/planner/optimizations/PruneUnreferencedOutputs.class */
public class PruneUnreferencedOutputs implements PlanOptimizer {
    private final Metadata metadata;
    private final TypeAnalyzer typeAnalyzer;

    /* loaded from: input_file:io/prestosql/sql/planner/optimizations/PruneUnreferencedOutputs$Rewriter.class */
    private static class Rewriter extends SimplePlanRewriter<Set<Symbol>> {
        private final Metadata metadata;
        private final TypeProvider types;
        private final TypeAnalyzer typeAnalyzer;
        private final SymbolAllocator symbolAllocator;
        private final Session session;

        /**
         * Creates a rewriter holding the planner services used by the individual visit methods.
         *
         * @param metadata metadata handed to {@code PruneTableScanColumns.pruneColumns}
         * @param typeProvider symbol types handed to {@code PruneTableScanColumns.pruneColumns}
         * @param typeAnalyzer stored type analyzer (not read by the methods visible in this section)
         * @param symbolAllocator stored symbol allocator (not read by the methods visible in this section)
         * @param session session handed to {@code PruneTableScanColumns.pruneColumns}
         */
        public Rewriter(Metadata metadata, TypeProvider typeProvider, TypeAnalyzer typeAnalyzer, SymbolAllocator symbolAllocator, Session session) {
            this.session = session;
            this.symbolAllocator = symbolAllocator;
            this.typeAnalyzer = typeAnalyzer;
            this.types = typeProvider;
            this.metadata = metadata;
        }

        /**
         * Rewrites the children of an EXPLAIN ANALYZE node, requesting every output symbol of its
         * source regardless of which symbols the parent requires.
         */
        @Override
        public PlanNode visitExplainAnalyze(ExplainAnalyzeNode explainAnalyzeNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            List<Symbol> sourceOutputs = explainAnalyzeNode.getSource().getOutputSymbols();
            Set<Symbol> requested = ImmutableSet.copyOf(sourceOutputs);
            return rewriteContext.defaultRewrite(explainAnalyzeNode, requested);
        }

        /**
         * Drops exchange outputs that are neither required by the parent nor used by the exchange itself.
         *
         * <p>An output is kept when the parent requires it, or when it is the partitioning hash column,
         * a partitioning column, or an ordering column. Inputs are positional (input {@code i} of each
         * source feeds output {@code i}), so for every kept output the same position is kept in each
         * source's input list. Each source is then rewritten to produce exactly those inputs.
         */
        @Override
        public PlanNode visitExchange(ExchangeNode exchangeNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            PartitioningScheme originalScheme = exchangeNode.getPartitioningScheme();

            Set<Symbol> expectedOutputs = new HashSet<>(rewriteContext.get());
            originalScheme.getHashColumn().ifPresent(expectedOutputs::add);
            expectedOutputs.addAll(originalScheme.getPartitioning().getColumns());
            exchangeNode.getOrderingScheme().ifPresent(orderingScheme -> expectedOutputs.addAll(orderingScheme.getOrderBy()));

            int sourceCount = exchangeNode.getInputs().size();
            List<List<Symbol>> inputsBySource = new ArrayList<>(sourceCount);
            for (int source = 0; source < sourceCount; source++) {
                inputsBySource.add(new ArrayList<>());
            }

            List<Symbol> outputs = exchangeNode.getOutputSymbols();
            List<Symbol> newOutputs = new ArrayList<>(outputs.size());
            for (int position = 0; position < outputs.size(); position++) {
                Symbol output = outputs.get(position);
                if (!expectedOutputs.contains(output)) {
                    continue;
                }
                newOutputs.add(output);
                for (int source = 0; source < sourceCount; source++) {
                    inputsBySource.get(source).add(exchangeNode.getInputs().get(source).get(position));
                }
            }

            // Partitioning and hash symbols were added to expectedOutputs above, so only the output
            // layout changes here; partitioning, hash column and bucket mapping are carried over as-is.
            PartitioningScheme partitioningScheme = new PartitioningScheme(
                    originalScheme.getPartitioning(),
                    newOutputs,
                    originalScheme.getHashColumn(),
                    originalScheme.isReplicateNullsAndAny(),
                    originalScheme.getBucketToPartition());

            ImmutableList.Builder<PlanNode> rewrittenSources = ImmutableList.builder();
            for (int source = 0; source < exchangeNode.getSources().size(); source++) {
                Set<Symbol> expectedInputs = ImmutableSet.copyOf(inputsBySource.get(source));
                rewrittenSources.add(rewriteContext.rewrite(exchangeNode.getSources().get(source), expectedInputs));
            }

            return new ExchangeNode(
                    exchangeNode.getId(),
                    exchangeNode.getType(),
                    exchangeNode.getScope(),
                    partitioningScheme,
                    rewrittenSources.build(),
                    inputsBySource,
                    exchangeNode.getOrderingScheme());
        }

        /**
         * Prunes both sides of a join and its left/right output lists.
         *
         * <p>Each side is asked for: the symbols required by the parent, the symbols referenced by the
         * join filter (if any), its side of every equi-join clause, and its hash symbol (if any).
         * The node's left and right output lists are reduced to the symbols the parent requires,
         * with duplicates removed.
         */
        @Override
        public PlanNode visitJoin(JoinNode joinNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            Set<Symbol> required = rewriteContext.get();
            Set<Symbol> filterSymbols = joinNode.getFilter()
                    .map(SymbolsExtractor::extractUnique)
                    .orElse(ImmutableSet.of());

            ImmutableSet.Builder<Symbol> leftInputs = ImmutableSet.<Symbol>builder()
                    .addAll(required)
                    .addAll(filterSymbols)
                    .addAll(Iterables.transform(joinNode.getCriteria(), clause -> clause.getLeft()));
            joinNode.getLeftHashSymbol().ifPresent(leftInputs::add);

            ImmutableSet.Builder<Symbol> rightInputs = ImmutableSet.<Symbol>builder()
                    .addAll(required)
                    .addAll(filterSymbols)
                    .addAll(Iterables.transform(joinNode.getCriteria(), clause -> clause.getRight()));
            joinNode.getRightHashSymbol().ifPresent(rightInputs::add);

            PlanNode left = rewriteContext.rewrite(joinNode.getLeft(), leftInputs.build());
            PlanNode right = rewriteContext.rewrite(joinNode.getRight(), rightInputs.build());

            List<Symbol> leftOutputs = joinNode.getLeftOutputSymbols().stream()
                    .filter(required::contains)
                    .distinct()
                    .collect(ImmutableList.toImmutableList());
            List<Symbol> rightOutputs = joinNode.getRightOutputSymbols().stream()
                    .filter(required::contains)
                    .distinct()
                    .collect(ImmutableList.toImmutableList());

            return new JoinNode(
                    joinNode.getId(),
                    joinNode.getType(),
                    left,
                    right,
                    joinNode.getCriteria(),
                    leftOutputs,
                    rightOutputs,
                    joinNode.getFilter(),
                    joinNode.getLeftHashSymbol(),
                    joinNode.getRightHashSymbol(),
                    joinNode.getDistributionType(),
                    joinNode.isSpillable(),
                    joinNode.getDynamicFilters(),
                    joinNode.getReorderJoinStatsAndCost());
        }

        /**
         * Prunes a semi join.
         *
         * <p>When the parent does not require the semi-join output symbol, the node is removed and its
         * source is rewritten directly. Otherwise the source is asked for the parent's symbols, the
         * source join symbol and the source hash symbol (if any). The filtering source is asked only
         * for its join symbol and hash symbol (if any).
         */
        @Override
        public PlanNode visitSemiJoin(SemiJoinNode semiJoinNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            Set<Symbol> required = rewriteContext.get();
            PlanNode source = semiJoinNode.getSource();
            if (!required.contains(semiJoinNode.getSemiJoinOutput())) {
                return rewriteContext.rewrite(source, required);
            }

            ImmutableSet.Builder<Symbol> sourceInputs = ImmutableSet.<Symbol>builder()
                    .addAll(required)
                    .add(semiJoinNode.getSourceJoinSymbol());
            semiJoinNode.getSourceHashSymbol().ifPresent(sourceInputs::add);

            ImmutableSet.Builder<Symbol> filteringInputs = ImmutableSet.<Symbol>builder()
                    .add(semiJoinNode.getFilteringSourceJoinSymbol());
            semiJoinNode.getFilteringSourceHashSymbol().ifPresent(filteringInputs::add);

            PlanNode rewrittenSource = rewriteContext.rewrite(source, sourceInputs.build());
            PlanNode rewrittenFilteringSource = rewriteContext.rewrite(semiJoinNode.getFilteringSource(), filteringInputs.build());

            return new SemiJoinNode(
                    semiJoinNode.getId(),
                    rewrittenSource,
                    rewrittenFilteringSource,
                    semiJoinNode.getSourceJoinSymbol(),
                    semiJoinNode.getFilteringSourceJoinSymbol(),
                    semiJoinNode.getSemiJoinOutput(),
                    semiJoinNode.getSourceHashSymbol(),
                    semiJoinNode.getFilteringSourceHashSymbol(),
                    semiJoinNode.getDistributionType());
        }

        /**
         * Prunes both sides of a spatial join and its output list.
         *
         * <p>Each side is asked for its partition symbol (if any), the symbols referenced by the join
         * filter, and the symbols required by the parent. The output list is reduced to the symbols
         * the parent requires, with duplicates removed.
         */
        @Override
        public PlanNode visitSpatialJoin(SpatialJoinNode spatialJoinNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            Set<Symbol> required = rewriteContext.get();
            Set<Symbol> requiredInputs = ImmutableSet.<Symbol>builder()
                    .addAll(SymbolsExtractor.extractUnique(spatialJoinNode.getFilter()))
                    .addAll(required)
                    .build();

            // ifPresent (not Optional.map) is the idiomatic way to run a side effect on an Optional
            ImmutableSet.Builder<Symbol> leftInputs = ImmutableSet.builder();
            spatialJoinNode.getLeftPartitionSymbol().ifPresent(leftInputs::add);
            ImmutableSet.Builder<Symbol> rightInputs = ImmutableSet.builder();
            spatialJoinNode.getRightPartitionSymbol().ifPresent(rightInputs::add);

            PlanNode left = rewriteContext.rewrite(spatialJoinNode.getLeft(), leftInputs.addAll(requiredInputs).build());
            PlanNode right = rewriteContext.rewrite(spatialJoinNode.getRight(), rightInputs.addAll(requiredInputs).build());

            List<Symbol> outputs = spatialJoinNode.getOutputSymbols().stream()
                    .filter(required::contains)
                    .distinct()
                    .collect(ImmutableList.toImmutableList());

            return new SpatialJoinNode(
                    spatialJoinNode.getId(),
                    spatialJoinNode.getType(),
                    left,
                    right,
                    outputs,
                    spatialJoinNode.getFilter(),
                    spatialJoinNode.getLeftPartitionSymbol(),
                    spatialJoinNode.getRightPartitionSymbol(),
                    spatialJoinNode.getKdbTree());
        }

        /**
         * Prunes both inputs of an index join.
         *
         * <p>The probe source is asked for the parent's symbols, every criterion's probe symbol and the
         * probe hash symbol (if any). The index source is asked for the parent's symbols, every
         * criterion's index symbol and the index hash symbol (if any).
         */
        @Override
        public PlanNode visitIndexJoin(IndexJoinNode indexJoinNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            Set<Symbol> required = rewriteContext.get();

            ImmutableSet.Builder<Symbol> probeInputs = ImmutableSet.<Symbol>builder()
                    .addAll(required)
                    .addAll(Iterables.transform(indexJoinNode.getCriteria(), clause -> clause.getProbe()));
            indexJoinNode.getProbeHashSymbol().ifPresent(probeInputs::add);

            ImmutableSet.Builder<Symbol> indexInputs = ImmutableSet.<Symbol>builder()
                    .addAll(required)
                    .addAll(Iterables.transform(indexJoinNode.getCriteria(), clause -> clause.getIndex()));
            indexJoinNode.getIndexHashSymbol().ifPresent(indexInputs::add);

            PlanNode probeSource = rewriteContext.rewrite(indexJoinNode.getProbeSource(), probeInputs.build());
            PlanNode indexSource = rewriteContext.rewrite(indexJoinNode.getIndexSource(), indexInputs.build());

            return new IndexJoinNode(
                    indexJoinNode.getId(),
                    indexJoinNode.getType(),
                    probeSource,
                    indexSource,
                    indexJoinNode.getCriteria(),
                    indexJoinNode.getProbeHashSymbol(),
                    indexJoinNode.getIndexHashSymbol());
        }

        /**
         * Restricts an index source to the output and lookup symbols the parent requires.
         *
         * <p>The column assignments are rebuilt for the kept outputs only. Like the original, this uses
         * {@code Collectors.toMap}, which throws if a kept output has no assignment.
         */
        @Override
        public PlanNode visitIndexSource(IndexSourceNode indexSourceNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            Set<Symbol> required = rewriteContext.get();

            List<Symbol> newOutputs = indexSourceNode.getOutputSymbols().stream()
                    .filter(required::contains)
                    .collect(ImmutableList.toImmutableList());

            Set<Symbol> newLookups = indexSourceNode.getLookupSymbols().stream()
                    .filter(required::contains)
                    .collect(ImmutableSet.toImmutableSet());

            Map<Symbol, ColumnHandle> assignments = indexSourceNode.getAssignments();
            Map<Symbol, ColumnHandle> newAssignments = newOutputs.stream()
                    .collect(Collectors.toMap(Function.identity(), assignments::get));

            return new IndexSourceNode(
                    indexSourceNode.getId(),
                    indexSourceNode.getIndexHandle(),
                    indexSourceNode.getTableHandle(),
                    newLookups,
                    newOutputs,
                    newAssignments);
        }

        /**
         * Drops aggregations whose output symbol the parent does not require.
         *
         * <p>The source is asked for every grouping key, the hash symbol (if any), and the symbols
         * referenced by each kept aggregation.
         * NOTE(review): pre-grouped symbols are reset to an empty list rather than pruned — presumably
         * recomputed by a later pass; confirm.
         */
        @Override
        public PlanNode visitAggregation(AggregationNode aggregationNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            Set<Symbol> required = rewriteContext.get();

            ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(aggregationNode.getGroupingKeys());
            aggregationNode.getHashSymbol().ifPresent(expectedInputs::add);

            ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> keptAggregations = ImmutableMap.builder();
            aggregationNode.getAggregations().forEach((symbol, aggregation) -> {
                if (!required.contains(symbol)) {
                    return;
                }
                expectedInputs.addAll(SymbolsExtractor.extractUnique(aggregation));
                keptAggregations.put(symbol, aggregation);
            });

            PlanNode source = rewriteContext.rewrite(aggregationNode.getSource(), expectedInputs.build());
            return new AggregationNode(
                    aggregationNode.getId(),
                    source,
                    keptAggregations.build(),
                    aggregationNode.getGroupingSets(),
                    ImmutableList.of(),
                    aggregationNode.getStep(),
                    aggregationNode.getHashSymbol(),
                    aggregationNode.getGroupIdSymbol());
        }

        /**
         * Drops window functions whose output symbol the parent does not require.
         *
         * <p>If no function remains, the window node is removed and its source is rewritten with the
         * parent's symbols. Otherwise the source is asked for the parent's symbols, the partition-by
         * and order-by symbols, the hash symbol (if any), and the symbols referenced by each kept
         * function.
         */
        @Override
        public PlanNode visitWindow(WindowNode windowNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            Set<Symbol> required = rewriteContext.get();

            Map<Symbol, WindowNode.Function> keptFunctions = windowNode.getWindowFunctions().entrySet().stream()
                    .filter(entry -> required.contains(entry.getKey()))
                    .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
            if (keptFunctions.isEmpty()) {
                return rewriteContext.rewrite(windowNode.getSource(), required);
            }

            ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(required)
                    .addAll(windowNode.getPartitionBy());
            windowNode.getOrderingScheme().ifPresent(orderingScheme -> expectedInputs.addAll(orderingScheme.getOrderBy()));
            windowNode.getHashSymbol().ifPresent(expectedInputs::add);
            keptFunctions.values().forEach(function -> expectedInputs.addAll(SymbolsExtractor.extractUnique(function)));

            PlanNode source = rewriteContext.rewrite(windowNode.getSource(), expectedInputs.build());
            return new WindowNode(
                    windowNode.getId(),
                    source,
                    windowNode.getSpecification(),
                    keptFunctions,
                    windowNode.getHashSymbol(),
                    windowNode.getPrePartitionedInputs(),
                    windowNode.getPreSortedOrderPrefix());
        }

        /**
         * Delegates column pruning of a table scan to {@code PruneTableScanColumns.pruneColumns}.
         * When that call yields no replacement, the original scan is returned unchanged.
         */
        @Override
        public PlanNode visitTableScan(TableScanNode tableScanNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            Set<Symbol> referencedOutputs = rewriteContext.get();
            return PruneTableScanColumns
                    .pruneColumns(metadata, types, session, tableScanNode, referencedOutputs)
                    .orElse(tableScanNode);
        }

        /**
         * Keeps a filter node and asks its source for the predicate's symbols plus the symbols the
         * parent requires.
         */
        @Override
        public PlanNode visitFilter(FilterNode filterNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            Expression predicate = filterNode.getPredicate();
            Set<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(SymbolsExtractor.extractUnique(predicate))
                    .addAll(rewriteContext.get())
                    .build();
            PlanNode source = rewriteContext.rewrite(filterNode.getSource(), expectedInputs);
            return new FilterNode(filterNode.getId(), source, predicate);
        }

        /**
         * Prunes a GroupId node.
         *
         * <p>Aggregation arguments are kept only if the parent requires them. Within each grouping set,
         * only the output symbols the parent requires are kept; for each kept output, its input column
         * (from {@code getGroupingColumns()}) is recorded in the new mapping and requested from the source.
         */
        @Override
        public PlanNode visitGroupId(GroupIdNode groupIdNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            Set<Symbol> required = rewriteContext.get();
            ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.builder();

            List<Symbol> newAggregationArguments = groupIdNode.getAggregationArguments().stream()
                    .filter(required::contains)
                    .collect(Collectors.toList());
            expectedInputs.addAll(newAggregationArguments);

            Map<Symbol, Symbol> groupingColumns = groupIdNode.getGroupingColumns();
            Map<Symbol, Symbol> newGroupingColumns = new HashMap<>();
            ImmutableList.Builder<List<Symbol>> newGroupingSets = ImmutableList.builder();
            for (List<Symbol> groupingSet : groupIdNode.getGroupingSets()) {
                ImmutableList.Builder<Symbol> newGroupingSet = ImmutableList.builder();
                for (Symbol output : groupingSet) {
                    if (required.contains(output)) {
                        Symbol input = groupingColumns.get(output);
                        newGroupingSet.add(output);
                        // the same output may appear in several grouping sets; keep the first mapping
                        newGroupingColumns.putIfAbsent(output, input);
                        expectedInputs.add(input);
                    }
                }
                newGroupingSets.add(newGroupingSet.build());
            }

            PlanNode source = rewriteContext.rewrite(groupIdNode.getSource(), expectedInputs.build());
            return new GroupIdNode(
                    groupIdNode.getId(),
                    source,
                    newGroupingSets.build(),
                    newGroupingColumns,
                    newAggregationArguments,
                    groupIdNode.getGroupIdSymbol());
        }

        /**
         * Prunes a MarkDistinct node.
         *
         * <p>If the parent does not require the marker symbol, the node is removed and its source is
         * rewritten directly. Otherwise the source is asked for the distinct symbols, every required
         * symbol other than the marker, and the hash symbol (if any).
         */
        @Override
        public PlanNode visitMarkDistinct(MarkDistinctNode markDistinctNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            Set<Symbol> required = rewriteContext.get();
            Symbol marker = markDistinctNode.getMarkerSymbol();
            if (!required.contains(marker)) {
                return rewriteContext.rewrite(markDistinctNode.getSource(), required);
            }

            ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(markDistinctNode.getDistinctSymbols());
            for (Symbol symbol : required) {
                // the marker is produced by this node, so the source cannot supply it
                if (!symbol.equals(marker)) {
                    expectedInputs.add(symbol);
                }
            }
            markDistinctNode.getHashSymbol().ifPresent(expectedInputs::add);

            PlanNode source = rewriteContext.rewrite(markDistinctNode.getSource(), expectedInputs.build());
            return new MarkDistinctNode(
                    markDistinctNode.getId(),
                    source,
                    marker,
                    markDistinctNode.getDistinctSymbols(),
                    markDistinctNode.getHashSymbol());
        }

        /**
         * Prunes the replicate symbols and the ordinality symbol of an unnest node.
         *
         * <p>A replicate symbol or the ordinality symbol is kept if the parent requires it or the unnest
         * filter (if any) references it. The source is asked for the kept replicate symbols plus the
         * input symbol of every unnest mapping.
         */
        @Override
        public PlanNode visitUnnest(UnnestNode unnestNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            ImmutableSet.Builder<Symbol> expectedOutputsBuilder = ImmutableSet.<Symbol>builder()
                    .addAll(rewriteContext.get());
            unnestNode.getFilter().ifPresent(expression -> expectedOutputsBuilder.addAll(SymbolsExtractor.extractUnique(expression)));
            Set<Symbol> expectedOutputs = expectedOutputsBuilder.build();

            List<Symbol> replicateSymbols = unnestNode.getReplicateSymbols().stream()
                    .filter(expectedOutputs::contains)
                    .collect(ImmutableList.toImmutableList());
            Optional<Symbol> ordinalitySymbol = unnestNode.getOrdinalitySymbol()
                    .filter(expectedOutputs::contains);

            ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(replicateSymbols);
            unnestNode.getMappings().forEach(mapping -> expectedInputs.add(mapping.getInput()));

            PlanNode source = rewriteContext.rewrite(unnestNode.getSource(), expectedInputs.build());
            return new UnnestNode(
                    unnestNode.getId(),
                    source,
                    replicateSymbols,
                    unnestNode.getMappings(),
                    ordinalitySymbol,
                    unnestNode.getJoinType(),
                    unnestNode.getFilter());
        }

        /**
         * Keeps only the assignments whose target symbol the parent requires, and asks the source for
         * the symbols those assignments' expressions reference.
         */
        @Override
        public PlanNode visitProject(ProjectNode projectNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            Set<Symbol> required = rewriteContext.get();
            ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.builder();
            Assignments.Builder keptAssignments = Assignments.builder();

            projectNode.getAssignments().forEach((symbol, expression) -> {
                if (!required.contains(symbol)) {
                    return;
                }
                expectedInputs.addAll(SymbolsExtractor.extractUnique(expression));
                keptAssignments.put(symbol, expression);
            });

            PlanNode source = rewriteContext.rewrite(projectNode.getSource(), expectedInputs.build());
            return new ProjectNode(projectNode.getId(), source, keptAssignments.build());
        }

        /**
         * Asks the source for exactly the output node's own output symbols; the output node itself is
         * rebuilt with unchanged column names and outputs.
         */
        @Override
        public PlanNode visitOutput(OutputNode outputNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            List<Symbol> outputs = outputNode.getOutputSymbols();
            PlanNode source = rewriteContext.rewrite(outputNode.getSource(), ImmutableSet.copyOf(outputs));
            return new OutputNode(outputNode.getId(), source, outputNode.getColumnNames(), outputs);
        }

        /**
         * Passes the parent's required symbols straight through to the source of an offset node.
         */
        @Override
        public PlanNode visitOffset(OffsetNode offsetNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            Set<Symbol> expectedInputs = ImmutableSet.copyOf(rewriteContext.get());
            PlanNode source = rewriteContext.rewrite(offsetNode.getSource(), expectedInputs);
            return new OffsetNode(offsetNode.getId(), source, offsetNode.getCount());
        }

        /**
         * Asks the source of a limit node for the parent's symbols plus the order-by symbols of the
         * ties-resolving scheme (if any).
         */
        @Override
        public PlanNode visitLimit(LimitNode limitNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            List<Symbol> tiesOrderBy = limitNode.getTiesResolvingScheme()
                    .map(scheme -> scheme.getOrderBy())
                    .orElse(ImmutableList.of());
            Set<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(rewriteContext.get())
                    .addAll(tiesOrderBy)
                    .build();
            PlanNode source = rewriteContext.rewrite(limitNode.getSource(), expectedInputs);
            return new LimitNode(
                    limitNode.getId(),
                    source,
                    limitNode.getCount(),
                    limitNode.getTiesResolvingScheme(),
                    limitNode.isPartial());
        }

        /**
         * Asks the source of a distinct-limit node only for its distinct symbols and its hash symbol
         * (if any); the parent's requirements are not forwarded.
         */
        @Override
        public PlanNode visitDistinctLimit(DistinctLimitNode distinctLimitNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(distinctLimitNode.getDistinctSymbols());
            distinctLimitNode.getHashSymbol().ifPresent(expectedInputs::add);

            PlanNode source = rewriteContext.rewrite(distinctLimitNode.getSource(), expectedInputs.build());
            return new DistinctLimitNode(
                    distinctLimitNode.getId(),
                    source,
                    distinctLimitNode.getLimit(),
                    distinctLimitNode.isPartial(),
                    distinctLimitNode.getDistinctSymbols(),
                    distinctLimitNode.getHashSymbol());
        }

        /**
         * Asks the source of a TopN node for the parent's symbols plus the ordering symbols.
         */
        @Override
        public PlanNode visitTopN(TopNNode topNNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            Set<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(rewriteContext.get())
                    .addAll(topNNode.getOrderingScheme().getOrderBy())
                    .build();
            PlanNode source = rewriteContext.rewrite(topNNode.getSource(), expectedInputs);
            return new TopNNode(
                    topNNode.getId(),
                    source,
                    topNNode.getCount(),
                    topNNode.getOrderingScheme(),
                    topNNode.getStep());
        }

        /**
         * Prunes a RowNumber node.
         *
         * <p>When the parent does not require the row-number symbol:
         * <ul>
         *   <li>without a per-partition row cap, the node is removed and its source is returned;</li>
         *   <li>with a cap but no partitioning, the node is replaced by a {@code LimitNode}.</li>
         * </ul>
         * In every other case the node is kept, and its source is asked for the parent's symbols, the
         * partition-by symbols and the hash symbol (if any).
         *
         * <p>The source is rewritten exactly once on every path. Previously the fall-through case
         * rewrote it once, discarded the result, and rewrote it again, which doubled the work for each
         * nested RowNumber node.
         */
        @Override
        public PlanNode visitRowNumber(RowNumberNode rowNumberNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            Set<Symbol> required = rewriteContext.get();
            if (!required.contains(rowNumberNode.getRowNumberSymbol())) {
                Optional<Integer> maxRowCount = rowNumberNode.getMaxRowCountPerPartition();
                if (maxRowCount.isEmpty()) {
                    return rewriteContext.rewrite(rowNumberNode.getSource(), required);
                }
                if (rowNumberNode.getPartitionBy().isEmpty()) {
                    PlanNode source = rewriteContext.rewrite(rowNumberNode.getSource(), required);
                    return new LimitNode(rowNumberNode.getId(), source, maxRowCount.get().intValue(), false);
                }
                // partitioned row cap: the node must stay to enforce it, so fall through
            }

            ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(required)
                    .addAll(rowNumberNode.getPartitionBy());
            rowNumberNode.getHashSymbol().ifPresent(expectedInputs::add);

            PlanNode source = rewriteContext.rewrite(rowNumberNode.getSource(), expectedInputs.build());
            return new RowNumberNode(
                    rowNumberNode.getId(),
                    source,
                    rowNumberNode.getPartitionBy(),
                    rowNumberNode.isOrderSensitive(),
                    rowNumberNode.getRowNumberSymbol(),
                    rowNumberNode.getMaxRowCountPerPartition(),
                    rowNumberNode.getHashSymbol());
        }

        /**
         * Asks the source of a TopNRowNumber node for the parent's symbols, the partition-by and
         * order-by symbols, and the hash symbol (if any).
         */
        @Override
        public PlanNode visitTopNRowNumber(TopNRowNumberNode topNRowNumberNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(rewriteContext.get())
                    .addAll(topNRowNumberNode.getPartitionBy())
                    .addAll(topNRowNumberNode.getOrderingScheme().getOrderBy());
            topNRowNumberNode.getHashSymbol().ifPresent(expectedInputs::add);

            PlanNode source = rewriteContext.rewrite(topNRowNumberNode.getSource(), expectedInputs.build());
            return new TopNRowNumberNode(
                    topNRowNumberNode.getId(),
                    source,
                    topNRowNumberNode.getSpecification(),
                    topNRowNumberNode.getRowNumberSymbol(),
                    topNRowNumberNode.getMaxRowCountPerPartition(),
                    topNRowNumberNode.isPartial(),
                    topNRowNumberNode.getHashSymbol());
        }

        /**
         * Asks the source of a sort node for the parent's symbols plus the ordering symbols.
         */
        @Override
        public PlanNode visitSort(SortNode sortNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            Set<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(rewriteContext.get())
                    .addAll(sortNode.getOrderingScheme().getOrderBy())
                    .build();
            PlanNode source = rewriteContext.rewrite(sortNode.getSource(), expectedInputs);
            return new SortNode(sortNode.getId(), source, sortNode.getOrderingScheme(), sortNode.isPartial());
        }

        /**
         * Asks the source of a table writer for the symbols it writes and uses.
         *
         * <p>The requested symbols are: the written columns; the partitioning columns and hash column of
         * the partitioning scheme (if any); and, when a statistics aggregation is present, its grouping
         * symbols and the symbols referenced by each of its aggregations. The parent's requirements are
         * not forwarded.
         */
        @Override
        public PlanNode visitTableWriter(TableWriterNode tableWriterNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(tableWriterNode.getColumns());

            tableWriterNode.getPartitioningScheme().ifPresent(partitioningScheme -> {
                expectedInputs.addAll(partitioningScheme.getPartitioning().getColumns());
                partitioningScheme.getHashColumn().ifPresent(expectedInputs::add);
            });

            tableWriterNode.getStatisticsAggregation().ifPresent(statisticAggregations -> {
                expectedInputs.addAll(statisticAggregations.getGroupingSymbols());
                statisticAggregations.getAggregations().values()
                        .forEach(aggregation -> expectedInputs.addAll(SymbolsExtractor.extractUnique(aggregation)));
            });

            PlanNode source = rewriteContext.rewrite(tableWriterNode.getSource(), expectedInputs.build());
            return new TableWriterNode(
                    tableWriterNode.getId(),
                    source,
                    tableWriterNode.getTarget(),
                    tableWriterNode.getRowCountSymbol(),
                    tableWriterNode.getFragmentSymbol(),
                    tableWriterNode.getColumns(),
                    tableWriterNode.getColumnNames(),
                    tableWriterNode.getPartitioningScheme(),
                    tableWriterNode.getStatisticsAggregation(),
                    tableWriterNode.getStatisticsAggregationDescriptor());
        }

        /**
         * Asks the source of a statistics writer for all of its own output symbols, ignoring the
         * parent's requirements.
         */
        @Override
        public PlanNode visitStatisticsWriterNode(StatisticsWriterNode statisticsWriterNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            PlanNode originalSource = statisticsWriterNode.getSource();
            Set<Symbol> expectedInputs = ImmutableSet.copyOf(originalSource.getOutputSymbols());
            PlanNode source = rewriteContext.rewrite(originalSource, expectedInputs);
            return new StatisticsWriterNode(
                    statisticsWriterNode.getId(),
                    source,
                    statisticsWriterNode.getTarget(),
                    statisticsWriterNode.getRowCountSymbol(),
                    statisticsWriterNode.isRowCountEnabled(),
                    statisticsWriterNode.getDescriptor());
        }

        /**
         * Rewrites the source of a table-finish node while requesting every symbol the source
         * outputs; the finish node is rebuilt with its original target, row-count symbol and
         * statistics aggregation settings.
         */
        @Override // io.prestosql.sql.planner.plan.PlanVisitor
        public PlanNode visitTableFinish(TableFinishNode tableFinishNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            PlanNode originalSource = tableFinishNode.getSource();
            Set<Symbol> keepAll = ImmutableSet.copyOf(originalSource.getOutputSymbols());
            PlanNode rewrittenSource = rewriteContext.rewrite(originalSource, keepAll);
            return new TableFinishNode(
                    tableFinishNode.getId(),
                    rewrittenSource,
                    tableFinishNode.getTarget(),
                    tableFinishNode.getRowCountSymbol(),
                    tableFinishNode.getStatisticsAggregation(),
                    tableFinishNode.getStatisticsAggregationDescriptor());
        }

        /**
         * Rewrites the source of a DELETE so that it only has to produce the row-id symbol,
         * which is the only input this node requests.
         */
        @Override // io.prestosql.sql.planner.plan.PlanVisitor
        public PlanNode visitDelete(DeleteNode deleteNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            Symbol rowId = deleteNode.getRowId();
            Set<Symbol> onlyRowId = ImmutableSet.of(rowId);
            PlanNode rewrittenSource = rewriteContext.rewrite(deleteNode.getSource(), onlyRowId);
            return new DeleteNode(
                    deleteNode.getId(),
                    rewrittenSource,
                    deleteNode.getTarget(),
                    deleteNode.getRowId(),
                    deleteNode.getOutputSymbols());
        }

        /**
         * Prunes unreferenced outputs of a UNION.
         *
         * <p>Only output symbols present in the rewrite context keep their entry in the symbol
         * mapping. Each source is then rewritten to produce just the input symbols that feed
         * those retained outputs, and a new UnionNode is built from the pruned mapping.
         */
        @Override // io.prestosql.sql.planner.plan.PlanVisitor
        public PlanNode visitUnion(UnionNode unionNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            // Keep the mapping only for outputs the parent actually references.
            ImmutableListMultimap.Builder<Symbol, Symbol> mappingBuilder = ImmutableListMultimap.builder();
            for (Symbol outputSymbol : unionNode.getOutputSymbols()) {
                if (rewriteContext.get().contains(outputSymbol)) {
                    mappingBuilder.putAll(outputSymbol, unionNode.getSymbolMapping().get(outputSymbol));
                }
            }
            ImmutableListMultimap<Symbol, Symbol> prunedMapping = mappingBuilder.build();

            ImmutableList.Builder<PlanNode> rewrittenSources = ImmutableList.builder();
            for (int i = 0; i < unionNode.getSources().size(); i++) {
                // The i-th element of each output's mapped list is taken as that output's input in source i.
                ImmutableSet.Builder<Symbol> expectedSourceSymbols = ImmutableSet.builder();
                for (Collection<Symbol> mappedInputs : prunedMapping.asMap().values()) {
                    expectedSourceSymbols.add(Iterables.get(mappedInputs, i));
                }
                rewrittenSources.add(rewriteContext.rewrite(unionNode.getSources().get(i), expectedSourceSymbols.build()));
            }
            return new UnionNode(unionNode.getId(), rewrittenSources.build(), prunedMapping, ImmutableList.copyOf(prunedMapping.keySet()));
        }

        /**
         * Rewrites the sources of an INTERSECT; delegates to the shared set-operation helper,
         * which requests each source's full output layout.
         */
        @Override // io.prestosql.sql.planner.plan.PlanVisitor
        public PlanNode visitIntersect(IntersectNode intersectNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            PlanNode rewritten = rewriteSetOperationChildren(intersectNode, rewriteContext);
            return rewritten;
        }

        /**
         * Rewrites the sources of an EXCEPT; delegates to the shared set-operation helper,
         * which requests each source's full output layout.
         */
        @Override // io.prestosql.sql.planner.plan.PlanVisitor
        public PlanNode visitExcept(ExceptNode exceptNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            PlanNode rewritten = rewriteSetOperationChildren(exceptNode, rewriteContext);
            return rewritten;
        }

        /**
         * Rewrites every source of a set operation, requesting the complete output layout of
         * each source, and returns the node with its children replaced.
         *
         * <p>The parent's rewrite context is deliberately not consulted. Presumably this is
         * because INTERSECT/EXCEPT compare whole rows, so no column can be dropped (TODO confirm).
         */
        private PlanNode rewriteSetOperationChildren(SetOperationNode setOperationNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            ImmutableList.Builder<PlanNode> rewrittenSources = ImmutableList.builder();
            for (int i = 0; i < setOperationNode.getSources().size(); i++) {
                Set<Symbol> sourceLayout = ImmutableSet.copyOf(setOperationNode.sourceOutputLayout(i));
                rewrittenSources.add(rewriteContext.rewrite(setOperationNode.getSources().get(i), sourceLayout));
            }
            return setOperationNode.replaceChildren(rewrittenSources.build());
        }

        /**
         * Prunes unreferenced columns from a VALUES node.
         *
         * <p>Output symbols absent from the rewrite context are dropped, together with the value
         * at the same position in every row. Retained columns keep their relative order.
         */
        @Override // io.prestosql.sql.planner.plan.PlanVisitor
        public PlanNode visitValues(ValuesNode valuesNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            int rowCount = valuesNode.getRows().size();
            ImmutableList.Builder<Symbol> retainedSymbols = ImmutableList.builder();

            // One builder per row; each collects that row's values for the retained columns.
            ImmutableList.Builder<ImmutableList.Builder<Expression>> rowBuildersBuilder = ImmutableList.builder();
            for (int row = 0; row < rowCount; row++) {
                rowBuildersBuilder.add(ImmutableList.builder());
            }
            ImmutableList<ImmutableList.Builder<Expression>> rowBuilders = rowBuildersBuilder.build();

            for (int column = 0; column < valuesNode.getOutputSymbols().size(); column++) {
                Symbol symbol = valuesNode.getOutputSymbols().get(column);
                if (rewriteContext.get().contains(symbol)) {
                    retainedSymbols.add(symbol);
                    for (int row = 0; row < rowCount; row++) {
                        rowBuilders.get(row).add(valuesNode.getRows().get(row).get(column));
                    }
                }
            }

            List<List<Expression>> rewrittenRows = rowBuilders.stream()
                    .map(ImmutableList.Builder::build)
                    .collect(ImmutableList.toImmutableList());
            return new ValuesNode(valuesNode.getId(), retainedSymbols.build(), rewrittenRows);
        }

        /**
         * Prunes an APPLY node.
         *
         * <p>If none of the subquery assignment outputs is referenced, the node is removed and
         * only its input is rewritten. Otherwise:
         * <ul>
         *   <li>only the referenced subquery assignments are kept;</li>
         *   <li>the subquery is rewritten to produce the symbols those assignments use;</li>
         *   <li>correlation symbols no longer referenced by the rewritten subquery are dropped;</li>
         *   <li>the input is rewritten to produce the parent's symbols, the remaining correlation
         *       symbols and the symbols used by the kept assignments.</li>
         * </ul>
         */
        @Override // io.prestosql.sql.planner.plan.PlanVisitor
        public PlanNode visitApply(ApplyNode applyNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            // Nothing the subquery computes is needed: drop the apply entirely.
            if (Sets.intersection(applyNode.getSubqueryAssignments().getSymbols(), rewriteContext.get()).isEmpty()) {
                return rewriteContext.rewrite(applyNode.getInput(), rewriteContext.get());
            }

            ImmutableSet.Builder<Symbol> assignmentInputsBuilder = ImmutableSet.builder();
            Assignments.Builder retainedAssignments = Assignments.builder();
            for (Map.Entry<Symbol, Expression> entry : applyNode.getSubqueryAssignments().getMap().entrySet()) {
                Symbol output = entry.getKey();
                Expression expression = entry.getValue();
                if (rewriteContext.get().contains(output)) {
                    assignmentInputsBuilder.addAll(SymbolsExtractor.extractUnique(expression));
                    retainedAssignments.put(output, expression);
                }
            }
            Set<Symbol> assignmentInputs = assignmentInputsBuilder.build();
            PlanNode subquery = rewriteContext.rewrite(applyNode.getSubquery(), assignmentInputs);

            // Keep only the correlation symbols the rewritten subquery still references.
            Set<Symbol> subquerySymbols = SymbolsExtractor.extractUnique(subquery);
            List<Symbol> newCorrelation = applyNode.getCorrelation().stream()
                    .filter(subquerySymbols::contains)
                    .collect(ImmutableList.toImmutableList());

            // Symbols used by kept assignments must also come from the input
            // (e.g. "expr" in "expr IN (SELECT ...)").
            Set<Symbol> inputContext = ImmutableSet.<Symbol>builder()
                    .addAll(rewriteContext.get())
                    .addAll(newCorrelation)
                    .addAll(assignmentInputs)
                    .build();
            PlanNode input = rewriteContext.rewrite(applyNode.getInput(), inputContext);
            return new ApplyNode(applyNode.getId(), input, subquery, retainedAssignments.build(), newCorrelation, applyNode.getOriginSubquery());
        }

        /**
         * Removes an AssignUniqueId node whose generated id column is not referenced by the
         * parent; otherwise applies the default rewrite with the same context.
         */
        @Override // io.prestosql.sql.planner.plan.PlanVisitor
        public PlanNode visitAssignUniqueId(AssignUniqueId assignUniqueId, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            if (rewriteContext.get().contains(assignUniqueId.getIdColumn())) {
                return rewriteContext.defaultRewrite(assignUniqueId, rewriteContext.get());
            }
            // The id column is unused, so skip this node and continue with its source.
            return rewriteContext.rewrite(assignUniqueId.getSource(), rewriteContext.get());
        }

        /**
         * Prunes a correlated join.
         *
         * <p>The subquery is rewritten to produce the filter's symbols plus the parent's symbols.
         * If none of its outputs is referenced by the parent, the subquery is dropped in two cases:
         * <ul>
         *   <li>INNER, the subquery is scalar and the filter is TRUE;</li>
         *   <li>LEFT, the subquery is at most scalar.</li>
         * </ul>
         *
         * <p>Correlation symbols not referenced by the rewritten subquery are then pruned. The
         * input is rewritten to produce the remaining correlation symbols, the parent's symbols and
         * the filter's symbols. If the input's outputs do not intersect the correlation and parent
         * symbols, the input is dropped in two cases:
         * <ul>
         *   <li>INNER, the input is scalar and the filter is TRUE;</li>
         *   <li>RIGHT, the input is at most scalar.</li>
         * </ul>
         */
        @Override // io.prestosql.sql.planner.plan.PlanVisitor
        public PlanNode visitCorrelatedJoin(CorrelatedJoinNode correlatedJoinNode, SimplePlanRewriter.RewriteContext<Set<Symbol>> rewriteContext) {
            Set<Symbol> filterSymbols = SymbolsExtractor.extractUnique(correlatedJoinNode.getFilter());
            Set<Symbol> subqueryContext = ImmutableSet.<Symbol>builder()
                    .addAll(filterSymbols)
                    .addAll(rewriteContext.get())
                    .build();
            PlanNode subquery = rewriteContext.rewrite(correlatedJoinNode.getSubquery(), subqueryContext);

            // Remove a subquery whose outputs nobody needs, when that cannot change the row count.
            if (Sets.intersection(ImmutableSet.copyOf(subquery.getOutputSymbols()), rewriteContext.get()).isEmpty()) {
                if (correlatedJoinNode.getType() == CorrelatedJoinNode.Type.INNER && QueryCardinalityUtil.isScalar(subquery) && correlatedJoinNode.getFilter().equals(BooleanLiteral.TRUE_LITERAL)) {
                    return rewriteContext.rewrite(correlatedJoinNode.getInput(), rewriteContext.get());
                }
                if (correlatedJoinNode.getType() == CorrelatedJoinNode.Type.LEFT && QueryCardinalityUtil.isAtMostScalar(subquery)) {
                    return rewriteContext.rewrite(correlatedJoinNode.getInput(), rewriteContext.get());
                }
            }

            // Keep only the correlation symbols the rewritten subquery still references.
            Set<Symbol> subquerySymbols = SymbolsExtractor.extractUnique(subquery);
            List<Symbol> newCorrelation = correlatedJoinNode.getCorrelation().stream()
                    .filter(subquerySymbols::contains)
                    .collect(ImmutableList.toImmutableList());

            Set<Symbol> requiredFromInput = ImmutableSet.<Symbol>builder()
                    .addAll(newCorrelation)
                    .addAll(rewriteContext.get())
                    .build();
            Set<Symbol> inputContext = ImmutableSet.<Symbol>builder()
                    .addAll(requiredFromInput)
                    .addAll(filterSymbols)
                    .build();
            PlanNode input = rewriteContext.rewrite(correlatedJoinNode.getInput(), inputContext);

            // Remove an input whose outputs nobody needs, when that cannot change the row count.
            if (Sets.intersection(ImmutableSet.copyOf(input.getOutputSymbols()), requiredFromInput).isEmpty()) {
                if (correlatedJoinNode.getType() == CorrelatedJoinNode.Type.INNER && QueryCardinalityUtil.isScalar(input) && correlatedJoinNode.getFilter().equals(BooleanLiteral.TRUE_LITERAL)) {
                    return subquery;
                }
                if (correlatedJoinNode.getType() == CorrelatedJoinNode.Type.RIGHT && QueryCardinalityUtil.isAtMostScalar(input)) {
                    return subquery;
                }
            }
            return new CorrelatedJoinNode(correlatedJoinNode.getId(), input, subquery, newCorrelation, correlatedJoinNode.getType(), correlatedJoinNode.getFilter(), correlatedJoinNode.getOriginSubquery());
        }
    }

    /**
     * Creates the optimizer.
     *
     * @param metadata metadata handed to the rewriter; must not be null
     * @param typeAnalyzer type analyzer handed to the rewriter; must not be null
     * @throws NullPointerException if either argument is null (metadata is checked first)
     */
    public PruneUnreferencedOutputs(Metadata metadata, TypeAnalyzer typeAnalyzer) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
    }

    /**
     * Runs the pruning rewriter over the plan, starting with an empty set of required symbols
     * at the root.
     *
     * @throws NullPointerException if the plan, session, types, symbol allocator or id allocator
     *         is null; the warning collector is not checked
     */
    @Override // io.prestosql.sql.planner.optimizations.PlanOptimizer
    public PlanNode optimize(PlanNode planNode, Session session, TypeProvider typeProvider, SymbolAllocator symbolAllocator, PlanNodeIdAllocator planNodeIdAllocator, WarningCollector warningCollector) {
        Objects.requireNonNull(planNode, "plan is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(typeProvider, "types is null");
        Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        Objects.requireNonNull(planNodeIdAllocator, "idAllocator is null");

        Rewriter rewriter = new Rewriter(this.metadata, typeProvider, this.typeAnalyzer, symbolAllocator, session);
        Set<Symbol> rootRequiredSymbols = ImmutableSet.of();
        return SimplePlanRewriter.rewriteWith(rewriter, planNode, rootRequiredSymbols);
    }
}
