package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.Plugin;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.Type;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Logical;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.PushFilterThroughCountAggregation;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.Assignments;
import java.util.List;
import org.junit.jupiter.api.Test;

/**
 * Rule tests for {@link PushFilterThroughCountAggregation}.
 *
 * <p>The tests build a filter on top of a grouped aggregation. The aggregation computes {@code count()}
 * with a {@code mask} symbol. The tests then check whether the rule moves the mask into a filter
 * beneath the aggregation, and what is left of the filter above it.
 *
 * <p>All references to {@code g}, {@code count} and {@code avg} are {@code BIGINT}-typed, matching how
 * the firing tests refer to those symbols. This keeps the negative ({@code doesNotFire}) cases from
 * differing from the positive ones in anything other than the property they test.
 */
public class TestPushFilterThroughCountAggregation extends BaseRuleTest {
    private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution();
    private static final ResolvedFunction MODULUS_BIGINT = FUNCTIONS.resolveOperator(OperatorType.MODULUS, ImmutableList.of(BigintType.BIGINT, BigintType.BIGINT));

    public TestPushFilterThroughCountAggregation() {
        // No additional plugins are required by these tests.
        super(new Plugin[0]);
    }

    /** The rule may not fire for a global (non-grouped) aggregation. */
    @Test
    public void testDoesNotFireWithNonGroupedAggregation() {
        tester().assertThat(ruleWithoutProject()).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter(
                    new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("count"), bigintConstant(0)),
                    p.aggregation(builder -> builder
                            .globalGrouping()
                            .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask)
                            .source(p.values(g, mask))));
        }).doesNotFire();
    }

    /** The rule may not fire when the aggregation computes more than one aggregate. */
    @Test
    public void testDoesNotFireWithMultipleAggregations() {
        tester().assertThat(ruleWithoutProject()).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            Symbol avg = p.symbol("avg");
            return p.filter(
                    new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("count"), bigintConstant(0)),
                    p.aggregation(builder -> builder
                            .singleGroupingSet(g)
                            .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask)
                            .addAggregation(avg, PlanBuilder.aggregation("avg", ImmutableList.of(bigintReference("g"))), ImmutableList.of(BigintType.BIGINT), mask)
                            .source(p.values(g, mask))));
        }).doesNotFire();
    }

    /** The rule may not fire when the aggregation has grouping keys but no aggregates at all. */
    @Test
    public void testDoesNotFireWithNoAggregations() {
        tester().assertThat(ruleWithoutProject()).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            return p.filter(
                    Booleans.TRUE,
                    p.aggregation(builder -> builder
                            .singleGroupingSet(g)
                            .source(p.values(g, mask))));
        }).doesNotFire();
    }

    /** The rule may not fire when {@code count()} has no mask symbol. */
    @Test
    public void testDoesNotFireWithNoMask() {
        tester().assertThat(ruleWithoutProject()).on(p -> {
            Symbol g = p.symbol("g");
            Symbol count = p.symbol("count");
            return p.filter(
                    new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("count"), bigintConstant(0)),
                    p.aggregation(builder -> builder
                            .singleGroupingSet(g)
                            .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of())
                            .source(p.values(g))));
        }).doesNotFire();
    }

    /**
     * The rule may not fire unless the single aggregate is an argument-less {@code count()}.
     * The first case uses {@code count(g)}; the second uses {@code avg(g)}.
     */
    @Test
    public void testDoesNotFireWithNoCountAggregation() {
        tester().assertThat(ruleWithoutProject()).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter(
                    new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("count"), bigintConstant(0)),
                    p.aggregation(builder -> builder
                            .singleGroupingSet(g)
                            .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of(bigintReference("g"))), ImmutableList.of(BigintType.BIGINT), mask)
                            .source(p.values(g, mask))));
        }).doesNotFire();

        tester().assertThat(ruleWithoutProject()).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol avg = p.symbol("avg");
            return p.filter(
                    new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("avg"), bigintConstant(0)),
                    p.aggregation(builder -> builder
                            .singleGroupingSet(g)
                            .addAggregation(avg, PlanBuilder.aggregation("avg", ImmutableList.of(bigintReference("g"))), ImmutableList.of(BigintType.BIGINT), mask)
                            .source(p.values(g, mask))));
        }).doesNotFire();
    }

    /**
     * An unsatisfiable predicate on {@code count} ({@code count < 0 AND count > 0}) is expected to
     * collapse the whole plan to {@code values(g, count)}.
     */
    @Test
    public void testFilterPredicateFalse() {
        tester().assertThat(ruleWithoutProject()).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter(
                    new Logical(Logical.Operator.AND, ImmutableList.of(
                            new Comparison(Comparison.Operator.LESS_THAN, bigintReference("count"), bigintConstant(0)),
                            new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("count"), bigintConstant(0)))),
                    p.aggregation(builder -> builder
                            .singleGroupingSet(g)
                            .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask)
                            .source(p.values(g, mask))));
        }).matches(PlanMatchPattern.values("g", "count"));
    }

    /** A constant {@code TRUE} filter gives the rule nothing to push, so it may not fire. */
    @Test
    public void testDoesNotFireWhenFilterPredicateTrue() {
        tester().assertThat(ruleWithoutProject()).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter(
                    Booleans.TRUE,
                    p.aggregation(builder -> builder
                            .singleGroupingSet(g)
                            .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask)
                            .source(p.values(g, mask))));
        }).doesNotFire();
    }

    /**
     * The {@code count} part of the predicate ({@code count < 0 OR count >= 0}) accepts every value,
     * so the rule may not fire even though the predicate also constrains {@code g}.
     */
    @Test
    public void testDoesNotFireWhenFilterPredicateSatisfiedByAllCountValues() {
        tester().assertThat(ruleWithoutProject()).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter(
                    new Logical(Logical.Operator.AND, ImmutableList.of(
                            new Logical(Logical.Operator.OR, ImmutableList.of(
                                    new Comparison(Comparison.Operator.LESS_THAN, bigintReference("count"), bigintConstant(0)),
                                    new Comparison(Comparison.Operator.GREATER_THAN_OR_EQUAL, bigintReference("count"), bigintConstant(0)))),
                            new Comparison(Comparison.Operator.EQUAL, bigintReference("g"), bigintConstant(5)))),
                    p.aggregation(builder -> builder
                            .singleGroupingSet(g)
                            .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask)
                            .source(p.values(g, mask))));
        }).doesNotFire();
    }

    /**
     * For {@code count > 0}, the expected plan filters the source on {@code mask} beneath the
     * aggregation and keeps no filter above it.
     */
    @Test
    public void testPushDownMaskAndRemoveFilter() {
        tester().assertThat(ruleWithoutProject()).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter(
                    new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("count"), bigintConstant(0)),
                    p.aggregation(builder -> builder
                            .singleGroupingSet(g)
                            .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask)
                            .source(p.values(g, mask))));
        }).matches(maskedCountOverValues());
    }

    /**
     * For {@code count > 0 AND <other>}, the expected plan pushes the mask down and keeps only
     * {@code <other>} above the aggregation. {@code <other>} is first a predicate on {@code g}, then a
     * predicate on {@code count} itself ({@code count % 2 = 0}).
     */
    @Test
    public void testPushDownMaskAndSimplifyFilter() {
        tester().assertThat(ruleWithoutProject()).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter(
                    new Logical(Logical.Operator.AND, ImmutableList.of(
                            new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("count"), bigintConstant(0)),
                            new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("g"), bigintConstant(5)))),
                    p.aggregation(builder -> builder
                            .singleGroupingSet(g)
                            .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask)
                            .source(p.values(g, mask))));
        }).matches(PlanMatchPattern.filter(
                new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("g"), bigintConstant(5)),
                maskedCountOverValues()));

        tester().assertThat(ruleWithoutProject()).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter(
                    new Logical(Logical.Operator.AND, ImmutableList.of(
                            new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("count"), bigintConstant(0)),
                            new Comparison(Comparison.Operator.EQUAL,
                                    new Call(MODULUS_BIGINT, ImmutableList.of(bigintReference("count"), bigintConstant(2))),
                                    bigintConstant(0)))),
                    p.aggregation(builder -> builder
                            .singleGroupingSet(g)
                            .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask)
                            .source(p.values(g, mask))));
        }).matches(PlanMatchPattern.filter(
                new Comparison(Comparison.Operator.EQUAL,
                        new Call(MODULUS_BIGINT, ImmutableList.of(bigintReference("count"), bigintConstant(2))),
                        bigintConstant(0)),
                maskedCountOverValues()));
    }

    /**
     * For {@code count > 5}, the expected plan pushes the mask down but keeps the original
     * predicate above the aggregation.
     */
    @Test
    public void testPushDownMaskAndRetainFilter() {
        tester().assertThat(ruleWithoutProject()).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter(
                    new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("count"), bigintConstant(5)),
                    p.aggregation(builder -> builder
                            .singleGroupingSet(g)
                            .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask)
                            .source(p.values(g, mask))));
        }).matches(PlanMatchPattern.filter(
                new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("count"), bigintConstant(5)),
                maskedCountOverValues()));
    }

    /**
     * Same scenarios as the remove / simplify / retain tests above, but with an identity projection
     * between the filter and the aggregation, exercising the "with project" variant of the rule.
     */
    @Test
    public void testWithProject() {
        tester().assertThat(ruleWithProject()).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter(
                    new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("count"), bigintConstant(0)),
                    p.project(
                            Assignments.identity(count),
                            p.aggregation(builder -> builder
                                    .singleGroupingSet(g)
                                    .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask)
                                    .source(p.values(g, mask)))));
        }).matches(PlanMatchPattern.project(
                ImmutableMap.of("count", PlanMatchPattern.expression(bigintReference("count"))),
                maskedCountOverValues()));

        tester().assertThat(ruleWithProject()).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter(
                    new Logical(Logical.Operator.AND, ImmutableList.of(
                            new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("count"), bigintConstant(0)),
                            new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("g"), bigintConstant(5)))),
                    p.project(
                            Assignments.identity(count, g),
                            p.aggregation(builder -> builder
                                    .singleGroupingSet(g)
                                    .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask)
                                    .source(p.values(g, mask)))));
        }).matches(PlanMatchPattern.filter(
                new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("g"), bigintConstant(5)),
                PlanMatchPattern.project(
                        ImmutableMap.of(
                                "count", PlanMatchPattern.expression(bigintReference("count")),
                                "g", PlanMatchPattern.expression(bigintReference("g"))),
                        maskedCountOverValues())));

        tester().assertThat(ruleWithProject()).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter(
                    new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("count"), bigintConstant(5)),
                    p.project(
                            Assignments.identity(count),
                            p.aggregation(builder -> builder
                                    .singleGroupingSet(g)
                                    .addAggregation(count, PlanBuilder.aggregation("count", ImmutableList.of()), ImmutableList.of(), mask)
                                    .source(p.values(g, mask)))));
        }).matches(PlanMatchPattern.filter(
                new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("count"), bigintConstant(5)),
                PlanMatchPattern.project(
                        ImmutableMap.of("count", PlanMatchPattern.expression(bigintReference("count"))),
                        maskedCountOverValues())));
    }

    /** Creates the rule variant that matches a filter directly on top of an aggregation. */
    private PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithoutProject ruleWithoutProject() {
        return new PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithoutProject(tester().getPlannerContext());
    }

    /** Creates the rule variant that matches a filter over a projection over an aggregation. */
    private PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithProject ruleWithProject() {
        return new PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithProject(tester().getPlannerContext());
    }

    /** Builds a {@code BIGINT}-typed reference to the named symbol. */
    private static Reference bigintReference(String name) {
        return new Reference(BigintType.BIGINT, name);
    }

    /** Builds a {@code BIGINT} constant; the {@code long} is boxed exactly as a {@code 0L} literal would be. */
    private static Constant bigintConstant(long value) {
        return new Constant(BigintType.BIGINT, value);
    }

    /**
     * Expected shape once the mask has been pushed down: {@code count()} aggregated over
     * {@code values(g, mask)} filtered on {@code mask}.
     */
    private static PlanMatchPattern maskedCountOverValues() {
        return PlanMatchPattern.aggregation(
                ImmutableMap.of("count", PlanMatchPattern.aggregationFunction("count", ImmutableList.of())),
                PlanMatchPattern.filter(new Reference(BooleanType.BOOLEAN, "mask"), PlanMatchPattern.values("g", "mask")));
    }
}
