package io.prestosql.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.sql.planner.OrderingScheme;
import io.prestosql.sql.planner.PartitioningScheme;
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.PlanNodeId;
import io.prestosql.sql.planner.plan.StatisticAggregations;
import io.prestosql.sql.planner.plan.StatisticAggregationsDescriptor;
import io.prestosql.sql.planner.plan.StatisticsWriterNode;
import io.prestosql.sql.planner.plan.TableFinishNode;
import io.prestosql.sql.planner.plan.TableWriterNode;
import io.prestosql.sql.planner.plan.TopNNode;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.ExpressionRewriter;
import io.prestosql.sql.tree.ExpressionTreeRewriter;
import io.prestosql.sql.tree.SymbolReference;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
 * Substitutes symbols in symbols, expressions and selected plan nodes according to a fixed
 * symbol-to-symbol mapping.
 *
 * <p>Mappings are followed transitively: with {@code a -> b} and {@code b -> c} present,
 * {@code a} maps to {@code c}. A symbol with no entry, or one mapped to itself, maps to itself.
 * The mapping is copied into an {@link ImmutableMap} on construction.
 */
public class SymbolMapper {
    private final Map<Symbol, Symbol> mapping;

    /**
     * Accumulates symbol mappings and produces a {@link SymbolMapper}.
     *
     * <p>NOTE(review): entries go into a Guava {@link ImmutableMap.Builder}, which rejects
     * duplicate keys when {@link #build()} is called — confirm callers never put the same source
     * symbol twice.
     */
    public static class Builder {
        private final ImmutableMap.Builder<Symbol, Symbol> mappings = ImmutableMap.builder();

        /** Creates a {@link SymbolMapper} from the mappings added so far. */
        public SymbolMapper build() {
            return new SymbolMapper(this.mappings.build());
        }

        /** Records that occurrences of {@code from} should be replaced by {@code to}. */
        public void put(Symbol from, Symbol to) {
            this.mappings.put(from, to);
        }
    }

    /**
     * @param mapping symbol substitutions to apply; copied defensively
     * @throws NullPointerException if {@code mapping} is null
     */
    public SymbolMapper(Map<Symbol, Symbol> mapping) {
        this.mapping = ImmutableMap.copyOf(Objects.requireNonNull(mapping, "mapping is null"));
    }

    /**
     * Returns the canonical replacement for {@code symbol}, following the mapping chain until a
     * symbol that is absent from the mapping or mapped to itself is reached.
     *
     * @throws IllegalStateException if the chain starting at {@code symbol} contains a cycle
     */
    public Symbol map(Symbol symbol) {
        Symbol canonical = symbol;
        // An acyclic chain visits each key at most once, so needing more hops than there are
        // entries means the mapping contains a cycle; fail fast instead of looping forever.
        int hops = 0;
        while (true) {
            // ImmutableMap never stores null values, so null here means "no entry".
            Symbol next = mapping.get(canonical);
            if (next == null || next.equals(canonical)) {
                return canonical;
            }
            if (++hops > mapping.size()) {
                throw new IllegalStateException("Symbol mapping contains a cycle reachable from " + symbol);
            }
            canonical = next;
        }
    }

    /** Returns a copy of {@code expression} with every symbol reference replaced via {@link #map(Symbol)}. */
    public Expression map(Expression expression) {
        return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>() {
            @Override
            public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
                return SymbolMapper.this.map(Symbol.from(node)).toSymbolReference();
            }
        }, expression);
    }

    /** Maps {@code aggregationNode} onto {@code planNode}, keeping the original node id. */
    public AggregationNode map(AggregationNode aggregationNode, PlanNode planNode) {
        return map(aggregationNode, planNode, aggregationNode.getId());
    }

    /** Maps {@code aggregationNode} onto {@code planNode} under a freshly allocated node id. */
    public AggregationNode map(AggregationNode aggregationNode, PlanNode planNode, PlanNodeIdAllocator planNodeIdAllocator) {
        return map(aggregationNode, planNode, planNodeIdAllocator.getNextId());
    }

    /**
     * Rebuilds {@code aggregationNode} with mapped aggregation outputs, calls and masks, mapped and
     * de-duplicated grouping keys, and mapped hash / group-id symbols.
     */
    private AggregationNode map(AggregationNode aggregationNode, PlanNode planNode, PlanNodeId planNodeId) {
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
            aggregations.put(map(entry.getKey()), map(entry.getValue()));
        }
        return new AggregationNode(
                planNodeId,
                planNode,
                aggregations.build(),
                AggregationNode.groupingSets(
                        mapAndDistinct(aggregationNode.getGroupingKeys()),
                        aggregationNode.getGroupingSetCount(),
                        aggregationNode.getGlobalGroupingSets()),
                // NOTE(review): an empty list is passed here (presumably the pre-grouped symbols),
                // dropping whatever the original node had — confirm this is intended.
                ImmutableList.of(),
                aggregationNode.getStep(),
                aggregationNode.getHashSymbol().map(this::map),
                aggregationNode.getGroupIdSymbol().map(this::map));
    }

    /** Maps the symbols referenced by the aggregation call and its optional mask. */
    private AggregationNode.Aggregation map(AggregationNode.Aggregation aggregation) {
        return new AggregationNode.Aggregation(
                map((Expression) aggregation.getCall()),
                aggregation.getSignature(),
                aggregation.getMask().map(this::map));
    }

    /**
     * Maps {@code topNNode} onto {@code planNode} under {@code planNodeId}. ORDER BY symbols that
     * collapse to the same canonical symbol are kept once, at the position and with the sort order
     * of their first occurrence.
     */
    public TopNNode map(TopNNode topNNode, PlanNode planNode, PlanNodeId planNodeId) {
        OrderingScheme orderingScheme = topNNode.getOrderingScheme();
        // canonical symbol -> first original ORDER BY symbol mapping to it, in ORDER BY order
        Map<Symbol, Symbol> firstSourceByCanonical = new LinkedHashMap<>();
        for (Symbol symbol : orderingScheme.getOrderBy()) {
            firstSourceByCanonical.putIfAbsent(map(symbol), symbol);
        }
        return new TopNNode(
                planNodeId,
                planNode,
                topNNode.getCount(),
                new OrderingScheme(
                        ImmutableList.copyOf(firstSourceByCanonical.keySet()),
                        firstSourceByCanonical.entrySet().stream()
                                .collect(ImmutableMap.toImmutableMap(
                                        Map.Entry::getKey,
                                        entry -> orderingScheme.getOrdering(entry.getValue())))),
                topNNode.getStep());
    }

    /** Maps {@code tableWriterNode} onto {@code planNode}, keeping the original node id. */
    public TableWriterNode map(TableWriterNode tableWriterNode, PlanNode planNode) {
        return map(tableWriterNode, planNode, tableWriterNode.getId());
    }

    /**
     * Maps the row-count, fragment and column symbols, the optional partitioning scheme (whose
     * output layout becomes {@code planNode}'s mapped, de-duplicated outputs) and the optional
     * statistics aggregation and descriptor of {@code tableWriterNode}.
     */
    public TableWriterNode map(TableWriterNode tableWriterNode, PlanNode planNode, PlanNodeId planNodeId) {
        return new TableWriterNode(
                planNodeId,
                planNode,
                tableWriterNode.getTarget(),
                map(tableWriterNode.getRowCountSymbol()),
                map(tableWriterNode.getFragmentSymbol()),
                map(tableWriterNode.getColumns()),
                tableWriterNode.getColumnNames(),
                tableWriterNode.getPartitioningScheme().map(partitioningScheme -> canonicalize(partitioningScheme, planNode)),
                tableWriterNode.getStatisticsAggregation().map(this::map),
                tableWriterNode.getStatisticsAggregationDescriptor().map(this::map));
    }

    /**
     * Maps the row-count symbol and the statistics descriptor of {@code statisticsWriterNode},
     * keeping the original node id. The row-count symbol is mapped for consistency with the
     * {@link TableWriterNode} and {@link TableFinishNode} overloads.
     */
    public StatisticsWriterNode map(StatisticsWriterNode statisticsWriterNode, PlanNode planNode) {
        return new StatisticsWriterNode(
                statisticsWriterNode.getId(),
                planNode,
                statisticsWriterNode.getTarget(),
                map(statisticsWriterNode.getRowCountSymbol()),
                statisticsWriterNode.isRowCountEnabled(),
                statisticsWriterNode.getDescriptor().map(this::map));
    }

    /**
     * Maps the row-count symbol and the optional statistics aggregation and descriptor of
     * {@code tableFinishNode}, keeping the original node id.
     */
    public TableFinishNode map(TableFinishNode tableFinishNode, PlanNode planNode) {
        return new TableFinishNode(
                tableFinishNode.getId(),
                planNode,
                tableFinishNode.getTarget(),
                map(tableFinishNode.getRowCountSymbol()),
                tableFinishNode.getStatisticsAggregation().map(this::map),
                tableFinishNode.getStatisticsAggregationDescriptor().map(this::map));
    }

    /**
     * Maps the partitioning and hash column of {@code partitioningScheme}; its output layout is
     * replaced by {@code planNode}'s outputs, mapped and de-duplicated.
     */
    private PartitioningScheme canonicalize(PartitioningScheme partitioningScheme, PlanNode planNode) {
        return new PartitioningScheme(
                partitioningScheme.getPartitioning().translate(this::map),
                mapAndDistinct(planNode.getOutputSymbols()),
                partitioningScheme.getHashColumn().map(this::map),
                partitioningScheme.isReplicateNullsAndAny(),
                partitioningScheme.getBucketToPartition());
    }

    /** Maps the output symbols and aggregations, and de-duplicates the mapped grouping symbols. */
    private StatisticAggregations map(StatisticAggregations statisticAggregations) {
        return new StatisticAggregations(
                statisticAggregations.getAggregations().entrySet().stream()
                        .collect(ImmutableMap.toImmutableMap(
                                entry -> map(entry.getKey()),
                                entry -> map(entry.getValue()))),
                mapAndDistinct(statisticAggregations.getGroupingSymbols()));
    }

    /** Maps every symbol in the descriptor. */
    private StatisticAggregationsDescriptor<Symbol> map(StatisticAggregationsDescriptor<Symbol> statisticAggregationsDescriptor) {
        return statisticAggregationsDescriptor.map(this::map);
    }

    /** Maps each symbol, preserving order and duplicates. */
    private List<Symbol> map(List<Symbol> symbols) {
        return symbols.stream().map(this::map).collect(ImmutableList.toImmutableList());
    }

    /** Maps each symbol and keeps only the first occurrence of each mapped symbol, preserving order. */
    private List<Symbol> mapAndDistinct(List<Symbol> symbols) {
        Set<Symbol> seen = new HashSet<>();
        ImmutableList.Builder<Symbol> distinct = ImmutableList.builder();
        for (Symbol symbol : symbols) {
            Symbol canonical = map(symbol);
            if (seen.add(canonical)) {
                distinct.add(canonical);
            }
        }
        return distinct.build();
    }

    /** Returns a new, empty {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }
}
