/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.Plugin;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.Type;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Logical;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.AggregationFunction;
import io.trino.sql.planner.assertions.ExpectedValueProvider;
import io.trino.sql.planner.assertions.ExpressionMatcher;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.PushFilterThroughCountAggregation;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PlanNode;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test;

public class TestPushFilterThroughCountAggregation
extends BaseRuleTest {
    private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution();
    private static final ResolvedFunction MODULUS_BIGINT = FUNCTIONS.resolveOperator(OperatorType.MODULUS, (List<? extends Type>)ImmutableList.of((Object)BigintType.BIGINT, (Object)BigintType.BIGINT));

    public TestPushFilterThroughCountAggregation() {
        super(new Plugin[0]);
    }

    @Test
    public void testDoesNotFireWithNonGroupedAggregation() {
        this.tester().assertThat((Rule<?>)new PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithoutProject(this.tester().getPlannerContext())).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter((Expression)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Reference((Type)IntegerType.INTEGER, "count"), (Expression)new Constant((Type)IntegerType.INTEGER, (Object)0L)), (PlanNode)p.aggregation(builder -> builder.globalGrouping().addAggregation(count, PlanBuilder.aggregation("count", (List<Expression>)ImmutableList.of()), (List<Type>)ImmutableList.of(), mask).source((PlanNode)p.values(g, mask))));
        }).doesNotFire();
    }

    @Test
    public void testDoesNotFireWithMultipleAggregations() {
        this.tester().assertThat((Rule<?>)new PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithoutProject(this.tester().getPlannerContext())).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            Symbol avg = p.symbol("avg");
            return p.filter((Expression)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Reference((Type)IntegerType.INTEGER, "count"), (Expression)new Constant((Type)IntegerType.INTEGER, (Object)0L)), (PlanNode)p.aggregation(builder -> builder.singleGroupingSet(g).addAggregation(count, PlanBuilder.aggregation("count", (List<Expression>)ImmutableList.of()), (List<Type>)ImmutableList.of(), mask).addAggregation(avg, PlanBuilder.aggregation("avg", (List<Expression>)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "g"))), (List<Type>)ImmutableList.of((Object)BigintType.BIGINT), mask).source((PlanNode)p.values(g, mask))));
        }).doesNotFire();
    }

    @Test
    public void testDoesNotFireWithNoAggregations() {
        this.tester().assertThat((Rule<?>)new PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithoutProject(this.tester().getPlannerContext())).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            return p.filter((Expression)Booleans.TRUE, (PlanNode)p.aggregation(builder -> builder.singleGroupingSet(g).source((PlanNode)p.values(g, mask))));
        }).doesNotFire();
    }

    @Test
    public void testDoesNotFireWithNoMask() {
        this.tester().assertThat((Rule<?>)new PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithoutProject(this.tester().getPlannerContext())).on(p -> {
            Symbol g = p.symbol("g");
            Symbol count = p.symbol("count");
            return p.filter((Expression)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Reference((Type)IntegerType.INTEGER, "count"), (Expression)new Constant((Type)IntegerType.INTEGER, (Object)0L)), (PlanNode)p.aggregation(builder -> builder.singleGroupingSet(g).addAggregation(count, PlanBuilder.aggregation("count", (List<Expression>)ImmutableList.of()), (List<Type>)ImmutableList.of()).source((PlanNode)p.values(g))));
        }).doesNotFire();
    }

    @Test
    public void testDoesNotFireWithNoCountAggregation() {
        this.tester().assertThat((Rule<?>)new PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithoutProject(this.tester().getPlannerContext())).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter((Expression)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Reference((Type)IntegerType.INTEGER, "count"), (Expression)new Constant((Type)IntegerType.INTEGER, (Object)0L)), (PlanNode)p.aggregation(builder -> builder.singleGroupingSet(g).addAggregation(count, PlanBuilder.aggregation("count", (List<Expression>)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "g"))), (List<Type>)ImmutableList.of((Object)BigintType.BIGINT), mask).source((PlanNode)p.values(g, mask))));
        }).doesNotFire();
        this.tester().assertThat((Rule<?>)new PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithoutProject(this.tester().getPlannerContext())).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol avg = p.symbol("avg");
            return p.filter((Expression)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Reference((Type)IntegerType.INTEGER, "avg"), (Expression)new Constant((Type)IntegerType.INTEGER, (Object)0L)), (PlanNode)p.aggregation(builder -> builder.singleGroupingSet(g).addAggregation(avg, PlanBuilder.aggregation("avg", (List<Expression>)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "g"))), (List<Type>)ImmutableList.of((Object)BigintType.BIGINT), mask).source((PlanNode)p.values(g, mask))));
        }).doesNotFire();
    }

    @Test
    public void testFilterPredicateFalse() {
        this.tester().assertThat((Rule<?>)new PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithoutProject(this.tester().getPlannerContext())).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter((Expression)new Logical(Logical.Operator.AND, (List)ImmutableList.of((Object)new Comparison(Comparison.Operator.LESS_THAN, (Expression)new Reference((Type)BigintType.BIGINT, "count"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)0L)), (Object)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Reference((Type)BigintType.BIGINT, "count"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)0L)))), (PlanNode)p.aggregation(builder -> builder.singleGroupingSet(g).addAggregation(count, PlanBuilder.aggregation("count", (List<Expression>)ImmutableList.of()), (List<Type>)ImmutableList.of(), mask).source((PlanNode)p.values(g, mask))));
        }).matches(PlanMatchPattern.values("g", "count"));
    }

    @Test
    public void testDoesNotFireWhenFilterPredicateTrue() {
        this.tester().assertThat((Rule<?>)new PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithoutProject(this.tester().getPlannerContext())).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter((Expression)Booleans.TRUE, (PlanNode)p.aggregation(builder -> builder.singleGroupingSet(g).addAggregation(count, PlanBuilder.aggregation("count", (List<Expression>)ImmutableList.of()), (List<Type>)ImmutableList.of(), mask).source((PlanNode)p.values(g, mask))));
        }).doesNotFire();
    }

    @Test
    public void testDoesNotFireWhenFilterPredicateSatisfiedByAllCountValues() {
        this.tester().assertThat((Rule<?>)new PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithoutProject(this.tester().getPlannerContext())).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter((Expression)new Logical(Logical.Operator.AND, (List)ImmutableList.of((Object)new Logical(Logical.Operator.OR, (List)ImmutableList.of((Object)new Comparison(Comparison.Operator.LESS_THAN, (Expression)new Reference((Type)BigintType.BIGINT, "count"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)0L)), (Object)new Comparison(Comparison.Operator.GREATER_THAN_OR_EQUAL, (Expression)new Reference((Type)BigintType.BIGINT, "count"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)0L)))), (Object)new Comparison(Comparison.Operator.EQUAL, (Expression)new Reference((Type)BigintType.BIGINT, "g"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)5L)))), (PlanNode)p.aggregation(builder -> builder.singleGroupingSet(g).addAggregation(count, PlanBuilder.aggregation("count", (List<Expression>)ImmutableList.of()), (List<Type>)ImmutableList.of(), mask).source((PlanNode)p.values(g, mask))));
        }).doesNotFire();
    }

    @Test
    public void testPushDownMaskAndRemoveFilter() {
        this.tester().assertThat((Rule<?>)new PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithoutProject(this.tester().getPlannerContext())).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter((Expression)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Reference((Type)BigintType.BIGINT, "count"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)0L)), (PlanNode)p.aggregation(builder -> builder.singleGroupingSet(g).addAggregation(count, PlanBuilder.aggregation("count", (List<Expression>)ImmutableList.of()), (List<Type>)ImmutableList.of(), mask).source((PlanNode)p.values(g, mask))));
        }).matches(PlanMatchPattern.aggregation((Map<String, ExpectedValueProvider<AggregationFunction>>)ImmutableMap.of((Object)"count", PlanMatchPattern.aggregationFunction("count", (List<String>)ImmutableList.of())), PlanMatchPattern.filter((Expression)new Reference((Type)BooleanType.BOOLEAN, "mask"), PlanMatchPattern.values("g", "mask"))));
    }

    @Test
    public void testPushDownMaskAndSimplifyFilter() {
        this.tester().assertThat((Rule<?>)new PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithoutProject(this.tester().getPlannerContext())).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter((Expression)new Logical(Logical.Operator.AND, (List)ImmutableList.of((Object)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Reference((Type)BigintType.BIGINT, "count"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)0L)), (Object)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Reference((Type)BigintType.BIGINT, "g"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)5L)))), (PlanNode)p.aggregation(builder -> builder.singleGroupingSet(g).addAggregation(count, PlanBuilder.aggregation("count", (List<Expression>)ImmutableList.of()), (List<Type>)ImmutableList.of(), mask).source((PlanNode)p.values(g, mask))));
        }).matches(PlanMatchPattern.filter((Expression)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Reference((Type)BigintType.BIGINT, "g"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)5L)), PlanMatchPattern.aggregation((Map<String, ExpectedValueProvider<AggregationFunction>>)ImmutableMap.of((Object)"count", PlanMatchPattern.aggregationFunction("count", (List<String>)ImmutableList.of())), PlanMatchPattern.filter((Expression)new Reference((Type)BooleanType.BOOLEAN, "mask"), PlanMatchPattern.values("g", "mask")))));
        this.tester().assertThat((Rule<?>)new PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithoutProject(this.tester().getPlannerContext())).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter((Expression)new Logical(Logical.Operator.AND, (List)ImmutableList.of((Object)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Reference((Type)BigintType.BIGINT, "count"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)0L)), (Object)new Comparison(Comparison.Operator.EQUAL, (Expression)new Call(MODULUS_BIGINT, (List)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "count"), (Object)new Constant((Type)BigintType.BIGINT, (Object)2L))), (Expression)new Constant((Type)BigintType.BIGINT, (Object)0L)))), (PlanNode)p.aggregation(builder -> builder.singleGroupingSet(g).addAggregation(count, PlanBuilder.aggregation("count", (List<Expression>)ImmutableList.of()), (List<Type>)ImmutableList.of(), mask).source((PlanNode)p.values(g, mask))));
        }).matches(PlanMatchPattern.filter((Expression)new Comparison(Comparison.Operator.EQUAL, (Expression)new Call(MODULUS_BIGINT, (List)ImmutableList.of((Object)new Reference((Type)BigintType.BIGINT, "count"), (Object)new Constant((Type)BigintType.BIGINT, (Object)2L))), (Expression)new Constant((Type)BigintType.BIGINT, (Object)0L)), PlanMatchPattern.aggregation((Map<String, ExpectedValueProvider<AggregationFunction>>)ImmutableMap.of((Object)"count", PlanMatchPattern.aggregationFunction("count", (List<String>)ImmutableList.of())), PlanMatchPattern.filter((Expression)new Reference((Type)BooleanType.BOOLEAN, "mask"), PlanMatchPattern.values("g", "mask")))));
    }

    @Test
    public void testPushDownMaskAndRetainFilter() {
        this.tester().assertThat((Rule<?>)new PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithoutProject(this.tester().getPlannerContext())).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter((Expression)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Reference((Type)BigintType.BIGINT, "count"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)5L)), (PlanNode)p.aggregation(builder -> builder.singleGroupingSet(g).addAggregation(count, PlanBuilder.aggregation("count", (List<Expression>)ImmutableList.of()), (List<Type>)ImmutableList.of(), mask).source((PlanNode)p.values(g, mask))));
        }).matches(PlanMatchPattern.filter((Expression)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Reference((Type)BigintType.BIGINT, "count"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)5L)), PlanMatchPattern.aggregation((Map<String, ExpectedValueProvider<AggregationFunction>>)ImmutableMap.of((Object)"count", PlanMatchPattern.aggregationFunction("count", (List<String>)ImmutableList.of())), PlanMatchPattern.filter((Expression)new Reference((Type)BooleanType.BOOLEAN, "mask"), PlanMatchPattern.values("g", "mask")))));
    }

    @Test
    public void testWithProject() {
        this.tester().assertThat((Rule<?>)new PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithProject(this.tester().getPlannerContext())).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter((Expression)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Reference((Type)BigintType.BIGINT, "count"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)0L)), (PlanNode)p.project(Assignments.identity((Symbol[])new Symbol[]{count}), (PlanNode)p.aggregation(builder -> builder.singleGroupingSet(g).addAggregation(count, PlanBuilder.aggregation("count", (List<Expression>)ImmutableList.of()), (List<Type>)ImmutableList.of(), mask).source((PlanNode)p.values(g, mask)))));
        }).matches(PlanMatchPattern.project((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"count", (Object)PlanMatchPattern.expression((Expression)new Reference((Type)BigintType.BIGINT, "count"))), PlanMatchPattern.aggregation((Map<String, ExpectedValueProvider<AggregationFunction>>)ImmutableMap.of((Object)"count", PlanMatchPattern.aggregationFunction("count", (List<String>)ImmutableList.of())), PlanMatchPattern.filter((Expression)new Reference((Type)BooleanType.BOOLEAN, "mask"), PlanMatchPattern.values("g", "mask")))));
        this.tester().assertThat((Rule<?>)new PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithProject(this.tester().getPlannerContext())).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter((Expression)new Logical(Logical.Operator.AND, (List)ImmutableList.of((Object)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Reference((Type)BigintType.BIGINT, "count"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)0L)), (Object)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Reference((Type)BigintType.BIGINT, "g"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)5L)))), (PlanNode)p.project(Assignments.identity((Symbol[])new Symbol[]{count, g}), (PlanNode)p.aggregation(builder -> builder.singleGroupingSet(g).addAggregation(count, PlanBuilder.aggregation("count", (List<Expression>)ImmutableList.of()), (List<Type>)ImmutableList.of(), mask).source((PlanNode)p.values(g, mask)))));
        }).matches(PlanMatchPattern.filter((Expression)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Reference((Type)BigintType.BIGINT, "g"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)5L)), PlanMatchPattern.project((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"count", (Object)PlanMatchPattern.expression((Expression)new Reference((Type)BigintType.BIGINT, "count")), (Object)"g", (Object)PlanMatchPattern.expression((Expression)new Reference((Type)BigintType.BIGINT, "g"))), PlanMatchPattern.aggregation((Map<String, ExpectedValueProvider<AggregationFunction>>)ImmutableMap.of((Object)"count", PlanMatchPattern.aggregationFunction("count", (List<String>)ImmutableList.of())), PlanMatchPattern.filter((Expression)new Reference((Type)BooleanType.BOOLEAN, "mask"), PlanMatchPattern.values("g", "mask"))))));
        this.tester().assertThat((Rule<?>)new PushFilterThroughCountAggregation.PushFilterThroughCountAggregationWithProject(this.tester().getPlannerContext())).on(p -> {
            Symbol g = p.symbol("g");
            Symbol mask = p.symbol("mask");
            Symbol count = p.symbol("count");
            return p.filter((Expression)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Reference((Type)BigintType.BIGINT, "count"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)5L)), (PlanNode)p.project(Assignments.identity((Symbol[])new Symbol[]{count}), (PlanNode)p.aggregation(builder -> builder.singleGroupingSet(g).addAggregation(count, PlanBuilder.aggregation("count", (List<Expression>)ImmutableList.of()), (List<Type>)ImmutableList.of(), mask).source((PlanNode)p.values(g, mask)))));
        }).matches(PlanMatchPattern.filter((Expression)new Comparison(Comparison.Operator.GREATER_THAN, (Expression)new Reference((Type)BigintType.BIGINT, "count"), (Expression)new Constant((Type)BigintType.BIGINT, (Object)5L)), PlanMatchPattern.project((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"count", (Object)PlanMatchPattern.expression((Expression)new Reference((Type)BigintType.BIGINT, "count"))), PlanMatchPattern.aggregation((Map<String, ExpectedValueProvider<AggregationFunction>>)ImmutableMap.of((Object)"count", PlanMatchPattern.aggregationFunction("count", (List<String>)ImmutableList.of())), PlanMatchPattern.filter((Expression)new Reference((Type)BooleanType.BOOLEAN, "mask"), PlanMatchPattern.values("g", "mask"))))));
    }
}

