package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.prestosql.Session;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.metadata.Metadata;
import io.prestosql.sql.DynamicFilters;
import io.prestosql.sql.ExpressionUtils;
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.SymbolAllocator;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.optimizations.PlanOptimizer;
import io.prestosql.sql.planner.plan.ChildReplacer;
import io.prestosql.sql.planner.plan.FilterNode;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.PlanVisitor;
import io.prestosql.sql.planner.plan.TableScanNode;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.ExpressionRewriter;
import io.prestosql.sql.tree.ExpressionTreeRewriter;
import io.prestosql.sql.tree.LogicalBinaryExpression;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.class */
/**
 * Plan optimizer that strips dynamic filter expressions from places where they
 * are not kept by this pass.
 *
 * <p>The plan is walked with a set of dynamic filter ids that are currently
 * allowed. A join adds the ids of its own dynamic filters to that set while its
 * left child is visited. A filter whose (rewritten) source is a
 * {@link TableScanNode} keeps only the dynamic filter conjuncts whose id is in
 * the allowed set; every other filter loses all of its dynamic filter
 * conjuncts. Afterwards each join keeps only those of its dynamic filters that
 * were actually consumed somewhere in its left subtree.
 */
public class RemoveUnsupportedDynamicFilters implements PlanOptimizer {
    private final Metadata metadata;

    /**
     * @param metadata metadata handed to {@link ExpressionUtils} when predicates are recombined
     * @throws NullPointerException if {@code metadata} is null
     */
    public RemoveUnsupportedDynamicFilters(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    /**
     * Rewrites {@code planNode}, starting with an empty set of allowed dynamic filter ids.
     *
     * @return the rewritten plan root; only {@code planNode} itself is used, the remaining
     *         arguments are ignored by this optimizer
     */
    @Override // io.prestosql.sql.planner.optimizations.PlanOptimizer
    public PlanNode optimize(PlanNode planNode, Session session, TypeProvider typeProvider, SymbolAllocator symbolAllocator, PlanNodeIdAllocator planNodeIdAllocator, WarningCollector warningCollector) {
        PlanWithConsumedDynamicFilters result = planNode.accept(new Rewriter(), ImmutableSet.of());
        return result.getNode();
    }

    /**
     * Result of rewriting a subtree: the rewritten node together with the ids of the
     * dynamic filters that are still referenced inside it.
     */
    public static class PlanWithConsumedDynamicFilters {
        private final PlanNode node;
        private final Set<String> consumedDynamicFilterIds;

        PlanWithConsumedDynamicFilters(PlanNode planNode, Set<String> set) {
            this.node = planNode;
            // Defensive immutable copy so later mutation by the caller cannot leak in.
            this.consumedDynamicFilterIds = ImmutableSet.copyOf(set);
        }

        PlanNode getNode() {
            return this.node;
        }

        Set<String> getConsumedDynamicFilterIds() {
            return this.consumedDynamicFilterIds;
        }
    }

    /**
     * Visitor performing the rewrite. The visitor context is the set of dynamic filter
     * ids that a filter directly above a table scan is allowed to keep.
     */
    public class Rewriter extends PlanVisitor<PlanWithConsumedDynamicFilters, Set<String>> {
        private Rewriter() {
        }

        /**
         * Default case: rewrites every source with the same allowed ids, replaces the
         * children of {@code planNode} and reports the union of the ids consumed by them.
         */
        @Override // io.prestosql.sql.planner.plan.PlanVisitor
        public PlanWithConsumedDynamicFilters visitPlan(PlanNode planNode, Set<String> allowedDynamicFilterIds) {
            List<PlanWithConsumedDynamicFilters> children = planNode.getSources().stream()
                    .map(source -> source.accept(this, allowedDynamicFilterIds))
                    .collect(ImmutableList.toImmutableList());

            List<PlanNode> rewrittenSources = children.stream()
                    .map(PlanWithConsumedDynamicFilters::getNode)
                    .collect(Collectors.toList());

            Set<String> consumedIds = children.stream()
                    .map(PlanWithConsumedDynamicFilters::getConsumedDynamicFilterIds)
                    .flatMap(Set::stream)
                    .collect(ImmutableSet.toImmutableSet());

            return new PlanWithConsumedDynamicFilters(ChildReplacer.replaceChildren(planNode, rewrittenSources), consumedIds);
        }

        /**
         * Rewrites a join. The left child is visited with the join's own dynamic filter
         * ids added to the allowed set; the right child is visited with the incoming set
         * only. The join retains just those of its dynamic filters that the left subtree
         * consumed, and the retained ids are not reported further up.
         */
        @Override // io.prestosql.sql.planner.plan.PlanVisitor
        public PlanWithConsumedDynamicFilters visitJoin(JoinNode joinNode, Set<String> allowedDynamicFilterIds) {
            Set<String> allowedOnLeft = ImmutableSet.<String>builder()
                    .addAll(joinNode.getDynamicFilters().keySet())
                    .addAll(allowedDynamicFilterIds)
                    .build();

            PlanWithConsumedDynamicFilters leftResult = joinNode.getLeft().accept(this, allowedOnLeft);
            Set<String> consumedOnLeft = leftResult.getConsumedDynamicFilterIds();
            Set<String> retainedIds = joinNode.getDynamicFilters().keySet().stream()
                    .filter(consumedOnLeft::contains)
                    .collect(ImmutableSet.toImmutableSet());

            PlanWithConsumedDynamicFilters rightResult = joinNode.getRight().accept(this, allowedDynamicFilterIds);

            // Ids consumed below this join, minus the ones this join itself provides.
            Set<String> consumedIds = new HashSet<>(rightResult.getConsumedDynamicFilterIds());
            consumedIds.addAll(consumedOnLeft);
            consumedIds.removeAll(retainedIds);

            PlanNode left = leftResult.getNode();
            PlanNode right = rightResult.getNode();

            // The retained filters are the join's own map restricted to retainedIds, so the
            // map is unchanged exactly when no key was dropped.
            boolean unchanged = left.equals(joinNode.getLeft())
                    && right.equals(joinNode.getRight())
                    && retainedIds.equals(joinNode.getDynamicFilters().keySet());
            if (unchanged) {
                return new PlanWithConsumedDynamicFilters(joinNode, consumedIds);
            }

            JoinNode rewrittenJoin = new JoinNode(
                    joinNode.getId(),
                    joinNode.getType(),
                    left,
                    right,
                    joinNode.getCriteria(),
                    joinNode.getOutputSymbols(),
                    joinNode.getFilter(),
                    joinNode.getLeftHashSymbol(),
                    joinNode.getRightHashSymbol(),
                    joinNode.getDistributionType(),
                    joinNode.isSpillable(),
                    retainEntries(joinNode.getDynamicFilters(), retainedIds));
            return new PlanWithConsumedDynamicFilters(rewrittenJoin, consumedIds);
        }

        /**
         * Rewrites a filter. Directly above a table scan only the allowed dynamic filter
         * conjuncts survive (and are recorded as consumed); anywhere else all dynamic
         * filter conjuncts are dropped. A predicate that collapses to {@code TRUE}
         * removes the filter node altogether.
         */
        @Override // io.prestosql.sql.planner.plan.PlanVisitor
        public PlanWithConsumedDynamicFilters visitFilter(FilterNode filterNode, Set<String> allowedDynamicFilterIds) {
            PlanWithConsumedDynamicFilters sourceResult = filterNode.getSource().accept(this, allowedDynamicFilterIds);

            Expression original = filterNode.getPredicate();
            ImmutableSet.Builder<String> consumedIds = ImmutableSet.<String>builder()
                    .addAll(sourceResult.getConsumedDynamicFilterIds());

            PlanNode source = sourceResult.getNode();
            Expression modified = source instanceof TableScanNode
                    ? removeDynamicFilters(original, allowedDynamicFilterIds, consumedIds)
                    : removeAllDynamicFilters(original);

            if (BooleanLiteral.TRUE_LITERAL.equals(modified)) {
                // Nothing left to filter on: return the source in place of the filter.
                return new PlanWithConsumedDynamicFilters(source, consumedIds.build());
            }
            if (original.equals(modified) && source == filterNode.getSource()) {
                // Neither the predicate nor the source changed: reuse the existing node.
                return new PlanWithConsumedDynamicFilters(filterNode, consumedIds.build());
            }
            return new PlanWithConsumedDynamicFilters(new FilterNode(filterNode.getId(), source, modified), consumedIds.build());
        }

        /** Returns the entries of {@code entries} whose key is contained in {@code keysToKeep}. */
        private <V> Map<String, V> retainEntries(Map<String, V> entries, Set<String> keysToKeep) {
            return entries.entrySet().stream()
                    .filter(entry -> keysToKeep.contains(entry.getKey()))
                    .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
        }

        /**
         * Keeps every conjunct of {@code expression} that is not a dynamic filter, plus the
         * dynamic filter conjuncts whose id is in {@code allowedIds}. The id of each kept
         * dynamic filter is added to {@code consumedIds}.
         */
        private Expression removeDynamicFilters(Expression expression, Set<String> allowedIds, ImmutableSet.Builder<String> consumedIds) {
            List<Expression> retained = ExpressionUtils.extractConjuncts(expression).stream()
                    .map(this::removeNestedDynamicFilters)
                    .filter(conjunct -> DynamicFilters.getDescriptor(conjunct)
                            .map(descriptor -> {
                                if (!allowedIds.contains(descriptor.getId())) {
                                    return false;
                                }
                                consumedIds.add(descriptor.getId());
                                return true;
                            })
                            // No descriptor: not a dynamic filter, always kept.
                            .orElse(true))
                    .collect(ImmutableList.toImmutableList());
            return ExpressionUtils.combineConjuncts(RemoveUnsupportedDynamicFilters.this.metadata, retained);
        }

        /**
         * Removes nested dynamic filters and then, if any dynamic conjuncts remain,
         * returns only the static conjuncts combined back into one predicate.
         */
        private Expression removeAllDynamicFilters(Expression expression) {
            Expression rewritten = removeNestedDynamicFilters(expression);
            DynamicFilters.ExtractResult extracted = DynamicFilters.extractDynamicFilters(rewritten);
            if (extracted.getDynamicConjuncts().isEmpty()) {
                return rewritten;
            }
            return ExpressionUtils.combineConjuncts(RemoveUnsupportedDynamicFilters.this.metadata, extracted.getStaticConjuncts());
        }

        /**
         * Replaces every dynamic filter that is a direct operand of a logical binary
         * expression with {@code TRUE} and recombines the operands with the original
         * operator. Expressions without such operands are returned as-is.
         */
        private Expression removeNestedDynamicFilters(Expression expression) {
            return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>() {
                @Override
                public Expression rewriteLogicalBinaryExpression(LogicalBinaryExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
                    // Rewrite the operands first so nested logical expressions are handled bottom-up.
                    LogicalBinaryExpression rewritten = (LogicalBinaryExpression) treeRewriter.defaultRewrite(node, context);

                    Expression left = rewritten.getLeft();
                    Expression right = rewritten.getRight();
                    boolean leftIsDynamic = DynamicFilters.isDynamicFilter(left);
                    boolean rightIsDynamic = DynamicFilters.isDynamicFilter(right);

                    if (node == rewritten && !leftIsDynamic && !rightIsDynamic) {
                        return node;
                    }

                    List<Expression> operands = ImmutableList.of(
                            leftIsDynamic ? BooleanLiteral.TRUE_LITERAL : left,
                            rightIsDynamic ? BooleanLiteral.TRUE_LITERAL : right);
                    return ExpressionUtils.combinePredicates(RemoveUnsupportedDynamicFilters.this.metadata, node.getOperator(), operands);
                }
            }, expression);
        }
    }
}
