package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.SystemSessionProperties;
import io.trino.cost.TaskCountEstimator;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.FunctionResolver;
import io.trino.metadata.GlobalFunctionCatalog;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.ir.Coalesce;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.planner.OptimizerConfig;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.GroupIdNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.QualifiedName;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/OptimizeMixedDistinctAggregations.class */
/**
 * Rewrites a single-step aggregation into a pre-aggregating plan built on a {@link GroupIdNode}.
 * The rule applies when the aggregation has either several DISTINCT aggregations over different
 * argument sets, or a mix of DISTINCT and non-distinct aggregations.
 *
 * <pre>
 * Aggregation[keys]: d_i := f_i(DISTINCT x_i), n_j := g_j(y_j)
 * </pre>
 * becomes
 * <pre>
 * [Project: n_j := coalesce(c_j, 0)]      -- only when some g_j is count / count_if / approx_distinct
 *   Aggregation[keys]: d_i := f_i(x_i) FILTER (gid-filter-k),   k = group id of x_i
 *                      n_j (or c_j) := arbitrary(inner_j) FILTER (non-distinct-gid-filter)
 *     Project: gid-filter-k := (group = k), non-distinct-gid-filter := (group = 0), identities
 *       Aggregation[keys, x_1..x_n, group]: inner_j := g_j(y_j')
 *         GroupId[group]: {keys, y_1'..y_m'}, {keys, x_1}, ..., {keys, x_n}
 * </pre>
 *
 * <p>Group 0 and its filter exist only when there are non-distinct aggregations. A non-distinct
 * argument that is also a distinct argument is routed through a fresh {@code gid-non-distinct}
 * alias (y') so it can appear in group 0 independently.
 *
 * <p>The inner aggregation groups by the grouping keys, every distinct argument and the group id.
 * The outer DISTINCT aggregations therefore see already de-duplicated values and are emitted as
 * plain aggregations filtered to their own group. (This assumes GroupIdNode exposes only its group's
 * columns in each row; that behavior is outside this file.)
 */
public class OptimizeMixedDistinctAggregations implements Rule<AggregationNode> {
    private static final CatalogSchemaFunctionName COUNT_NAME = GlobalFunctionCatalog.builtinFunctionName("count");
    private static final CatalogSchemaFunctionName COUNT_IF_NAME = GlobalFunctionCatalog.builtinFunctionName("count_if");
    private static final CatalogSchemaFunctionName APPROX_DISTINCT_NAME = GlobalFunctionCatalog.builtinFunctionName("approx_distinct");
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().matching(OptimizeMixedDistinctAggregations::canUsePreAggregate);

    private final FunctionResolver functionResolver;
    private final DistinctAggregationStrategyChooser distinctAggregationStrategyChooser;

    /**
     * Structural precondition for this rewrite. The node must:
     * <ul>
     *   <li>have multiple distinct argument sets, or mixed DISTINCT and non-distinct aggregations;</li>
     *   <li>have only single-argument DISTINCT aggregations;</li>
     *   <li>use no FILTER, mask or ORDER BY;</li>
     *   <li>be a {@code SINGLE}-step aggregation.</li>
     * </ul>
     */
    public static boolean canUsePreAggregate(AggregationNode aggregationNode) {
        return (hasMultipleDistincts(aggregationNode) || hasMixedDistinctAndNonDistincts(aggregationNode))
                && allDistinctAggregationsHaveSingleArgument(aggregationNode)
                && noFilters(aggregationNode)
                && noMasks(aggregationNode)
                && !aggregationNode.hasOrderings()
                && aggregationNode.getStep().equals(AggregationNode.Step.SINGLE);
    }

    /** Returns true when the DISTINCT aggregations use more than one distinct argument set. */
    public static boolean hasMultipleDistincts(AggregationNode aggregationNode) {
        return distinctAggregationsUniqueArgumentCount(aggregationNode) > 1;
    }

    /**
     * Counts the different argument lists among the DISTINCT aggregations. Each list is compared as a
     * set, so argument order and repetition are ignored.
     */
    public static long distinctAggregationsUniqueArgumentCount(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream()
                .filter(AggregationNode.Aggregation::isDistinct)
                .map(aggregation -> new HashSet<>(aggregation.getArguments()))
                .distinct()
                .count();
    }

    /** Returns true when at least one aggregation is DISTINCT and at least one is not. */
    private static boolean hasMixedDistinctAndNonDistincts(AggregationNode aggregationNode) {
        long distinctCount = aggregationNode.getAggregations().values().stream()
                .filter(AggregationNode.Aggregation::isDistinct)
                .count();
        return distinctCount > 0 && distinctCount < aggregationNode.getAggregations().size();
    }

    /** Returns true when every DISTINCT aggregation takes exactly one argument. */
    private static boolean allDistinctAggregationsHaveSingleArgument(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream()
                .filter(AggregationNode.Aggregation::isDistinct)
                .allMatch(aggregation -> aggregation.getArguments().size() == 1);
    }

    /** Returns true when no aggregation carries a FILTER clause. */
    private static boolean noFilters(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream()
                .noneMatch(aggregation -> aggregation.getFilter().isPresent());
    }

    /** Returns true when no aggregation carries a mask symbol. */
    private static boolean noMasks(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream()
                .noneMatch(aggregation -> aggregation.getMask().isPresent());
    }

    /**
     * @param plannerContext supplies the function resolver (used to resolve {@code arbitrary}) and the
     *        metadata for the strategy chooser
     * @param taskCountEstimator passed to the chooser used when the strategy is {@code AUTOMATIC}
     */
    public OptimizeMixedDistinctAggregations(PlannerContext plannerContext, TaskCountEstimator taskCountEstimator) {
        this.functionResolver = plannerContext.getFunctionResolver();
        this.distinctAggregationStrategyChooser = DistinctAggregationStrategyChooser.createDistinctAggregationStrategyChooser(taskCountEstimator, plannerContext.getMetadata());
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        if (!isPreAggregateStrategySelected(aggregationNode, context)) {
            return Rule.Result.empty();
        }
        SymbolAllocator symbolAllocator = context.getSymbolAllocator();
        Map<Symbol, AggregationNode.Aggregation> aggregations = aggregationNode.getAggregations();

        // Every symbol that is the (single) argument of some DISTINCT aggregation, in encounter order.
        Set<Symbol> distinctArguments = aggregations.values().stream()
                .filter(AggregationNode.Aggregation::isDistinct)
                .flatMap(aggregation -> aggregation.getArguments().stream())
                .map(Symbol::from)
                .collect(ImmutableSet.toImmutableSet());
        boolean hasNonDistinctAggregations = aggregations.values().stream()
                .anyMatch(aggregation -> !aggregation.isDistinct());

        // Group 0 is reserved for the non-distinct aggregations when there are any.
        Map<Symbol, Integer> distinctArgumentGroupIds = assignGroupIds(distinctArguments, hasNonDistinctAggregations ? 1 : 0);

        // Columns forwarded by the GroupIdNode (output -> input). The same key can be added more than
        // once (e.g. a grouping key that is also an aggregation argument), hence buildKeepingLast().
        ImmutableMap.Builder<Symbol, Symbol> groupIdMappings = ImmutableMap.builder();
        for (Symbol argument : distinctArguments) {
            groupIdMappings.put(argument, argument);
        }
        for (Symbol groupingKey : aggregationNode.getGroupingKeys()) {
            groupIdMappings.put(groupingKey, groupingKey);
        }
        Symbol groupSymbol = symbolAllocator.newSymbol("group", BigintType.BIGINT);

        // Boolean "row belongs to group N" columns, projected between the inner and outer aggregation.
        Assignments.Builder groupFilterProjections = Assignments.builder();
        Optional<Symbol> nonDistinctFilter = Optional.empty();
        if (hasNonDistinctAggregations) {
            // Only allocated when it will actually be used by an outer arbitrary(...) aggregation.
            Symbol filter = symbolAllocator.newSymbol("non-distinct-gid-filter", BooleanType.BOOLEAN);
            groupFilterProjections.put(filter, new Comparison(Comparison.Operator.EQUAL, groupSymbol.toSymbolReference(), new Constant(BigintType.BIGINT, 0L)));
            nonDistinctFilter = Optional.of(filter);
        }

        // DISTINCT aggregations become plain aggregations restricted to the group of their argument.
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> outerAggregations = ImmutableMap.builder();
        Map<Integer, Symbol> groupFilters = new HashMap<>();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : aggregations.entrySet()) {
            AggregationNode.Aggregation aggregation = entry.getValue();
            if (!aggregation.isDistinct()) {
                continue;
            }
            Integer groupId = distinctArgumentGroupIds.get(Symbol.from(aggregation.getArguments().getFirst()));
            // Aggregations sharing an argument share one filter column.
            Symbol groupFilter = groupFilters.computeIfAbsent(groupId, id -> {
                Symbol filter = symbolAllocator.newSymbol("gid-filter-" + id, BooleanType.BOOLEAN);
                groupFilterProjections.put(filter, new Comparison(Comparison.Operator.EQUAL, groupSymbol.toSymbolReference(), new Constant(BigintType.BIGINT, id.longValue())));
                return filter;
            });
            outerAggregations.put(entry.getKey(), new AggregationNode.Aggregation(
                    aggregation.getResolvedFunction(),
                    aggregation.getArguments(),
                    false,
                    Optional.of(groupFilter),
                    Optional.empty(),
                    Optional.empty()));
        }

        // Non-distinct aggregations are computed by the inner aggregation (on group 0). The outer
        // aggregation then picks that single value with arbitrary(...) filtered to group 0.
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> innerAggregations = ImmutableMap.builder();
        ImmutableMap.Builder<Symbol, Symbol> coalesceTargets = ImmutableMap.builder();
        ImmutableSet.Builder<Symbol> nonDistinctArguments = ImmutableSet.builder();
        Map<Symbol, Symbol> nonDistinctAliases = new HashMap<>();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : aggregations.entrySet()) {
            AggregationNode.Aggregation aggregation = entry.getValue();
            if (aggregation.isDistinct()) {
                continue;
            }
            Symbol output = entry.getKey();

            ImmutableList.Builder<Expression> innerArguments = ImmutableList.builder();
            for (Expression argument : aggregation.getArguments()) {
                Symbol source = Symbol.from(argument);
                Symbol alias = source;
                if (distinctArguments.contains(source)) {
                    // A distinct argument is a column of its own group. Group 0 needs a separate copy.
                    alias = nonDistinctAliases.computeIfAbsent(source, symbol -> symbolAllocator.newSymbol("gid-non-distinct", symbol.type()));
                }
                groupIdMappings.put(alias, source);
                innerArguments.add(alias.toSymbolReference());
                nonDistinctArguments.add(alias);
            }

            Symbol innerOutput = symbolAllocator.newSymbol("inner", output.type());
            innerAggregations.put(innerOutput, new AggregationNode.Aggregation(
                    aggregation.getResolvedFunction(),
                    innerArguments.build(),
                    false,
                    Optional.empty(),
                    Optional.empty(),
                    Optional.empty()));

            AggregationNode.Aggregation pickInnerValue = new AggregationNode.Aggregation(
                    functionResolver.resolveFunction(
                            context.getSession(),
                            QualifiedName.of("arbitrary"),
                            TypeSignatureProvider.fromTypes(aggregation.getResolvedFunction().signature().getReturnType()),
                            new AllowAllAccessControl()),
                    ImmutableList.of(innerOutput.toSymbolReference()),
                    false,
                    nonDistinctFilter,
                    Optional.empty(),
                    Optional.empty());

            Symbol outerOutput = output;
            CatalogSchemaFunctionName name = aggregation.getResolvedFunction().signature().getName();
            if (name.equals(COUNT_NAME) || name.equals(COUNT_IF_NAME) || name.equals(APPROX_DISTINCT_NAME)) {
                // arbitrary(...) presumably yields NULL when no row passes its filter. These functions
                // must report 0 instead, so a later projection wraps them in coalesce(..., 0).
                outerOutput = symbolAllocator.newSymbol("coalesce_expr", output.type());
                coalesceTargets.put(outerOutput, output);
            }
            outerAggregations.put(outerOutput, pickInnerValue);
        }

        // Node ids are requested in the same order as before: inner aggregation, then GroupId.
        var innerAggregationId = context.getIdAllocator().getNextId();
        GroupIdNode groupIdNode = new GroupIdNode(
                context.getIdAllocator().getNextId(),
                aggregationNode.getSource(),
                createGroups(aggregationNode.getGroupingKeys(), nonDistinctArguments.build(), hasNonDistinctAggregations, distinctArgumentGroupIds),
                groupIdMappings.buildKeepingLast(),
                ImmutableList.of(),
                groupSymbol);
        AggregationNode innerAggregation = new AggregationNode(
                innerAggregationId,
                groupIdNode,
                innerAggregations.buildOrThrow(),
                AggregationNode.singleGroupingSet(ImmutableList.copyOf(ImmutableSet.<Symbol>builder()
                        .addAll(aggregationNode.getGroupingKeys())
                        .addAll(distinctArguments)
                        .add(groupSymbol)
                        .build())),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                aggregationNode.getHashSymbol(),
                Optional.empty());

        groupFilterProjections.putIdentities(innerAggregation.getOutputSymbols());
        var outerAggregationId = context.getIdAllocator().getNextId();
        ProjectNode groupFilterProject = new ProjectNode(context.getIdAllocator().getNextId(), innerAggregation, groupFilterProjections.build());
        AggregationNode outerAggregation = new AggregationNode(
                outerAggregationId,
                groupFilterProject,
                outerAggregations.buildOrThrow(),
                aggregationNode.getGroupingSets(),
                ImmutableList.of(),
                aggregationNode.getStep(),
                Optional.empty(),
                aggregationNode.getGroupIdSymbol());

        Map<Symbol, Symbol> coalescedOutputs = coalesceTargets.buildOrThrow();
        if (coalescedOutputs.isEmpty()) {
            return Rule.Result.ofPlanNode(outerAggregation);
        }
        return Rule.Result.ofPlanNode(coalesceToZero(outerAggregation, coalescedOutputs, context));
    }

    /**
     * Returns whether the session's distinct-aggregation strategy asks for this rewrite. The answer is
     * yes for {@code PRE_AGGREGATE}. For {@code AUTOMATIC}, the strategy chooser decides.
     */
    private boolean isPreAggregateStrategySelected(AggregationNode aggregationNode, Rule.Context context) {
        OptimizerConfig.DistinctAggregationsStrategy strategy = SystemSessionProperties.distinctAggregationsStrategy(context.getSession());
        if (strategy.equals(OptimizerConfig.DistinctAggregationsStrategy.PRE_AGGREGATE)) {
            return true;
        }
        return strategy.equals(OptimizerConfig.DistinctAggregationsStrategy.AUTOMATIC)
                && distinctAggregationStrategyChooser.shouldUsePreAggregate(aggregationNode, context.getSession(), context.getStatsProvider(), context.getLookup());
    }

    /**
     * Gives each distinct argument a consecutive group id. Ids follow the set's iteration order and
     * start at {@code firstGroupId}.
     */
    private static Map<Symbol, Integer> assignGroupIds(Set<Symbol> distinctArguments, int firstGroupId) {
        ImmutableMap.Builder<Symbol, Integer> groupIds = ImmutableMap.builder();
        int groupId = firstGroupId;
        for (Symbol argument : distinctArguments) {
            groupIds.put(argument, groupId++);
        }
        return groupIds.buildOrThrow();
    }

    /**
     * Adds a projection on top of {@code aggregation}. For each key of {@code originalOutputs}, it
     * outputs {@code coalesce(symbol, BIGINT 0)} under the original output symbol; every other column
     * is passed through unchanged.
     */
    private static ProjectNode coalesceToZero(AggregationNode aggregation, Map<Symbol, Symbol> originalOutputs, Rule.Context context) {
        Assignments.Builder outputs = Assignments.builder();
        for (Symbol symbol : aggregation.getOutputSymbols()) {
            Symbol originalOutput = originalOutputs.get(symbol);
            if (originalOutput == null) {
                outputs.putIdentity(symbol);
            }
            else {
                outputs.put(originalOutput, new Coalesce(symbol.toSymbolReference(), new Constant(BigintType.BIGINT, 0L)));
            }
        }
        return new ProjectNode(context.getIdAllocator().getNextId(), aggregation, outputs.build());
    }

    /**
     * Builds the GroupIdNode grouping sets. The list index of each set is its group id.
     * <ul>
     *   <li>When {@code hasNonDistinctAggregations} is true, set 0 is the grouping keys plus the
     *       non-distinct arguments.</li>
     *   <li>After that, each distinct argument contributes one set: the grouping keys plus that
     *       argument, ordered by its group id.</li>
     * </ul>
     */
    private static List<List<Symbol>> createGroups(List<Symbol> groupingKeys, Set<Symbol> nonDistinctArguments, boolean hasNonDistinctAggregations, Map<Symbol, Integer> distinctArgumentGroupIds) {
        ImmutableList.Builder<List<Symbol>> groups = ImmutableList.builder();
        if (hasNonDistinctAggregations) {
            groups.add(ImmutableList.copyOf(ImmutableSet.<Symbol>builder()
                    .addAll(groupingKeys)
                    .addAll(nonDistinctArguments)
                    .build()));
        }
        List<Map.Entry<Symbol, Integer>> byGroupId = distinctArgumentGroupIds.entrySet().stream()
                .sorted(Map.Entry.comparingByValue())
                .toList();
        for (Map.Entry<Symbol, Integer> entry : byGroupId) {
            groups.add(ImmutableList.copyOf(ImmutableSet.<Symbol>builder()
                    .addAll(groupingKeys)
                    .add(entry.getKey())
                    .build()));
        }
        return groups.build();
    }
}
