package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.spi.Plugin;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.RealType;
import io.prestosql.sql.planner.assertions.PlanMatchPattern;
import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest;
import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder;
import io.prestosql.sql.planner.plan.AggregationNode;
import java.util.Optional;
import org.testng.annotations.Test;

/* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/TestSingleDistinctAggregationToGroupBy.class */
/**
 * Rule tests for {@link SingleDistinctAggregationToGroupBy}.
 *
 * <p>The cases below pin down when the rule fires. It fires for a global aggregation whose
 * aggregate calls are all {@code DISTINCT} over the same set of input symbols (argument order
 * may differ) and carry no {@code FILTER} clause. When it fires, the expected plan is an outer
 * global aggregation applying the same functions to the same arguments. That outer aggregation
 * is fed by an inner aggregation that groups by the distinct inputs and computes no aggregates.
 * Every other shape must leave the plan untouched.
 */
public class TestSingleDistinctAggregationToGroupBy extends BaseRuleTest {
    public TestSingleDistinctAggregationToGroupBy() {
        // No additional plugins are registered for these tests.
        super(new Plugin[0]);
    }

    /** A plain aggregate without DISTINCT gives the rule nothing to rewrite. */
    @Test
    public void testNoDistinct() {
        tester().assertThat(new SingleDistinctAggregationToGroupBy())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(p.symbol("output1"), PlanBuilder.expression("count(input1)"), ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(p.symbol("input1"), p.symbol("input2")))))
                .doesNotFire();
    }

    /** DISTINCT aggregates over different input symbols cannot share one group-by. */
    @Test
    public void testMultipleDistincts() {
        tester().assertThat(new SingleDistinctAggregationToGroupBy())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(p.symbol("output1"), PlanBuilder.expression("count(DISTINCT input1)"), ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(p.symbol("output2"), PlanBuilder.expression("count(DISTINCT input2)"), ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(p.symbol("input1"), p.symbol("input2")))))
                .doesNotFire();
    }

    /** Mixing a DISTINCT aggregate with a non-DISTINCT one must not be rewritten. */
    @Test
    public void testMixedDistinctAndNonDistinct() {
        tester().assertThat(new SingleDistinctAggregationToGroupBy())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(p.symbol("output1"), PlanBuilder.expression("count(DISTINCT input1)"), ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(p.symbol("output2"), PlanBuilder.expression("count(input2)"), ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(p.symbol("input1"), p.symbol("input2")))))
                .doesNotFire();
    }

    /** A DISTINCT aggregate carrying a FILTER clause must not be rewritten. */
    @Test
    public void testDistinctWithFilter() {
        tester().assertThat(new SingleDistinctAggregationToGroupBy())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(p.symbol("output"), PlanBuilder.expression("count(DISTINCT input1) filter (where input2 > 0)"), ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(p.symbol("input1"), p.symbol("input2")))))
                .doesNotFire();
    }

    /** A single DISTINCT aggregate becomes a group-by on its input plus a plain global aggregate. */
    @Test
    public void testSingleAggregation() {
        tester().assertThat(new SingleDistinctAggregationToGroupBy())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(p.symbol("output"), PlanBuilder.expression("count(DISTINCT input)"), ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(p.symbol("input")))))
                .matches(PlanMatchPattern.aggregation(
                        PlanMatchPattern.globalAggregation(),
                        ImmutableMap.of(
                                Optional.of("output"), PlanMatchPattern.functionCall("count", ImmutableList.of("input"))),
                        ImmutableMap.of(),
                        Optional.empty(),
                        AggregationNode.Step.SINGLE,
                        distinctInputsGroupBy("input")));
    }

    /** Several DISTINCT aggregates over the same input share a single group-by. */
    @Test
    public void testMultipleAggregations() {
        tester().assertThat(new SingleDistinctAggregationToGroupBy())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(p.symbol("output1"), PlanBuilder.expression("count(DISTINCT input)"), ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(p.symbol("output2"), PlanBuilder.expression("sum(DISTINCT input)"), ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(p.symbol("input")))))
                .matches(PlanMatchPattern.aggregation(
                        PlanMatchPattern.globalAggregation(),
                        // ImmutableMap.of (not a bare builder() chain) so the key/value types are
                        // inferred from the matcher's parameter instead of collapsing to Object.
                        ImmutableMap.of(
                                Optional.of("output1"), PlanMatchPattern.functionCall("count", ImmutableList.of("input")),
                                Optional.of("output2"), PlanMatchPattern.functionCall("sum", ImmutableList.of("input"))),
                        ImmutableMap.of(),
                        Optional.empty(),
                        AggregationNode.Step.SINGLE,
                        distinctInputsGroupBy("input")));
    }

    /**
     * Multi-argument DISTINCT aggregates over the same symbol set, with arguments in different
     * orders, still share one group-by over all of those symbols.
     */
    @Test
    public void testMultipleInputs() {
        tester().assertThat(new SingleDistinctAggregationToGroupBy())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(p.symbol("output1"), PlanBuilder.expression("corr(DISTINCT x, y)"), ImmutableList.of(RealType.REAL, RealType.REAL))
                        .addAggregation(p.symbol("output2"), PlanBuilder.expression("corr(DISTINCT y, x)"), ImmutableList.of(RealType.REAL, RealType.REAL))
                        .source(p.values(p.symbol("x"), p.symbol("y")))))
                .matches(PlanMatchPattern.aggregation(
                        PlanMatchPattern.globalAggregation(),
                        // See testMultipleAggregations: ImmutableMap.of keeps the map properly typed.
                        ImmutableMap.of(
                                Optional.of("output1"), PlanMatchPattern.functionCall("corr", ImmutableList.of("x", "y")),
                                Optional.of("output2"), PlanMatchPattern.functionCall("corr", ImmutableList.of("y", "x"))),
                        ImmutableMap.of(),
                        Optional.empty(),
                        AggregationNode.Step.SINGLE,
                        distinctInputsGroupBy("x", "y")));
    }

    /**
     * Builds the expected inner node of a rewritten plan.
     *
     * <p>The node is a SINGLE-step aggregation with no aggregate calls and no masks. It groups by
     * {@code inputs} directly over a values node that produces those same symbols.
     *
     * @param inputs the symbol aliases that are both the grouping keys and the values columns
     * @return the matcher for the de-duplicating group-by
     */
    private static PlanMatchPattern distinctInputsGroupBy(String... inputs) {
        return PlanMatchPattern.aggregation(
                PlanMatchPattern.singleGroupingSet(inputs),
                ImmutableMap.of(),
                ImmutableMap.of(),
                Optional.empty(),
                AggregationNode.Step.SINGLE,
                PlanMatchPattern.values(inputs));
    }
}
