package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableSet;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.SymbolsExtractor;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.planner.plan.SpatialJoinNode;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/PruneSpatialJoinChildrenColumns.class */
/**
 * Optimizer rule that narrows the columns requested from both children of a {@link SpatialJoinNode}.
 *
 * <p>Each child is permitted to expose only:
 * <ul>
 *   <li>the symbols the join node itself outputs,</li>
 *   <li>the symbols referenced by the join filter, and</li>
 *   <li>that side's partition symbol, if one is present.</li>
 * </ul>
 * The actual restriction of the child plans is delegated to {@code Util.restrictChildOutputs}.
 */
public class PruneSpatialJoinChildrenColumns implements Rule<SpatialJoinNode> {
    private static final Pattern<SpatialJoinNode> PATTERN = Patterns.spatialJoin();

    @Override
    public Pattern<SpatialJoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Restricts each child of {@code spatialJoinNode} to the symbols the join actually needs.
     *
     * @param spatialJoinNode the matched spatial join
     * @param captures unused; the pattern captures nothing
     * @param context rule context, used here only for its plan-node id allocator
     * @return a result wrapping the rewritten plan node when {@code Util.restrictChildOutputs}
     *     produces one, otherwise {@link Rule.Result#empty()}
     */
    @Override
    public Rule.Result apply(SpatialJoinNode spatialJoinNode, Captures captures, Rule.Context context) {
        // Symbols needed regardless of side: the join's own outputs plus everything the filter reads.
        ImmutableSet<Symbol> requiredOutputAndFilterSymbols = ImmutableSet.<Symbol>builder()
                .addAll(spatialJoinNode.getOutputSymbols())
                .addAll(SymbolsExtractor.extractUnique(spatialJoinNode.getFilter()))
                .build();

        // Each side additionally keeps its own partition symbol, when one exists.
        ImmutableSet.Builder<Symbol> leftInputs = ImmutableSet.<Symbol>builder()
                .addAll(requiredOutputAndFilterSymbols);
        spatialJoinNode.getLeftPartitionSymbol().ifPresent(leftInputs::add);

        ImmutableSet.Builder<Symbol> rightInputs = ImmutableSet.<Symbol>builder()
                .addAll(requiredOutputAndFilterSymbols);
        spatialJoinNode.getRightPartitionSymbol().ifPresent(rightInputs::add);

        // Sets are passed in child order (left, then right). An empty Optional from the helper
        // means no rewrite was produced, so the rule reports no change.
        return Util.restrictChildOutputs(
                        context.getIdAllocator(), spatialJoinNode, leftInputs.build(), rightInputs.build())
                .map(Rule.Result::ofPlanNode)
                .orElse(Rule.Result.empty());
    }
}
