/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.PartitionFunctionBinding;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.SystemPartitioningHandle;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.ChildReplacer;
import com.facebook.presto.sql.planner.plan.DistinctLimitNode;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.IndexJoinNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.MarkDistinctNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanVisitor;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.RowNumberNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.TopNRowNumberNode;
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.facebook.presto.sql.planner.plan.UnnestNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.QualifiedNameReference;
import com.facebook.presto.util.ImmutableCollectors;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Lists;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;

public class HashGenerationOptimizer
extends PlanOptimizer {
    /** Initial value of the combined hash; presumably the seed for {@code combine_hash} chains (0). */
    public static final int INITIAL_HASH_VALUE = 0;
    /** Mangled function name of the {@code HASH_CODE} operator, used to hash each individual field. */
    private static final String HASH_CODE = FunctionRegistry.mangleOperatorName("HASH_CODE");

    /**
     * Rewrites {@code plan} so that hash-based operators receive precomputed hash symbols.
     * The rewrite only runs when the optimize-hash-generation session property is enabled;
     * otherwise the plan is returned unchanged.
     */
    @Override
    public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
        Objects.requireNonNull(plan, "plan is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(types, "types is null");
        Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        Objects.requireNonNull(idAllocator, "idAllocator is null");

        boolean enabled = SystemSessionProperties.isOptimizeHashGenerationEnabled(session);
        if (!enabled) {
            return plan;
        }

        Rewriter rewriter = new Rewriter(idAllocator, symbolAllocator, types);
        // The root has no consumer above it, so it starts with an empty preference set.
        PlanWithProperties rewritten = plan.accept(rewriter, new HashComputationSet());
        return rewritten.getNode();
    }

    /**
     * Builds the hash computation over the given fields, in order.
     *
     * @param fields symbols whose values are combined into a single hash
     * @return the computation, or {@link Optional#empty()} when {@code fields} is empty
     * @throws NullPointerException if {@code fields} is null
     */
    public static Optional<HashComputation> computeHash(Iterable<Symbol> fields) {
        Objects.requireNonNull(fields, "fields is null");
        // Snapshot once and reuse the snapshot, so a single-pass Iterable is never traversed twice.
        List<Symbol> symbols = ImmutableList.copyOf(fields);
        if (symbols.isEmpty()) {
            return Optional.empty();
        }
        return Optional.of(new HashComputation(symbols));
    }

    /**
     * Maps each projection output whose expression is a plain column reference back to the
     * referenced input symbol. Outputs computed by any other expression are left out.
     */
    private static Map<Symbol, Symbol> computeIdentityTranslations(Map<Symbol, Expression> assignments) {
        Map<Symbol, Symbol> identities = new HashMap<>();
        assignments.forEach((output, expression) -> {
            if (expression instanceof QualifiedNameReference) {
                QualifiedNameReference reference = (QualifiedNameReference) expression;
                identities.put(output, Symbol.fromQualifiedName(reference.getName()));
            }
        });
        return identities;
    }

    /** A rewritten plan node together with the precomputed hash symbols it exposes. */
    private static class PlanWithProperties {
        private final PlanNode node;
        // hash computation -> symbol carrying its precomputed value (one-to-one, immutable)
        private final BiMap<HashComputation, Symbol> hashSymbols;

        public PlanWithProperties(PlanNode node, Map<HashComputation, Symbol> hashSymbols) {
            Objects.requireNonNull(node, "node is null");
            Objects.requireNonNull(hashSymbols, "hashSymbols is null");
            this.node = node;
            // copyOf rejects two computations mapped to the same symbol
            this.hashSymbols = ImmutableBiMap.copyOf(hashSymbols);
        }

        public PlanNode getNode() {
            return node;
        }

        public BiMap<HashComputation, Symbol> getHashSymbols() {
            return hashSymbols;
        }

        /**
         * Returns the symbol holding the precomputed value of {@code hash}.
         *
         * @throws NullPointerException if this plan does not expose that hash
         */
        public Symbol getRequiredHashSymbol(HashComputation hash) {
            return Objects.requireNonNull(hashSymbols.get(hash), () -> "No hash symbol generated for " + hash);
        }
    }

    /**
     * An ordered, non-empty list of symbols whose combined hash can be precomputed into one symbol.
     * Two computations are equal when they hash the same symbols in the same order.
     */
    private static class HashComputation {
        private final List<Symbol> fields;

        private HashComputation(Iterable<Symbol> fields) {
            Objects.requireNonNull(fields, "fields is null");
            this.fields = ImmutableList.copyOf(fields);
            Preconditions.checkArgument(!this.fields.isEmpty(), "fields can not be empty");
        }

        /** Returns the hashed symbols, in hashing order. */
        public List<Symbol> getFields() {
            return this.fields;
        }

        /**
         * Rewrites every field through {@code translator}.
         *
         * @return the translated computation, or empty if any field has no translation
         */
        public Optional<HashComputation> translate(Function<Symbol, Optional<Symbol>> translator) {
            ImmutableList.Builder<Symbol> newSymbols = ImmutableList.builder();
            for (Symbol field : this.fields) {
                Optional<Symbol> newSymbol = translator.apply(field);
                if (!newSymbol.isPresent()) {
                    return Optional.empty();
                }
                newSymbols.add(newSymbol.get());
            }
            return HashGenerationOptimizer.computeHash(newSymbols.build());
        }

        /** Returns true when every field is contained in {@code availableFields}. */
        public boolean canComputeWith(Set<Symbol> availableFields) {
            return availableFields.containsAll(this.fields);
        }

        /**
         * Builds the hash expression by folding the fields left to right:
         * {@code combine_hash(... combine_hash(0, coalesce(HASH_CODE(f1), 0)) ..., coalesce(HASH_CODE(fn), 0))}.
         * A NULL field contributes 0 through the coalesce.
         */
        private Expression getHashExpression() {
            // Declared as Expression: after the first field it holds a FunctionCall, not a LongLiteral.
            Expression hashExpression = new LongLiteral(String.valueOf(INITIAL_HASH_VALUE));
            for (Symbol field : this.fields) {
                hashExpression = getHashFunctionCall(hashExpression, field);
            }
            return hashExpression;
        }

        /** Returns {@code combine_hash(previousHashValue, coalesce(HASH_CODE(symbol), 0))}. */
        private static Expression getHashFunctionCall(Expression previousHashValue, Symbol symbol) {
            FunctionCall functionCall = new FunctionCall(
                    QualifiedName.of(HASH_CODE),
                    Optional.empty(),
                    false,
                    ImmutableList.<Expression>of(new QualifiedNameReference(symbol.toQualifiedName())));
            List<Expression> arguments = ImmutableList.of(previousHashValue, orNullHashCode(functionCall));
            return new FunctionCall(QualifiedName.of("combine_hash"), arguments);
        }

        /** Wraps {@code expression} so that a NULL hash becomes the initial hash value (0). */
        private static Expression orNullHashCode(Expression expression) {
            return new CoalesceExpression(expression, new LongLiteral(String.valueOf(INITIAL_HASH_VALUE)));
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            HashComputation that = (HashComputation) o;
            return Objects.equals(this.fields, that.fields);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.fields);
        }

        @Override
        public String toString() {
            return MoreObjects.toStringHelper(this).add("fields", this.fields).toString();
        }
    }

    /** An immutable set of hash computations that a consumer requires or would like to receive. */
    private static class HashComputationSet {
        private final Set<HashComputation> hashes;

        /** Creates an empty set. */
        public HashComputationSet() {
            this.hashes = ImmutableSet.of();
        }

        /** Creates a set holding {@code hash} if present, otherwise an empty set. */
        public HashComputationSet(Optional<HashComputation> hash) {
            Objects.requireNonNull(hash, "hash is null");
            this.hashes = hash.isPresent() ? ImmutableSet.of(hash.get()) : ImmutableSet.of();
        }

        private HashComputationSet(Iterable<HashComputation> hashes) {
            Objects.requireNonNull(hashes, "hashes is null");
            this.hashes = ImmutableSet.copyOf(hashes);
        }

        public Set<HashComputation> getHashes() {
            return this.hashes;
        }

        /** Keeps only the computations whose fields are all among {@code symbols}. */
        public HashComputationSet pruneSymbols(List<Symbol> symbols) {
            Set<Symbol> uniqueSymbols = ImmutableSet.copyOf(symbols);
            ImmutableSet<HashComputation> prunedHashes = this.hashes.stream()
                    .filter(hash -> hash.canComputeWith(uniqueSymbols))
                    .collect(ImmutableCollectors.toImmutableSet());
            return new HashComputationSet(prunedHashes);
        }

        /** Translates every computation; computations with an untranslatable field are dropped. */
        public HashComputationSet translate(Function<Symbol, Optional<Symbol>> translator) {
            ImmutableSet<HashComputation> newHashes = this.hashes.stream()
                    .map(hash -> hash.translate(translator))
                    .filter(Optional::isPresent)
                    .map(Optional::get)
                    .collect(ImmutableCollectors.toImmutableSet());
            return new HashComputationSet(newHashes);
        }

        /** Prunes to {@code node}'s output symbols, then adds {@code hashComputation} if present. */
        public HashComputationSet withHashComputation(PlanNode node, Optional<HashComputation> hashComputation) {
            return this.pruneSymbols(node.getOutputSymbols()).withHashComputation(hashComputation);
        }

        /** Returns this set plus {@code hashComputation}, or this set unchanged when it is empty. */
        public HashComputationSet withHashComputation(Optional<HashComputation> hashComputation) {
            if (!hashComputation.isPresent()) {
                return this;
            }
            return new HashComputationSet(ImmutableSet.<HashComputation>builder()
                    .addAll(this.hashes)
                    .add(hashComputation.get())
                    .build());
        }
    }

    private static class Rewriter
    extends PlanVisitor<HashComputationSet, PlanWithProperties> {
        // Supplies ids for the rewritten plan nodes.
        private final PlanNodeIdAllocator idAllocator;
        // Allocates fresh symbols for newly introduced hash columns.
        private final SymbolAllocator symbolAllocator;
        // Symbol types; consulted to skip hashing for a single BIGINT key.
        private final Map<Symbol, Type> types;

        private Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Map<Symbol, Type> types) {
            Objects.requireNonNull(idAllocator, "idAllocator is null");
            Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
            Objects.requireNonNull(types, "types is null");
            this.idAllocator = idAllocator;
            this.symbolAllocator = symbolAllocator;
            this.types = types;
        }

        /** Default handling: plan the node's single source (if any), forwarding the parent's preference. */
        @Override
        protected PlanWithProperties visitPlan(PlanNode node, HashComputationSet parentPreference) {
            final boolean pruneExtraHashSymbols = true;
            return planSimpleNodeWithProperties(node, parentPreference, pruneExtraHashSymbols);
        }

        /** Plans the source with no hash preference; the parent's preference is deliberately ignored. */
        @Override
        public PlanWithProperties visitEnforceSingleRow(EnforceSingleRowNode node, HashComputationSet parentPreference) {
            HashComputationSet noPreference = new HashComputationSet();
            return planSimpleNodeWithProperties(node, noPreference, true);
        }

        /**
         * Requests the group-by hash from the source, unless the grouping keys are empty or a single BIGINT.
         * Parent preferences are not passed through the aggregation; only the group-by hash is exposed.
         */
        @Override
        public PlanWithProperties visitAggregation(AggregationNode node, HashComputationSet parentPreference) {
            Optional<HashComputation> groupByHash = Optional.empty();
            if (!this.canSkipHashGeneration(node.getGroupBy())) {
                groupByHash = HashGenerationOptimizer.computeHash(node.getGroupBy());
            }
            HashComputationSet requiredHashes = new HashComputationSet(groupByHash);
            PlanWithProperties child = this.planAndEnforce(node.getSource(), requiredHashes, false, requiredHashes);
            Optional<Symbol> hashSymbol = groupByHash.map(child::getRequiredHashSymbol);

            Map<HashComputation, Symbol> exposedHashSymbols = hashSymbol.isPresent()
                    ? ImmutableMap.of(groupByHash.get(), hashSymbol.get())
                    : ImmutableMap.of();
            return new PlanWithProperties(
                    new AggregationNode(
                            this.idAllocator.getNextId(),
                            child.getNode(),
                            node.getGroupBy(),
                            node.getAggregations(),
                            node.getFunctions(),
                            node.getMasks(),
                            node.getGroupingSets(),
                            node.getStep(),
                            node.getSampleWeight(),
                            node.getConfidence(),
                            hashSymbol),
                    exposedHashSymbols);
        }

        /**
         * Returns true when no precomputed hash is needed for these symbols. That is the case when
         * there are none, or when there is exactly one and its type is BIGINT.
         */
        private boolean canSkipHashGeneration(List<Symbol> partitionSymbols) {
            if (partitionSymbols.isEmpty()) {
                return true;
            }
            if (partitionSymbols.size() != 1) {
                return false;
            }
            Type onlyType = this.types.get(Iterables.getOnlyElement(partitionSymbols));
            return onlyType.equals(BigintType.BIGINT);
        }

        /** Forwards only the preferred hashes that can be computed from the source's own outputs. */
        @Override
        public PlanWithProperties visitGroupId(GroupIdNode node, HashComputationSet parentPreference) {
            List<Symbol> sourceOutputs = node.getSource().getOutputSymbols();
            HashComputationSet forwarded = parentPreference.pruneSymbols(sourceOutputs);
            return planSimpleNodeWithProperties(node, forwarded, true);
        }

        /**
         * Requests a precomputed hash of the distinct symbols from the source. Hash generation is
         * skipped (the node is planned as a simple pass-through) when the distinct symbols are empty
         * or a single BIGINT, matching the rule used for aggregations.
         */
        @Override
        public PlanWithProperties visitDistinctLimit(DistinctLimitNode node, HashComputationSet parentPreference) {
            // Previously inverted: multi-column keys skipped hashing, and an empty key list hit Optional.get().
            if (this.canSkipHashGeneration(node.getDistinctSymbols())) {
                return this.planSimpleNodeWithProperties(node, parentPreference);
            }
            Optional<HashComputation> hashComputation = HashGenerationOptimizer.computeHash(node.getDistinctSymbols());
            PlanWithProperties child = this.planAndEnforce(
                    node.getSource(),
                    new HashComputationSet(hashComputation),
                    false,
                    parentPreference.withHashComputation(node, hashComputation));
            Symbol hashSymbol = child.getRequiredHashSymbol(hashComputation.get());
            return new PlanWithProperties(
                    new DistinctLimitNode(this.idAllocator.getNextId(), child.getNode(), node.getLimit(), Optional.of(hashSymbol)),
                    child.getHashSymbols());
        }

        /**
         * Requests a precomputed hash of the distinct symbols from the source. Hash generation is
         * skipped (the node is planned as a simple pass-through) when the distinct symbols are empty
         * or a single BIGINT, matching the rule used for aggregations.
         */
        @Override
        public PlanWithProperties visitMarkDistinct(MarkDistinctNode node, HashComputationSet parentPreference) {
            // Previously inverted: multi-column keys skipped hashing, and an empty key list hit Optional.get().
            if (this.canSkipHashGeneration(node.getDistinctSymbols())) {
                return this.planSimpleNodeWithProperties(node, parentPreference);
            }
            Optional<HashComputation> hashComputation = HashGenerationOptimizer.computeHash(node.getDistinctSymbols());
            PlanWithProperties child = this.planAndEnforce(
                    node.getSource(),
                    new HashComputationSet(hashComputation),
                    false,
                    parentPreference.withHashComputation(node, hashComputation));
            Symbol hashSymbol = child.getRequiredHashSymbol(hashComputation.get());
            return new PlanWithProperties(
                    new MarkDistinctNode(
                            this.idAllocator.getNextId(),
                            child.getNode(),
                            node.getMarkerSymbol(),
                            node.getDistinctSymbols(),
                            Optional.of(hashSymbol)),
                    child.getHashSymbols());
        }

        /**
         * Requests the partition-by hash from the source, together with any parent preferences
         * computable from this node's outputs. The child's hash symbols are exposed unchanged.
         */
        @Override
        public PlanWithProperties visitRowNumber(RowNumberNode node, HashComputationSet parentPreference) {
            List<Symbol> partitionBy = node.getPartitionBy();
            if (partitionBy.isEmpty()) {
                // Unpartitioned: nothing to hash.
                return planSimpleNodeWithProperties(node, parentPreference, true);
            }
            Optional<HashComputation> partitionHash = HashGenerationOptimizer.computeHash(partitionBy);
            HashComputationSet required = new HashComputationSet(partitionHash);
            HashComputationSet preferred = parentPreference.withHashComputation(node, partitionHash);
            PlanWithProperties child = planAndEnforce(node.getSource(), required, false, preferred);
            Symbol hashSymbol = child.getRequiredHashSymbol(partitionHash.get());

            RowNumberNode rewritten = new RowNumberNode(
                    idAllocator.getNextId(),
                    child.getNode(),
                    node.getPartitionBy(),
                    node.getRowNumberSymbol(),
                    node.getMaxRowCountPerPartition(),
                    Optional.of(hashSymbol));
            return new PlanWithProperties(rewritten, child.getHashSymbols());
        }

        /**
         * Requests the partition-by hash from the source, together with any parent preferences
         * computable from this node's outputs. The child's hash symbols are exposed unchanged.
         */
        @Override
        public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, HashComputationSet parentPreference) {
            List<Symbol> partitionBy = node.getPartitionBy();
            if (partitionBy.isEmpty()) {
                // A single global ranking: nothing to hash.
                return planSimpleNodeWithProperties(node, parentPreference, true);
            }
            Optional<HashComputation> partitionHash = HashGenerationOptimizer.computeHash(partitionBy);
            HashComputationSet required = new HashComputationSet(partitionHash);
            HashComputationSet preferred = parentPreference.withHashComputation(node, partitionHash);
            PlanWithProperties child = planAndEnforce(node.getSource(), required, false, preferred);
            Symbol hashSymbol = child.getRequiredHashSymbol(partitionHash.get());

            TopNRowNumberNode rewritten = new TopNRowNumberNode(
                    idAllocator.getNextId(),
                    child.getNode(),
                    node.getPartitionBy(),
                    node.getOrderBy(),
                    node.getOrderings(),
                    node.getRowNumberSymbol(),
                    node.getMaxRowCountPerPartition(),
                    node.isPartial(),
                    Optional.of(hashSymbol));
            return new PlanWithProperties(rewritten, child.getHashSymbols());
        }

        /**
         * Plans both join inputs, requiring each to produce exactly the hash of its own join keys.
         * Parent preferences are not passed through the join. For a cross join (no criteria), neither
         * side carries any hash symbol.
         */
        @Override
        public PlanWithProperties visitJoin(JoinNode node, HashComputationSet parentPreference) {
            List<JoinNode.EquiJoinClause> clauses = node.getCriteria();

            if (clauses.isEmpty()) {
                HashComputationSet none = new HashComputationSet();
                PlanWithProperties crossLeft = planAndEnforce(node.getLeft(), none, true, none);
                PlanWithProperties crossRight = planAndEnforce(node.getRight(), none, true, none);
                // Right first, then left, so left entries win on a key collision.
                Map<HashComputation, Symbol> crossHashSymbols = new HashMap<>(crossRight.getHashSymbols());
                crossHashSymbols.putAll(crossLeft.getHashSymbols());
                JoinNode crossJoin = new JoinNode(idAllocator.getNextId(), node.getType(), crossLeft.getNode(), crossRight.getNode(), node.getCriteria(), Optional.empty(), Optional.empty());
                return new PlanWithProperties(crossJoin, crossHashSymbols);
            }

            Optional<HashComputation> leftKeysHash = HashGenerationOptimizer.computeHash(Lists.transform(clauses, JoinNode.EquiJoinClause::getLeft));
            HashComputationSet leftRequired = new HashComputationSet(leftKeysHash);
            PlanWithProperties left = planAndEnforce(node.getLeft(), leftRequired, true, leftRequired);
            Symbol leftHashSymbol = left.getRequiredHashSymbol(leftKeysHash.get());

            Optional<HashComputation> rightKeysHash = HashGenerationOptimizer.computeHash(Lists.transform(clauses, JoinNode.EquiJoinClause::getRight));
            HashComputationSet rightRequired = new HashComputationSet(rightKeysHash);
            PlanWithProperties right = planAndEnforce(node.getRight(), rightRequired, true, rightRequired);
            Symbol rightHashSymbol = right.getRequiredHashSymbol(rightKeysHash.get());

            // A side's hashes are exposed only for INNER joins or when that side is the preserved one
            // (LEFT / RIGHT); any other join type exposes none.
            JoinNode.Type joinType = node.getType();
            boolean exposeLeft = joinType == JoinNode.Type.INNER || joinType == JoinNode.Type.LEFT;
            boolean exposeRight = joinType == JoinNode.Type.INNER || joinType == JoinNode.Type.RIGHT;
            Map<HashComputation, Symbol> allHashSymbols = new HashMap<>();
            if (exposeLeft) {
                allHashSymbols.putAll(left.getHashSymbols());
            }
            if (exposeRight) {
                allHashSymbols.putAll(right.getHashSymbols());
            }

            JoinNode join = new JoinNode(idAllocator.getNextId(), joinType, left.getNode(), right.getNode(), node.getCriteria(), Optional.of(leftHashSymbol), Optional.of(rightHashSymbol));
            return new PlanWithProperties(join, allHashSymbols);
        }

        /**
         * Requires the source and the filtering source to each produce exactly the hash of their join symbol.
         * Parent preferences are not passed through. Only the source side's hash symbols are exposed.
         */
        @Override
        public PlanWithProperties visitSemiJoin(SemiJoinNode node, HashComputationSet parentPreference) {
            Optional<HashComputation> sourceHashComputation = HashGenerationOptimizer.computeHash(ImmutableList.of(node.getSourceJoinSymbol()));
            PlanWithProperties source = this.planAndEnforce(
                    node.getSource(),
                    new HashComputationSet(sourceHashComputation),
                    true,
                    new HashComputationSet(sourceHashComputation));
            Symbol sourceHashSymbol = source.getRequiredHashSymbol(sourceHashComputation.get());

            Optional<HashComputation> filterHashComputation = HashGenerationOptimizer.computeHash(ImmutableList.of(node.getFilteringSourceJoinSymbol()));
            HashComputationSet requiredHashes = new HashComputationSet(filterHashComputation);
            PlanWithProperties filteringSource = this.planAndEnforce(node.getFilteringSource(), requiredHashes, true, requiredHashes);
            Symbol filteringSourceHashSymbol = filteringSource.getRequiredHashSymbol(filterHashComputation.get());

            return new PlanWithProperties(
                    new SemiJoinNode(
                            this.idAllocator.getNextId(),
                            source.getNode(),
                            filteringSource.getNode(),
                            node.getSourceJoinSymbol(),
                            node.getFilteringSourceJoinSymbol(),
                            node.getSemiJoinOutput(),
                            Optional.of(sourceHashSymbol),
                            Optional.of(filteringSourceHashSymbol)),
                    source.getHashSymbols());
        }

        /**
         * Requires the probe and index inputs to each produce exactly the hash of their join keys.
         * Parent preferences are not passed through the join.
         */
        @Override
        public PlanWithProperties visitIndexJoin(IndexJoinNode node, HashComputationSet parentPreference) {
            List<IndexJoinNode.EquiJoinClause> clauses = node.getCriteria();

            Optional<HashComputation> probeKeysHash = HashGenerationOptimizer.computeHash(Lists.transform(clauses, IndexJoinNode.EquiJoinClause::getProbe));
            HashComputationSet probeRequired = new HashComputationSet(probeKeysHash);
            PlanWithProperties probe = planAndEnforce(node.getProbeSource(), probeRequired, true, probeRequired);
            Symbol probeHashSymbol = probe.getRequiredHashSymbol(probeKeysHash.get());

            Optional<HashComputation> indexKeysHash = HashGenerationOptimizer.computeHash(Lists.transform(clauses, IndexJoinNode.EquiJoinClause::getIndex));
            HashComputationSet indexRequired = new HashComputationSet(indexKeysHash);
            PlanWithProperties index = planAndEnforce(node.getIndexSource(), indexRequired, true, indexRequired);
            Symbol indexHashSymbol = index.getRequiredHashSymbol(indexKeysHash.get());

            // NOTE(review): probe hashes are kept only for INNER joins, while index hashes are always kept.
            // For an outer index join the index side looks like the one that may be null-padded, which
            // would suggest the opposite. Confirm against the IndexJoinNode.Type semantics.
            Map<HashComputation, Symbol> allHashSymbols = new HashMap<>();
            if (node.getType() == IndexJoinNode.Type.INNER) {
                allHashSymbols.putAll(probe.getHashSymbols());
            }
            allHashSymbols.putAll(index.getHashSymbols());

            IndexJoinNode rewritten = new IndexJoinNode(idAllocator.getNextId(), node.getType(), probe.getNode(), index.getNode(), node.getCriteria(), Optional.of(probeHashSymbol), Optional.of(indexHashSymbol));
            return new PlanWithProperties(rewritten, allHashSymbols);
        }

        /**
         * Requests the partition-by hash from the source, together with any parent preferences
         * computable from this node's outputs. Extra hash symbols are pruned from the child.
         */
        @Override
        public PlanWithProperties visitWindow(WindowNode node, HashComputationSet parentPreference) {
            List<Symbol> partitionBy = node.getPartitionBy();
            if (partitionBy.isEmpty()) {
                return planSimpleNodeWithProperties(node, parentPreference, true);
            }
            Optional<HashComputation> partitionHash = HashGenerationOptimizer.computeHash(partitionBy);
            HashComputationSet required = new HashComputationSet(partitionHash);
            HashComputationSet preferred = parentPreference.withHashComputation(node, partitionHash);
            PlanWithProperties child = planAndEnforce(node.getSource(), required, true, preferred);
            Symbol hashSymbol = child.getRequiredHashSymbol(partitionHash.get());

            WindowNode rewritten = new WindowNode(
                    idAllocator.getNextId(),
                    child.getNode(),
                    node.getPartitionBy(),
                    node.getOrderBy(),
                    node.getOrderings(),
                    node.getFrame(),
                    node.getWindowFunctions(),
                    node.getSignatures(),
                    Optional.of(hashSymbol),
                    node.getPrePartitionedInputs(),
                    node.getPreSortedOrderPrefix());
            return new PlanWithProperties(rewritten, child.getHashSymbols());
        }

        /**
         * Carries preferred hashes, plus the partitioning hash when applicable, across the exchange.
         * A fresh hash symbol is allocated per computation and appended to the output layout and to
         * each source's input list in one fixed order.
         */
        @Override
        public PlanWithProperties visitExchange(ExchangeNode node, HashComputationSet parentPreference) {
            // Keep only parent preferences computable from this exchange's outputs.
            HashComputationSet preference = parentPreference.pruneSymbols(node.getOutputSymbols());

            // A precomputed partitioning hash is only used for FIXED_HASH_DISTRIBUTION when every argument is a column.
            Optional<HashComputation> partitionSymbols = Optional.empty();
            PartitionFunctionBinding partitionFunction = node.getPartitionFunction();
            if (partitionFunction.getPartitioningHandle().equals(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION)
                    && partitionFunction.getPartitionFunctionArguments().stream().allMatch(PartitionFunctionBinding.PartitionFunctionArgumentBinding::isVariable)) {
                List<Symbol> partitionColumns = partitionFunction.getPartitionFunctionArguments().stream()
                        .map(PartitionFunctionBinding.PartitionFunctionArgumentBinding::getColumn)
                        .collect(ImmutableCollectors.toImmutableList());
                partitionSymbols = HashGenerationOptimizer.computeHash(partitionColumns);
                preference = preference.withHashComputation(partitionSymbols);
            }

            // Fix one order for the hash symbols so every source appends them in the same positions.
            List<HashComputation> hashSymbolOrder = ImmutableList.copyOf(preference.getHashes());
            Map<HashComputation, Symbol> newHashSymbols = new HashMap<>();
            for (HashComputation preferredHashSymbol : hashSymbolOrder) {
                newHashSymbols.put(preferredHashSymbol, this.symbolAllocator.newHashSymbol());
            }
            List<Symbol> orderedHashSymbols = hashSymbolOrder.stream()
                    .map(newHashSymbols::get)
                    .collect(ImmutableCollectors.toImmutableList());

            // Append the hash symbols to the output layout and point the partition function at the partitioning hash.
            List<Symbol> outputLayout = ImmutableList.<Symbol>builder()
                    .addAll(partitionFunction.getOutputLayout())
                    .addAll(orderedHashSymbols)
                    .build();
            partitionFunction = new PartitionFunctionBinding(
                    partitionFunction.getPartitioningHandle(),
                    outputLayout,
                    partitionFunction.getPartitionFunctionArguments(),
                    partitionSymbols.map(newHashSymbols::get),
                    partitionFunction.isReplicateNulls(),
                    partitionFunction.getBucketToPartition());

            ImmutableList.Builder<List<Symbol>> newInputs = ImmutableList.builder();
            ImmutableList.Builder<PlanNode> newSources = ImmutableList.builder();
            for (int sourceId = 0; sourceId < node.getSources().size(); ++sourceId) {
                PlanNode source = node.getSources().get(sourceId);
                List<Symbol> inputSymbols = node.getInputs().get(sourceId);

                // Positional mapping: the i-th exchange output comes from this source's i-th input symbol.
                Map<Symbol, Symbol> outputToInputMap = new HashMap<>();
                for (int symbolId = 0; symbolId < inputSymbols.size(); ++symbolId) {
                    outputToInputMap.put(node.getOutputSymbols().get(symbolId), inputSymbols.get(symbolId));
                }
                Function<Symbol, Optional<Symbol>> outputToInputTranslator = symbol -> Optional.of(outputToInputMap.get(symbol));

                HashComputationSet sourceContext = preference.translate(outputToInputTranslator);
                PlanWithProperties child = this.planAndEnforce(source, sourceContext, true, sourceContext);
                newSources.add(child.getNode());

                // Original inputs first, then this source's hash symbols in hashSymbolOrder.
                ImmutableList.Builder<Symbol> newInputSymbols = ImmutableList.builder();
                newInputSymbols.addAll(inputSymbols);
                for (HashComputation preferredHashSymbol : hashSymbolOrder) {
                    HashComputation hashComputation = preferredHashSymbol.translate(outputToInputTranslator).get();
                    newInputSymbols.add(child.getRequiredHashSymbol(hashComputation));
                }
                newInputs.add(newInputSymbols.build());
            }

            return new PlanWithProperties(
                    new ExchangeNode(this.idAllocator.getNextId(), node.getType(), partitionFunction, newSources.build(), newInputs.build()),
                    newHashSymbols);
        }

        /**
         * Allocates one new output hash symbol per preferred hash. Each union source is then required
         * to provide the matching hash over its own input symbols.
         */
        @Override
        public PlanWithProperties visitUnion(UnionNode node, HashComputationSet parentPreference) {
            // Keep only parent preferences computable from the union's outputs.
            HashComputationSet preference = parentPreference.pruneSymbols(node.getOutputSymbols());

            Map<HashComputation, Symbol> newHashSymbols = new HashMap<>();
            for (HashComputation preferredHashSymbol : preference.getHashes()) {
                newHashSymbols.put(preferredHashSymbol, this.symbolAllocator.newHashSymbol());
            }

            ImmutableListMultimap.Builder<Symbol, Symbol> newSymbolMapping = ImmutableListMultimap.builder();
            newSymbolMapping.putAll(node.getSymbolMapping());
            ImmutableList.Builder<PlanNode> newSources = ImmutableList.builder();
            for (int sourceId = 0; sourceId < node.getSources().size(); ++sourceId) {
                // Typed map so the translator below returns Optional<Symbol>, not Optional<Object>.
                Map<Symbol, Symbol> outputToInputMap = new HashMap<>();
                for (Symbol outputSymbol : node.getOutputSymbols()) {
                    outputToInputMap.put(outputSymbol, node.getSymbolMapping().get(outputSymbol).get(sourceId));
                }
                Function<Symbol, Optional<Symbol>> outputToInputTranslator = symbol -> Optional.of(outputToInputMap.get(symbol));

                HashComputationSet sourcePreference = preference.translate(outputToInputTranslator);
                PlanWithProperties child = this.planAndEnforce(node.getSources().get(sourceId), sourcePreference, true, sourcePreference);
                newSources.add(child.getNode());

                for (Map.Entry<HashComputation, Symbol> entry : newHashSymbols.entrySet()) {
                    HashComputation hashComputation = entry.getKey().translate(outputToInputTranslator).get();
                    newSymbolMapping.put(entry.getValue(), child.getRequiredHashSymbol(hashComputation));
                }
            }

            ListMultimap<Symbol, Symbol> symbolMapping = newSymbolMapping.build();
            return new PlanWithProperties(
                    new UnionNode(this.idAllocator.getNextId(), newSources.build(), symbolMapping, ImmutableList.copyOf(symbolMapping.keySet())),
                    newHashSymbols);
        }

        /**
         * Translates parent preferences through identity assignments and plans the source with them.
         * Each translated hash is then projected: either by reusing the child's hash symbol, or by
         * computing the hash expression into a new symbol.
         */
        @Override
        public PlanWithProperties visitProject(ProjectNode node, HashComputationSet parentPreference) {
            Map<Symbol, Symbol> outputToInputMapping = HashGenerationOptimizer.computeIdentityTranslations(node.getAssignments());
            HashComputationSet sourceContext = parentPreference.translate(symbol -> Optional.ofNullable(outputToInputMapping.get(symbol)));
            PlanWithProperties child = this.plan(node.getSource(), sourceContext);

            // Start from all of the original assignments.
            Map<Symbol, Expression> newAssignments = new HashMap<>(node.getAssignments());

            Map<HashComputation, Symbol> allHashSymbols = new HashMap<>();
            for (HashComputation hashComputation : sourceContext.getHashes()) {
                Symbol hashSymbol = child.getHashSymbols().get(hashComputation);
                // Declared as Expression: it is either a FunctionCall tree or a column reference.
                Expression hashExpression;
                if (hashSymbol == null) {
                    hashSymbol = this.symbolAllocator.newHashSymbol();
                    hashExpression = hashComputation.getHashExpression();
                } else {
                    hashExpression = new QualifiedNameReference(hashSymbol.toQualifiedName());
                }
                newAssignments.put(hashSymbol, hashExpression);
                allHashSymbols.put(hashComputation, hashSymbol);
            }
            return new PlanWithProperties(new ProjectNode(this.idAllocator.getNextId(), child.getNode(), newAssignments), allHashSymbols);
        }

        /**
         * Passes preferred hashes computable from the source through the unnest. The hash symbols are
         * added to the replicate symbols so they appear on every unnested row.
         */
        @Override
        public PlanWithProperties visitUnnest(UnnestNode node, HashComputationSet parentPreference) {
            PlanWithProperties child = this.plan(node.getSource(), parentPreference.pruneSymbols(node.getSource().getOutputSymbols()));

            // Expose only the hashes the parent asked for.
            Map<HashComputation, Symbol> hashSymbols = new HashMap<>(child.getHashSymbols());
            hashSymbols.keySet().retainAll(parentPreference.getHashes());

            List<Symbol> replicateSymbols = ImmutableList.<Symbol>builder()
                    .addAll(node.getReplicateSymbols())
                    .addAll(hashSymbols.values())
                    .build();
            return new PlanWithProperties(
                    new UnnestNode(this.idAllocator.getNextId(), child.getNode(), replicateSymbols, node.getUnnestSymbols(), node.getOrdinalitySymbol()),
                    hashSymbols);
        }

        /**
         * Plans a node that has at most one source. Hash symbols the node does not need are
         * always pruned.
         *
         * @param node the node to plan
         * @param preferredHashes hash computations the source should preferably produce
         * @return the planned node and the hash symbols it passes through
         */
        private PlanWithProperties planSimpleNodeWithProperties(PlanNode node, HashComputationSet preferredHashes) {
            boolean pruneExtraHashSymbols = true;
            return planSimpleNodeWithProperties(node, preferredHashes, pruneExtraHashSymbols);
        }

        /**
         * Plans a node that has at most one source.
         * <p>
         * A node without sources is returned as-is with no hash symbols. Otherwise its only
         * source is planned with no required hashes and {@code preferredHashes} as a
         * preference. The node is rebuilt on top of the planned source. Only the hash symbols
         * that still appear in the rebuilt node's outputs are reported.
         *
         * @param node the node to plan
         * @param preferredHashes hash computations the source should preferably produce
         * @param alwaysPruneExtraHashSymbols forwarded to {@code planAndEnforce} as its
         *        {@code pruneExtraHashSymbols} flag
         * @return the rebuilt node and the hash symbols it passes through
         */
        private PlanWithProperties planSimpleNodeWithProperties(PlanNode node, HashComputationSet preferredHashes, boolean alwaysPruneExtraHashSymbols) {
            if (node.getSources().isEmpty()) {
                return new PlanWithProperties(node, ImmutableMap.<HashComputation, Symbol>of());
            }
            // Nothing is required of the source; the caller's hashes are only a preference.
            PlanWithProperties source = planAndEnforce(Iterables.getOnlyElement(node.getSources()), new HashComputationSet(), alwaysPruneExtraHashSymbols, preferredHashes);
            PlanNode result = ChildReplacer.replaceChildren(node, ImmutableList.<PlanNode>of(source.getNode()));

            // Report only the hash symbols that survive through the rebuilt node's outputs.
            Map<HashComputation, Symbol> hashSymbols = new HashMap<>(source.getHashSymbols());
            hashSymbols.values().retainAll(result.getOutputSymbols());
            return new PlanWithProperties(result, hashSymbols);
        }

        /**
         * Plans {@code node} with a preference and adds a projection only when the required
         * hashes are not already satisfied.
         *
         * @param node the node to plan
         * @param requiredHashes hash computations the result must expose
         * @param pruneExtraHashSymbols when true the produced hashes must equal the required set
         *        exactly; when false they only need to contain it
         * @param preferredHashes preference passed down while planning {@code node}
         * @return the planned node, wrapped by {@code enforce} if the requirement was not met
         */
        private PlanWithProperties planAndEnforce(PlanNode node, HashComputationSet requiredHashes, boolean pruneExtraHashSymbols, HashComputationSet preferredHashes) {
            PlanWithProperties planned = plan(node, preferredHashes);
            boolean satisfied;
            if (pruneExtraHashSymbols) {
                // Any extra hash symbol counts as unsatisfied, so that enforce() strips it.
                satisfied = planned.getHashSymbols().keySet().equals(requiredHashes.getHashes());
            } else {
                satisfied = planned.getHashSymbols().keySet().containsAll(requiredHashes.getHashes());
            }
            return satisfied ? planned : enforce(planned, requiredHashes);
        }

        /**
         * Wraps a planned node in a projection that exposes exactly the required hash computations.
         * <p>
         * Every output symbol of the planned node is passed through, except hash symbols whose
         * computation is not in {@code requiredHashes}; those are dropped. Any required
         * computation that the planned node does not already produce gets a freshly allocated
         * hash symbol, computed from its hash expression.
         *
         * @param planWithProperties the planned node and the hash symbols it currently exposes
         * @param requiredHashes hash computations the resulting projection must expose
         * @return the new projection and the hash symbols it exposes
         */
        private PlanWithProperties enforce(PlanWithProperties planWithProperties, HashComputationSet requiredHashes) {
            ImmutableMap.Builder<Symbol, Expression> assignments = ImmutableMap.builder();
            Map<HashComputation, Symbol> outputHashSymbols = new HashMap<>();

            // Pass outputs through, dropping hash symbols that are not required.
            BiMap<Symbol, HashComputation> computationBySymbol = planWithProperties.getHashSymbols().inverse();
            for (Symbol symbol : planWithProperties.getNode().getOutputSymbols()) {
                HashComputation computation = computationBySymbol.get(symbol);
                if (computation != null && !requiredHashes.getHashes().contains(computation)) {
                    continue;
                }
                assignments.put(symbol, new QualifiedNameReference(symbol.toQualifiedName()));
                if (computation != null) {
                    outputHashSymbols.put(computation, symbol);
                }
            }

            // Compute any required hash that the planned node does not already produce.
            for (HashComputation computation : requiredHashes.getHashes()) {
                if (planWithProperties.getHashSymbols().containsKey(computation)) {
                    continue;
                }
                Expression hashExpression = computation.getHashExpression();
                Symbol hashSymbol = symbolAllocator.newHashSymbol();
                assignments.put(hashSymbol, hashExpression);
                outputHashSymbols.put(computation, hashSymbol);
            }

            ProjectNode projectNode = new ProjectNode(idAllocator.getNextId(), planWithProperties.getNode(), assignments.build());
            return new PlanWithProperties(projectNode, outputHashSymbols);
        }

        /**
         * Plans {@code node} by dispatching to this visitor, then checks that every hash symbol
         * the result declares is actually among the planned node's output symbols.
         *
         * @param node the node to plan
         * @param parentPreference hash computations the parent would like {@code node} to output
         * @return the planned node and its hash symbols
         * @throws IllegalStateException if a declared hash symbol is missing from the outputs
         */
        private PlanWithProperties plan(PlanNode node, HashComputationSet parentPreference) {
            PlanWithProperties planned = node.accept(this, parentPreference);
            PlanNode plannedNode = planned.getNode();
            boolean hashSymbolsAreOutputs = plannedNode.getOutputSymbols().containsAll(planned.getHashSymbols().values());
            Preconditions.checkState(hashSymbolsAreOutputs, "Node %s declares hash symbols not in the output", plannedNode.getClass().getSimpleName());
            return planned;
        }
    }
}

