package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.spi.Plugin;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DoubleType;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import java.util.List;
import java.util.Optional;
import org.junit.jupiter.api.Test;

/**
 * Rule test for {@link PushPartialAggregationThroughJoin}.
 *
 * <p>The scenario is a PARTIAL aggregation on top of an INNER join. The join has one equi clause and a
 * non-equi filter. The test expects the rule to push the partial aggregation below the join onto the
 * left input, with a projection above the join that re-exposes the original aggregation outputs.
 */
public class TestPushPartialAggregationThroughJoin extends BaseRuleTest {
    public TestPushPartialAggregationThroughJoin() {
        // This rule test registers no extra plugins.
        super(new Plugin[0]);
    }

    @Test
    public void testPushesPartialAggregationThroughJoin() {
        // Single argument of the AVG aggregation in the input plan.
        List<Expression> avgArguments = ImmutableList.of(new Reference(BigintType.BIGINT, "LEFT_AGGR"));

        tester().assertThat(new PushPartialAggregationThroughJoin())
                // The rule is gated by this session property, so enable it explicitly.
                .setSystemProperty("push_partial_aggregation_through_join", "true")
                .on(p -> p.aggregation(agg -> agg
                        // The join is built first so symbols and node ids are allocated in the same order as before.
                        .source(innerJoinWithNonEquiFilter(p))
                        .addAggregation(p.symbol("AVG", DoubleType.DOUBLE), PlanBuilder.aggregation("AVG", avgArguments), ImmutableList.of(DoubleType.DOUBLE))
                        .singleGroupingSet(p.symbol("LEFT_GROUP_BY"), p.symbol("RIGHT_GROUP_BY"))
                        .step(AggregationNode.Step.PARTIAL)))
                .matches(PlanMatchPattern.project(
                        // The projection above the join re-exposes the grouping keys and the aggregate.
                        ImmutableMap.of(
                                "LEFT_GROUP_BY", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "LEFT_GROUP_BY")),
                                "RIGHT_GROUP_BY", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "RIGHT_GROUP_BY")),
                                "AVG", PlanMatchPattern.expression(new Reference(DoubleType.DOUBLE, "AVG"))),
                        PlanMatchPattern.join(JoinType.INNER, join -> join
                                .equiCriteria("LEFT_EQUI", "RIGHT_EQUI")
                                .filter(leftNonEquiAtMostRightNonEqui())
                                // The partial aggregation now sits on the left input. It groups by every left symbol
                                // the join still needs (equi key, filter operand, hash) plus the original group key.
                                .left(PlanMatchPattern.aggregation(
                                        PlanMatchPattern.singleGroupingSet("LEFT_EQUI", "LEFT_NON_EQUI", "LEFT_GROUP_BY", "LEFT_HASH"),
                                        ImmutableMap.of(Optional.of("AVG"), PlanMatchPattern.aggregationFunction("avg", ImmutableList.of("LEFT_AGGR"))),
                                        Optional.empty(),
                                        AggregationNode.Step.PARTIAL,
                                        PlanMatchPattern.values("LEFT_EQUI", "LEFT_NON_EQUI", "LEFT_GROUP_BY", "LEFT_AGGR", "LEFT_HASH")))
                                // The right input is left untouched.
                                .right(PlanMatchPattern.values("RIGHT_EQUI", "RIGHT_NON_EQUI", "RIGHT_GROUP_BY", "RIGHT_HASH")))));
    }

    /**
     * Builds the input join: an INNER join of two VALUES nodes on {@code LEFT_EQUI = RIGHT_EQUI}.
     *
     * <p>The join also carries the filter {@code LEFT_NON_EQUI <= RIGHT_NON_EQUI} and explicit hash
     * symbols for both sides. It outputs {@code LEFT_GROUP_BY} and {@code LEFT_AGGR} from the left input
     * and {@code RIGHT_GROUP_BY} from the right input.
     */
    private static JoinNode innerJoinWithNonEquiFilter(PlanBuilder p) {
        return p.join(
                JoinType.INNER,
                p.values(p.symbol("LEFT_EQUI"), p.symbol("LEFT_NON_EQUI"), p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_AGGR"), p.symbol("LEFT_HASH")),
                p.values(p.symbol("RIGHT_EQUI"), p.symbol("RIGHT_NON_EQUI"), p.symbol("RIGHT_GROUP_BY"), p.symbol("RIGHT_HASH")),
                ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("LEFT_EQUI"), p.symbol("RIGHT_EQUI"))),
                ImmutableList.of(p.symbol("LEFT_GROUP_BY"), p.symbol("LEFT_AGGR")),
                ImmutableList.of(p.symbol("RIGHT_GROUP_BY")),
                Optional.of(leftNonEquiAtMostRightNonEqui()),
                Optional.of(p.symbol("LEFT_HASH")),
                Optional.of(p.symbol("RIGHT_HASH")));
    }

    /** Returns a new {@code LEFT_NON_EQUI <= RIGHT_NON_EQUI} comparison over BIGINT references. */
    private static Comparison leftNonEquiAtMostRightNonEqui() {
        return new Comparison(
                Comparison.Operator.LESS_THAN_OR_EQUAL,
                new Reference(BigintType.BIGINT, "LEFT_NON_EQUI"),
                new Reference(BigintType.BIGINT, "RIGHT_NON_EQUI"));
    }
}
