package io.druid.sql.calcite.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.druid.java.util.common.ISE;
import io.druid.java.util.common.StringUtils;
import io.druid.math.expr.ExprMacroTable;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.CountAggregatorFactory;
import io.druid.query.aggregation.DoubleMaxAggregatorFactory;
import io.druid.query.aggregation.DoubleMinAggregatorFactory;
import io.druid.query.aggregation.DoubleSumAggregatorFactory;
import io.druid.query.aggregation.FloatMaxAggregatorFactory;
import io.druid.query.aggregation.FloatMinAggregatorFactory;
import io.druid.query.aggregation.FloatSumAggregatorFactory;
import io.druid.query.aggregation.LongMaxAggregatorFactory;
import io.druid.query.aggregation.LongMinAggregatorFactory;
import io.druid.query.aggregation.LongSumAggregatorFactory;
import io.druid.query.aggregation.PostAggregator;
import io.druid.query.aggregation.post.ArithmeticPostAggregator;
import io.druid.query.aggregation.post.FieldAccessPostAggregator;
import io.druid.query.filter.AndDimFilter;
import io.druid.query.filter.DimFilter;
import io.druid.segment.column.ValueType;
import io.druid.sql.calcite.aggregation.Aggregation;
import io.druid.sql.calcite.aggregation.ApproxCountDistinctSqlAggregator;
import io.druid.sql.calcite.aggregation.SqlAggregator;
import io.druid.sql.calcite.expression.DruidExpression;
import io.druid.sql.calcite.expression.Expressions;
import io.druid.sql.calcite.filtration.Filtration;
import io.druid.sql.calcite.planner.PlannerContext;
import io.druid.sql.calcite.planner.Rules;
import io.druid.sql.calcite.table.RowSignature;
import java.util.List;
import java.util.function.Function;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;

/* loaded from: input_file:io/druid/sql/calcite/rule/GroupByRules.class */
public class GroupByRules {
    /**
     * Shared aggregator instance that {@code translateAggregateCall} delegates to for
     * {@code COUNT(DISTINCT ...)} calls when the planner config enables approximate count-distinct.
     */
    private static final ApproxCountDistinctSqlAggregator APPROX_COUNT_DISTINCT = new ApproxCountDistinctSqlAggregator();

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: io.druid.sql.calcite.rule.GroupByRules$1, reason: invalid class name */
    /* loaded from: input_file:io/druid/sql/calcite/rule/GroupByRules$1.class */
    /*
     * Decompiled form of a synthetic "switch map" holder.
     *
     * The array below is indexed by ValueType ordinal and stores a small case number:
     * LONG -> 1, FLOAT -> 2, DOUBLE -> 3. Slots for every other constant keep the
     * int[] default of 0. The array is sized from ValueType.values() at class-initialisation
     * time, so the table always has one slot per constant present at runtime.
     */
    public static /* synthetic */ class AnonymousClass1 {
        // Lookup table: ValueType.ordinal() -> case number (0 means "no case assigned").
        static final /* synthetic */ int[] $SwitchMap$io$druid$segment$column$ValueType = new int[ValueType.values().length];

        static {
            // Each assignment is guarded separately so that a constant missing from the
            // ValueType class loaded at runtime (NoSuchFieldError) only leaves its own slot at 0
            // instead of aborting the whole initialiser.
            try {
                $SwitchMap$io$druid$segment$column$ValueType[ValueType.LONG.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
                // Intentionally ignored: no slot is assigned for LONG.
            }
            try {
                $SwitchMap$io$druid$segment$column$ValueType[ValueType.FLOAT.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
                // Intentionally ignored: no slot is assigned for FLOAT.
            }
            try {
                $SwitchMap$io$druid$segment$column$ValueType[ValueType.DOUBLE.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
                // Intentionally ignored: no slot is assigned for DOUBLE.
            }
        }
    }

    /** Not instantiable: this class only exposes static helpers. */
    private GroupByRules() {
    }

    /**
     * Translates a Calcite {@link AggregateCall} into a Druid {@link Aggregation}.
     *
     * <p>Handling, in order:
     * <ol>
     *   <li>An optional filter argument ({@code filterArg >= 0}) is read from {@code project} and
     *       converted to a {@link DimFilter}; it is attached to whatever aggregation is produced.</li>
     *   <li>{@code COUNT} with no arguments becomes a plain count aggregator.</li>
     *   <li>Distinct calls are only supported for {@code COUNT}, and only when the planner config
     *       enables approximate count-distinct.</li>
     *   <li>Kinds other than COUNT/SUM/SUM0/MIN/MAX/AVG are delegated to a {@link SqlAggregator}
     *       looked up in the planner's operator table.</li>
     *   <li>{@code COUNT(expr)} becomes a count aggregator, additionally filtered by
     *       {@code expr IS NOT NULL} when the argument type is nullable.</li>
     *   <li>SUM/SUM0/MIN/MAX map to the matching long/float/double aggregator; AVG becomes a
     *       {@code "<name>:sum"} aggregator and a {@code "<name>:count"} count aggregator combined
     *       by a {@code "quotient"} arithmetic post-aggregator. Note that the AVG count aggregator
     *       carries only the call's own filter, not an extra not-null filter.</li>
     * </ol>
     *
     * @param plannerContext planner context supplying config, operator table and expression macros
     * @param rowSignature   signature of the rows being aggregated
     * @param rexBuilder     used to build the {@code IS NOT NULL} expression for nullable COUNT arguments
     * @param project        projection feeding the aggregate; may be {@code null}, in which case a
     *                       call with a filter argument cannot be translated
     * @param aggregateCall  the call to translate
     * @param list           aggregations handed unchanged to delegated {@link SqlAggregator}s
     * @param str            name given to the resulting aggregator (for AVG, to the post-aggregator)
     * @return the translated aggregation, or {@code null} if the call cannot be translated
     * @throws ISE if a not-null filter cannot be built for a nullable COUNT argument, or if the
     *             call's SQL return type is neither integer-like, FLOAT nor fractional
     */
    public static Aggregation translateAggregateCall(PlannerContext plannerContext, RowSignature rowSignature, RexBuilder rexBuilder, Project project, AggregateCall aggregateCall, List<Aggregation> list, String str) {
        final SqlKind aggKind = aggregateCall.getAggregation().getKind();
        final SqlTypeName returnType = aggregateCall.getType().getSqlTypeName();

        // Resolve the optional filter argument; without a projection (or a translatable filter) we give up.
        DimFilter callFilter = null;
        if (aggregateCall.filterArg >= 0) {
            if (project == null) {
                return null;
            }
            final RexNode filterRex = (RexNode) project.getChildExps().get(aggregateCall.filterArg);
            callFilter = Expressions.toFilter(plannerContext, rowSignature, filterRex);
            if (callFilter == null) {
                return null;
            }
        }

        final boolean isCount = aggKind == SqlKind.COUNT;

        // COUNT with no arguments.
        if (isCount && aggregateCall.getArgList().isEmpty()) {
            final AggregatorFactory countAll = new CountAggregatorFactory(str);
            return Aggregation.create(countAll).filter(makeFilter(callFilter, rowSignature));
        }

        // DISTINCT is only handled for COUNT, and only approximately.
        if (aggregateCall.isDistinct()) {
            final boolean approximate = isCount && plannerContext.getPlannerConfig().isUseApproximateCountDistinct();
            return approximate
                    ? APPROX_COUNT_DISTINCT.toDruidAggregation(str, rowSignature, plannerContext, list, project, aggregateCall, makeFilter(callFilter, rowSignature))
                    : null;
        }

        // Anything that is not one of the built-in kinds goes through the operator table.
        final boolean builtIn = isCount
                || aggKind == SqlKind.SUM
                || aggKind == SqlKind.SUM0
                || aggKind == SqlKind.MIN
                || aggKind == SqlKind.MAX
                || aggKind == SqlKind.AVG;
        if (!builtIn) {
            final SqlAggregator custom = plannerContext.getOperatorTable().lookupAggregator(aggregateCall.getAggregation());
            if (custom == null) {
                return null;
            }
            return custom.toDruidAggregation(str, rowSignature, plannerContext, list, project, aggregateCall, makeFilter(callFilter, rowSignature));
        }

        // Built-in kinds take exactly one argument.
        final int argIndex = Iterables.getOnlyElement(aggregateCall.getArgList());
        final RexNode inputRex = Expressions.fromFieldAccess(rowSignature, project, argIndex);
        final DruidExpression inputExpr = toDruidExpressionForAggregator(plannerContext, rowSignature, inputRex);
        if (inputExpr == null) {
            return null;
        }

        if (isCount) {
            if (inputRex.getType().isNullable()) {
                // COUNT(x) over a nullable x must only count rows where x is not null.
                final RexNode notNullRex = rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, ImmutableList.of(inputRex));
                final DimFilter notNullFilter = Expressions.toFilter(plannerContext, rowSignature, notNullRex);
                if (notNullFilter == null) {
                    throw new ISE("Could not create not-null filter for rexNode[%s]", inputRex);
                }
                final AggregatorFactory countNonNull = new CountAggregatorFactory(str);
                final Aggregation counted = Aggregation.create(countNonNull);
                final DimFilter combined = callFilter == null
                        ? notNullFilter
                        : new AndDimFilter(ImmutableList.of(callFilter, notNullFilter));
                return counted.filter(makeFilter(combined, rowSignature));
            }
            final AggregatorFactory countRows = new CountAggregatorFactory(str);
            return Aggregation.create(countRows).filter(makeFilter(callFilter, rowSignature));
        }

        // Pick the Druid value type from the SQL return type of the call.
        final ValueType aggregationType;
        if (SqlTypeName.INT_TYPES.contains(returnType) || SqlTypeName.TIMESTAMP == returnType || SqlTypeName.DATE == returnType) {
            aggregationType = ValueType.LONG;
        } else if (SqlTypeName.FLOAT == returnType) {
            aggregationType = ValueType.FLOAT;
        } else if (SqlTypeName.FRACTIONAL_TYPES.contains(returnType)) {
            aggregationType = ValueType.DOUBLE;
        } else {
            throw new ISE("Cannot determine aggregation type for SQL operator[%s] type[%s]", aggregateCall.getAggregation().getName(), returnType);
        }

        final ExprMacroTable macroTable = plannerContext.getExprMacroTable();

        // Exactly one of fieldName / expression is non-null.
        final boolean directColumn = inputExpr.isDirectColumnAccess();
        final String fieldName = directColumn ? inputExpr.getDirectColumn() : null;
        final String expression = directColumn ? null : inputExpr.getExpression();

        final Aggregation translated;
        if (aggKind == SqlKind.SUM || aggKind == SqlKind.SUM0) {
            translated = Aggregation.create(createSumAggregatorFactory(aggregationType, str, fieldName, expression, macroTable));
        } else if (aggKind == SqlKind.MIN) {
            translated = Aggregation.create(createMinAggregatorFactory(aggregationType, str, fieldName, expression, macroTable));
        } else if (aggKind == SqlKind.MAX) {
            translated = Aggregation.create(createMaxAggregatorFactory(aggregationType, str, fieldName, expression, macroTable));
        } else if (aggKind == SqlKind.AVG) {
            // AVG = sum / count, computed by a post-aggregator over two helper aggregators.
            final String sumName = StringUtils.format("%s:sum", str);
            final String countName = StringUtils.format("%s:count", str);
            final AggregatorFactory sumFactory = createSumAggregatorFactory(aggregationType, sumName, fieldName, expression, macroTable);
            final AggregatorFactory countFactory = new CountAggregatorFactory(countName);
            final List<AggregatorFactory> factories = ImmutableList.of(sumFactory, countFactory);
            final PostAggregator quotient = new ArithmeticPostAggregator(
                    str,
                    "quotient",
                    ImmutableList.of(new FieldAccessPostAggregator((String) null, sumName), new FieldAccessPostAggregator((String) null, countName)));
            translated = Aggregation.create(factories, quotient);
        } else {
            throw new ISE("WTF?! Kind[%s] got into the built-in aggregator path somehow?!", aggKind);
        }

        return translated.filter(makeFilter(callFilter, rowSignature));
    }

    /**
     * Converts {@code rexNode} to a {@link DruidExpression} suitable as aggregator input.
     *
     * <p>If the converted expression is a simple extraction that is either not a direct column
     * access, or is a direct access to a {@code STRING} column, the result is re-mapped so that its
     * simple-extraction part becomes {@code null} while the expression string is kept as is.
     * (Presumably this forces the aggregator to evaluate the expression instead of reading the
     * string column directly; confirm against {@code DruidExpression.map}.)
     *
     * @return the expression to aggregate over, or {@code null} if {@code rexNode} cannot be converted
     */
    private static DruidExpression toDruidExpressionForAggregator(PlannerContext plannerContext, RowSignature rowSignature, RexNode rexNode) {
        final DruidExpression converted = Expressions.toDruidExpression(plannerContext, rowSignature, rexNode);
        if (converted == null) {
            return null;
        }

        if (converted.isSimpleExtraction()
                && (!converted.isDirectColumnAccess()
                    || rowSignature.getColumnType(converted.getDirectColumn()) == ValueType.STRING)) {
            return converted.map(simpleExtraction -> null, Function.identity());
        }
        return converted;
    }

    /**
     * Runs {@code dimFilter} through {@link Filtration}'s filter-only optimisation against
     * {@code rowSignature}.
     *
     * @return the optimised filter, or {@code null} when {@code dimFilter} is {@code null}
     */
    private static DimFilter makeFilter(DimFilter dimFilter, RowSignature rowSignature) {
        return dimFilter == null
                ? null
                : Filtration.create(dimFilter).optimizeFilterOnly(rowSignature).getDimFilter();
    }

    /**
     * Builds the sum aggregator factory matching {@code valueType}.
     *
     * <p>Switches on the enum directly rather than through an ordinal-indexed lookup table and an
     * unrelated integer constant as a case label.
     *
     * @param valueType      numeric type to aggregate as; must be LONG, FLOAT or DOUBLE
     * @param str            output name, passed as the factory's first constructor argument
     * @param str2           input column to read directly, or {@code null} when {@code str3} is used
     * @param str3           expression to evaluate, or {@code null} when {@code str2} is used
     * @param exprMacroTable macro table handed to the factory
     * @return a long, float or double sum aggregator factory
     * @throws ISE if {@code valueType} is not LONG, FLOAT or DOUBLE
     */
    private static AggregatorFactory createSumAggregatorFactory(ValueType valueType, String str, String str2, String str3, ExprMacroTable exprMacroTable) {
        switch (valueType) {
            case LONG:
                return new LongSumAggregatorFactory(str, str2, str3, exprMacroTable);
            case FLOAT:
                return new FloatSumAggregatorFactory(str, str2, str3, exprMacroTable);
            case DOUBLE:
                return new DoubleSumAggregatorFactory(str, str2, str3, exprMacroTable);
            default:
                throw new ISE("Cannot create aggregator factory for type[%s]", valueType);
        }
    }

    /**
     * Builds the min aggregator factory matching {@code valueType}.
     *
     * <p>Switches on the enum directly rather than through an ordinal-indexed lookup table and an
     * unrelated integer constant as a case label.
     *
     * @param valueType      numeric type to aggregate as; must be LONG, FLOAT or DOUBLE
     * @param str            output name, passed as the factory's first constructor argument
     * @param str2           input column to read directly, or {@code null} when {@code str3} is used
     * @param str3           expression to evaluate, or {@code null} when {@code str2} is used
     * @param exprMacroTable macro table handed to the factory
     * @return a long, float or double min aggregator factory
     * @throws ISE if {@code valueType} is not LONG, FLOAT or DOUBLE
     */
    private static AggregatorFactory createMinAggregatorFactory(ValueType valueType, String str, String str2, String str3, ExprMacroTable exprMacroTable) {
        switch (valueType) {
            case LONG:
                return new LongMinAggregatorFactory(str, str2, str3, exprMacroTable);
            case FLOAT:
                return new FloatMinAggregatorFactory(str, str2, str3, exprMacroTable);
            case DOUBLE:
                return new DoubleMinAggregatorFactory(str, str2, str3, exprMacroTable);
            default:
                throw new ISE("Cannot create aggregator factory for type[%s]", valueType);
        }
    }

    /**
     * Builds the max aggregator factory matching {@code valueType}.
     *
     * <p>Switches on the enum directly rather than through an ordinal-indexed lookup table and an
     * unrelated integer constant as a case label.
     *
     * @param valueType      numeric type to aggregate as; must be LONG, FLOAT or DOUBLE
     * @param str            output name, passed as the factory's first constructor argument
     * @param str2           input column to read directly, or {@code null} when {@code str3} is used
     * @param str3           expression to evaluate, or {@code null} when {@code str2} is used
     * @param exprMacroTable macro table handed to the factory
     * @return a long, float or double max aggregator factory
     * @throws ISE if {@code valueType} is not LONG, FLOAT or DOUBLE
     */
    private static AggregatorFactory createMaxAggregatorFactory(ValueType valueType, String str, String str2, String str3, ExprMacroTable exprMacroTable) {
        switch (valueType) {
            case LONG:
                return new LongMaxAggregatorFactory(str, str2, str3, exprMacroTable);
            case FLOAT:
                return new FloatMaxAggregatorFactory(str, str2, str3, exprMacroTable);
            case DOUBLE:
                return new DoubleMaxAggregatorFactory(str, str2, str3, exprMacroTable);
            default:
                throw new ISE("Cannot create aggregator factory for type[%s]", valueType);
        }
    }
}
