package io.prestosql.sql.planner.optimizations;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import io.prestosql.sql.ExpressionUtils;
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.SymbolsExtractor;
import io.prestosql.sql.planner.iterative.Lookup;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.EnforceSingleRowNode;
import io.prestosql.sql.planner.plan.FilterNode;
import io.prestosql.sql.planner.plan.LimitNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.PlanVisitor;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.SymbolReference;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:io/prestosql/sql/planner/optimizations/PlanNodeDecorrelator.class */
public class PlanNodeDecorrelator {
    private final PlanNodeIdAllocator idAllocator;
    private final Lookup lookup;

    /* loaded from: input_file:io/prestosql/sql/planner/optimizations/PlanNodeDecorrelator$DecorrelatedNode.class */
    /**
     * Value holder returned by {@code decorrelateFilters}: a plan node together
     * with the correlated predicates collected for it.
     */
    public static class DecorrelatedNode {
        private final List<Expression> correlatedPredicates;
        private final PlanNode node;

        /**
         * @param list correlated predicates; copied into an immutable list
         * @param planNode plan node exposed through {@link #getNode()}
         * @throws NullPointerException if {@code list} or {@code planNode} is null
         */
        public DecorrelatedNode(List<Expression> list, PlanNode planNode) {
            Objects.requireNonNull(list, "correlatedPredicates is null");
            this.correlatedPredicates = ImmutableList.copyOf(list);
            Objects.requireNonNull(planNode, "node is null");
            this.node = planNode;
        }

        /**
         * @return the stored predicates combined with {@code ExpressionUtils.and},
         *         or an empty Optional when no predicate was stored
         */
        public Optional<Expression> getCorrelatedPredicates() {
            if (this.correlatedPredicates.isEmpty()) {
                return Optional.empty();
            }
            return Optional.of(ExpressionUtils.and(this.correlatedPredicates));
        }

        /**
         * @return the plan node supplied at construction time
         */
        public PlanNode getNode() {
            return this.node;
        }
    }

    /* loaded from: input_file:io/prestosql/sql/planner/optimizations/PlanNodeDecorrelator$DecorrelatingVisitor.class */
    private class DecorrelatingVisitor extends PlanVisitor<Optional<DecorrelationResult>, Void> {
        final List<Symbol> correlation;

        /**
         * @param list the correlation symbols this visitor checks predicates against
         * @throws NullPointerException if {@code list} is null
         */
        DecorrelatingVisitor(List<Symbol> list) {
            Objects.requireNonNull(list, "correlation is null");
            this.correlation = list;
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /**
         * Generic visit callback: wraps the node as-is in a result that carries no
         * symbols to propagate, no correlated predicates, no symbol mapping, and
         * is not flagged as producing at most a single row.
         */
        @Override // io.prestosql.sql.planner.plan.PlanVisitor
        public Optional<DecorrelationResult> visitPlan(PlanNode planNode, Void r10) {
            DecorrelationResult untouched = new DecorrelationResult(
                    planNode,
                    ImmutableSet.of(),
                    ImmutableList.of(),
                    ImmutableMultimap.of(),
                    false);
            return Optional.of(untouched);
        }

        /**
         * Splits the filter predicate into correlated and uncorrelated conjuncts.
         *
         * <p>The source is only visited when it contains correlation; otherwise it
         * is used unchanged. The uncorrelated conjuncts stay in a new
         * {@link FilterNode} placed on top of the (possibly rewritten) source. The
         * correlated conjuncts are appended to the child's correlated predicates,
         * and the non-correlation symbols they reference are added to the symbols
         * to propagate.
         *
         * @return the decorrelation result, or empty when the source could not be
         *         decorrelated
         */
        @Override // io.prestosql.sql.planner.plan.PlanVisitor
        public Optional<DecorrelationResult> visitFilter(FilterNode filterNode, Void r11) {
            Optional<DecorrelationResult> childResult = Optional.of(new DecorrelationResult(
                    filterNode.getSource(),
                    ImmutableSet.of(),
                    ImmutableList.of(),
                    ImmutableMultimap.of(),
                    false));
            // Only descend when the subtree actually references correlation symbols.
            if (PlanNodeDecorrelator.this.containsCorrelation(filterNode.getSource(), this.correlation)) {
                childResult = PlanNodeDecorrelator.this.lookup.resolve(filterNode.getSource()).accept(this, null);
            }
            if (!childResult.isPresent()) {
                return Optional.empty();
            }

            Map<Boolean, List<Expression>> conjunctsByCorrelation = ExpressionUtils.extractConjuncts(filterNode.getPredicate()).stream()
                    .collect(Collectors.partitioningBy(this::isCorrelated));
            List<Expression> correlatedConjuncts = ImmutableList.copyOf(conjunctsByCorrelation.get(true));
            List<Expression> uncorrelatedConjuncts = ImmutableList.copyOf(conjunctsByCorrelation.get(false));

            DecorrelationResult child = childResult.get();
            FilterNode rewrittenFilter = new FilterNode(
                    PlanNodeDecorrelator.this.idAllocator.getNextId(),
                    child.node,
                    ExpressionUtils.combineConjuncts(uncorrelatedConjuncts));

            // Symbols used by the correlated conjuncts, minus the correlation symbols themselves.
            Set<Symbol> newSymbolsToPropagate = Sets.difference(
                    SymbolsExtractor.extractUnique(correlatedConjuncts),
                    ImmutableSet.copyOf(this.correlation));

            return Optional.of(new DecorrelationResult(
                    rewrittenFilter,
                    Sets.union(child.symbolsToPropagate, newSymbolsToPropagate),
                    ImmutableList.<Expression>builder()
                            .addAll(child.correlatedPredicates)
                            .addAll(correlatedConjuncts)
                            .build(),
                    ImmutableMultimap.<Symbol, Symbol>builder()
                            .putAll(child.correlatedSymbolsMapping)
                            .putAll(extractCorrelatedSymbolsMapping(correlatedConjuncts))
                            .build(),
                    child.atMostSingleRow));
        }

        @Override // io.prestosql.sql.planner.plan.PlanVisitor
        /**
         * Handles a limit node on top of a decorrelated source.
         *
         * <p>Returns empty when the source cannot be decorrelated or the limit
         * count is 0. When the source result is already flagged as at most a
         * single row, that result is returned unchanged. Otherwise only a count of
         * 1 is handled, and only when the constant symbols are non-empty and cover
         * every output symbol of the source. In that case the result holds an
         * {@link AggregationNode} with no aggregations, grouped on the source's
         * output symbols, and is flagged as at most a single row.
         */
        @Override // io.prestosql.sql.planner.plan.PlanVisitor
        public Optional<DecorrelationResult> visitLimit(LimitNode limitNode, Void r13) {
            Optional<DecorrelationResult> childResult = PlanNodeDecorrelator.this.lookup.resolve(limitNode.getSource()).accept(this, null);
            if (!childResult.isPresent()) {
                return Optional.empty();
            }
            if (limitNode.getCount() == 0) {
                return Optional.empty();
            }

            DecorrelationResult child = childResult.get();
            if (child.atMostSingleRow) {
                return childResult;
            }
            if (limitNode.getCount() != 1) {
                return Optional.empty();
            }

            Set<Symbol> constants = child.getConstantSymbols();
            PlanNode decorrelatedSource = child.node;
            if (constants.isEmpty()) {
                return Optional.empty();
            }
            if (!constants.containsAll(decorrelatedSource.getOutputSymbols())) {
                return Optional.empty();
            }

            AggregationNode groupedOnOutputs = new AggregationNode(
                    PlanNodeDecorrelator.this.idAllocator.getNextId(),
                    decorrelatedSource,
                    ImmutableMap.of(),
                    AggregationNode.singleGroupingSet(decorrelatedSource.getOutputSymbols()),
                    ImmutableList.of(),
                    AggregationNode.Step.SINGLE,
                    Optional.empty(),
                    Optional.empty());
            return Optional.of(new DecorrelationResult(
                    groupedOnOutputs,
                    child.symbolsToPropagate,
                    child.correlatedPredicates,
                    child.correlatedSymbolsMapping,
                    true));
        }

        /**
         * Decorrelates the source of the enforce-single-row node and keeps the
         * result only when it is flagged as producing at most a single row.
         * The enforce-single-row node itself is not part of the returned result.
         *
         * @return the source's decorrelation result, or empty when the source could
         *         not be decorrelated or is not flagged as at most a single row
         */
        @Override // io.prestosql.sql.planner.plan.PlanVisitor
        public Optional<DecorrelationResult> visitEnforceSingleRow(EnforceSingleRowNode enforceSingleRowNode, Void r6) {
            Optional<DecorrelationResult> childResult = PlanNodeDecorrelator.this.lookup.resolve(enforceSingleRowNode.getSource()).accept(this, null);
            return childResult.filter(result -> result.atMostSingleRow);
        }

        /**
         * Decorrelates an aggregation that has no empty grouping set.
         *
         * <p>The aggregation is re-mapped onto the decorrelated source through the
         * child's correlated symbol mapper. Symbols to propagate that are not
         * already grouping keys are appended to the grouping keys, which is only
         * done when all of them are constant symbols. The result is flagged as at
         * most a single row when the new aggregation has exactly one grouping set
         * and all of its grouping keys are constant symbols.
         *
         * @return the decorrelation result, or empty when the aggregation has an
         *         empty grouping set, the source cannot be decorrelated, or a
         *         symbol that would have to be added is not constant
         */
        @Override // io.prestosql.sql.planner.plan.PlanVisitor
        public Optional<DecorrelationResult> visitAggregation(AggregationNode aggregationNode, Void r13) {
            if (aggregationNode.hasEmptyGroupingSet()) {
                return Optional.empty();
            }
            Optional<DecorrelationResult> childResult = PlanNodeDecorrelator.this.lookup.resolve(aggregationNode.getSource()).accept(this, null);
            if (!childResult.isPresent()) {
                return Optional.empty();
            }

            DecorrelationResult child = childResult.get();
            Set<Symbol> constantSymbols = child.getConstantSymbols();
            AggregationNode mappedAggregation = child.getCorrelatedSymbolMapper().map(aggregationNode, child.node);

            Set<Symbol> groupingKeys = ImmutableSet.copyOf(aggregationNode.getGroupingKeys());
            List<Symbol> symbolsToAdd = child.symbolsToPropagate.stream()
                    .filter(symbol -> !groupingKeys.contains(symbol))
                    .collect(ImmutableList.toImmutableList());
            if (!constantSymbols.containsAll(symbolsToAdd)) {
                return Optional.empty();
            }

            AggregationNode newAggregation = new AggregationNode(
                    mappedAggregation.getId(),
                    mappedAggregation.getSource(),
                    mappedAggregation.getAggregations(),
                    AggregationNode.singleGroupingSet(ImmutableList.<Symbol>builder()
                            .addAll(aggregationNode.getGroupingKeys())
                            .addAll(symbolsToAdd)
                            .build()),
                    ImmutableList.of(),
                    mappedAggregation.getStep(),
                    mappedAggregation.getHashSymbol(),
                    mappedAggregation.getGroupIdSymbol());

            boolean atMostSingleRow = newAggregation.getGroupingSetCount() == 1
                    && constantSymbols.containsAll(newAggregation.getGroupingKeys());

            return Optional.of(new DecorrelationResult(
                    newAggregation,
                    child.symbolsToPropagate,
                    child.correlatedPredicates,
                    child.correlatedSymbolsMapping,
                    atMostSingleRow));
        }

        /**
         * Rebuilds the projection on top of the decorrelated source. The new
         * projection keeps the original assignments and adds identity assignments
         * for every symbol to propagate that is not already an output of the
         * original projection.
         *
         * @return the decorrelation result, or empty when the source could not be
         *         decorrelated
         */
        @Override // io.prestosql.sql.planner.plan.PlanVisitor
        public Optional<DecorrelationResult> visitProject(ProjectNode projectNode, Void r10) {
            Optional<DecorrelationResult> childResult = PlanNodeDecorrelator.this.lookup.resolve(projectNode.getSource()).accept(this, null);
            if (!childResult.isPresent()) {
                return Optional.empty();
            }

            DecorrelationResult child = childResult.get();
            Set<Symbol> projectedSymbols = ImmutableSet.copyOf(projectNode.getOutputSymbols());
            List<Symbol> passThroughSymbols = child.symbolsToPropagate.stream()
                    .filter(symbol -> !projectedSymbols.contains(symbol))
                    .collect(ImmutableList.toImmutableList());

            ProjectNode rewrittenProject = new ProjectNode(
                    PlanNodeDecorrelator.this.idAllocator.getNextId(),
                    child.node,
                    Assignments.builder()
                            .putAll(projectNode.getAssignments())
                            .putIdentities(passThroughSymbols)
                            .build());

            return Optional.of(new DecorrelationResult(
                    rewrittenProject,
                    child.symbolsToPropagate,
                    child.correlatedPredicates,
                    child.correlatedSymbolsMapping,
                    child.atMostSingleRow));
        }

        /**
         * Collects, from equality comparisons between two symbol references, the
         * pairs where exactly one side is a correlation symbol.
         *
         * @param list predicates to inspect; anything that is not an {@code EQUAL}
         *        comparison of two {@link SymbolReference}s is ignored
         * @return multimap from correlation symbol to the non-correlation symbol it
         *         is compared with
         */
        private Multimap<Symbol, Symbol> extractCorrelatedSymbolsMapping(List<Expression> list) {
            ImmutableMultimap.Builder<Symbol, Symbol> mapping = ImmutableMultimap.builder();
            for (Expression conjunct : list) {
                if (!(conjunct instanceof ComparisonExpression)) {
                    continue;
                }
                ComparisonExpression comparison = (ComparisonExpression) conjunct;
                if (!(comparison.getLeft() instanceof SymbolReference)
                        || !(comparison.getRight() instanceof SymbolReference)
                        || !comparison.getOperator().equals(ComparisonExpression.Operator.EQUAL)) {
                    continue;
                }

                Symbol left = Symbol.from(comparison.getLeft());
                Symbol right = Symbol.from(comparison.getRight());
                // Record the pair keyed by whichever side is the correlation symbol.
                if (this.correlation.contains(left) && !this.correlation.contains(right)) {
                    mapping.put(left, right);
                }
                if (this.correlation.contains(right) && !this.correlation.contains(left)) {
                    mapping.put(right, left);
                }
            }
            return mapping.build();
        }

        /**
         * @param expression expression to inspect
         * @return true when the expression references at least one of the
         *         correlation symbols
         */
        private boolean isCorrelated(Expression expression) {
            Set<Symbol> referencedSymbols = SymbolsExtractor.extractUnique(expression);
            return this.correlation.stream().anyMatch(referencedSymbols::contains);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/prestosql/sql/planner/optimizations/PlanNodeDecorrelator$DecorrelationResult.class */
    public static class DecorrelationResult {
        final PlanNode node;
        final Set<Symbol> symbolsToPropagate;
        final List<Expression> correlatedPredicates;
        final Multimap<Symbol, Symbol> correlatedSymbolsMapping;
        final boolean atMostSingleRow;

        DecorrelationResult(PlanNode planNode, Set<Symbol> set, List<Expression> list, Multimap<Symbol, Symbol> multimap, boolean z) {
            this.node = planNode;
            this.symbolsToPropagate = set;
            this.correlatedPredicates = list;
            this.atMostSingleRow = z;
            this.correlatedSymbolsMapping = multimap;
            Preconditions.checkState(set.containsAll(multimap.values()), "Expected symbols to propagate to contain all constant symbols");
        }

        SymbolMapper getCorrelatedSymbolMapper() {
            return new SymbolMapper((Map) this.correlatedSymbolsMapping.asMap().entrySet().stream().collect(ImmutableMap.toImmutableMap((v0) -> {
                return v0.getKey();
            }, entry -> {
                return (Symbol) Iterables.getLast((Iterable) entry.getValue());
            })));
        }

        Set<Symbol> getConstantSymbols() {
            return ImmutableSet.copyOf(this.correlatedSymbolsMapping.values());
        }
    }

    /**
     * @param planNodeIdAllocator allocator used for the ids of newly created plan nodes
     * @param lookup used to resolve plan nodes before they are visited
     * @throws NullPointerException if either argument is null
     */
    public PlanNodeDecorrelator(PlanNodeIdAllocator planNodeIdAllocator, Lookup lookup) {
        Objects.requireNonNull(planNodeIdAllocator, "idAllocator is null");
        this.idAllocator = planNodeIdAllocator;
        Objects.requireNonNull(lookup, "lookup is null");
        this.lookup = lookup;
    }

    /**
     * Runs the decorrelating visitor over the resolved plan node.
     *
     * @param planNode root of the plan to decorrelate
     * @param list correlation symbols
     * @return the decorrelated node with its correlated predicates, or empty
     *         when the visitor gives up or the rewritten plan still contains
     *         one of the correlation symbols
     */
    public Optional<DecorrelatedNode> decorrelateFilters(PlanNode planNode, List<Symbol> list) {
        Optional<DecorrelationResult> visitorResult = this.lookup.resolve(planNode).accept(new DecorrelatingVisitor(list), null);
        return visitorResult.flatMap(result -> decorrelatedNode(result.correlatedPredicates, result.node, list));
    }

    /**
     * Wraps the rewritten plan in a {@link DecorrelatedNode}, unless the plan
     * still contains one of the correlation symbols.
     *
     * @param list correlated predicates collected during decorrelation
     * @param planNode rewritten plan node
     * @param list2 correlation symbols
     * @return the wrapped node, or empty when {@code planNode} still contains correlation
     */
    private Optional<DecorrelatedNode> decorrelatedNode(List<Expression> list, PlanNode planNode, List<Symbol> list2) {
        if (containsCorrelation(planNode, list2)) {
            return Optional.empty();
        }
        DecorrelatedNode wrapped = new DecorrelatedNode(list, planNode);
        return Optional.of(wrapped);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Checks whether the plan rooted at {@code planNode} uses any of the given
     * symbols, looking at both the symbols extracted from the plan and its
     * output symbols.
     *
     * @param planNode root of the plan to inspect
     * @param list correlation symbols to look for
     * @return true when at least one extracted or output symbol is in {@code list}
     */
    public boolean containsCorrelation(PlanNode planNode, List<Symbol> list) {
        Set<Symbol> planSymbols = Sets.union(
                SymbolsExtractor.extractUnique(planNode, this.lookup),
                SymbolsExtractor.extractOutputSymbols(planNode, this.lookup));
        return planSymbols.stream().anyMatch(list::contains);
    }
}
