package io.prestosql.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.prestosql.Session;
import io.prestosql.SystemSessionProperties;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.sql.analyzer.FeaturesConfig;
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.optimizations.joins.JoinGraph;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.FilterNode;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.tree.Expression;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Set;

/* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.class */
public class EliminateCrossJoins implements Rule<JoinNode> {
    private static final Pattern<JoinNode> PATTERN = Patterns.join();

    @Override // io.prestosql.sql.planner.iterative.Rule
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    @Override // io.prestosql.sql.planner.iterative.Rule
    public boolean isEnabled(Session session) {
        FeaturesConfig.JoinReorderingStrategy joinReorderingStrategy = SystemSessionProperties.getJoinReorderingStrategy(session);
        return joinReorderingStrategy == FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS || joinReorderingStrategy == FeaturesConfig.JoinReorderingStrategy.AUTOMATIC;
    }

    @Override // io.prestosql.sql.planner.iterative.Rule
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        JoinGraph buildShallowFrom = JoinGraph.buildShallowFrom(joinNode, context.getLookup());
        if (buildShallowFrom.size() < 3) {
            return Rule.Result.empty();
        }
        List<Integer> joinOrder = getJoinOrder(buildShallowFrom);
        return isOriginalOrder(joinOrder) ? Rule.Result.empty() : Rule.Result.ofPlanNode(buildJoinTree(joinNode.getOutputSymbols(), buildShallowFrom, joinOrder, context.getIdAllocator()));
    }

    public static boolean isOriginalOrder(List<Integer> list) {
        for (int i = 0; i < list.size(); i++) {
            if (list.get(i).intValue() != i) {
                return false;
            }
        }
        return true;
    }

    public static List<Integer> getJoinOrder(JoinGraph joinGraph) {
        ImmutableList.Builder builder = ImmutableList.builder();
        HashMap hashMap = new HashMap();
        for (int i = 0; i < joinGraph.size(); i++) {
            hashMap.put(joinGraph.getNode(i).getId(), Integer.valueOf(i));
        }
        PriorityQueue priorityQueue = new PriorityQueue(joinGraph.size(), Comparator.comparing(planNode -> {
            return (Integer) hashMap.get(planNode.getId());
        }));
        HashSet hashSet = new HashSet();
        priorityQueue.add(joinGraph.getNode(0));
        while (!priorityQueue.isEmpty()) {
            PlanNode planNode2 = (PlanNode) priorityQueue.poll();
            if (!hashSet.contains(planNode2)) {
                hashSet.add(planNode2);
                builder.add(planNode2);
                Iterator<JoinGraph.Edge> it = joinGraph.getEdges(planNode2).iterator();
                while (it.hasNext()) {
                    priorityQueue.add(it.next().getTargetNode());
                }
            }
            if (priorityQueue.isEmpty() && hashSet.size() < joinGraph.size()) {
                Optional<PlanNode> findFirst = joinGraph.getNodes().stream().filter(planNode3 -> {
                    return !hashSet.contains(planNode3);
                }).findFirst();
                if (findFirst.isPresent()) {
                    priorityQueue.add(findFirst.get());
                }
            }
        }
        Preconditions.checkState(hashSet.size() == joinGraph.size());
        return (List) builder.build().stream().map(planNode4 -> {
            return (Integer) hashMap.get(planNode4.getId());
        }).collect(ImmutableList.toImmutableList());
    }

    public static PlanNode buildJoinTree(List<Symbol> list, JoinGraph joinGraph, List<Integer> list2, PlanNodeIdAllocator planNodeIdAllocator) {
        Objects.requireNonNull(list, "expectedOutputSymbols is null");
        Objects.requireNonNull(planNodeIdAllocator, "idAllocator is null");
        Objects.requireNonNull(joinGraph, "graph is null");
        ImmutableList copyOf = ImmutableList.copyOf((Collection) Objects.requireNonNull(list2, "joinOrder is null"));
        Preconditions.checkArgument(copyOf.size() >= 2);
        PlanNode node = joinGraph.getNode(((Integer) copyOf.get(0)).intValue());
        HashSet hashSet = new HashSet();
        hashSet.add(node.getId());
        for (int i = 1; i < copyOf.size(); i++) {
            PlanNode node2 = joinGraph.getNode(((Integer) copyOf.get(i)).intValue());
            hashSet.add(node2.getId());
            ImmutableList.Builder builder = ImmutableList.builder();
            for (JoinGraph.Edge edge : joinGraph.getEdges(node2)) {
                if (hashSet.contains(edge.getTargetNode().getId())) {
                    builder.add(new JoinNode.EquiJoinClause(edge.getTargetSymbol(), edge.getSourceSymbol()));
                }
            }
            node = new JoinNode(planNodeIdAllocator.getNextId(), JoinNode.Type.INNER, node, node2, builder.build(), ImmutableList.builder().addAll(node.getOutputSymbols()).addAll(node2.getOutputSymbols()).build(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
        }
        Iterator<Expression> it = joinGraph.getFilters().iterator();
        while (it.hasNext()) {
            node = new FilterNode(planNodeIdAllocator.getNextId(), node, it.next());
        }
        if (joinGraph.getAssignments().isPresent()) {
            node = new ProjectNode(planNodeIdAllocator.getNextId(), node, Assignments.copyOf(joinGraph.getAssignments().get()));
        }
        return Util.restrictOutputs(planNodeIdAllocator, node, ImmutableSet.copyOf(list)).orElse(node);
    }
}
