package io.trino.execution.scheduler;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import io.airlift.concurrent.Threads;
import io.airlift.testing.TestingTicker;
import io.trino.Session;
import io.trino.client.NodeVersion;
import io.trino.execution.MockRemoteTaskFactory;
import io.trino.execution.NodeTaskMap;
import io.trino.execution.RemoteTask;
import io.trino.execution.StageId;
import io.trino.execution.TaskId;
import io.trino.execution.scheduler.NodeSchedulerConfig;
import io.trino.execution.scheduler.UniformNodeSelector;
import io.trino.metadata.InMemoryNodeManager;
import io.trino.metadata.InternalNode;
import io.trino.metadata.Split;
import io.trino.spi.HostAddress;
import io.trino.spi.connector.CatalogHandle;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.TestingHandles;
import io.trino.testing.TestingSession;
import io.trino.testing.TestingSplit;
import io.trino.util.FinalizerService;
import java.net.InetAddress;
import java.net.URI;
import java.net.UnknownHostException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

/**
 * Tests for {@link UniformNodeSelector}: how splits are spread over two worker nodes, how the
 * per-task pending-split queue size is adjusted as tasks start (or fail to start) their splits,
 * and how a split that can run elsewhere is reassigned once its preferred node disappears.
 */
@TestInstance(TestInstance.Lifecycle.PER_METHOD)
public class TestUniformNodeSelector {
    private static final InternalNode node1 = new InternalNode("node1", URI.create("http://10.0.0.1:13"), NodeVersion.UNKNOWN, false);
    private static final InternalNode node2 = new InternalNode("node2", URI.create("http://10.0.0.1:12"), NodeVersion.UNKNOWN, false);

    /** Splits offered to the selector in the current test; insertion-ordered for reproducibility. */
    private final Set<Split> splits = new LinkedHashSet<>();
    private FinalizerService finalizerService;
    private NodeTaskMap nodeTaskMap;
    private InMemoryNodeManager nodeManager;
    private NodeSchedulerConfig nodeSchedulerConfig;
    private NodeScheduler nodeScheduler;
    private NodeSelector nodeSelector;
    /** The single task created per node; passed back to the selector as the existing tasks. */
    private Map<InternalNode, RemoteTask> taskMap;
    private ExecutorService remoteTaskExecutor;
    private ScheduledExecutorService remoteTaskScheduledExecutor;
    private Session session;

    @BeforeEach
    public void setUp() {
        session = TestingSession.testSessionBuilder().build();
        finalizerService = new FinalizerService();
        nodeTaskMap = new NodeTaskMap(finalizerService);
        nodeManager = new InMemoryNodeManager();
        nodeManager.addNodes(node1);
        nodeManager.addNodes(node2);
        nodeSchedulerConfig = new NodeSchedulerConfig()
                .setMaxSplitsPerNode(20)
                .setMinPendingSplitsPerTask(10)
                .setMaxAdjustedPendingSplitsWeightPerTask(100)
                .setIncludeCoordinator(false);
        nodeScheduler = new NodeScheduler(new UniformNodeSelectorFactory(nodeManager, nodeSchedulerConfig, nodeTaskMap));
        taskMap = new HashMap<>();
        nodeSelector = nodeScheduler.createNodeSelector(session, Optional.of(TestingHandles.TEST_CATALOG_HANDLE));
        remoteTaskExecutor = Executors.newCachedThreadPool(Threads.daemonThreadsNamed("remoteTaskExecutor-%s"));
        remoteTaskScheduledExecutor = Executors.newScheduledThreadPool(2, Threads.daemonThreadsNamed("remoteTaskScheduledExecutor-%s"));
        finalizerService.start();
    }

    @AfterEach
    public void tearDown() {
        remoteTaskExecutor.shutdown();
        remoteTaskExecutor = null;
        remoteTaskScheduledExecutor.shutdown();
        remoteTaskScheduledExecutor = null;
        nodeSchedulerConfig = null;
        nodeScheduler = null;
        nodeSelector = null;
        finalizerService.destroy();
        finalizerService = null;
    }

    /**
     * Once every task has started its splits, the adjusted queue size for node1 grows to 20. It
     * then scales down to 13 once a full second of ticker time has elapsed (but not after 999 ms).
     */
    @Test
    public void testQueueSizeAdjustmentScaleDown() {
        TestingTicker ticker = new TestingTicker();
        UniformNodeSelector.QueueSizeAdjuster queueSizeAdjuster = new UniformNodeSelector.QueueSizeAdjuster(10L, 100L, ticker);
        nodeSelector = new UniformNodeSelector(
                nodeManager,
                nodeTaskMap,
                false,
                () -> createNodeMap(TestingHandles.TEST_CATALOG_HANDLE),
                10,
                100L,
                10L,
                500,
                NodeSchedulerConfig.SplitsBalancingPolicy.STAGE,
                false,
                queueSizeAdjuster);
        addRemoteSplits(20);

        // Round 1: only two of the twenty splits are placed; every new task starts all of them.
        Multimap<InternalNode, Split> assignments1 = assign(splits);
        Assertions.assertThat(assignments1.size()).isEqualTo(2);
        createTableScanTasks(assignments1, assignments1.keySet());
        Set<Split> remaining = remainingSplits(splits, assignments1);
        Assertions.assertThat(remaining.size()).isEqualTo(18);

        // Round 2: two more splits are queued on the tasks but are not started.
        Multimap<InternalNode, Split> assignments2 = assign(remaining);
        Assertions.assertThat(assignments2.size()).isEqualTo(2);
        addSplitsToTasks(assignments2, ImmutableSet.of());

        long adjustedWeight = queueSizeAdjuster.getAdjustedMaxPendingSplitsWeightPerTask(node1.getNodeIdentifier());
        Assertions.assertThat(adjustedWeight).isEqualTo(20);

        // 999 ms later the adjusted size is unchanged and nothing new is assigned...
        ticker.increment(999L, TimeUnit.MILLISECONDS);
        Assertions.assertThat(assign(remaining).size()).isEqualTo(0);
        Assertions.assertThat(queueSizeAdjuster.getAdjustedMaxPendingSplitsWeightPerTask(node1.getNodeIdentifier()))
                .isEqualTo(adjustedWeight);

        // ...but one more millisecond triggers the scale-down.
        ticker.increment(1L, TimeUnit.MILLISECONDS);
        Assertions.assertThat(assign(remaining).size()).isEqualTo(0);
        Assertions.assertThat(queueSizeAdjuster.getAdjustedMaxPendingSplitsWeightPerTask(node1.getNodeIdentifier()))
                .isEqualTo(13);
    }

    /**
     * When both nodes start every split they receive, each scheduling round places more splits
     * (40, then 40, then 80). A final round with no splits started places nothing new.
     */
    @Test
    public void testQueueSizeAdjustmentAllNodes() {
        addRemoteSplits(180);

        Multimap<InternalNode, Split> assignments1 = assign(splits);
        Assertions.assertThat(assignments1.size()).isEqualTo(40);
        createTableScanTasks(assignments1, assignments1.keySet());
        Set<Split> remaining1 = remainingSplits(splits, assignments1);
        Assertions.assertThat(remaining1.size()).isEqualTo(140);

        Multimap<InternalNode, Split> assignments2 = assign(remaining1);
        addSplitsToTasks(assignments2, assignments2.keySet());
        Set<Split> remaining2 = remainingSplits(remaining1, assignments2);
        Assertions.assertThat(remaining2.size()).isEqualTo(100);

        // Splits from this round are queued but never started.
        Multimap<InternalNode, Split> assignments3 = assign(remaining2);
        addSplitsToTasks(assignments3, ImmutableSet.of());
        Set<Split> remaining3 = remainingSplits(remaining2, assignments3);
        Assertions.assertThat(remaining3.size()).isEqualTo(20);

        Multimap<InternalNode, Split> assignments4 = assign(remaining3);
        Assertions.assertThat(remainingSplits(remaining3, assignments4).size()).isEqualTo(20);
    }

    /**
     * Only node1 ever starts its splits. From the second round on, node1 alone receives new
     * splits (20, then 40) and node2 receives none.
     */
    @Test
    public void testQueueSizeAdjustmentOneOfAll() {
        addRemoteSplits(180);
        Set<InternalNode> startingNodes = ImmutableSet.of(node1);

        Multimap<InternalNode, Split> assignments1 = assign(splits);
        Assertions.assertThat(assignments1.size()).isEqualTo(40);
        createTableScanTasks(assignments1, startingNodes);
        Set<Split> remaining1 = remainingSplits(splits, assignments1);
        Assertions.assertThat(remaining1.size()).isEqualTo(140);

        Multimap<InternalNode, Split> assignments2 = assign(remaining1);
        addSplitsToTasks(assignments2, startingNodes);
        Set<Split> remaining2 = remainingSplits(remaining1, assignments2);
        Assertions.assertThat(remaining2.size()).isEqualTo(120);
        Assertions.assertThat(assignments2.get(node1).size()).isEqualTo(20);
        Assertions.assertThat(assignments2.containsKey(node2)).isFalse();

        Multimap<InternalNode, Split> assignments3 = assign(remaining2);
        addSplitsToTasks(assignments3, startingNodes);
        Assertions.assertThat(remainingSplits(remaining2, assignments3).size()).isEqualTo(80);
        Assertions.assertThat(assignments3.get(node1).size()).isEqualTo(40);
        // The third round must also skip node2 (previously this re-checked assignments2 by mistake).
        Assertions.assertThat(assignments3.containsKey(node2)).isFalse();
    }

    /**
     * Both splits prefer node1. After node1 is removed, scheduling the pair fails. The split built
     * with {@code TestingSplit(true, ...)} is still placed, on node2.
     */
    @Test
    public void testFailover() {
        nodeSelector = new UniformNodeSelector(
                nodeManager,
                nodeTaskMap,
                false,
                () -> createNodeMap(TestingHandles.TEST_CATALOG_HANDLE),
                10,
                2000L,
                1000L,
                2000,
                NodeSchedulerConfig.SplitsBalancingPolicy.STAGE,
                true,
                new UniformNodeSelector.QueueSizeAdjuster(1000L, 10000L, new TestingTicker()));
        splits.add(new Split(TestingHandles.TEST_CATALOG_HANDLE, new TestingSplit(false, ImmutableList.of(node1.getHostAndPort()))));
        // NOTE(review): the `true` flag presumably marks this split as remotely accessible;
        // that matches it being reassigned to node2 below.
        Split relocatableSplit = new Split(TestingHandles.TEST_CATALOG_HANDLE, new TestingSplit(true, ImmutableList.of(node1.getHostAndPort())));
        splits.add(relocatableSplit);

        Multimap<InternalNode, Split> assignments = assign(splits);
        Multimap<InternalNode, Split> expected = ArrayListMultimap.create();
        expected.putAll(node1, splits);
        org.assertj.guava.api.Assertions.assertThat(assignments).hasSameEntriesAs(expected);

        nodeManager.removeNode(node1);
        Assertions.assertThatThrownBy(() -> nodeSelector.computeAssignments(splits, ImmutableList.copyOf(taskMap.values())))
                .hasMessage("No nodes available to run query");

        Multimap<InternalNode, Split> failoverAssignments = assign(ImmutableSet.of(relocatableSplit));
        Multimap<InternalNode, Split> expectedAfterFailover = ArrayListMultimap.create();
        expectedAfterFailover.put(node2, relocatableSplit);
        org.assertj.guava.api.Assertions.assertThat(failoverAssignments).hasSameEntriesAs(expectedAfterFailover);
    }

    /** Adds {@code count} fresh remote testing splits to {@link #splits}. */
    private void addRemoteSplits(int count) {
        for (int i = 0; i < count; i++) {
            splits.add(new Split(TestingHandles.TEST_CATALOG_HANDLE, TestingSplit.createRemoteSplit()));
        }
    }

    /** Runs one scheduling round for {@code candidates} against the tasks created so far. */
    private Multimap<InternalNode, Split> assign(Set<Split> candidates) {
        return nodeSelector.computeAssignments(candidates, ImmutableList.copyOf(taskMap.values())).getAssignments();
    }

    /** Returns the splits in {@code candidates} that were not placed by {@code assignments}. */
    private static Set<Split> remainingSplits(Set<Split> candidates, Multimap<InternalNode, Split> assignments) {
        // Snapshot the assigned values so the returned view does not depend on the multimap later.
        return Sets.difference(candidates, new HashSet<>(assignments.values()));
    }

    /**
     * Creates one mock table-scan task per node in {@code assignments}, seeded with that node's
     * splits. Each task is registered in {@link #nodeTaskMap} and {@link #taskMap}. Tasks on
     * nodes in {@code startingNodes} have all of their splits started before registration.
     */
    private void createTableScanTasks(Multimap<InternalNode, Split> assignments, Set<InternalNode> startingNodes) {
        MockRemoteTaskFactory taskFactory = new MockRemoteTaskFactory(remoteTaskExecutor, remoteTaskScheduledExecutor);
        int partitionId = 0;
        for (InternalNode node : assignments.keySet()) {
            TaskId taskId = new TaskId(new StageId("test", 1), partitionId, 0);
            partitionId++;
            MockRemoteTaskFactory.MockRemoteTask task = taskFactory.createTableScanTask(
                    taskId,
                    node,
                    ImmutableList.copyOf(assignments.get(node)),
                    nodeTaskMap.createPartitionedSplitCountTracker(node, taskId));
            if (startingNodes.contains(node)) {
                task.startSplits(task.getPartitionedSplitsInfo().getCount());
            }
            nodeTaskMap.addTask(node, task);
            taskMap.put(node, task);
        }
    }

    /**
     * Feeds each node's newly assigned splits to its existing task. On nodes in
     * {@code startingNodes}, the task then starts every split it currently holds.
     */
    private void addSplitsToTasks(Multimap<InternalNode, Split> assignments, Set<InternalNode> startingNodes) {
        for (InternalNode node : assignments.keySet()) {
            MockRemoteTaskFactory.MockRemoteTask task = (MockRemoteTaskFactory.MockRemoteTask) taskMap.get(node);
            task.addSplits(ImmutableMultimap.<PlanNodeId, Split>builder()
                    .putAll(new PlanNodeId("sourceId"), assignments.get(node))
                    .build());
            if (startingNodes.contains(node)) {
                task.startSplits(task.getPartitionedSplitsInfo().getCount());
            }
        }
    }

    /**
     * Builds a {@link NodeMap} from the nodes currently active for {@code catalogHandle}, indexed
     * by host-and-port and by internal address, plus the identifiers of the coordinators.
     */
    private NodeMap createNodeMap(CatalogHandle catalogHandle) {
        Set<InternalNode> activeCatalogNodes = nodeManager.getActiveCatalogNodes(catalogHandle);
        Set<String> coordinatorNodeIds = nodeManager.getCoordinators().stream()
                .map(InternalNode::getNodeIdentifier)
                .collect(ImmutableSet.toImmutableSet());

        ImmutableSetMultimap.Builder<HostAddress, InternalNode> byHostAndPort = ImmutableSetMultimap.builder();
        ImmutableSetMultimap.Builder<InetAddress, InternalNode> byHost = ImmutableSetMultimap.builder();
        for (InternalNode node : activeCatalogNodes) {
            try {
                byHostAndPort.put(node.getHostAndPort(), node);
                byHost.put(node.getInternalAddress(), node);
            } catch (UnknownHostException ignored) {
                // Best effort: a node whose address cannot be resolved is left out of the index
                // (any entry already added for it is kept). The test nodes use IP-literal URIs,
                // so this branch is not expected to run here.
            }
        }
        return new NodeMap(byHostAndPort.build(), byHost.build(), ImmutableSetMultimap.of(), coordinatorNodeIds);
    }
}
