package io.prestosql.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.prestosql.Session;
import io.prestosql.SystemSessionProperties;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.Signature;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.analyzer.TypeSignatureProvider;
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.SymbolAllocator;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.GroupIdNode;
import io.prestosql.sql.planner.plan.MarkDistinctNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.SimplePlanRewriter;
import io.prestosql.sql.tree.Cast;
import io.prestosql.sql.tree.CoalesceExpression;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.IfExpression;
import io.prestosql.sql.tree.LongLiteral;
import io.prestosql.sql.tree.NullLiteral;
import io.prestosql.sql.tree.QualifiedName;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Rewrites an {@link AggregationNode} that mixes ordinary aggregations with DISTINCT aggregations
 * sharing a single mask symbol, replacing the {@link MarkDistinctNode} that produces that mask.
 *
 * <p>For an aggregation such as {@code SELECT a, f1(b1), ..., fm(bm), F(DISTINCT c) ... GROUP BY a}
 * the rewrite produces, from the bottom up:
 * <ol>
 *   <li>a {@link GroupIdNode} (in place of the MarkDistinct) with two grouping sets,
 *       {@code {a, b1..bm}} and {@code {a, c}}, tagged by a new {@code group} symbol;</li>
 *   <li>a single-step {@link AggregationNode} computing {@code f1..fm} grouped by
 *       {@code a, c, group};</li>
 *   <li>a {@link ProjectNode} exposing {@code IF(group = 0, fi)} for each ordinary result,
 *       {@code IF(group = 1, c)} for the distinct argument, the group-by symbols, and NULL for
 *       the old mask symbol;</li>
 *   <li>the original aggregation rebuilt on top: DISTINCT aggregations run unmasked over the
 *       projected distinct column, ordinary aggregations become {@code arbitrary(...)} over the
 *       projected results, and {@code count}, {@code count_if} and {@code approx_distinct} are
 *       wrapped in {@code COALESCE(..., 0)}.</li>
 * </ol>
 *
 * <p>NOTE(review): the literals 0 and 1 compared against {@code group} assume GroupIdNode numbers
 * grouping sets by their position in the list it is given — confirm against GroupIdNode.
 *
 * <p>The rewrite runs only when
 * {@link SystemSessionProperties#isOptimizeDistinctAggregationEnabled} is true for the session.
 */
public class OptimizeMixedDistinctAggregations implements PlanOptimizer {
    private final Metadata metadata;

    /**
     * Creates the optimizer.
     *
     * @param metadata used to resolve the {@code arbitrary} aggregation function
     * @throws NullPointerException if {@code metadata} is null
     */
    public OptimizeMixedDistinctAggregations(Metadata metadata) {
        // Fail fast at construction, consistent with Optimizer's own null checks.
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    /**
     * Applies the rewrite when the session enables it; otherwise returns {@code plan} unchanged.
     * Symbol types are read from {@code symbolAllocator}; {@code types} is not consulted.
     */
    @Override
    public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        if (!SystemSessionProperties.isOptimizeDistinctAggregationEnabled(session)) {
            return plan;
        }
        return SimplePlanRewriter.rewriteWith(new Optimizer(idAllocator, symbolAllocator, metadata), plan, Optional.empty());
    }

    /**
     * State for one candidate aggregation. It is created by {@code Optimizer.visitAggregation},
     * completed by {@code Optimizer.visitMarkDistinct} (which records the symbols it projects),
     * and read back by {@code visitAggregation} to rebuild the outer aggregation.
     * Mutable and unsynchronized.
     */
    public static class AggregateInfo {
        private final List<Symbol> groupBySymbols;
        private final Symbol mask;
        private final Map<Symbol, AggregationNode.Aggregation> aggregations;
        // Original unmasked aggregation output -> projected IF(group = 0, ...) symbol; null until set.
        private Map<Symbol, Symbol> newNonDistinctAggregateSymbols;
        // Projected IF(group = 1, distinct argument) symbol; null until set.
        private Symbol newDistinctAggregateSymbol;
        private boolean foundMarkDistinct;

        public AggregateInfo(List<Symbol> groupBySymbols, Symbol mask, Map<Symbol, AggregationNode.Aggregation> aggregations) {
            this.groupBySymbols = ImmutableList.copyOf(groupBySymbols);
            this.mask = mask;
            this.aggregations = ImmutableMap.copyOf(aggregations);
        }

        /**
         * Returns the distinct argument symbols of all unmasked aggregations, in first-seen order,
         * as a new mutable list (callers may replace entries in place).
         */
        public List<Symbol> getOriginalNonDistinctAggregateArgs() {
            return collectArguments(false);
        }

        /**
         * Returns the distinct argument symbols of all masked (DISTINCT) aggregations, in
         * first-seen order, as a new mutable list.
         */
        public List<Symbol> getOriginalDistinctAggregateArgs() {
            return collectArguments(true);
        }

        // Deduplicated arguments of the aggregations whose mask presence equals `masked`.
        private List<Symbol> collectArguments(boolean masked) {
            return aggregations.values().stream()
                    .filter(aggregation -> aggregation.getMask().isPresent() == masked)
                    .flatMap(aggregation -> aggregation.getArguments().stream())
                    .distinct()
                    .map(Symbol::from)
                    // toList() does not promise mutability; visitMarkDistinct calls set() on the result.
                    .collect(Collectors.toCollection(ArrayList::new));
        }

        public Symbol getNewDistinctAggregateSymbol() {
            return newDistinctAggregateSymbol;
        }

        public void setNewDistinctAggregateSymbol(Symbol symbol) {
            this.newDistinctAggregateSymbol = symbol;
        }

        public Map<Symbol, Symbol> getNewNonDistinctAggregateSymbols() {
            return newNonDistinctAggregateSymbols;
        }

        public void setNewNonDistinctAggregateSymbols(Map<Symbol, Symbol> map) {
            this.newNonDistinctAggregateSymbols = map;
        }

        public Symbol getMask() {
            return mask;
        }

        public List<Symbol> getGroupBySymbols() {
            return groupBySymbols;
        }

        public Map<Symbol, AggregationNode.Aggregation> getAggregations() {
            return aggregations;
        }

        /** Records that the MarkDistinctNode producing {@link #getMask()} was found and replaced. */
        public void foundMarkDistinct() {
            this.foundMarkDistinct = true;
        }

        public boolean isFoundMarkDistinct() {
            return foundMarkDistinct;
        }
    }

    /**
     * Plan rewriter that carries the candidate {@link AggregateInfo} (if any) from an aggregation
     * down to the {@link MarkDistinctNode} producing its mask.
     */
    private static class Optimizer extends SimplePlanRewriter<Optional<AggregateInfo>> {
        private final PlanNodeIdAllocator idAllocator;
        private final SymbolAllocator symbolAllocator;
        private final Metadata metadata;

        private Optimizer(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Metadata metadata) {
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        }

        /**
         * Rewrites {@code node} when all of the following hold:
         * <ul>
         *   <li>it has exactly one distinct mask and at least one unmasked aggregation;</li>
         *   <li>no aggregation has a filter or ordering;</li>
         *   <li>the masked aggregations share exactly one single-symbol argument;</li>
         *   <li>the non-distinct argument types are comparable;</li>
         *   <li>a MarkDistinctNode for the mask is found in its source.</li>
         * </ul>
         * Otherwise the node is kept, and its source is rewritten with no candidate.
         */
        @Override
        public PlanNode visitAggregation(AggregationNode node, SimplePlanRewriter.RewriteContext<Optional<AggregateInfo>> context) {
            List<Symbol> masks = node.getAggregations().values().stream()
                    .map(AggregationNode.Aggregation::getMask)
                    .filter(Optional::isPresent)
                    .map(Optional::get)
                    .collect(ImmutableList.toImmutableList());
            Set<Symbol> uniqueMasks = ImmutableSet.copyOf(masks);
            // Only the mixed case: one shared mask, and not every aggregation is masked.
            if (uniqueMasks.size() != 1 || masks.size() == node.getAggregations().size()) {
                return context.defaultRewrite(node, Optional.empty());
            }
            boolean hasFilter = node.getAggregations().values().stream()
                    .map(AggregationNode.Aggregation::getFilter)
                    .anyMatch(Optional::isPresent);
            if (hasFilter || node.hasOrderings()) {
                return context.defaultRewrite(node, Optional.empty());
            }

            AggregateInfo aggregateInfo = new AggregateInfo(node.getGroupingKeys(), Iterables.getOnlyElement(uniqueMasks), node.getAggregations());
            // visitMarkDistinct needs exactly one distinct argument (it would otherwise fail in
            // Iterables.getOnlyElement), and the distinct aggregations are rebuilt below with a
            // single argument. The rewrite also groups by the non-distinct arguments, so their
            // types must be comparable.
            if (!hasSingleDistinctArgument(aggregateInfo) || !checkAllEquatableTypes(aggregateInfo)) {
                return context.defaultRewrite(node, Optional.empty());
            }

            PlanNode source = context.rewrite(node.getSource(), Optional.of(aggregateInfo));
            // No MarkDistinctNode for this mask below: discard the attempt and rewrite normally.
            if (!aggregateInfo.isFoundMarkDistinct()) {
                return context.defaultRewrite(node, Optional.empty());
            }

            ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
            // Fresh symbol holding arbitrary(...) of a count-like aggregation -> original output symbol.
            ImmutableMap.Builder<Symbol, Symbol> coalesceSymbolsBuilder = ImmutableMap.builder();
            for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
                AggregationNode.Aggregation aggregation = entry.getValue();
                if (aggregation.getMask().isPresent()) {
                    // Distinct values were deduplicated below; aggregate the projected column unmasked.
                    aggregations.put(entry.getKey(), new AggregationNode.Aggregation(
                            aggregation.getSignature(),
                            ImmutableList.of(aggregateInfo.getNewDistinctAggregateSymbol().toSymbolReference()),
                            false,
                            Optional.empty(),
                            Optional.empty(),
                            Optional.empty()));
                    continue;
                }

                // Already computed below and exposed only on group-0 rows (NULL elsewhere);
                // arbitrary(...) carries that value through.
                Symbol argument = aggregateInfo.getNewNonDistinctAggregateSymbols().get(entry.getKey());
                AggregationNode.Aggregation arbitrary = new AggregationNode.Aggregation(
                        getFunctionSignature(QualifiedName.of("arbitrary"), argument),
                        ImmutableList.of(argument.toSymbolReference()),
                        false,
                        Optional.empty(),
                        Optional.empty(),
                        Optional.empty());
                String functionName = aggregation.getSignature().getName();
                if (functionName.equals("count") || functionName.equals("count_if") || functionName.equals("approx_distinct")) {
                    // These functions yield a number even with no input, whereas arbitrary(...) yields
                    // NULL when no group-0 value reaches the outer group (presumably e.g. a global
                    // aggregation over empty input). Compute under a fresh symbol; COALESCE to 0 below.
                    Symbol newSymbol = symbolAllocator.newSymbol("expr", symbolAllocator.getTypes().get(entry.getKey()));
                    aggregations.put(newSymbol, arbitrary);
                    coalesceSymbolsBuilder.put(newSymbol, entry.getKey());
                } else {
                    aggregations.put(entry.getKey(), arbitrary);
                }
            }
            Map<Symbol, Symbol> coalesceSymbols = coalesceSymbolsBuilder.build();

            AggregationNode aggregationNode = new AggregationNode(
                    idAllocator.getNextId(),
                    source,
                    aggregations.build(),
                    node.getGroupingSets(),
                    ImmutableList.of(),
                    node.getStep(),
                    Optional.empty(),
                    node.getGroupIdSymbol());
            if (coalesceSymbols.isEmpty()) {
                return aggregationNode;
            }

            Assignments.Builder outputs = Assignments.builder();
            for (Symbol symbol : aggregationNode.getOutputSymbols()) {
                if (coalesceSymbols.containsKey(symbol)) {
                    outputs.put(coalesceSymbols.get(symbol), new CoalesceExpression(symbol.toSymbolReference(), new Cast(new LongLiteral("0"), "bigint")));
                } else {
                    outputs.putIdentity(symbol);
                }
            }
            return new ProjectNode(idAllocator.getNextId(), aggregationNode, outputs.build());
        }

        /**
         * Replaces the MarkDistinctNode whose marker is the candidate's mask with
         * GroupId -> inner Aggregation -> Project. Along the way it records the projected symbols
         * in the candidate {@link AggregateInfo}. Any other MarkDistinctNode is left unchanged.
         */
        @Override
        public PlanNode visitMarkDistinct(MarkDistinctNode node, SimplePlanRewriter.RewriteContext<Optional<AggregateInfo>> context) {
            Optional<AggregateInfo> candidate = context.get();
            if (!candidate.isPresent() || !candidate.get().getMask().equals(node.getMarkerSymbol())) {
                return context.defaultRewrite(node, Optional.empty());
            }
            AggregateInfo aggregateInfo = candidate.get();
            aggregateInfo.foundMarkDistinct();

            PlanNode source = context.rewrite(node.getSource(), Optional.empty());

            List<Symbol> groupBySymbols = aggregateInfo.getGroupBySymbols();
            // Mutable copy: an entry may be swapped for a duplicate of the distinct symbol below.
            List<Symbol> nonDistinctArgs = aggregateInfo.getOriginalNonDistinctAggregateArgs();
            Symbol distinctSymbol = Iterables.getOnlyElement(aggregateInfo.getOriginalDistinctAggregateArgs());

            // When the distinct argument is also an ordinary argument (e.g. sum(x), count(DISTINCT x)),
            // the ordinary side reads from a separate copy so the two grouping sets stay independent.
            Symbol duplicatedDistinctSymbol = distinctSymbol;
            if (nonDistinctArgs.contains(distinctSymbol)) {
                Symbol newSymbol = symbolAllocator.newSymbol(distinctSymbol.getName(), symbolAllocator.getTypes().get(distinctSymbol));
                nonDistinctArgs.set(nonDistinctArgs.indexOf(distinctSymbol), newSymbol);
                duplicatedDistinctSymbol = newSymbol;
            }

            Set<Symbol> allSymbols = new HashSet<>();
            allSymbols.addAll(groupBySymbols);
            allSymbols.addAll(nonDistinctArgs);
            allSymbols.add(distinctSymbol);

            // 1. Fan rows out into the two grouping sets.
            Symbol groupSymbol = symbolAllocator.newSymbol("group", BigintType.BIGINT);
            GroupIdNode groupIdNode = createGroupIdNode(groupBySymbols, nonDistinctArgs, distinctSymbol, duplicatedDistinctSymbol, groupSymbol, allSymbols, source);

            // 2. Compute the ordinary aggregations per (group-by keys, distinct symbol, group id).
            Set<Symbol> groupByKeys = new HashSet<>(groupBySymbols);
            groupByKeys.add(distinctSymbol);
            groupByKeys.add(groupSymbol);
            ImmutableMap.Builder<Symbol, Symbol> aggregationOutputs = ImmutableMap.builder();
            AggregationNode aggregationNode = createNonDistinctAggregation(aggregateInfo, distinctSymbol, duplicatedDistinctSymbol, groupByKeys, groupIdNode, node, aggregationOutputs);

            // 3. Project the per-group values that the rebuilt outer aggregation consumes.
            return createProjectNode(aggregationNode, aggregateInfo, distinctSymbol, groupSymbol, groupBySymbols, aggregationOutputs.build());
        }

        // True if every masked aggregation takes one argument and they all share the same symbol.
        private static boolean hasSingleDistinctArgument(AggregateInfo aggregateInfo) {
            boolean singleArgumentEach = aggregateInfo.getAggregations().values().stream()
                    .filter(aggregation -> aggregation.getMask().isPresent())
                    .allMatch(aggregation -> aggregation.getArguments().size() == 1);
            return singleArgumentEach && aggregateInfo.getOriginalDistinctAggregateArgs().size() == 1;
        }

        /**
         * Returns true if every non-distinct argument type and the mask type are comparable.
         * The distinct argument itself is not checked here.
         */
        private boolean checkAllEquatableTypes(AggregateInfo aggregateInfo) {
            for (Symbol argument : aggregateInfo.getOriginalNonDistinctAggregateArgs()) {
                if (!symbolAllocator.getTypes().get(argument).isComparable()) {
                    return false;
                }
            }
            return symbolAllocator.getTypes().get(aggregateInfo.getMask()).isComparable();
        }

        /**
         * Projects over the inner aggregation:
         * <ul>
         *   <li>{@code IF(group = 1, distinctSymbol)} under a fresh symbol, recorded as the new
         *       distinct aggregate symbol;</li>
         *   <li>{@code IF(group = 0, result)} under a fresh symbol for each inner aggregation
         *       output, recorded as original output -> fresh symbol;</li>
         *   <li>each group-by symbol passed through unchanged;</li>
         *   <li>NULL for the old mask symbol.</li>
         * </ul>
         */
        private ProjectNode createProjectNode(AggregationNode source, AggregateInfo aggregateInfo, Symbol distinctSymbol, Symbol groupSymbol, List<Symbol> groupBySymbols, Map<Symbol, Symbol> aggregationOutputs) {
            Assignments.Builder assignments = Assignments.builder();
            ImmutableMap.Builder<Symbol, Symbol> newNonDistinctSymbols = ImmutableMap.builder();
            for (Symbol symbol : source.getOutputSymbols()) {
                if (distinctSymbol.equals(symbol)) {
                    Symbol newSymbol = symbolAllocator.newSymbol("expr", symbolAllocator.getTypes().get(symbol));
                    aggregateInfo.setNewDistinctAggregateSymbol(newSymbol);
                    assignments.put(newSymbol, createIfExpression(groupSymbol.toSymbolReference(), new Cast(new LongLiteral("1"), "bigint"), ComparisonExpression.Operator.EQUAL, symbol.toSymbolReference(), symbolAllocator.getTypes().get(symbol)));
                } else if (aggregationOutputs.containsKey(symbol)) {
                    Symbol newSymbol = symbolAllocator.newSymbol("expr", symbolAllocator.getTypes().get(symbol));
                    newNonDistinctSymbols.put(aggregationOutputs.get(symbol), newSymbol);
                    assignments.put(newSymbol, createIfExpression(groupSymbol.toSymbolReference(), new Cast(new LongLiteral("0"), "bigint"), ComparisonExpression.Operator.EQUAL, symbol.toSymbolReference(), symbolAllocator.getTypes().get(symbol)));
                }
                // A symbol may be both a group-by key and the distinct argument; keep it as well.
                if (groupBySymbols.contains(symbol)) {
                    assignments.put(symbol, symbol.toSymbolReference());
                }
            }
            // Keeps the old mask symbol defined for consumers above; presumably pruned later if unused.
            assignments.put(aggregateInfo.getMask(), new NullLiteral());
            aggregateInfo.setNewNonDistinctAggregateSymbols(newNonDistinctSymbols.build());
            return new ProjectNode(idAllocator.getNextId(), source, assignments.build());
        }

        /**
         * Builds a GroupIdNode over {@code source} with two grouping sets, in this order:
         * {@code groupBySymbols + nonDistinctArgs}, then {@code groupBySymbols + distinctSymbol}.
         * Every symbol in {@code allSymbols} is a grouping column. The duplicated distinct symbol
         * (if it differs) is read from {@code distinctSymbol}.
         */
        private GroupIdNode createGroupIdNode(List<Symbol> groupBySymbols, List<Symbol> nonDistinctArgs, Symbol distinctSymbol, Symbol duplicatedDistinctSymbol, Symbol groupSymbol, Set<Symbol> allSymbols, PlanNode source) {
            List<List<Symbol>> groupingSets = new ArrayList<>();

            Set<Symbol> nonDistinctSet = new HashSet<>();
            nonDistinctSet.addAll(groupBySymbols);
            nonDistinctSet.addAll(nonDistinctArgs);
            groupingSets.add(ImmutableList.copyOf(nonDistinctSet));

            Set<Symbol> distinctSet = new HashSet<>(groupBySymbols);
            distinctSet.add(distinctSymbol);
            groupingSets.add(ImmutableList.copyOf(distinctSet));

            // Key: symbol produced by the GroupIdNode; value: source symbol it is read from.
            Map<Symbol, Symbol> groupingColumns = allSymbols.stream()
                    .collect(Collectors.toMap(
                            symbol -> symbol,
                            symbol -> symbol.equals(duplicatedDistinctSymbol) ? distinctSymbol : symbol));
            return new GroupIdNode(idAllocator.getNextId(), groupIdNode_source(source), groupingSets, groupingColumns, ImmutableList.of(), groupSymbol);
        }

        // Identity helper kept trivial so the GroupIdNode construction above reads in argument order.
        private static PlanNode groupIdNode_source(PlanNode source) {
            return source;
        }

        /**
         * Builds a single-step aggregation of every unmasked aggregation in {@code aggregateInfo}
         * over {@code groupIdNode}, grouped by {@code groupByKeys}. Each result gets a fresh symbol,
         * and {@code aggregationOutputs} receives fresh symbol -> original output symbol.
         * Arguments referring to {@code distinctSymbol} are redirected to
         * {@code duplicatedDistinctSymbol} when the two differ.
         */
        private AggregationNode createNonDistinctAggregation(AggregateInfo aggregateInfo, Symbol distinctSymbol, Symbol duplicatedDistinctSymbol, Set<Symbol> groupByKeys, GroupIdNode groupIdNode, MarkDistinctNode markDistinctNode, ImmutableMap.Builder<Symbol, Symbol> aggregationOutputs) {
            ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
            boolean redirectDistinct = !duplicatedDistinctSymbol.equals(distinctSymbol);
            for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : aggregateInfo.getAggregations().entrySet()) {
                AggregationNode.Aggregation aggregation = entry.getValue();
                if (aggregation.getMask().isPresent()) {
                    continue; // DISTINCT aggregations are evaluated by the rebuilt outer node.
                }
                Symbol newSymbol = symbolAllocator.newSymbol(entry.getKey().toSymbolReference(), symbolAllocator.getTypes().get(entry.getKey()));
                aggregationOutputs.put(newSymbol, entry.getKey());
                if (redirectDistinct && aggregation.getArguments().contains(distinctSymbol.toSymbolReference())) {
                    ImmutableList.Builder<Expression> arguments = ImmutableList.builder();
                    for (Expression argument : aggregation.getArguments()) {
                        arguments.add(distinctSymbol.toSymbolReference().equals(argument) ? duplicatedDistinctSymbol.toSymbolReference() : argument);
                    }
                    aggregation = new AggregationNode.Aggregation(aggregation.getSignature(), arguments.build(), false, Optional.empty(), Optional.empty(), Optional.empty());
                }
                aggregations.put(newSymbol, aggregation);
            }
            return new AggregationNode(
                    idAllocator.getNextId(),
                    groupIdNode,
                    aggregations.build(),
                    AggregationNode.singleGroupingSet(ImmutableList.copyOf(groupByKeys)),
                    ImmutableList.of(),
                    AggregationNode.Step.SINGLE,
                    markDistinctNode.getHashSymbol(),
                    Optional.empty());
        }

        // Resolves the signature of aggregation `name` applied to a single argument of `argument`'s type.
        private Signature getFunctionSignature(QualifiedName name, Symbol argument) {
            return metadata.resolveFunction(name, TypeSignatureProvider.fromTypes(symbolAllocator.getTypes().get(argument)));
        }

        // Builds IF(left <operator> right, result, CAST(NULL AS type)).
        private static IfExpression createIfExpression(Expression left, Expression right, ComparisonExpression.Operator operator, Expression result, Type type) {
            return new IfExpression(new ComparisonExpression(operator, left, right), result, new Cast(new NullLiteral(), type.getTypeSignature().toString()));
        }
    }
}
