package io.prestosql.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.prestosql.Session;
import io.prestosql.plugin.tpch.TpchConnectorFactory;
import io.prestosql.sql.analyzer.FeaturesConfig;
import io.prestosql.sql.planner.assertions.BasePlanTest;
import io.prestosql.sql.planner.assertions.ExpectedValueProvider;
import io.prestosql.sql.planner.assertions.PlanMatchPattern;
import io.prestosql.sql.planner.plan.ExchangeNode;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.MarkDistinctNode;
import io.prestosql.sql.planner.plan.ValuesNode;
import io.prestosql.testing.LocalQueryRunner;
import io.prestosql.testing.TestingSession;
import java.util.List;
import java.util.Optional;
import org.testng.annotations.Test;

/* loaded from: input_file:io/prestosql/sql/planner/optimizations/TestAddExchangesPlans.class */
/**
 * Distributed-plan tests for exchange placement (presumably produced by the {@code AddExchanges}
 * optimizer, going by this class's name). Each test plans a query against {@code tpch.tiny} and
 * asserts, via {@link PlanMatchPattern}, where REMOTE/LOCAL exchanges appear and which symbols
 * they partition on.
 */
public class TestAddExchangesPlans extends BasePlanTest {
    /** Query planned by {@link #testForcePartitioningMarkDistinctInput()} under two session settings. */
    private static final String MARK_DISTINCT_OVER_UNION_QUERY =
            "SELECT count(orderkey), count(distinct orderkey), custkey , count(1) "
                    + "FROM ( SELECT * FROM (VALUES (1, 2)) as t(custkey, orderkey) UNION ALL SELECT 3, 4) GROUP BY 3";

    /** Inner equi-join of nation and region, planned by {@link #testNonSpillableBroadcastJoinAboveTableScan()}. */
    private static final String NATION_REGION_JOIN_QUERY =
            "SELECT * FROM nation n join region r on n.nationkey = r.regionkey";

    public TestAddExchangesPlans() {
        super(TestAddExchangesPlans::createQueryRunner);
    }

    /**
     * Builds the local query runner used by every test: default session on {@code tpch.tiny}, with a
     * TPC-H catalog registered under the name {@code tpch}.
     */
    private static LocalQueryRunner createQueryRunner() {
        Session defaultSession = TestingSession.testSessionBuilder()
                .setCatalog("tpch")
                .setSchema("tiny")
                .build();
        // Spiller paths are configured because several sessions below set spill_enabled=true;
        // presumably enabling spill is rejected without them — confirm.
        LocalQueryRunner queryRunner = LocalQueryRunner.builder(defaultSession)
                .withFeaturesConfig(new FeaturesConfig().setSpillerSpillPaths("/tmp/test_spill_path"))
                .build();
        // NOTE(review): the constructor argument is presumably the default splits per node — confirm.
        queryRunner.createCatalog("tpch", new TpchConnectorFactory(1), ImmutableMap.of());
        return queryRunner;
    }

    /**
     * A {@code UNION} (distinct) is planned as an aggregation over both inputs. Each input must be
     * remotely repartitioned below it, whether the input is a table scan or a constant row.
     */
    @Test
    public void testRepartitionForUnionWithAnyTableScans() {
        assertDistributedPlan(
                "SELECT nationkey FROM nation UNION select regionkey from region",
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.aggregation(
                                ImmutableMap.of(),
                                PlanMatchPattern.anyTree(
                                        PlanMatchPattern.anyTree(
                                                remoteRepartition(PlanMatchPattern.anyTree(PlanMatchPattern.tableScan("nation")))),
                                        PlanMatchPattern.anyTree(
                                                remoteRepartition(PlanMatchPattern.anyTree(PlanMatchPattern.tableScan("region"))))))));

        assertDistributedPlan(
                "SELECT nationkey FROM nation UNION select 1",
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.aggregation(
                                ImmutableMap.of(),
                                PlanMatchPattern.anyTree(
                                        PlanMatchPattern.anyTree(
                                                remoteRepartition(PlanMatchPattern.anyTree(PlanMatchPattern.tableScan("nation")))),
                                        PlanMatchPattern.anyTree(
                                                remoteRepartition(PlanMatchPattern.anyTree(PlanMatchPattern.values())))))));
    }

    /**
     * Forces a partitioned join (cross-join elimination only). Each UNION ALL input on the left side
     * must get its own remote repartitioning exchange, and so must the region input on the right.
     */
    @Test
    public void testRepartitionForUnionAllBeforeHashJoin() {
        Session partitionedJoinSession = Session.builder(getQueryRunner().getDefaultSession())
                .setSystemProperty("join_distribution_type", FeaturesConfig.JoinDistributionType.PARTITIONED.name())
                .setSystemProperty("join_reordering_strategy", FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS.name())
                .build();

        assertDistributedPlan(
                "SELECT * FROM (SELECT nationkey FROM nation UNION ALL select nationkey from nation) n join region r on n.nationkey = r.regionkey",
                partitionedJoinSession,
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.join(
                                JoinNode.Type.INNER,
                                ImmutableList.of(PlanMatchPattern.equiJoinClause("nationkey", "regionkey")),
                                PlanMatchPattern.anyTree(
                                        remoteRepartition(PlanMatchPattern.anyTree(PlanMatchPattern.tableScan("nation", ImmutableMap.of("nationkey", "nationkey")))),
                                        remoteRepartition(PlanMatchPattern.anyTree(PlanMatchPattern.tableScan("nation")))),
                                PlanMatchPattern.anyTree(
                                        remoteRepartition(PlanMatchPattern.anyTree(PlanMatchPattern.tableScan("region", ImmutableMap.of("regionkey", "regionkey"))))))));

        assertDistributedPlan(
                "SELECT * FROM (SELECT nationkey FROM nation UNION ALL select 1) n join region r on n.nationkey = r.regionkey",
                partitionedJoinSession,
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.join(
                                JoinNode.Type.INNER,
                                ImmutableList.of(PlanMatchPattern.equiJoinClause("nationkey", "regionkey")),
                                PlanMatchPattern.anyTree(
                                        remoteRepartition(PlanMatchPattern.anyTree(PlanMatchPattern.tableScan("nation", ImmutableMap.of("nationkey", "nationkey")))),
                                        remoteRepartition(PlanMatchPattern.anyTree(PlanMatchPattern.values()))),
                                PlanMatchPattern.anyTree(
                                        remoteRepartition(PlanMatchPattern.anyTree(PlanMatchPattern.tableScan("region", ImmutableMap.of("regionkey", "regionkey"))))))));
    }

    /**
     * With spilling enabled, a REPLICATED (broadcast) join whose left input is a table scan with no
     * exchange in between is expected to be non-spillable. The PARTITIONED variant does not
     * constrain the spillable flag. It checks that the right side goes through a LOCAL exchange over
     * a REMOTE one.
     */
    @Test
    public void testNonSpillableBroadcastJoinAboveTableScan() {
        // Casts are deliberately absent: Optional.empty() is only target-typed in invocation context.
        assertDistributedPlan(
                NATION_REGION_JOIN_QUERY,
                noJoinReordering(),
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.join(
                                JoinNode.Type.INNER,
                                ImmutableList.of(PlanMatchPattern.equiJoinClause("nationkey", "regionkey")),
                                Optional.empty(),
                                Optional.of(JoinNode.DistributionType.REPLICATED),
                                Optional.of(false),
                                PlanMatchPattern.anyNot(ExchangeNode.class, PlanMatchPattern.tableScan("nation", ImmutableMap.of("nationkey", "nationkey"))),
                                PlanMatchPattern.anyTree(
                                        PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, ExchangeNode.Type.REPLICATE,
                                                PlanMatchPattern.anyTree(PlanMatchPattern.tableScan("region", ImmutableMap.of("regionkey", "regionkey"))))))));

        assertDistributedPlan(
                NATION_REGION_JOIN_QUERY,
                spillEnabledWithJoinDistributionType(FeaturesConfig.JoinDistributionType.PARTITIONED),
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.join(
                                JoinNode.Type.INNER,
                                ImmutableList.of(PlanMatchPattern.equiJoinClause("nationkey", "regionkey")),
                                Optional.empty(),
                                Optional.of(JoinNode.DistributionType.PARTITIONED),
                                Optional.empty(),
                                remoteRepartition(PlanMatchPattern.anyTree(PlanMatchPattern.tableScan("nation", ImmutableMap.of("nationkey", "nationkey")))),
                                PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.REPARTITION,
                                        remoteRepartition(PlanMatchPattern.anyTree(PlanMatchPattern.tableScan("region", ImmutableMap.of("regionkey", "regionkey"))))))));
    }

    /**
     * Plans a MarkDistinct over a UNION ALL of two constant inputs and checks the remote exchange
     * on each branch.
     *
     * <ul>
     *   <li>With {@code ignore_downstream_preferences=true}, the expected partitioning symbols are
     *       both columns of each branch.</li>
     *   <li>With {@code false}, only the first column is expected. For the VALUES branch that is
     *       {@code custkey}, the GROUP BY column.</li>
     * </ul>
     */
    @Test
    public void testForcePartitioningMarkDistinctInput() {
        assertDistributedPlan(
                MARK_DISTINCT_OVER_UNION_QUERY,
                Session.builder(getQueryRunner().getDefaultSession())
                        .setSystemProperty("ignore_downstream_preferences", "true")
                        .build(),
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.node(MarkDistinctNode.class,
                                PlanMatchPattern.anyTree(
                                        PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, ExchangeNode.Type.REPARTITION, ImmutableList.of(), ImmutableSet.of("partition1", "partition2"),
                                                PlanMatchPattern.anyTree(PlanMatchPattern.values(ImmutableList.of("partition1", "partition2")))),
                                        // Both columns are listed so this branch is held to the same expectation as the VALUES branch.
                                        PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, ExchangeNode.Type.REPARTITION, ImmutableList.of(), ImmutableSet.of("partition3", "partition4"),
                                                PlanMatchPattern.project(
                                                        PlanMatchPattern.project(
                                                                ImmutableMap.of("partition3", PlanMatchPattern.expression("3"), "partition4", PlanMatchPattern.expression("4")),
                                                                PlanMatchPattern.anyTree(PlanMatchPattern.node(ValuesNode.class)))))))));

        assertDistributedPlan(
                MARK_DISTINCT_OVER_UNION_QUERY,
                Session.builder(getQueryRunner().getDefaultSession())
                        .setSystemProperty("ignore_downstream_preferences", "false")
                        .build(),
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.node(MarkDistinctNode.class,
                                PlanMatchPattern.anyTree(
                                        PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, ExchangeNode.Type.REPARTITION, ImmutableList.of(), ImmutableSet.of("partition1"),
                                                PlanMatchPattern.anyTree(PlanMatchPattern.values(ImmutableList.of("partition1", "partition2")))),
                                        PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, ExchangeNode.Type.REPARTITION, ImmutableList.of(), ImmutableSet.of("partition3"),
                                                PlanMatchPattern.project(
                                                        PlanMatchPattern.project(
                                                                ImmutableMap.of("partition3", PlanMatchPattern.expression("3"), "partition4", PlanMatchPattern.expression("4")),
                                                                PlanMatchPattern.anyTree(PlanMatchPattern.node(ValuesNode.class)))))))));
    }

    /** Pattern for a REMOTE, REPARTITION exchange over a single source, the most common shape in these tests. */
    private static PlanMatchPattern remoteRepartition(PlanMatchPattern source) {
        return PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, ExchangeNode.Type.REPARTITION, source);
    }

    /**
     * Returns the default session with the given join distribution type, spilling enabled and
     * {@code task_concurrency=16}.
     */
    private Session spillEnabledWithJoinDistributionType(FeaturesConfig.JoinDistributionType joinDistributionType) {
        return Session.builder(getQueryRunner().getDefaultSession())
                .setSystemProperty("join_distribution_type", joinDistributionType.toString())
                .setSystemProperty("spill_enabled", "true")
                .setSystemProperty("task_concurrency", "16")
                .build();
    }

    /**
     * Returns the default session with join reordering disabled. Despite its name, it also forces
     * BROADCAST join distribution, enables spilling and sets {@code task_concurrency=16}.
     */
    private Session noJoinReordering() {
        return Session.builder(getQueryRunner().getDefaultSession())
                .setSystemProperty("join_reordering_strategy", FeaturesConfig.JoinReorderingStrategy.NONE.name())
                .setSystemProperty("join_distribution_type", FeaturesConfig.JoinDistributionType.BROADCAST.name())
                .setSystemProperty("spill_enabled", "true")
                .setSystemProperty("task_concurrency", "16")
                .build();
    }
}
