package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.spi.Plugin;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.assertions.ExpectedValueProvider;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.testng.annotations.Test;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.class */
/**
 * Rule tests for {@link TransformCorrelatedGlobalAggregationWithoutProjection}.
 *
 * <p>Every test builds a plan with {@link PlanBuilder}. It then asserts one of two outcomes:
 * <ul>
 *   <li>the rule does not fire; or</li>
 *   <li>the rule rewrites a {@code correlatedJoin} over a global aggregation into the expected shape.
 *       That shape is a LEFT join between the correlated input (wrapped in {@code assignUniqueId}) and
 *       the subquery source (projected with a constant {@code non_null} marker). The join is topped
 *       by the aggregation(s) and a final projection.</li>
 * </ul>
 */
public class TestTransformCorrelatedGlobalAggregationWithoutProjection extends BaseRuleTest {
    public TestTransformCorrelatedGlobalAggregationWithoutProjection() {
        // No plugins are passed to the base rule tester.
        super(new Plugin[0]);
    }

    /** Creates the rule under test, bound to the tester's metadata. */
    private TransformCorrelatedGlobalAggregationWithoutProjection rule() {
        return new TransformCorrelatedGlobalAggregationWithoutProjection(tester().getMetadata());
    }

    /**
     * Returns the expected left side of the rewritten LEFT join. This is the single-column
     * {@code corr} values node, wrapped in {@code assignUniqueId} that produces the symbol {@code unique}.
     */
    private static PlanMatchPattern correlatedInputWithUniqueId() {
        Map<String, Integer> corrColumn = ImmutableMap.of("corr", 0);
        return PlanMatchPattern.assignUniqueId("unique", PlanMatchPattern.values(corrColumn));
    }

    @Test
    public void doesNotFireOnPlanWithoutCorrelatedJoinNode() {
        tester().assertThat(rule())
                .on(p -> p.values(p.symbol("a")))
                .doesNotFire();
    }

    @Test
    public void doesNotFireOnCorrelatedWithoutAggregation() {
        tester().assertThat(rule())
                .on(p -> p.correlatedJoin(
                        ImmutableList.of(p.symbol("corr")),
                        p.values(p.symbol("corr")),
                        p.values(p.symbol("a"))))
                .doesNotFire();
    }

    @Test
    public void doesNotFireOnUncorrelated() {
        // An empty correlation list makes the join uncorrelated.
        tester().assertThat(rule())
                .on(p -> p.correlatedJoin(
                        ImmutableList.of(),
                        p.values(p.symbol("a")),
                        p.values(p.symbol("b"))))
                .doesNotFire();
    }

    @Test
    public void doesNotFireOnCorrelatedWithNonScalarAggregation() {
        // Grouping on "b" means the subquery aggregation is not a global aggregation.
        tester().assertThat(rule())
                .on(p -> p.correlatedJoin(
                        ImmutableList.of(p.symbol("corr")),
                        p.values(p.symbol("corr")),
                        p.aggregation(builder -> builder
                                .source(p.values(p.symbol("a"), p.symbol("b")))
                                .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BigintType.BIGINT))
                                .singleGroupingSet(p.symbol("b")))))
                .doesNotFire();
    }

    @Test
    public void doesNotFireOnMultipleProjections() {
        tester().assertThat(rule())
                .on(p -> p.correlatedJoin(
                        ImmutableList.of(p.symbol("corr")),
                        p.values(p.symbol("corr")),
                        p.project(
                                Assignments.of(p.symbol("expr_2"), PlanBuilder.expression("expr - 1")),
                                p.project(
                                        Assignments.of(p.symbol("expr"), PlanBuilder.expression("sum + 1")),
                                        p.aggregation(builder -> builder
                                                .source(p.values(p.symbol("a"), p.symbol("b")))
                                                .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BigintType.BIGINT))
                                                .globalGrouping())))))
                .doesNotFire();
    }

    @Test
    public void rewritesOnSubqueryWithoutProjection() {
        tester().assertThat(rule())
                .on(p -> p.correlatedJoin(
                        ImmutableList.of(p.symbol("corr")),
                        p.values(p.symbol("corr")),
                        p.aggregation(builder -> builder
                                .source(p.values(p.symbol("a"), p.symbol("b")))
                                .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BigintType.BIGINT))
                                .globalGrouping())))
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "sum_1", PlanMatchPattern.expression("sum_1"),
                                "corr", PlanMatchPattern.expression("corr")),
                        PlanMatchPattern.aggregation(
                                ImmutableMap.of("sum_1", PlanMatchPattern.functionCall("sum", ImmutableList.of("a"))),
                                PlanMatchPattern.join(
                                        JoinNode.Type.LEFT,
                                        ImmutableList.of(),
                                        correlatedInputWithUniqueId(),
                                        PlanMatchPattern.project(
                                                ImmutableMap.of("non_null", PlanMatchPattern.expression("true")),
                                                PlanMatchPattern.values(ImmutableMap.of("a", 0, "b", 1)))))));
    }

    /**
     * Despite its name, this test asserts that the rule does NOT fire. The case covered is a
     * projection sitting between the correlated join and the global aggregation.
     */
    @Test
    public void rewritesOnSubqueryWithProjection() {
        tester().assertThat(rule())
                .on(p -> p.correlatedJoin(
                        ImmutableList.of(p.symbol("corr")),
                        p.values(p.symbol("corr")),
                        p.project(
                                Assignments.of(p.symbol("expr"), PlanBuilder.expression("sum + 1")),
                                p.aggregation(builder -> builder
                                        .source(p.values(p.symbol("a"), p.symbol("b")))
                                        .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BigintType.BIGINT))
                                        .globalGrouping()))))
                .doesNotFire();
    }

    @Test
    public void testSubqueryWithCount() {
        tester().assertThat(rule())
                .on(p -> p.correlatedJoin(
                        ImmutableList.of(p.symbol("corr")),
                        p.values(p.symbol("corr")),
                        p.aggregation(builder -> builder
                                .source(p.values(p.symbol("a"), p.symbol("b")))
                                .addAggregation(p.symbol("count_rows"), PlanBuilder.expression("count(*)"), ImmutableList.of())
                                .addAggregation(p.symbol("count_non_null_values"), PlanBuilder.expression("count(a)"), ImmutableList.of(BigintType.BIGINT))
                                .globalGrouping())))
                .matches(PlanMatchPattern.project(
                        PlanMatchPattern.aggregation(
                                ImmutableMap.of(
                                        "count_rows", PlanMatchPattern.functionCall("count", ImmutableList.of()),
                                        "count_non_null_values", PlanMatchPattern.functionCall("count", ImmutableList.of("a"))),
                                PlanMatchPattern.join(
                                        JoinNode.Type.LEFT,
                                        ImmutableList.of(),
                                        correlatedInputWithUniqueId(),
                                        PlanMatchPattern.project(
                                                ImmutableMap.of("non_null", PlanMatchPattern.expression("true")),
                                                PlanMatchPattern.values(ImmutableMap.of("a", 0, "b", 1)))))));
    }

    @Test
    public void rewritesOnSubqueryWithDistinct() {
        // The expected join has no equi-criteria; the correlation predicate is matched as its filter.
        // Declared with an explicit type so ImmutableList.of() infers the right element type.
        List<ExpectedValueProvider<JoinNode.EquiJoinClause>> noEquiCriteria = ImmutableList.of();
        tester().assertThat(rule())
                .on(p -> p.correlatedJoin(
                        ImmutableList.of(p.symbol("corr")),
                        p.values(p.symbol("corr")),
                        p.aggregation(outer -> outer
                                .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("count"), PlanBuilder.expression("count()"), ImmutableList.of())
                                .globalGrouping()
                                // Grouping on "a" with no aggregates makes the inner aggregation act as a DISTINCT.
                                .source(p.aggregation(inner -> inner
                                        .singleGroupingSet(p.symbol("a"))
                                        .source(p.filter(
                                                PlanBuilder.expression("b > corr"),
                                                p.values(p.symbol("a"), p.symbol("b")))))))))
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "corr", PlanMatchPattern.expression("corr"),
                                "sum_agg", PlanMatchPattern.expression("sum_agg"),
                                "count_agg", PlanMatchPattern.expression("count_agg")),
                        PlanMatchPattern.aggregation(
                                PlanMatchPattern.singleGroupingSet("corr", "unique"),
                                ImmutableMap.of(
                                        Optional.of("sum_agg"), PlanMatchPattern.functionCall("sum", ImmutableList.of("a")),
                                        Optional.of("count_agg"), PlanMatchPattern.functionCall("count", ImmutableList.of())),
                                ImmutableList.of(),
                                ImmutableList.of("non_null"),
                                Optional.empty(),
                                AggregationNode.Step.SINGLE,
                                PlanMatchPattern.aggregation(
                                        PlanMatchPattern.singleGroupingSet("corr", "unique", "non_null", "a"),
                                        ImmutableMap.of(),
                                        Optional.empty(),
                                        AggregationNode.Step.SINGLE,
                                        PlanMatchPattern.join(
                                                JoinNode.Type.LEFT,
                                                noEquiCriteria,
                                                Optional.of("b > corr"),
                                                PlanMatchPattern.assignUniqueId("unique", PlanMatchPattern.values("corr")),
                                                PlanMatchPattern.project(
                                                        ImmutableMap.of("non_null", PlanMatchPattern.expression("true")),
                                                        PlanMatchPattern.filter("true", PlanMatchPattern.values("a", "b"))))))));
    }

    @Test
    public void testWithPreexistingMask() {
        // Declared with an explicit type so ImmutableList.of(...) infers List<Type>, not List<BigintType>.
        List<Type> countInputTypes = ImmutableList.of(BigintType.BIGINT);
        tester().assertThat(rule())
                .on(p -> p.correlatedJoin(
                        ImmutableList.of(p.symbol("corr")),
                        p.values(p.symbol("corr")),
                        p.aggregation(builder -> builder
                                .source(p.values(p.symbol("a"), p.symbol("mask")))
                                .addAggregation(p.symbol("count_non_null_values"), PlanBuilder.expression("count(a)"), countInputTypes, p.symbol("mask"))
                                .globalGrouping())))
                // The pre-existing mask is expected to be ANDed with the non_null marker into new_mask.
                .matches(PlanMatchPattern.project(
                        PlanMatchPattern.aggregation(
                                PlanMatchPattern.singleGroupingSet("corr", "unique"),
                                ImmutableMap.of(Optional.of("count_non_null_values"), PlanMatchPattern.functionCall("count", ImmutableList.of("a"))),
                                ImmutableList.of(),
                                ImmutableList.of("new_mask"),
                                Optional.empty(),
                                AggregationNode.Step.SINGLE,
                                PlanMatchPattern.project(
                                        ImmutableMap.of("new_mask", PlanMatchPattern.expression("mask AND non_null")),
                                        PlanMatchPattern.join(
                                                JoinNode.Type.LEFT,
                                                ImmutableList.of(),
                                                correlatedInputWithUniqueId(),
                                                PlanMatchPattern.project(
                                                        ImmutableMap.of("non_null", PlanMatchPattern.expression("true")),
                                                        PlanMatchPattern.values(ImmutableMap.of("a", 0, "mask", 1))))))));
    }
}
