package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.TaskCountEstimator;
import io.trino.spi.Plugin;
import io.trino.spi.type.BigintType;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import org.testng.annotations.Test;

/**
 * Rule tests for {@link MultipleDistinctAggregationToMarkDistinct}.
 *
 * <p>Each test builds an aggregation plan with {@link PlanBuilder} and asserts either that the rule
 * does not fire, or that the rewritten plan matches an aggregation sitting on top of one
 * {@code MarkDistinct} node per distinct argument set.
 */
public class TestMultipleDistinctAggregationToMarkDistinct extends BaseRuleTest {
    /** Node count reported by {@link #TASK_COUNT_ESTIMATOR}; the NDV test scales its row counts by it. */
    private static final int NODES_COUNT = 4;
    private static final TaskCountEstimator TASK_COUNT_ESTIMATOR = new TaskCountEstimator(() -> NODES_COUNT);

    public TestMultipleDistinctAggregationToMarkDistinct() {
        super(new Plugin[0]);
    }

    /** Creates a fresh instance of the rule under test, backed by {@link #TASK_COUNT_ESTIMATOR}. */
    private static MultipleDistinctAggregationToMarkDistinct newRule() {
        return new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR);
    }

    /** Builds a stats estimate that carries only an output row count. */
    private static PlanNodeStatsEstimate outputRows(double rowCount) {
        return PlanNodeStatsEstimate.builder().setOutputRowCount(rowCount).build();
    }

    /**
     * Builds a projection over a values node with columns {@code input1} and {@code input2}.
     * It passes both inputs through and adds {@code filter1 := input2 > 0} and
     * {@code filter2 := input1 > 0}, which the FILTER clauses in {@link #testDistinctWithFilter()} reference.
     */
    private static PlanNode projectInputsWithFilters(PlanBuilder p) {
        return p.project(
                Assignments.builder()
                        .putIdentity(p.symbol("input1"))
                        .putIdentity(p.symbol("input2"))
                        .put(p.symbol("filter1"), PlanBuilder.expression("input2 > 0"))
                        .put(p.symbol("filter2"), PlanBuilder.expression("input1 > 0"))
                        .build(),
                p.values(p.symbol("input1"), p.symbol("input2")));
    }

    @Test
    public void testNoDistinct() {
        // Only plain (non-DISTINCT) aggregations. This must target the rule under test; asserting
        // against a different rule would leave this case untested.
        tester().assertThat(newRule())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(p.symbol("output1"), PlanBuilder.expression("count(input1)"), ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(p.symbol("output2"), PlanBuilder.expression("count(input2)"), ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(p.symbol("input1"), p.symbol("input2")))))
                .doesNotFire();
    }

    @Test
    public void testSingleDistinct() {
        // A single DISTINCT aggregation and nothing else.
        tester().assertThat(newRule())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(p.symbol("output1"), PlanBuilder.expression("count(DISTINCT input1)"), ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(p.symbol("input1"), p.symbol("input2")))))
                .doesNotFire();
    }

    @Test
    public void testMultipleAggregations() {
        // Two DISTINCT aggregations that share the same argument.
        tester().assertThat(newRule())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(p.symbol("output1"), PlanBuilder.expression("count(DISTINCT input)"), ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(p.symbol("output2"), PlanBuilder.expression("sum(DISTINCT input)"), ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(p.symbol("input")))))
                .doesNotFire();
    }

    @Test
    public void testDistinctWithFilter() {
        // Both DISTINCT aggregations carry a FILTER clause.
        tester().assertThat(newRule())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(p.symbol("output1"), PlanBuilder.expression("count(DISTINCT input1) filter (where filter1)"), ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(p.symbol("output2"), PlanBuilder.expression("count(DISTINCT input2) filter (where filter2)"), ImmutableList.of(BigintType.BIGINT))
                        .source(projectInputsWithFilters(p))))
                .doesNotFire();

        // Only one of the two DISTINCT aggregations carries a FILTER clause.
        tester().assertThat(newRule())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(p.symbol("output1"), PlanBuilder.expression("count(DISTINCT input1) filter (where filter1)"), ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(p.symbol("output2"), PlanBuilder.expression("count(DISTINCT input2)"), ImmutableList.of(BigintType.BIGINT))
                        .source(projectInputsWithFilters(p))))
                .doesNotFire();
    }

    @Test
    public void testGlobalAggregation() {
        // Two DISTINCT aggregations over different arguments. Expect a MarkDistinct per argument,
        // with mark_input1 directly above the values node and mark_input2 above it.
        tester().assertThat(newRule())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(p.symbol("output1"), PlanBuilder.expression("count(DISTINCT input1)"), ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(p.symbol("output2"), PlanBuilder.expression("count(DISTINCT input2)"), ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(p.symbol("input1"), p.symbol("input2")))))
                .matches(PlanMatchPattern.aggregation(
                        PlanMatchPattern.globalAggregation(),
                        ImmutableMap.of(
                                Optional.of("output1"), PlanMatchPattern.functionCall("count", ImmutableList.of("input1")),
                                Optional.of("output2"), PlanMatchPattern.functionCall("count", ImmutableList.of("input2"))),
                        ImmutableList.of(),
                        ImmutableList.of("mark_input1", "mark_input2"),
                        Optional.empty(),
                        AggregationNode.Step.SINGLE,
                        PlanMatchPattern.markDistinct("mark_input2", ImmutableList.of("input2"),
                                PlanMatchPattern.markDistinct("mark_input1", ImmutableList.of("input1"),
                                        PlanMatchPattern.values((Map<String, Integer>) ImmutableMap.of("input1", 0, "input2", 1))))));
    }

    @Test
    public void testAggregationNDV() {
        // Stats are overridden on this id, i.e. on the aggregation node itself.
        PlanNodeId aggregationId = new PlanNodeId("aggregationNodeId");

        // One DISTINCT and one plain aggregation over the same input, grouped by a single key.
        Function<PlanBuilder, PlanNode> mixedDistinct = p -> p.aggregation(builder -> builder
                .nodeId(aggregationId)
                .singleGroupingSet(p.symbol("key"))
                .addAggregation(p.symbol("output1"), PlanBuilder.expression("count(DISTINCT input)"), ImmutableList.of(BigintType.BIGINT))
                .addAggregation(p.symbol("output2"), PlanBuilder.expression("sum(input)"), ImmutableList.of(BigintType.BIGINT))
                .source(p.values(p.symbol("input"), p.symbol("key"))));

        // Expected rewrite: a MarkDistinct on (input, key) feeding the aggregation.
        PlanMatchPattern expectedMarkDistinct = PlanMatchPattern.aggregation(
                PlanMatchPattern.singleGroupingSet("key"),
                ImmutableMap.of(
                        Optional.of("output1"), PlanMatchPattern.functionCall("count", ImmutableList.of("input")),
                        Optional.of("output2"), PlanMatchPattern.functionCall("sum", ImmutableList.of("input"))),
                ImmutableList.of(),
                ImmutableList.of("mark_input"),
                Optional.empty(),
                AggregationNode.Step.SINGLE,
                PlanMatchPattern.markDistinct("mark_input", ImmutableList.of("input", "key"),
                        PlanMatchPattern.values((Map<String, Integer>) ImmutableMap.of("input", 0, "key", 1))));

        // Row-count estimates below are multiples of NODES_COUNT * task_concurrency.
        int taskConcurrency = (Integer) tester().getSession().getSystemProperty("task_concurrency", Integer.class);
        int clusterThreadCount = NODES_COUNT * taskConcurrency;

        // 2x thread count, default session: the rule fires.
        tester().assertThat(newRule())
                .on(mixedDistinct)
                .overrideStats(aggregationId.toString(), outputRows(2 * clusterThreadCount))
                .matches(expectedMarkDistinct);

        // Unknown (NaN) row count, default session: the rule fires.
        tester().assertThat(newRule())
                .on(mixedDistinct)
                .overrideStats(aggregationId.toString(), outputRows(Double.NaN))
                .matches(expectedMarkDistinct);

        // 50x thread count with optimize_mixed_distinct_aggregations=true: the rule fires.
        tester().assertThat(newRule())
                .on(mixedDistinct)
                .setSystemProperty("optimize_mixed_distinct_aggregations", "true")
                .overrideStats(aggregationId.toString(), outputRows(50 * clusterThreadCount))
                .matches(expectedMarkDistinct);

        // Same row count with optimize_mixed_distinct_aggregations=false: the rule does not fire.
        tester().assertThat(newRule())
                .on(mixedDistinct)
                .setSystemProperty("optimize_mixed_distinct_aggregations", "false")
                .overrideStats(aggregationId.toString(), outputRows(50 * clusterThreadCount))
                .doesNotFire();

        // 50x thread count with optimize_mixed_distinct_aggregations=true, but two DISTINCT
        // aggregations over different inputs and no plain aggregation: the rule does not fire.
        tester().assertThat(newRule())
                .on(p -> p.aggregation(builder -> builder
                        .nodeId(aggregationId)
                        .singleGroupingSet(p.symbol("key"))
                        .addAggregation(p.symbol("output1"), PlanBuilder.expression("count(DISTINCT input1)"), ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(p.symbol("output2"), PlanBuilder.expression("count(DISTINCT input2)"), ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(p.symbol("input1"), p.symbol("input2"), p.symbol("key")))))
                .setSystemProperty("optimize_mixed_distinct_aggregations", "true")
                .overrideStats(aggregationId.toString(), outputRows(50 * clusterThreadCount))
                .doesNotFire();

        // 1000x thread count, default session: the rule does not fire.
        tester().assertThat(newRule())
                .on(mixedDistinct)
                .overrideStats(aggregationId.toString(), outputRows(1000 * clusterThreadCount))
                .doesNotFire();

        // 1000x thread count with mark_distinct_strategy=always: the rule fires.
        tester().assertThat(newRule())
                .on(mixedDistinct)
                .setSystemProperty("mark_distinct_strategy", "always")
                .overrideStats(aggregationId.toString(), outputRows(1000 * clusterThreadCount))
                .matches(expectedMarkDistinct);

        // 1000x thread count with use_mark_distinct=true: the rule does not fire
        // (NOTE(review): presumably use_mark_distinct=true does not force "always" — confirm).
        tester().assertThat(newRule())
                .on(mixedDistinct)
                .setSystemProperty("use_mark_distinct", "true")
                .overrideStats(aggregationId.toString(), outputRows(1000 * clusterThreadCount))
                .doesNotFire();

        // 2x thread count with mark_distinct_strategy=none: the rule does not fire.
        tester().assertThat(newRule())
                .on(mixedDistinct)
                .setSystemProperty("mark_distinct_strategy", "none")
                .overrideStats(aggregationId.toString(), outputRows(2 * clusterThreadCount))
                .doesNotFire();

        // 2x thread count with use_mark_distinct=false: the rule does not fire.
        tester().assertThat(newRule())
                .on(mixedDistinct)
                .setSystemProperty("use_mark_distinct", "false")
                .overrideStats(aggregationId.toString(), outputRows(2 * clusterThreadCount))
                .doesNotFire();

        // Two grouping keys at 1000x thread count, default session: unlike the single-key case above,
        // the rule fires, and the MarkDistinct covers input plus both keys.
        tester().assertThat(newRule())
                .on(p -> p.aggregation(builder -> builder
                        .nodeId(aggregationId)
                        .singleGroupingSet(p.symbol("key1"), p.symbol("key2"))
                        .addAggregation(p.symbol("output1"), PlanBuilder.expression("count(DISTINCT input)"), ImmutableList.of(BigintType.BIGINT))
                        .addAggregation(p.symbol("output2"), PlanBuilder.expression("sum(input)"), ImmutableList.of(BigintType.BIGINT))
                        .source(p.values(p.symbol("input"), p.symbol("key1"), p.symbol("key2")))))
                .overrideStats(aggregationId.toString(), outputRows(1000 * clusterThreadCount))
                .matches(PlanMatchPattern.aggregation(
                        PlanMatchPattern.singleGroupingSet("key1", "key2"),
                        ImmutableMap.of(
                                Optional.of("output1"), PlanMatchPattern.functionCall("count", ImmutableList.of("input")),
                                Optional.of("output2"), PlanMatchPattern.functionCall("sum", ImmutableList.of("input"))),
                        ImmutableList.of(),
                        ImmutableList.of("mark_input"),
                        Optional.empty(),
                        AggregationNode.Step.SINGLE,
                        PlanMatchPattern.markDistinct("mark_input", ImmutableList.of("input", "key1", "key2"),
                                PlanMatchPattern.values((Map<String, Integer>) ImmutableMap.of("input", 0, "key1", 1, "key2", 2)))));
    }
}
