package io.prestosql.sql.planner.iterative.rule;

import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.prestosql.SystemSessionProperties;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.spi.type.BooleanType;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.MarkDistinctNode;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.tree.FunctionCall;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.class */
/**
 * Planner rule that rewrites DISTINCT aggregations into non-distinct aggregations that are
 * restricted by a boolean marker symbol produced by a {@link MarkDistinctNode}.
 *
 * <p>The rule matches an {@link AggregationNode} when no DISTINCT aggregation carries a filter
 * or a mask, and either the DISTINCT aggregations use more than one distinct argument set or
 * DISTINCT and non-DISTINCT aggregations are mixed. For every distinct argument set one
 * {@link MarkDistinctNode} is stacked on top of the aggregation's source.
 */
public class MultipleDistinctAggregationToMarkDistinct implements Rule<AggregationNode> {
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation()
            .matching(Predicates.and(
                    MultipleDistinctAggregationToMarkDistinct::hasNoDistinctWithFilterOrMask,
                    Predicates.or(
                            MultipleDistinctAggregationToMarkDistinct::hasMultipleDistincts,
                            MultipleDistinctAggregationToMarkDistinct::hasMixedDistinctAndNonDistincts)));

    /**
     * Returns {@code true} when none of the DISTINCT aggregations has a filter or a mask.
     */
    private static boolean hasNoDistinctWithFilterOrMask(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream()
                .noneMatch(aggregation -> aggregation.getCall().isDistinct()
                        && (aggregation.getCall().getFilter().isPresent() || aggregation.getMask().isPresent()));
    }

    /**
     * Returns {@code true} when the DISTINCT aggregations use more than one distinct set of
     * arguments. Argument lists are compared as sets, so ordering and duplicates are ignored.
     */
    private static boolean hasMultipleDistincts(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream()
                .filter(aggregation -> aggregation.getCall().isDistinct())
                .map(AggregationNode.Aggregation::getCall)
                .map(FunctionCall::getArguments)
                .map(HashSet::new)
                .distinct()
                .count() > 1;
    }

    /**
     * Returns {@code true} when at least one, but not every, aggregation is DISTINCT.
     */
    private static boolean hasMixedDistinctAndNonDistincts(AggregationNode aggregationNode) {
        long distincts = aggregationNode.getAggregations().values().stream()
                .map(AggregationNode.Aggregation::getCall)
                .filter(FunctionCall::isDistinct)
                .count();
        return distincts > 0 && distincts < aggregationNode.getAggregations().size();
    }

    @Override // io.prestosql.sql.planner.iterative.Rule
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites the matched aggregation.
     *
     * @param aggregationNode the aggregation node matched by {@link #getPattern()}
     * @param captures captures from pattern matching (not used by this rule)
     * @param context supplies the session, the symbol allocator and the plan node id allocator
     * @return an empty result when the mark-distinct session property is disabled; otherwise a
     *         new {@link AggregationNode} with the same id whose source is the original source
     *         wrapped in one {@link MarkDistinctNode} per distinct argument set
     */
    @Override // io.prestosql.sql.planner.iterative.Rule
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        if (!SystemSessionProperties.useMarkDistinct(context.getSession())) {
            return Rule.Result.empty();
        }

        // Marker symbol already allocated for a given set of distinct input symbols.
        Map<Set<Symbol>, Symbol> markers = new HashMap<>();
        Map<Symbol, AggregationNode.Aggregation> newAggregations = new HashMap<>();
        PlanNode subPlan = aggregationNode.getSource();

        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
            AggregationNode.Aggregation aggregation = entry.getValue();
            FunctionCall call = aggregation.getCall();

            // Only plain DISTINCT calls (no filter, no mask) are rewritten; everything else is kept as is.
            if (!call.isDistinct() || call.getFilter().isPresent() || aggregation.getMask().isPresent()) {
                newAggregations.put(entry.getKey(), aggregation);
                continue;
            }

            Set<Symbol> inputs = call.getArguments().stream()
                    .map(Symbol::from)
                    .collect(Collectors.toSet());

            Symbol marker = markers.get(inputs);
            if (marker == null) {
                // The last input (in the set's iteration order) is only used as a name hint for the new symbol.
                marker = context.getSymbolAllocator().newSymbol(Iterables.getLast(inputs).getName(), BooleanType.BOOLEAN, "distinct");
                markers.put(inputs, marker);

                // Distinctness is evaluated over grouping keys + inputs (+ the group id symbol, if any).
                ImmutableSet.Builder<Symbol> distinctSymbols = ImmutableSet.<Symbol>builder()
                        .addAll(aggregationNode.getGroupingKeys())
                        .addAll(inputs);
                aggregationNode.getGroupIdSymbol().ifPresent(distinctSymbols::add);

                subPlan = new MarkDistinctNode(
                        context.getIdAllocator().getNextId(),
                        subPlan,
                        marker,
                        ImmutableList.copyOf(distinctSymbols.build()),
                        Optional.empty());
            }

            // Same call with the distinct flag cleared; the marker is attached as the aggregation's mask.
            newAggregations.put(entry.getKey(), new AggregationNode.Aggregation(
                    new FunctionCall(call.getName(), call.getWindow(), call.getFilter(), call.getOrderBy(), false, call.getArguments()),
                    aggregation.getSignature(),
                    Optional.of(marker)));
        }

        return Rule.Result.ofPlanNode(new AggregationNode(
                aggregationNode.getId(),
                subPlan,
                newAggregations,
                aggregationNode.getGroupingSets(),
                ImmutableList.of(),
                aggregationNode.getStep(),
                aggregationNode.getHashSymbol(),
                aggregationNode.getGroupIdSymbol()));
    }
}
