package io.prestosql.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.sql.planner.PartitioningScheme;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.AssignUniqueId;
import io.prestosql.sql.planner.plan.ExchangeNode;
import io.prestosql.sql.planner.plan.Patterns;
import java.util.List;
import java.util.Optional;

/* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/PushRemoteExchangeThroughAssignUniqueId.class */
/**
 * Optimizer rule that moves an {@link AssignUniqueId} node from directly below a remote exchange to
 * directly above it.
 *
 * <p>Matches an {@link ExchangeNode} whose scope is {@link ExchangeNode.Scope#REMOTE}, whose type is
 * not {@link ExchangeNode.Type#REPLICATE}, and whose source is an {@link AssignUniqueId}. The plan
 * <pre>
 *   Exchange(REMOTE)
 *     AssignUniqueId(id)
 *       source
 * </pre>
 * is rewritten to
 * <pre>
 *   AssignUniqueId(id)
 *     Exchange(REMOTE)
 *       source
 * </pre>
 * The id column is removed from the exchange's output layout and from its input symbol list, because
 * the exchange no longer receives it. NOTE(review): presumably the point is to avoid shipping the
 * generated id column through the remote exchange — confirm intent.
 */
public final class PushRemoteExchangeThroughAssignUniqueId implements Rule<ExchangeNode> {
    private static final Capture<AssignUniqueId> ASSIGN_UNIQUE_ID = Capture.newCapture();

    // NOTE(review): REPLICATE exchanges are excluded; presumably because assigning ids after
    // replication would give each replica a different id for the same row — confirm.
    private static final Pattern<ExchangeNode> PATTERN = Patterns.exchange()
            .matching(exchange -> exchange.getScope() == ExchangeNode.Scope.REMOTE)
            .matching(exchange -> exchange.getType() != ExchangeNode.Type.REPLICATE)
            .with(Patterns.source().matching(Patterns.assignUniqueId().capturedAs(ASSIGN_UNIQUE_ID)));

    @Override // io.prestosql.sql.planner.iterative.Rule
    public Pattern<ExchangeNode> getPattern() {
        return PATTERN;
    }

    /**
     * Swaps the matched exchange and its {@link AssignUniqueId} source.
     *
     * @param exchangeNode the matched remote, non-replicate exchange
     * @param captures captures holding the {@link AssignUniqueId} source of {@code exchangeNode}
     * @param context rule context (unused)
     * @return an empty result when the exchange partitions on the generated id column; otherwise an
     *     {@link AssignUniqueId} over a new exchange that no longer carries the id column
     * @throws IllegalArgumentException if the exchange has an ordering scheme (a merging exchange)
     */
    @Override // io.prestosql.sql.planner.iterative.Rule
    public Rule.Result apply(ExchangeNode exchangeNode, Captures captures, Rule.Context context) {
        Preconditions.checkArgument(exchangeNode.getOrderingScheme().isEmpty(), "Merge exchange over AssignUniqueId not supported");

        AssignUniqueId assignUniqueId = captures.get(ASSIGN_UNIQUE_ID);
        Symbol idColumn = assignUniqueId.getIdColumn();
        PartitioningScheme partitioningScheme = exchangeNode.getPartitioningScheme();

        if (partitioningScheme.getPartitioning().getColumns().contains(idColumn)) {
            // The exchange partitions on the generated id, so the id must already exist below the
            // exchange; the AssignUniqueId node cannot be moved above it.
            return Rule.Result.empty();
        }

        // Same partitioning, but the output layout no longer mentions the id column.
        PartitioningScheme prunedScheme = new PartitioningScheme(
                partitioningScheme.getPartitioning(),
                removeSymbol(partitioningScheme.getOutputLayout(), idColumn),
                partitioningScheme.getHashColumn(),
                partitioningScheme.isReplicateNullsAndAny(),
                partitioningScheme.getBucketToPartition());

        // getOnlyElement requires exactly one input list. NOTE(review): presumably guaranteed because
        // Patterns.source() only matches a single-source exchange — confirm.
        List<Symbol> prunedInputs = removeSymbol(Iterables.getOnlyElement(exchangeNode.getInputs()), idColumn);

        ExchangeNode pushedExchange = new ExchangeNode(
                exchangeNode.getId(),
                exchangeNode.getType(),
                exchangeNode.getScope(),
                prunedScheme,
                ImmutableList.of(assignUniqueId.getSource()),
                ImmutableList.of(prunedInputs),
                Optional.empty());

        return Rule.Result.ofPlanNode(new AssignUniqueId(assignUniqueId.getId(), pushedExchange, idColumn));
    }

    /**
     * Returns an immutable copy of {@code symbols} with every occurrence of {@code symbolToRemove}
     * dropped. The order of the remaining symbols is preserved.
     */
    private static List<Symbol> removeSymbol(List<Symbol> symbols, Symbol symbolToRemove) {
        return symbols.stream()
                .filter(symbol -> !symbolToRemove.equals(symbol))
                .collect(ImmutableList.toImmutableList());
    }
}
