/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import com.google.common.base.Preconditions;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ListMultimap;
import com.google.inject.Inject;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.ComparisonStatsCalculator;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.PlanNodeStatsEstimateMath;
import io.trino.cost.ScalarStatsCalculator;
import io.trino.cost.StatsNormalizer;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.metadata.GlobalFunctionCatalog;
import io.trino.spi.statistics.StatsUtil;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.DynamicFilters;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Between;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.In;
import io.trino.sql.ir.IrExpressions;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.IsNull;
import io.trino.sql.ir.Logical;
import io.trino.sql.ir.Reference;
import io.trino.sql.ir.optimizer.IrExpressionOptimizer;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.util.DisjointSet;
import jakarta.annotation.Nullable;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalDouble;
import java.util.Set;
import java.util.stream.IntStream;

/**
 * Estimates statistics for the rows that pass a filter predicate, starting from the statistics
 * of the filter input.
 *
 * <p>The predicate is first simplified with {@link IrExpressionOptimizer} and then walked by an
 * {@link IrVisitor}. Conjunctions, disjunctions, comparisons, {@code IN}, {@code BETWEEN},
 * {@code IS NULL}, {@code $not} calls, dynamic filters and boolean constants are estimated;
 * any other expression yields {@link PlanNodeStatsEstimate#unknown()}.
 */
public class FilterStatsCalculator {
    // Package-private and not referenced in this class.
    // NOTE(review): presumably a default selectivity used by other estimators in this package; confirm at the call sites.
    static final double UNKNOWN_FILTER_COEFFICIENT = 0.9;

    private final PlannerContext plannerContext;
    private final ScalarStatsCalculator scalarStatsCalculator;
    private final StatsNormalizer normalizer;

    @Inject
    public FilterStatsCalculator(PlannerContext plannerContext, ScalarStatsCalculator scalarStatsCalculator, StatsNormalizer normalizer) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
        this.scalarStatsCalculator = Objects.requireNonNull(scalarStatsCalculator, "scalarStatsCalculator is null");
        this.normalizer = Objects.requireNonNull(normalizer, "normalizer is null");
    }

    /**
     * Estimates the statistics of {@code statsEstimate} after filtering by {@code predicate}.
     *
     * @param statsEstimate statistics of the filter input
     * @param predicate the filter predicate
     * @param session session used to simplify the predicate and to read session properties
     * @return normalized output statistics. When the input row count is zero or unknown, the
     *     normalized input is returned unchanged. When the predicate cannot be estimated, the
     *     result is an unknown estimate.
     */
    public PlanNodeStatsEstimate filterStats(PlanNodeStatsEstimate statsEstimate, Expression predicate, Session session) {
        Expression simplifiedExpression = simplifyExpression(session, predicate);
        return new FilterExpressionStatsCalculatingVisitor(statsEstimate, session).process(simplifiedExpression);
    }

    /**
     * Runs the IR optimizer over {@code predicate}, falling back to the original predicate when
     * the optimizer produces nothing. A predicate that folds to a {@code null} constant is
     * replaced by {@link Booleans#FALSE}.
     */
    private Expression simplifyExpression(Session session, Expression predicate) {
        Expression value = IrExpressionOptimizer.newOptimizer(plannerContext)
                .process(predicate, session, ImmutableMap.<Symbol, Expression>of())
                .orElse(predicate);
        if (value instanceof Constant constant && constant.value() == null) {
            // A filter keeps no row for which the predicate is NULL, so a top-level NULL behaves like FALSE.
            value = Booleans.FALSE;
        }
        return value;
    }

    /**
     * Splits the terms of a conjunction into groups whose terms may be correlated.
     *
     * <p>Two terms land in the same group when they share a symbol, either directly or through a
     * chain of other terms. Symbols are merged with a {@link DisjointSet}. Terms that reference
     * no symbol form one extra group of their own.
     *
     * <p>When {@code filterConjunctionIndependenceFactor} is exactly {@code 1.0}, grouping is
     * skipped and all terms are returned as a single group.
     *
     * @param terms the conjunction terms
     * @param filterConjunctionIndependenceFactor the session's conjunction independence factor
     * @return the terms partitioned into groups; each term occurrence is placed in exactly one group
     */
    private static List<List<Expression>> extractCorrelatedGroups(List<Expression> terms, double filterConjunctionIndependenceFactor) {
        if (filterConjunctionIndependenceFactor == 1.0) {
            return ImmutableList.of(terms);
        }

        ListMultimap<Expression, Symbol> expressionUniqueSymbols = ArrayListMultimap.create();
        terms.forEach(expression -> expressionUniqueSymbols.putAll(expression, SymbolsExtractor.extractUnique(expression)));

        // Union the symbols that appear together in a term, so each equivalence class
        // holds the symbols connected through the terms.
        DisjointSet<Symbol> symbolsPartitioner = new DisjointSet<>();
        for (Expression term : terms) {
            List<Symbol> expressionSymbols = expressionUniqueSymbols.get(term);
            if (expressionSymbols.isEmpty()) {
                continue;
            }
            Symbol first = expressionSymbols.get(0);
            // Register the first symbol even when it is the only one in the term.
            symbolsPartitioner.find(first);
            for (int i = 1; i < expressionSymbols.size(); i++) {
                symbolsPartitioner.findAndUnion(first, expressionSymbols.get(i));
            }
        }

        List<Set<Symbol>> symbolPartitions = ImmutableList.copyOf(symbolsPartitioner.getEquivalentClasses());
        Preconditions.checkState(symbolPartitions.size() <= terms.size(), "symbolPartitions size exceeds number of expressions");

        ListMultimap<Integer, Expression> expressionPartitions = ArrayListMultimap.create();
        for (Expression term : terms) {
            List<Symbol> expressionSymbols = expressionUniqueSymbols.get(term);
            int expressionPartitionId;
            if (expressionSymbols.isEmpty()) {
                // Symbol-free terms share one extra id, past the last symbol partition.
                expressionPartitionId = symbolPartitions.size();
            } else {
                // All symbols of a term were unioned above, so any one of them identifies its partition.
                Symbol symbol = expressionSymbols.get(0);
                expressionPartitionId = IntStream.range(0, symbolPartitions.size())
                        .filter(partition -> symbolPartitions.get(partition).contains(symbol))
                        .findFirst()
                        .orElseThrow();
            }
            expressionPartitions.put(expressionPartitionId, term);
        }

        return expressionPartitions.keySet().stream()
                .map(expressionPartitions::get)
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Visitor that estimates the output of filtering {@link #input} by the visited expression.
     * Each {@code visitXxx} method returns an estimate for one expression shape, or
     * {@link PlanNodeStatsEstimate#unknown()} when that shape cannot be estimated.
     */
    private class FilterExpressionStatsCalculatingVisitor
            extends IrVisitor<PlanNodeStatsEstimate, Void> {
        private final PlanNodeStatsEstimate input;
        private final Session session;

        FilterExpressionStatsCalculatingVisitor(PlanNodeStatsEstimate input, Session session) {
            this.input = input;
            this.session = session;
        }

        /**
         * Dispatches to the matching {@code visitXxx} method and normalizes the result.
         * If the input row count is zero or unknown, the input is passed through without
         * visiting {@code node}.
         */
        @Override
        public PlanNodeStatsEstimate process(Expression node, @Nullable Void context) {
            PlanNodeStatsEstimate output;
            if (input.getOutputRowCount() == 0.0 || input.isOutputRowCountUnknown()) {
                output = input;
            } else {
                output = super.process(node, context);
            }
            return normalizer.normalize(output);
        }

        @Override
        protected PlanNodeStatsEstimate visitExpression(Expression node, Void context) {
            return PlanNodeStatsEstimate.unknown();
        }

        @Override
        protected PlanNodeStatsEstimate visitLogical(Logical node, Void context) {
            return switch (node.operator()) {
                case AND -> estimateLogicalAnd(node.terms());
                case OR -> estimateLogicalOr(node.terms());
            };
        }

        /**
         * Estimates a conjunction. Terms are grouped by {@link #extractCorrelatedGroups} and
         * estimated per group. The per-group estimates are then combined using the session's
         * filter conjunction independence factor.
         */
        private PlanNodeStatsEstimate estimateLogicalAnd(List<Expression> terms) {
            double filterConjunctionIndependenceFactor = SystemSessionProperties.getFilterConjunctionIndependenceFactor(session);
            List<PlanNodeStatsEstimate> estimates = estimateCorrelatedExpressions(terms, filterConjunctionIndependenceFactor);
            double outputRowCount = PlanNodeStatsEstimateMath.estimateCorrelatedConjunctionRowCount(input, estimates, filterConjunctionIndependenceFactor);
            if (Double.isNaN(outputRowCount)) {
                return PlanNodeStatsEstimate.unknown();
            }
            return normalizer.normalize(new PlanNodeStatsEstimate(outputRowCount, PlanNodeStatsEstimateMath.intersectCorrelatedStats(estimates)));
        }

        /**
         * Produces one estimate per correlated group. Within a group, terms are applied
         * progressively: each term after the first successfully estimated one is estimated
         * against the group's running result rather than against {@link #input}.
         *
         * <p>Terms that cannot be estimated are skipped within their group. If any term was
         * skipped, one extra {@link PlanNodeStatsEstimate#unknown()} is appended to the result.
         */
        private List<PlanNodeStatsEstimate> estimateCorrelatedExpressions(List<Expression> terms, double filterConjunctionIndependenceFactor) {
            List<List<Expression>> extractedCorrelatedGroups = extractCorrelatedGroups(terms, filterConjunctionIndependenceFactor);
            ImmutableList.Builder<PlanNodeStatsEstimate> estimatesBuilder = ImmutableList.builderWithExpectedSize(extractedCorrelatedGroups.size());
            boolean hasUnestimatedTerm = false;
            for (List<Expression> correlatedExpressions : extractedCorrelatedGroups) {
                PlanNodeStatsEstimate combinedEstimate = PlanNodeStatsEstimate.unknown();
                for (Expression expression : correlatedExpressions) {
                    // combinedEstimate stays unknown until the first term of the group is estimated.
                    PlanNodeStatsEstimate estimate = combinedEstimate.isOutputRowCountUnknown()
                            ? process(expression)
                            : new FilterExpressionStatsCalculatingVisitor(combinedEstimate, session).process(expression);
                    if (estimate.isOutputRowCountUnknown()) {
                        hasUnestimatedTerm = true;
                    } else {
                        combinedEstimate = estimate;
                    }
                }
                estimatesBuilder.add(combinedEstimate);
            }
            if (hasUnestimatedTerm) {
                estimatesBuilder.add(PlanNodeStatsEstimate.unknown());
            }
            return estimatesBuilder.build();
        }

        /**
         * Estimates a disjunction by folding terms left to right. The fold uses inclusion-exclusion:
         * |A or B| = |A| + |B| - |A and B|, where |A and B| is estimated by applying the next term
         * to the running estimate. Each step is capped at {@link #input}. If any step is unknown,
         * the whole result is unknown.
         */
        private PlanNodeStatsEstimate estimateLogicalOr(List<Expression> terms) {
            PlanNodeStatsEstimate previous = process(terms.get(0));
            if (previous.isOutputRowCountUnknown()) {
                return PlanNodeStatsEstimate.unknown();
            }
            for (int i = 1; i < terms.size(); i++) {
                Expression term = terms.get(i);
                PlanNodeStatsEstimate current = process(term);
                if (current.isOutputRowCountUnknown()) {
                    return PlanNodeStatsEstimate.unknown();
                }
                PlanNodeStatsEstimate andEstimate = new FilterExpressionStatsCalculatingVisitor(previous, session).process(term);
                if (andEstimate.isOutputRowCountUnknown()) {
                    return PlanNodeStatsEstimate.unknown();
                }
                previous = PlanNodeStatsEstimateMath.capStats(
                        PlanNodeStatsEstimateMath.subtractSubsetStats(
                                PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues(previous, current),
                                andEstimate),
                        input);
            }
            return previous;
        }

        /**
         * Handles boolean constants. TRUE keeps the input unchanged. FALSE yields zero rows, with
         * zero statistics for every symbol that had known statistics. Other constants are
         * delegated to the superclass.
         */
        @Override
        protected PlanNodeStatsEstimate visitConstant(Constant node, Void context) {
            if (node.type().equals(BooleanType.BOOLEAN) && node.value() != null) {
                if ((Boolean) node.value()) {
                    return input;
                }
                PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
                result.setOutputRowCount(0.0);
                input.getSymbolsWithKnownStatistics().forEach(symbol -> result.addSymbolStatistics(symbol, SymbolStatsEstimate.zero()));
                return result.build();
            }
            return super.visitConstant(node, context);
        }

        /**
         * Estimates {@code symbol IS NULL}. The row count is scaled by the symbol's nulls fraction,
         * and the symbol becomes all-null. Non-symbol operands are unknown.
         */
        @Override
        protected PlanNodeStatsEstimate visitIsNull(IsNull node, Void context) {
            if (!(node.value() instanceof Reference)) {
                return PlanNodeStatsEstimate.unknown();
            }
            Symbol symbol = Symbol.from(node.value());
            SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
            PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
            result.setOutputRowCount(input.getOutputRowCount() * symbolStats.getNullsFraction());
            result.addSymbolStatistics(symbol, SymbolStatsEstimate.builder()
                    .setNullsFraction(1.0)
                    .setLowValue(Double.NaN)
                    .setHighValue(Double.NaN)
                    .setDistinctValuesCount(0.0)
                    .build());
            return result.build();
        }

        /**
         * Rewrites {@code value BETWEEN min AND max} as {@code value >= min AND value <= max}.
         * This is only done when both bounds have single-value statistics.
         */
        @Override
        protected PlanNodeStatsEstimate visitBetween(Between node, Void context) {
            SymbolStatsEstimate valueStats = getExpressionStats(node.value());
            if (valueStats.isUnknown()) {
                return PlanNodeStatsEstimate.unknown();
            }
            if (!getExpressionStats(node.min()).isSingleValue()) {
                return PlanNodeStatsEstimate.unknown();
            }
            if (!getExpressionStats(node.max()).isSingleValue()) {
                return PlanNodeStatsEstimate.unknown();
            }
            Comparison lowerBound = new Comparison(Comparison.Operator.GREATER_THAN_OR_EQUAL, node.value(), node.min());
            Comparison upperBound = new Comparison(Comparison.Operator.LESS_THAN_OR_EQUAL, node.value(), node.max());
            // Conjunct order matters: terms in the same correlated group are estimated progressively,
            // each against the previous term's output (see estimateCorrelatedExpressions).
            // NOTE(review): presumably the lower bound goes first when the low value is infinite,
            // so the infinite side is cut to a finite range before the other bound is applied.
            Expression transformed = Double.isInfinite(valueStats.getLowValue())
                    ? IrUtils.and(lowerBound, upperBound)
                    : IrUtils.and(upperBound, lowerBound);
            return process(transformed);
        }

        /**
         * Estimates {@code value IN (v1, ..., vn)} as the sum of the per-value equality estimates.
         * The row count is capped at the non-null input rows. For a symbol operand, its distinct
         * values count is also capped at its input distinct values count.
         */
        @Override
        protected PlanNodeStatsEstimate visitIn(In node, Void context) {
            List<PlanNodeStatsEstimate> equalityEstimates = node.valueList().stream()
                    .map(inValue -> process(new Comparison(Comparison.Operator.EQUAL, node.value(), inValue)))
                    .collect(ImmutableList.toImmutableList());
            if (equalityEstimates.stream().anyMatch(PlanNodeStatsEstimate::isOutputRowCountUnknown)) {
                return PlanNodeStatsEstimate.unknown();
            }
            PlanNodeStatsEstimate inEstimate = equalityEstimates.stream()
                    .reduce(PlanNodeStatsEstimateMath::addStatsAndSumDistinctValues)
                    .orElse(PlanNodeStatsEstimate.unknown());
            if (inEstimate.isOutputRowCountUnknown()) {
                return PlanNodeStatsEstimate.unknown();
            }
            SymbolStatsEstimate valueStats = getExpressionStats(node.value());
            if (valueStats.isUnknown()) {
                return PlanNodeStatsEstimate.unknown();
            }
            double notNullValuesBeforeIn = input.getOutputRowCount() * (1.0 - valueStats.getNullsFraction());
            PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
            result.setOutputRowCount(Double.min(inEstimate.getOutputRowCount(), notNullValuesBeforeIn));
            if (node.value() instanceof Reference) {
                Symbol valueSymbol = Symbol.from(node.value());
                SymbolStatsEstimate newSymbolStats = inEstimate.getSymbolStatistics(valueSymbol)
                        .mapDistinctValuesCount(newDistinctValuesCount -> Double.min(newDistinctValuesCount, valueStats.getDistinctValuesCount()));
                result.addSymbolStatistics(valueSymbol, newSymbolStats);
            }
            return result.build();
        }

        /**
         * Estimates a binary comparison.
         *
         * <p>Operands are first normalized so a symbol (or, failing that, a non-constant) is on the
         * left and a constant is on the right. A constant or single-valued right side is estimated
         * as an expression-to-literal comparison. Anything else is estimated as an
         * expression-to-expression comparison.
         *
         * @throws IllegalArgumentException if both operands are constants
         */
        @Override
        protected PlanNodeStatsEstimate visitComparison(Comparison node, Void context) {
            Comparison.Operator operator = node.operator();
            Expression left = node.left();
            Expression right = node.right();
            Preconditions.checkArgument(!(left instanceof Constant && right instanceof Constant), "Literal-to-literal not supported here, should be eliminated earlier");

            if (!(left instanceof Reference) && right instanceof Reference) {
                // Normalize so that the symbol is on the left.
                return process(new Comparison(operator.flip(), right, left));
            }
            if (left instanceof Constant) {
                // Normalize so that the literal is on the right.
                return process(new Comparison(operator.flip(), right, left));
            }
            if (left instanceof Reference && left.equals(right)) {
                // NOTE(review): this shortcut treats `x <op> x` as `x IS NOT NULL` for every operator.
                // That is exact for =, <=, >=, but irreflexive operators (e.g. <, >, <>) match no rows.
                return process(IrExpressions.not(plannerContext.getMetadata(), new IsNull(left)));
            }

            SymbolStatsEstimate leftStats = getExpressionStats(left);
            Optional<Symbol> leftSymbol = left instanceof Reference ? Optional.of(Symbol.from(left)) : Optional.empty();
            if (right instanceof Constant constant) {
                Object literalValue = constant.value();
                if (literalValue == null) {
                    // A comparison with NULL is never true, so no rows pass.
                    return input.mapOutputRowCount(rowCountEstimate -> 0.0);
                }
                OptionalDouble literal = StatsUtil.toStatsRepresentation(right.type(), literalValue);
                return ComparisonStatsCalculator.estimateExpressionToLiteralComparison(input, leftStats, leftSymbol, literal, operator);
            }

            SymbolStatsEstimate rightStats = getExpressionStats(right);
            if (rightStats.isSingleValue()) {
                // A NaN low value means no numeric stats representation is available.
                OptionalDouble value = Double.isNaN(rightStats.getLowValue()) ? OptionalDouble.empty() : OptionalDouble.of(rightStats.getLowValue());
                return ComparisonStatsCalculator.estimateExpressionToLiteralComparison(input, leftStats, leftSymbol, value, operator);
            }

            Optional<Symbol> rightSymbol = right instanceof Reference ? Optional.of(Symbol.from(right)) : Optional.empty();
            return ComparisonStatsCalculator.estimateExpressionToExpressionComparison(input, leftStats, leftSymbol, rightStats, rightSymbol, operator);
        }

        /**
         * Handles the function calls this visitor understands:
         * <ul>
         *   <li>A dynamic filter is estimated as TRUE, so the input passes unchanged.</li>
         *   <li>{@code $not(symbol IS NULL)} keeps the symbol's non-null fraction of rows.</li>
         *   <li>Any other {@code $not(e)} is estimated as the input minus the estimate of {@code e}.</li>
         * </ul>
         * All other calls are unknown.
         */
        @Override
        protected PlanNodeStatsEstimate visitCall(Call node, Void context) {
            if (DynamicFilters.isDynamicFilter(node)) {
                return process(Booleans.TRUE, context);
            }
            if (!node.function().name().equals(GlobalFunctionCatalog.builtinFunctionName("$not"))) {
                return PlanNodeStatsEstimate.unknown();
            }
            Expression argument = node.arguments().getFirst();
            if (argument instanceof IsNull inner) {
                if (!(inner.value() instanceof Reference)) {
                    return PlanNodeStatsEstimate.unknown();
                }
                Symbol symbol = Symbol.from(inner.value());
                SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
                PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
                result.setOutputRowCount(input.getOutputRowCount() * (1.0 - symbolStats.getNullsFraction()));
                result.addSymbolStatistics(symbol, symbolStats.mapNullsFraction(x -> 0.0));
                return result.build();
            }
            return PlanNodeStatsEstimateMath.subtractSubsetStats(input, process(argument));
        }

        /**
         * Returns statistics for {@code expression}. A symbol reference is looked up in
         * {@link #input} and must be present. Any other expression is computed by the
         * {@link ScalarStatsCalculator}.
         *
         * @throws NullPointerException if a referenced symbol has no statistics in the input
         */
        private SymbolStatsEstimate getExpressionStats(Expression expression) {
            if (expression instanceof Reference) {
                Symbol symbol = Symbol.from(expression);
                return Objects.requireNonNull(input.getSymbolStatistics(symbol), () -> String.format("No statistics for symbol %s", symbol));
            }
            return scalarStatsCalculator.calculate(expression, input, session);
        }
    }
}

