package io.trino.sql.planner;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.connector.MockConnectorColumnHandle;
import io.trino.connector.MockConnectorFactory;
import io.trino.connector.MockConnectorTableHandle;
import io.trino.metadata.InMemoryNodeManager;
import io.trino.metadata.InternalNode;
import io.trino.metadata.InternalNodeManager;
import io.trino.spi.connector.BucketFunction;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.connector.ConnectorBucketNodeMap;
import io.trino.spi.connector.ConnectorNodePartitioningProvider;
import io.trino.spi.connector.ConnectorPartitioningHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorSplit;
import io.trino.spi.connector.ConnectorTablePartitioning;
import io.trino.spi.connector.ConnectorTableProperties;
import io.trino.spi.connector.ConnectorTransactionHandle;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.planner.LogicalPlanner;
import io.trino.sql.planner.assertions.BasePlanTest;
import io.trino.sql.planner.assertions.ExpectedValueProvider;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.tree.FunctionCall;
import io.trino.testing.LocalQueryRunner;
import io.trino.testing.TestingSession;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.ToIntFunction;
import org.assertj.core.api.Assertions;
import org.testng.annotations.Test;

/* loaded from: input_file:io/trino/sql/planner/TestTableScanNodePartitioning.class */
/**
 * Planner tests for {@code use_table_scan_node_partitioning}: verifies when a grouped aggregation over a
 * connector-partitioned table is planned as a single fragment that uses the connector's partitioning,
 * and when it instead falls back to a remote repartitioning exchange over a source-distributed scan.
 */
public class TestTableScanNodePartitioning extends BasePlanTest {
    public static final int BUCKET_COUNT = 10;
    public static final String PARTITIONED_TABLE = "partitioned_table";
    public static final String SINGLE_BUCKET_TABLE = "single_bucket_table";
    public static final String FIXED_PARTITIONED_TABLE = "fixed_partitioned_table";
    public static final String UNPARTITIONED_TABLE = "unpartitioned_table";
    public static final String TEST_SCHEMA = "test_schema";

    /** Catalog the mock connector is registered under; declared before the sessions that reference it. */
    private static final String TEST_CATALOG_NAME = "test-catalog";

    // task_concurrency=2 presumably forces a parallel local plan so the LOCAL repartition exchange appears;
    // NOTE(review): confirm against the planner's local exchange rules.
    public static final Session ENABLE_PLAN_WITH_TABLE_NODE_PARTITIONING = TestingSession.testSessionBuilder()
            .setCatalog(TEST_CATALOG_NAME)
            .setSchema(TEST_SCHEMA)
            .setSystemProperty("use_table_scan_node_partitioning", "true")
            .setSystemProperty("task_concurrency", "2")
            .build();
    public static final Session DISABLE_PLAN_WITH_TABLE_NODE_PARTITIONING = TestingSession.testSessionBuilder()
            .setCatalog(TEST_CATALOG_NAME)
            .setSchema(TEST_SCHEMA)
            .setSystemProperty("use_table_scan_node_partitioning", "false")
            .setSystemProperty("task_concurrency", "2")
            .build();

    // Distinct anonymous handles with no equals() override, so they compare by identity.
    public static final ConnectorPartitioningHandle PARTITIONING_HANDLE = new ConnectorPartitioningHandle() {
    };
    public static final ConnectorPartitioningHandle SINGLE_BUCKET_HANDLE = new ConnectorPartitioningHandle() {
    };
    public static final ConnectorPartitioningHandle FIXED_PARTITIONING_HANDLE = new ConnectorPartitioningHandle() {
    };

    public static final String COLUMN_A = "column_a";
    public static final ColumnHandle COLUMN_HANDLE_A = new MockConnectorColumnHandle(COLUMN_A, BigintType.BIGINT);
    public static final String COLUMN_B = "column_b";
    public static final ColumnHandle COLUMN_HANDLE_B = new MockConnectorColumnHandle(COLUMN_B, VarcharType.VARCHAR);

    /**
     * Partitioning provider for the mock connector. It maps each of the three test partitioning handles
     * to a bucket-node map and rejects anything else. Bucket and split functions are not needed by the
     * planner paths under test, so they are unsupported.
     */
    public static class TestPartitioningProvider implements ConnectorNodePartitioningProvider {
        private final InternalNodeManager nodeManager;

        /**
         * @param nodeManager source of the current node used for the fixed (node-pinned) partitioning
         * @throws NullPointerException if {@code nodeManager} is null
         */
        public TestPartitioningProvider(InternalNodeManager nodeManager) {
            this.nodeManager = Objects.requireNonNull(nodeManager, "nodeManager is null");
        }

        /**
         * Returns the bucket-node map for one of the test partitioning handles.
         *
         * <ul>
         *   <li>{@link #PARTITIONING_HANDLE}: {@link #BUCKET_COUNT} buckets.</li>
         *   <li>{@link #SINGLE_BUCKET_HANDLE}: a single bucket.</li>
         *   <li>{@link #FIXED_PARTITIONING_HANDLE}: one bucket pinned to the current node.</li>
         * </ul>
         *
         * @throws IllegalArgumentException for any other handle
         */
        @Override
        public Optional<ConnectorBucketNodeMap> getBucketNodeMapping(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) {
            if (partitioningHandle.equals(PARTITIONING_HANDLE)) {
                return Optional.of(ConnectorBucketNodeMap.createBucketNodeMap(BUCKET_COUNT));
            }
            if (partitioningHandle.equals(SINGLE_BUCKET_HANDLE)) {
                return Optional.of(ConnectorBucketNodeMap.createBucketNodeMap(1));
            }
            if (partitioningHandle.equals(FIXED_PARTITIONING_HANDLE)) {
                return Optional.of(ConnectorBucketNodeMap.createBucketNodeMap(ImmutableList.of(nodeManager.getCurrentNode())));
            }
            throw new IllegalArgumentException("Unexpected partitioning handle: " + partitioningHandle);
        }

        /** Not needed by these planner tests. */
        @Override
        public ToIntFunction<ConnectorSplit> getSplitBucketFunction(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) {
            throw new UnsupportedOperationException();
        }

        /** Not needed by these planner tests. */
        @Override
        public BucketFunction getBucketFunction(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle, List<Type> partitionChannelTypes, int bucketCount) {
            throw new UnsupportedOperationException();
        }
    }

    /**
     * Builds a local query runner with the mock connector registered under {@link #TEST_CATALOG_NAME}.
     * Stats are computed for 10 nodes.
     */
    @Override // io.trino.sql.planner.assertions.BasePlanTest
    protected LocalQueryRunner createLocalQueryRunner() {
        Session session = TestingSession.testSessionBuilder()
                .setCatalog(TEST_CATALOG_NAME)
                .setSchema(TEST_SCHEMA)
                .setSystemProperty("task_concurrency", "2")
                .build();
        LocalQueryRunner queryRunner = LocalQueryRunner.builder(session)
                .withNodeCountForStats(10)
                .build();
        queryRunner.createCatalog(TEST_CATALOG_NAME, createMockFactory(), ImmutableMap.of());
        return queryRunner;
    }

    /** With the feature enabled, a bucketed table is planned using the connector's partitioning. */
    @Test
    public void testEnablePlanWithTableNodePartitioning() {
        assertTableScanPlannedWithPartitioning(ENABLE_PLAN_WITH_TABLE_NODE_PARTITIONING, PARTITIONED_TABLE, PARTITIONING_HANDLE);
    }

    /** With the feature disabled, the same bucketed table falls back to a source-distributed scan. */
    @Test
    public void testDisablePlanWithTableNodePartitioning() {
        assertTableScanPlannedWithoutPartitioning(DISABLE_PLAN_WITH_TABLE_NODE_PARTITIONING, PARTITIONED_TABLE);
    }

    /** A table without connector partitioning is never planned with node partitioning. */
    @Test
    public void testTableScanWithoutConnectorPartitioning() {
        assertTableScanPlannedWithoutPartitioning(ENABLE_PLAN_WITH_TABLE_NODE_PARTITIONING, UNPARTITIONED_TABLE);
    }

    /** A node-pinned (fixed) partitioning is used even when the session property is disabled. */
    @Test
    public void testTableScanWithFixedConnectorPartitioning() {
        assertTableScanPlannedWithPartitioning(DISABLE_PLAN_WITH_TABLE_NODE_PARTITIONING, FIXED_PARTITIONED_TABLE, FIXED_PARTITIONING_HANDLE);
    }

    /** A single-bucket table does not offer enough buckets per task to use node partitioning. */
    @Test
    public void testTableScanWithInsufficientBucketToTaskRatio() {
        assertTableScanPlannedWithoutPartitioning(ENABLE_PLAN_WITH_TABLE_NODE_PARTITIONING, SINGLE_BUCKET_TABLE);
    }

    /**
     * Asserts the grouped count over {@code tableName} has no remote exchange between the partial and final
     * aggregations. Also asserts it is planned as a single fragment partitioned by {@code expectedHandle}.
     */
    void assertTableScanPlannedWithPartitioning(Session session, String tableName, ConnectorPartitioningHandle expectedHandle) {
        String query = groupedCountQuery(tableName);
        assertDistributedPlan(query, session, PlanMatchPattern.anyTree(
                finalCount(
                        PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.REPARTITION,
                                partialCountOverScan(tableName)))));

        SubPlan subplan = subplan(query, LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, false, session);
        Assertions.assertThat(subplan.getAllFragments()).hasSize(1);
        Assertions.assertThat(subplan.getAllFragments().get(0).getPartitioning().getConnectorHandle()).isEqualTo(expectedHandle);
    }

    /**
     * Asserts the grouped count over {@code tableName} needs a remote repartition exchange. The plan must
     * yield two fragments, and the second must use source distribution.
     */
    void assertTableScanPlannedWithoutPartitioning(Session session, String tableName) {
        String query = groupedCountQuery(tableName);
        assertDistributedPlan(query, session, PlanMatchPattern.anyTree(
                finalCount(
                        PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.REPARTITION,
                                PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, ExchangeNode.Type.REPARTITION,
                                        partialCountOverScan(tableName))))));

        SubPlan subplan = subplan(query, LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, false, session);
        Assertions.assertThat(subplan.getAllFragments()).hasSize(2);
        Assertions.assertThat(subplan.getAllFragments().get(1).getPartitioning().getConnectorHandle())
                .isEqualTo(SystemPartitioningHandle.SOURCE_DISTRIBUTION.getConnectorHandle());
    }

    /** The grouped-count query every plan assertion in this class runs against {@code tableName}. */
    private static String groupedCountQuery(String tableName) {
        return "SELECT count(column_b) FROM " + tableName + " GROUP BY column_a";
    }

    /** Pattern for the FINAL count aggregation that combines the partial "COUNT_PART" values. */
    private static PlanMatchPattern finalCount(PlanMatchPattern source) {
        return PlanMatchPattern.aggregation(
                (Map<String, ExpectedValueProvider<FunctionCall>>) ImmutableMap.of("COUNT", PlanMatchPattern.functionCall("count", ImmutableList.of("COUNT_PART"))),
                AggregationNode.Step.FINAL,
                source);
    }

    /** Pattern for a project over the PARTIAL count of column B, read directly from a scan of {@code tableName}. */
    private static PlanMatchPattern partialCountOverScan(String tableName) {
        return PlanMatchPattern.project(
                PlanMatchPattern.aggregation(
                        (Map<String, ExpectedValueProvider<FunctionCall>>) ImmutableMap.of("COUNT_PART", PlanMatchPattern.functionCall("count", ImmutableList.of("B"))),
                        AggregationNode.Step.PARTIAL,
                        PlanMatchPattern.tableScan(tableName, ImmutableMap.of("A", COLUMN_A, "B", COLUMN_B))));
    }

    /**
     * Builds the mock connector used by these tests.
     *
     * <ul>
     *   <li>Every table exposes the columns {@link #COLUMN_A} (bigint) and {@link #COLUMN_B} (varchar).</li>
     *   <li>The three partitioned tables report partitioning on column A with their respective handle.</li>
     *   <li>Every other table reports default (unpartitioned) properties.</li>
     * </ul>
     */
    public static MockConnectorFactory createMockFactory() {
        return MockConnectorFactory.builder()
                .withPartitionProvider(new TestPartitioningProvider(new InMemoryNodeManager(new InternalNode[0])))
                .withGetColumns(schemaTableName -> ImmutableList.of(
                        new ColumnMetadata(COLUMN_A, BigintType.BIGINT),
                        new ColumnMetadata(COLUMN_B, VarcharType.VARCHAR)))
                .withGetTableProperties((connectorSession, connectorTableHandle) -> {
                    String tableName = ((MockConnectorTableHandle) connectorTableHandle).getTableName().getTableName();
                    if (tableName.equals(PARTITIONED_TABLE)) {
                        return partitionedOnColumnA(PARTITIONING_HANDLE);
                    }
                    if (tableName.equals(SINGLE_BUCKET_TABLE)) {
                        return partitionedOnColumnA(SINGLE_BUCKET_HANDLE);
                    }
                    if (tableName.equals(FIXED_PARTITIONED_TABLE)) {
                        return partitionedOnColumnA(FIXED_PARTITIONING_HANDLE);
                    }
                    return new ConnectorTableProperties();
                })
                .build();
    }

    /**
     * Table properties with no predicate, partitioning on column A under {@code partitioningHandle},
     * no discrete predicates and no local properties.
     */
    private static ConnectorTableProperties partitionedOnColumnA(ConnectorPartitioningHandle partitioningHandle) {
        return new ConnectorTableProperties(
                TupleDomain.all(),
                Optional.of(new ConnectorTablePartitioning(partitioningHandle, ImmutableList.of(COLUMN_HANDLE_A))),
                Optional.empty(),
                ImmutableList.of());
    }
}
