package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.cost.TaskCountEstimator;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.sql.planner.NodeAndMappings;
import io.trino.sql.planner.PlanCopier;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.IntStream;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/MultipleDistinctAggregationsToSubqueries.class */
/**
 * Splits an aggregation over several distinct argument sets into one sub-aggregation per argument
 * set, joined back together on the grouping keys.
 * <p>
 * The aggregations of the matched node are grouped by the set of their arguments. Each group
 * becomes a copy of the original aggregation (source included, copied via {@link PlanCopier})
 * that computes only that group's aggregates. The copies are combined with {@link JoinType#INNER}
 * joins on their grouping keys. A final {@link ProjectNode}, reusing the original node's id, maps
 * every original output symbol (aggregates and grouping keys) onto the copied symbols.
 * <p>
 * Whether the rewrite is applied at all is decided by {@link DistinctAggregationStrategyChooser}.
 * <p>
 * NOTE(review): the joins are plain INNER equi-joins on the grouping keys. If
 * {@link JoinNode.EquiJoinClause} has SQL {@code =} semantics, groups whose key is NULL would not
 * match. Confirm the strategy chooser only admits cases where this cannot happen.
 */
public class MultipleDistinctAggregationsToSubqueries implements Rule<AggregationNode> {
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().matching(MultipleDistinctAggregationsToSubqueries::isAggregationCandidateForSplittingToSubqueries);

    private final DistinctAggregationStrategyChooser distinctAggregationStrategyChooser;

    /**
     * Returns whether {@code aggregationNode} has the shape this rule can rewrite. All of the
     * following must hold:
     * <ul>
     *   <li>{@code SingleDistinctAggregationToGroupBy.allDistinctAggregates} accepts it;</li>
     *   <li>{@code OptimizeMixedDistinctAggregations.hasMultipleDistincts} accepts it;</li>
     *   <li>it has exactly one grouping set;</li>
     *   <li>it has no hash symbol.</li>
     * </ul>
     *
     * @param aggregationNode the aggregation to inspect
     * @return {@code true} if all conditions hold
     */
    public static boolean isAggregationCandidateForSplittingToSubqueries(AggregationNode aggregationNode) {
        return SingleDistinctAggregationToGroupBy.allDistinctAggregates(aggregationNode)
                && OptimizeMixedDistinctAggregations.hasMultipleDistincts(aggregationNode)
                && aggregationNode.getGroupingSetCount() == 1
                && aggregationNode.getHashSymbol().isEmpty();
    }

    /**
     * @param taskCountEstimator passed to the strategy chooser
     * @param metadata passed to the strategy chooser
     */
    public MultipleDistinctAggregationsToSubqueries(TaskCountEstimator taskCountEstimator, Metadata metadata) {
        this.distinctAggregationStrategyChooser = DistinctAggregationStrategyChooser.createDistinctAggregationStrategyChooser(taskCountEstimator, metadata);
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites the aggregation into joined per-argument-set sub-aggregations, or returns
     * {@link Rule.Result#empty()} when the strategy chooser declines the split.
     *
     * @throws IllegalStateException if the aggregations do not form at least two argument groups
     *         (the pattern is expected to rule this out)
     */
    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        if (!distinctAggregationStrategyChooser.shouldSplitToSubqueries(aggregationNode, context.getSession(), context.getStatsProvider(), context.getLookup())) {
            return Rule.Result.empty();
        }

        List<Map<Symbol, AggregationNode.Aggregation>> aggregationGroups = groupAggregationsByArguments(aggregationNode);
        // With a single group there is no right side to join; previously this surfaced as an NPE in buildJoin.
        Preconditions.checkState(aggregationGroups.size() > 1, "expected aggregations over at least two distinct argument sets, got %s", aggregationGroups.size());

        // Shared across all sub-aggregations: maps every original output symbol to its copied counterpart.
        Assignments.Builder assignments = Assignments.builder();

        // Build the right side of the top join from the last group down to the second one. Each new
        // sub-aggregation becomes the left input of a join against the plan accumulated so far. The
        // join keys of the accumulated side remain those of the first sub-aggregation built (the last group).
        PlanNode right = null;
        List<Symbol> rightJoinSymbols = null;
        for (int groupIndex = aggregationGroups.size() - 1; groupIndex > 0; groupIndex--) {
            AggregationNode subAggregation = buildSubAggregation(aggregationNode, aggregationGroups.get(groupIndex), assignments, context);
            if (right == null) {
                right = subAggregation;
                rightJoinSymbols = subAggregation.getGroupingKeys();
            } else {
                right = buildJoin(subAggregation, subAggregation.getGroupingKeys(), right, rightJoinSymbols, context);
            }
        }

        // The first group is the left side of the top join; its grouping keys also supply the
        // values for the original grouping-key outputs.
        AggregationNode left = buildSubAggregation(aggregationNode, aggregationGroups.getFirst(), assignments, context);
        List<Symbol> originalGroupingKeys = aggregationNode.getGroupingKeys();
        List<Symbol> leftGroupingKeys = left.getGroupingKeys();
        for (int keyIndex = 0; keyIndex < leftGroupingKeys.size(); keyIndex++) {
            assignments.put(originalGroupingKeys.get(keyIndex), leftGroupingKeys.get(keyIndex).toSymbolReference());
        }

        JoinNode topJoin = buildJoin(left, leftGroupingKeys, right, rightJoinSymbols, context);
        return Rule.Result.ofPlanNode(new ProjectNode(aggregationNode.getId(), topJoin, assignments.build()));
    }

    /**
     * Groups the aggregations of {@code aggregationNode} by the set of their arguments.
     * <p>
     * Aggregations are visited in order of output symbol name and groups keep first-insertion
     * order. The resulting group order therefore does not depend on the iteration order of the
     * node's aggregation map.
     *
     * @return one map (output symbol to aggregation) per distinct argument set, in first-seen order
     */
    private static List<Map<Symbol, AggregationNode.Aggregation>> groupAggregationsByArguments(AggregationNode aggregationNode) {
        Map<Symbol, AggregationNode.Aggregation> aggregations = aggregationNode.getAggregations();
        // Keyed by argument set; the element type of the set is irrelevant here.
        Map<Set<?>, Map<Symbol, AggregationNode.Aggregation>> aggregationsByArguments = new LinkedHashMap<>(aggregations.size());
        List<Map.Entry<Symbol, AggregationNode.Aggregation>> sortedAggregations = aggregations.entrySet().stream()
                .sorted(Comparator.comparing(entry -> entry.getKey().name()))
                .collect(ImmutableList.toImmutableList());
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : sortedAggregations) {
            aggregationsByArguments
                    .computeIfAbsent(ImmutableSet.copyOf(entry.getValue().getArguments()), ignored -> new HashMap<>())
                    .put(entry.getKey(), entry.getValue());
        }
        return ImmutableList.copyOf(aggregationsByArguments.values());
    }

    /**
     * Copies {@code aggregationNode} (with its source) restricted to {@code aggregations}.
     * Records, for each original aggregation output symbol, the corresponding copied symbol in
     * {@code assignments}.
     *
     * @param aggregationNode the original aggregation to copy
     * @param aggregations the subset of aggregations the copy should compute
     * @param assignments receives original-symbol to copied-symbol mappings
     * @param context rule context providing symbol/id allocators and lookup
     * @return the copied aggregation node
     */
    private static AggregationNode buildSubAggregation(AggregationNode aggregationNode, Map<Symbol, AggregationNode.Aggregation> aggregations, Assignments.Builder assignments, Rule.Context context) {
        List<Symbol> originalOutputs = ImmutableList.copyOf(aggregations.keySet());
        AggregationNode restricted = AggregationNode.builderFrom(aggregationNode).setAggregations(aggregations).build();
        NodeAndMappings copy = PlanCopier.copyPlan(restricted, originalOutputs, context.getSymbolAllocator(), context.getIdAllocator(), context.getLookup());
        AggregationNode copiedAggregation = (AggregationNode) copy.getNode();
        // copy.getFields() is positionally aligned with originalOutputs, which was passed to copyPlan.
        for (int i = 0; i < originalOutputs.size(); i++) {
            assignments.put(originalOutputs.get(i), copy.getFields().get(i).toSymbolReference());
        }
        return copiedAggregation;
    }

    /**
     * Builds an INNER join of {@code left} and {@code right} that pairs {@code leftJoinSymbols}
     * and {@code rightJoinSymbols} positionally as equi-join clauses. The join outputs all output
     * symbols of both inputs.
     *
     * @throws IllegalArgumentException if the two symbol lists differ in length
     */
    private static JoinNode buildJoin(PlanNode left, List<Symbol> leftJoinSymbols, PlanNode right, List<Symbol> rightJoinSymbols, Rule.Context context) {
        Preconditions.checkArgument(leftJoinSymbols.size() == rightJoinSymbols.size(), "join symbol count mismatch: left %s, right %s", leftJoinSymbols, rightJoinSymbols);
        List<JoinNode.EquiJoinClause> criteria = IntStream.range(0, leftJoinSymbols.size())
                .mapToObj(i -> new JoinNode.EquiJoinClause(leftJoinSymbols.get(i), rightJoinSymbols.get(i)))
                .collect(ImmutableList.toImmutableList());
        return new JoinNode(
                context.getIdAllocator().getNextId(),
                JoinType.INNER,
                left,
                right,
                criteria,
                left.getOutputSymbols(),
                right.getOutputSymbols(),
                false,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of(),
                Optional.empty());
    }
}
