package io.trino.cost;

import com.google.common.collect.ImmutableMap;
import com.google.inject.Inject;
import io.trino.Session;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.metadata.GlobalFunctionCatalog;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.function.OperatorType;
import io.trino.spi.statistics.StatsUtil;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.Coalesce;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.Reference;
import io.trino.sql.ir.optimizer.IrExpressionOptimizer;
import io.trino.sql.planner.Symbol;
import io.trino.util.MoreMath;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.SwitchBootstraps;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.OptionalDouble;

/* loaded from: input_file:io/trino/cost/ScalarStatsCalculator.class */
public class ScalarStatsCalculator {
    private final PlannerContext plannerContext;

    /* loaded from: input_file:io/trino/cost/ScalarStatsCalculator$Visitor.class */
    private class Visitor extends IrVisitor<SymbolStatsEstimate, Void> {
        private final PlanNodeStatsEstimate input;
        private final Session session;

        /**
         * Creates a visitor bound to the statistics of the input plan node and to a session.
         *
         * @param planNodeStatsEstimate statistics used to look up referenced symbols and the output row count
         * @param session session handed to the expression optimizer when a call is evaluated
         */
        Visitor(PlanNodeStatsEstimate planNodeStatsEstimate, Session session) {
            this.session = session;
            this.input = planNodeStatsEstimate;
        }

        /**
         * Generic handler: reports the statistics of an arbitrary expression as unknown.
         *
         * @return {@link SymbolStatsEstimate#unknown()}
         */
        @Override
        public SymbolStatsEstimate visitExpression(Expression expression, Void r4) {
            SymbolStatsEstimate noKnowledge = SymbolStatsEstimate.unknown();
            return noKnowledge;
        }

        /**
         * Resolves a symbol reference by looking up its statistics in the input plan node estimate.
         *
         * @param reference the reference whose symbol is looked up
         * @return whatever the input estimate reports for that symbol
         */
        @Override
        public SymbolStatsEstimate visitReference(Reference reference, Void r5) {
            Symbol referenced = Symbol.from(reference);
            return this.input.getSymbolStatistics(referenced);
        }

        /**
         * Estimates statistics for a literal.
         *
         * <p>A null literal yields the all-nulls estimate. Any other literal yields a single distinct,
         * never-null value; when the value has a numeric stats representation, that number is used
         * as both the low and the high bound.
         *
         * @param constant the literal being estimated
         * @return the estimate for the literal
         */
        @Override
        public SymbolStatsEstimate visitConstant(Constant constant, Void r6) {
            Type literalType = constant.type();
            Object literal = constant.value();
            if (literal == null) {
                return ScalarStatsCalculator.nullStatsEstimate();
            }

            OptionalDouble point = StatsUtil.toStatsRepresentation(literalType, literal);

            SymbolStatsEstimate.Builder estimate = SymbolStatsEstimate.builder();
            estimate.setNullsFraction(0.0d);
            estimate.setDistinctValuesCount(1.0d);
            if (point.isPresent()) {
                // A constant occupies a single point, so both bounds collapse to the same value.
                double onlyValue = point.getAsDouble();
                estimate.setLowValue(onlyValue);
                estimate.setHighValue(onlyValue);
            }
            return estimate.build();
        }

        /**
         * Estimates statistics for a function call.
         *
         * <ul>
         *   <li>Negation: copies the operand's statistics and mirrors the value range around zero.</li>
         *   <li>ADD, SUBTRACT, MULTIPLY, DIVIDE, MODULUS: delegated to {@link #processArithmetic(Call)}.</li>
         *   <li>Anything else: the call is run through the IR expression optimizer. A null constant
         *       result yields the all-nulls estimate, any other constant yields a single distinct
         *       non-null value (no range), and a non-constant result yields unknown statistics.</li>
         * </ul>
         *
         * @param call the call being estimated
         * @return the estimate for the call's result
         */
        @Override
        public SymbolStatsEstimate visitCall(Call call, Void r7) {
            CatalogSchemaFunctionName functionName = call.function().name();

            if (functionName.equals(GlobalFunctionCatalog.builtinFunctionName(OperatorType.NEGATION))) {
                SymbolStatsEstimate operand = process(call.arguments().getFirst());
                SymbolStatsEstimate.Builder negated = SymbolStatsEstimate.buildFrom(operand);
                // Negating swaps the roles of the bounds: the old high becomes the new low.
                negated.setLowValue(-operand.getHighValue());
                negated.setHighValue(-operand.getLowValue());
                return negated.build();
            }

            OperatorType[] binaryArithmetic = {
                    OperatorType.ADD,
                    OperatorType.SUBTRACT,
                    OperatorType.MULTIPLY,
                    OperatorType.DIVIDE,
                    OperatorType.MODULUS,
            };
            for (OperatorType operator : binaryArithmetic) {
                if (functionName.equals(GlobalFunctionCatalog.builtinFunctionName(operator))) {
                    return processArithmetic(call);
                }
            }

            Expression folded = IrExpressionOptimizer.newOptimizer(ScalarStatsCalculator.this.plannerContext)
                    .process(call, this.session, ImmutableMap.of())
                    .orElse(call);

            if (!(folded instanceof Constant foldedConstant)) {
                // The optimizer could not reduce the call to a constant.
                return SymbolStatsEstimate.unknown();
            }
            if (foldedConstant.value() == null) {
                return ScalarStatsCalculator.nullStatsEstimate();
            }
            return SymbolStatsEstimate.builder()
                    .setNullsFraction(0.0d)
                    .setDistinctValuesCount(1.0d)
                    .build();
        }

        /**
         * Estimates statistics for a cast.
         *
         * <p>The nulls fraction, value range and distinct-value count of the operand are carried over.
         * When the target type is integral, finite bounds are rounded to the nearest integer and,
         * if both bounds are finite, the distinct-value count is capped at the number of integers in
         * the rounded range. The result is built from scratch, so only the four properties set
         * below are populated.
         *
         * @param cast the cast being estimated
         * @return the estimate for the cast's result
         */
        @Override
        public SymbolStatsEstimate visitCast(Cast cast, Void r7) {
            SymbolStatsEstimate source = process(cast.expression());

            double ndv = source.getDistinctValuesCount();
            double low = source.getLowValue();
            double high = source.getHighValue();

            if (isIntegralType(cast.type())) {
                low = Double.isFinite(low) ? Math.round(low) : low;
                high = Double.isFinite(high) ? Math.round(high) : high;

                boolean boundedRange = Double.isFinite(low) && Double.isFinite(high);
                if (boundedRange) {
                    // An integral column cannot hold more distinct values than integers in its range.
                    double integersInRange = (high - low) + 1.0d;
                    boolean ndvKnown = !Double.isNaN(ndv);
                    if (ndvKnown && ndv > integersInRange) {
                        ndv = integersInRange;
                    }
                }
            }

            SymbolStatsEstimate.Builder result = SymbolStatsEstimate.builder();
            result.setNullsFraction(source.getNullsFraction());
            result.setLowValue(low);
            result.setHighValue(high);
            result.setDistinctValuesCount(ndv);
            return result.build();
        }

        /**
         * Tells whether values of {@code type} are always whole numbers: one of the four
         * fixed-width integer types, or a decimal whose scale is zero.
         */
        private boolean isIntegralType(Type type) {
            if (type instanceof DecimalType decimal) {
                return decimal.getScale() == 0;
            }
            return type instanceof BigintType
                    || type instanceof IntegerType
                    || type instanceof SmallintType
                    || type instanceof TinyintType;
        }

        /**
         * Estimates statistics for a binary arithmetic call from the statistics of its two arguments.
         *
         * <p>If either argument's statistics are unknown, the result is unknown. Otherwise:
         * <ul>
         *   <li>average row size is the larger of the two arguments' sizes;</li>
         *   <li>nulls fraction is {@code l + r - l * r};</li>
         *   <li>distinct-value count is the product of the arguments' counts, capped at the input's
         *       output row count;</li>
         *   <li>the range is NaN when any bound is NaN; unbounded for a division whose divisor range
         *       strictly straddles zero; limited by the largest absolute divisor (and by the sign of
         *       the dividend range) for modulus; and otherwise the min/max of the operator applied
         *       to the four combinations of bounds.</li>
         * </ul>
         *
         * @param call a call with at least two arguments whose function is a binary arithmetic operator
         * @return the estimate for the call's result
         * @throws NullPointerException if {@code call} is null
         */
        protected SymbolStatsEstimate processArithmetic(Call call) {
            Objects.requireNonNull(call, "node is null");

            SymbolStatsEstimate left = process(call.arguments().get(0));
            SymbolStatsEstimate right = process(call.arguments().get(1));
            if (left.isUnknown() || right.isUnknown()) {
                return SymbolStatsEstimate.unknown();
            }

            double averageRowSize = Math.max(left.getAverageRowSize(), right.getAverageRowSize());
            double nullsFraction = (left.getNullsFraction() + right.getNullsFraction())
                    - (left.getNullsFraction() * right.getNullsFraction());
            double distinctValues = MoreMath.min(
                    left.getDistinctValuesCount() * right.getDistinctValuesCount(),
                    this.input.getOutputRowCount());

            SymbolStatsEstimate.Builder result = SymbolStatsEstimate.builder();
            result.setAverageRowSize(averageRowSize);
            result.setNullsFraction(nullsFraction);
            result.setDistinctValuesCount(distinctValues);

            double leftLow = left.getLowValue();
            double leftHigh = left.getHighValue();
            double rightLow = right.getLowValue();
            double rightHigh = right.getHighValue();

            CatalogSchemaFunctionName operator = call.function().name();
            boolean anyBoundMissing = Double.isNaN(leftLow) || Double.isNaN(leftHigh)
                    || Double.isNaN(rightLow) || Double.isNaN(rightHigh);
            boolean divisorStraddlesZero = rightLow < 0.0d && rightHigh > 0.0d;

            if (anyBoundMissing) {
                result.setLowValue(Double.NaN);
                result.setHighValue(Double.NaN);
                return result.build();
            }

            if (operator.equals(GlobalFunctionCatalog.builtinFunctionName(OperatorType.DIVIDE)) && divisorStraddlesZero) {
                // The divisor can be arbitrarily close to zero from either side.
                result.setLowValue(Double.NEGATIVE_INFINITY);
                result.setHighValue(Double.POSITIVE_INFINITY);
                return result.build();
            }

            if (operator.equals(GlobalFunctionCatalog.builtinFunctionName(OperatorType.MODULUS))) {
                double largestDivisor = MoreMath.max(Math.abs(rightLow), Math.abs(rightHigh));
                if (leftHigh <= 0.0d) {
                    // Dividend is never positive, so the remainder is never positive either.
                    result.setLowValue(MoreMath.max(-largestDivisor, leftLow));
                    result.setHighValue(0.0d);
                } else if (leftLow >= 0.0d) {
                    // Dividend is never negative, so the remainder is never negative either.
                    result.setLowValue(0.0d);
                    result.setHighValue(MoreMath.min(largestDivisor, leftHigh));
                } else {
                    result.setLowValue(MoreMath.max(-largestDivisor, leftLow));
                    result.setHighValue(MoreMath.min(largestDivisor, leftHigh));
                }
                return result.build();
            }

            // Evaluate the operator at the four corners of the two ranges.
            double lowLow = operate(operator, leftLow, rightLow);
            double lowHigh = operate(operator, leftLow, rightHigh);
            double highLow = operate(operator, leftHigh, rightLow);
            double highHigh = operate(operator, leftHigh, rightHigh);
            result.setLowValue(MoreMath.min(lowLow, lowHigh, highLow, highHigh));
            result.setHighValue(MoreMath.max(lowLow, lowHigh, highLow, highHigh));
            return result.build();
        }

        /**
         * Applies the built-in binary arithmetic operator identified by
         * {@code catalogSchemaFunctionName} to two {@code double} operands.
         *
         * <p>The operator is recognised by comparing the name against the built-in function names of
         * ADD, SUBTRACT, MULTIPLY, DIVIDE and MODULUS, in that order. This replaces a pattern-switch
         * dispatch loop that rebuilt a {@code typeSwitch} call site on every iteration and invoked a
         * method handle without handling its checked {@code Throwable}.
         *
         * @param catalogSchemaFunctionName name of the operator function to apply
         * @param d left operand
         * @param d2 right operand
         * @return the result of applying the operator to {@code d} and {@code d2}
         * @throws NullPointerException if {@code catalogSchemaFunctionName} is null
         * @throws IllegalStateException if the name is not one of the five supported operators
         */
        private double operate(CatalogSchemaFunctionName catalogSchemaFunctionName, double d, double d2) {
            Objects.requireNonNull(catalogSchemaFunctionName);
            if (catalogSchemaFunctionName.equals(GlobalFunctionCatalog.builtinFunctionName(OperatorType.ADD))) {
                return d + d2;
            }
            if (catalogSchemaFunctionName.equals(GlobalFunctionCatalog.builtinFunctionName(OperatorType.SUBTRACT))) {
                return d - d2;
            }
            if (catalogSchemaFunctionName.equals(GlobalFunctionCatalog.builtinFunctionName(OperatorType.MULTIPLY))) {
                return d * d2;
            }
            if (catalogSchemaFunctionName.equals(GlobalFunctionCatalog.builtinFunctionName(OperatorType.DIVIDE))) {
                return d / d2;
            }
            if (catalogSchemaFunctionName.equals(GlobalFunctionCatalog.builtinFunctionName(OperatorType.MODULUS))) {
                return d % d2;
            }
            throw new IllegalStateException("Unsupported binary arithmetic operation: " + catalogSchemaFunctionName);
        }

        /**
         * Estimates statistics for a COALESCE by folding its operands left to right with
         * {@link #estimateCoalesce(SymbolStatsEstimate, SymbolStatsEstimate)}.
         *
         * @param coalesce the coalesce expression being estimated
         * @return the combined estimate
         * @throws NullPointerException if {@code coalesce} is null, or if no estimate was produced
         */
        @Override
        public SymbolStatsEstimate visitCoalesce(Coalesce coalesce, Void r6) {
            Objects.requireNonNull(coalesce, "node is null");

            SymbolStatsEstimate combined = null;
            for (Expression operand : coalesce.operands()) {
                SymbolStatsEstimate operandStats = process(operand);
                if (combined == null) {
                    combined = operandStats;
                } else {
                    combined = estimateCoalesce(combined, operandStats);
                }
            }
            return Objects.requireNonNull(combined, "result is null");
        }

        /**
         * Combines the estimate of the operands seen so far ({@code left}) with the next
         * fallback operand ({@code right}).
         *
         * <p>If {@code left} is never null it wins outright; if it is always null the result is
         * {@code right}. Otherwise the ranges are unioned, the nulls fractions are multiplied, the
         * larger average row size is kept, and the distinct-value count of {@code left} is increased by
         * the smaller of {@code right}'s distinct-value count and
         * {@code outputRowCount * left.nullsFraction}.
         */
        private SymbolStatsEstimate estimateCoalesce(SymbolStatsEstimate left, SymbolStatsEstimate right) {
            if (left.getNullsFraction() == 0.0d) {
                return left;
            }
            if (left.getNullsFraction() == 1.0d) {
                return right;
            }

            double low = MoreMath.min(left.getLowValue(), right.getLowValue());
            double high = MoreMath.max(left.getHighValue(), right.getHighValue());
            // Only the rows where the left side is null can contribute values from the right side.
            double extraDistinctValues = MoreMath.min(
                    right.getDistinctValuesCount(),
                    this.input.getOutputRowCount() * left.getNullsFraction());
            double distinctValues = left.getDistinctValuesCount() + extraDistinctValues;
            double nullsFraction = left.getNullsFraction() * right.getNullsFraction();
            double averageRowSize = MoreMath.max(left.getAverageRowSize(), right.getAverageRowSize());

            SymbolStatsEstimate.Builder merged = SymbolStatsEstimate.builder();
            merged.setLowValue(low);
            merged.setHighValue(high);
            merged.setDistinctValuesCount(distinctValues);
            merged.setNullsFraction(nullsFraction);
            merged.setAverageRowSize(averageRowSize);
            return merged.build();
        }
    }

    /**
     * Creates a calculator backed by the given planner context.
     *
     * @param plannerContext context used to build the expression optimizer; must not be null
     * @throws NullPointerException if {@code plannerContext} is null
     */
    @Inject
    public ScalarStatsCalculator(PlannerContext plannerContext) {
        Objects.requireNonNull(plannerContext, "plannerContext cannot be null");
        this.plannerContext = plannerContext;
    }

    /**
     * Estimates the statistics of a scalar expression.
     *
     * @param expression the expression to estimate
     * @param planNodeStatsEstimate statistics of the plan node that supplies the expression's inputs
     * @param session session passed to the visitor
     * @return the estimate produced by visiting {@code expression}
     */
    public SymbolStatsEstimate calculate(Expression expression, PlanNodeStatsEstimate planNodeStatsEstimate, Session session) {
        Visitor visitor = new Visitor(planNodeStatsEstimate, session);
        return visitor.process(expression);
    }

    /**
     * Builds the estimate used for a value that is always null: zero distinct values and a
     * nulls fraction of one.
     */
    private static SymbolStatsEstimate nullStatsEstimate() {
        SymbolStatsEstimate.Builder allNulls = SymbolStatsEstimate.builder();
        allNulls.setDistinctValuesCount(0.0d);
        allNulls.setNullsFraction(1.0d);
        return allNulls.build();
    }
}
