package io.prestosql.memory;

import io.airlift.concurrent.Threads;
import io.airlift.stats.TestingGcMonitor;
import io.airlift.units.DataSize;
import io.prestosql.ExceededMemoryLimitException;
import io.prestosql.execution.TaskId;
import io.prestosql.execution.TaskStateMachine;
import io.prestosql.memory.context.LocalMemoryContext;
import io.prestosql.memory.context.MemoryTrackingContext;
import io.prestosql.operator.DriverContext;
import io.prestosql.operator.DriverStats;
import io.prestosql.operator.OperatorContext;
import io.prestosql.operator.OperatorStats;
import io.prestosql.operator.PipelineContext;
import io.prestosql.operator.PipelineStats;
import io.prestosql.operator.TaskContext;
import io.prestosql.operator.TaskStats;
import io.prestosql.spi.QueryId;
import io.prestosql.spi.memory.MemoryPoolId;
import io.prestosql.spiller.SpillSpaceTracker;
import io.prestosql.sql.planner.plan.PlanNodeId;
import io.prestosql.testing.TestingSession;
import java.util.OptionalInt;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import org.assertj.core.api.Assertions;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

/**
 * Tests how byte reservations made through operator, pipeline and task level memory contexts
 * show up in the memory tracking contexts, in the reported stats objects and in the backing
 * {@link MemoryPool}.
 *
 * <p>Each test method runs against a fresh memory pool and a fresh query, task, pipeline,
 * driver and operator context chain built by {@link #setUpTest()}. The two executors are created
 * once per class and shared by all test methods.
 */
@Test(singleThreaded = true)
public class TestMemoryTracking {
    private static final DataSize QUERY_MAX_MEMORY = DataSize.of(1, DataSize.Unit.GIGABYTE);
    private static final DataSize QUERY_MAX_TOTAL_MEMORY = DataSize.of(1, DataSize.Unit.GIGABYTE);
    private static final DataSize MEMORY_POOL_SIZE = DataSize.of(1, DataSize.Unit.GIGABYTE);
    private static final DataSize MAX_SPILL_SIZE = DataSize.of(1, DataSize.Unit.GIGABYTE);
    private static final DataSize QUERY_MAX_SPILL_SIZE = DataSize.of(1, DataSize.Unit.GIGABYTE);
    // Shared by every QueryContext created in this class; no test here reserves spill space.
    private static final SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(MAX_SPILL_SIZE);

    private QueryContext queryContext;
    private TaskContext taskContext;
    private PipelineContext pipelineContext;
    private DriverContext driverContext;
    private OperatorContext operatorContext;
    private MemoryPool memoryPool;
    private ExecutorService notificationExecutor;
    private ScheduledExecutorService yieldExecutor;

    /** Creates the executors that are shared by all test methods of this class. */
    @BeforeClass
    public void setUp() {
        notificationExecutor = Executors.newCachedThreadPool(Threads.daemonThreadsNamed("local-query-runner-executor-%s"));
        yieldExecutor = Executors.newScheduledThreadPool(2, Threads.daemonThreadsNamed("local-query-runner-scheduler-%s"));
    }

    /**
     * Shuts the shared executors down and drops all references held by the test instance.
     *
     * <p>Because of {@code alwaysRun = true} this method also runs when {@link #setUp()} failed,
     * in which case the executors may never have been created; they are therefore null-checked
     * so that a secondary {@link NullPointerException} does not hide the original failure.
     */
    @AfterClass(alwaysRun = true)
    public void tearDown() {
        if (notificationExecutor != null) {
            notificationExecutor.shutdownNow();
        }
        if (yieldExecutor != null) {
            yieldExecutor.shutdownNow();
        }
        queryContext = null;
        taskContext = null;
        pipelineContext = null;
        driverContext = null;
        operatorContext = null;
        memoryPool = null;
    }

    /** Builds a new memory pool and a new context chain (query down to operator) for each test. */
    @BeforeMethod
    public void setUpTest() {
        memoryPool = new MemoryPool(new MemoryPoolId("test"), MEMORY_POOL_SIZE);
        queryContext = new QueryContext(
                new QueryId("test_query"),
                QUERY_MAX_MEMORY,
                QUERY_MAX_TOTAL_MEMORY,
                memoryPool,
                new TestingGcMonitor(),
                notificationExecutor,
                yieldExecutor,
                QUERY_MAX_SPILL_SIZE,
                spillSpaceTracker);
        taskContext = queryContext.addTaskContext(
                new TaskStateMachine(new TaskId("query", 0, 0), notificationExecutor),
                TestingSession.testSessionBuilder().build(),
                () -> {
                },
                true,
                true,
                OptionalInt.empty());
        pipelineContext = taskContext.addPipelineContext(0, true, true, false);
        driverContext = pipelineContext.addDriverContext();
        operatorContext = driverContext.addOperatorContext(1, new PlanNodeId("a"), "test-operator");
    }

    /**
     * Walks user, system and revocable reservations of a single operator up and down and checks
     * the operator memory context and the pool after every step, including the rejection of a
     * negative byte count and the reset performed by {@code destroy()}.
     */
    @Test
    public void testOperatorAllocations() {
        MemoryTrackingContext operatorMemory = operatorContext.getOperatorMemoryContext();
        LocalMemoryContext systemMemory = operatorContext.newLocalSystemMemoryContext("test");
        LocalMemoryContext userMemory = operatorContext.localUserMemoryContext();
        LocalMemoryContext revocableMemory = operatorContext.localRevocableMemoryContext();

        userMemory.setBytes(100L);
        assertOperatorMemoryAllocations(operatorMemory, 100L, 0L, 0L);

        systemMemory.setBytes(1_000_000L);
        assertOperatorMemoryAllocations(operatorMemory, 100L, 1_000_000L, 0L);

        systemMemory.setBytes(2_000_000L);
        assertOperatorMemoryAllocations(operatorMemory, 100L, 2_000_000L, 0L);

        userMemory.setBytes(500L);
        assertOperatorMemoryAllocations(operatorMemory, 500L, 2_000_000L, 0L);

        userMemory.setBytes(userMemory.getBytes() - 500);
        assertOperatorMemoryAllocations(operatorMemory, 0L, 2_000_000L, 0L);

        revocableMemory.setBytes(300L);
        assertOperatorMemoryAllocations(operatorMemory, 0L, 2_000_000L, 300L);

        // User memory is already at zero here, so subtracting again yields a negative request.
        Assertions.assertThatThrownBy(() -> userMemory.setBytes(userMemory.getBytes() - 500))
                .isInstanceOf(IllegalArgumentException.class)
                .hasMessage("bytes cannot be negative");

        operatorContext.destroy();
        assertOperatorMemoryAllocations(operatorMemory, 0L, 0L, 0L);
    }

    /**
     * Verifies that a system allocation may grow up to exactly the per-node total memory limit
     * and that asking for one more byte fails with the expected {@link ExceededMemoryLimitException}.
     */
    @Test
    public void testLocalTotalMemoryLimitExceeded() {
        LocalMemoryContext systemMemory = operatorContext.newLocalSystemMemoryContext("test");

        systemMemory.setBytes(100L);
        assertOperatorMemoryAllocations(operatorContext.getOperatorMemoryContext(), 0L, 100L, 0L);

        systemMemory.setBytes(QUERY_MAX_TOTAL_MEMORY.toBytes());
        assertOperatorMemoryAllocations(operatorContext.getOperatorMemoryContext(), 0L, QUERY_MAX_TOTAL_MEMORY.toBytes(), 0L);

        Assertions.assertThatThrownBy(() -> systemMemory.setBytes(QUERY_MAX_TOTAL_MEMORY.toBytes() + 1))
                .isInstanceOf(ExceededMemoryLimitException.class)
                .hasMessage(String.format("Query exceeded per-node total memory limit of %1$s [Allocated: %1$s, Delta: 1B, Top Consumers: {test=%1$s}]", QUERY_MAX_TOTAL_MEMORY));
    }

    /**
     * Reserves system memory directly on the pipeline and task contexts and checks that both
     * land in the pool while the task level reservation stays out of the pipeline stats.
     */
    @Test
    public void testLocalSystemAllocations() {
        LocalMemoryContext pipelineSystemMemory = pipelineContext.localSystemMemoryContext();
        pipelineSystemMemory.setBytes(1_000_000L);
        assertLocalMemoryAllocations(pipelineContext.getPipelineMemoryContext(), 1_000_000L, 0L, 1_000_000L);

        LocalMemoryContext taskSystemMemory = taskContext.localSystemMemoryContext();
        taskSystemMemory.setBytes(10_000_000L);
        assertLocalMemoryAllocations(taskContext.getTaskMemoryContext(), 1_000_000L + 10_000_000L, 0L, 10_000_000L);
        Assert.assertEquals(
                pipelineContext.getPipelineStats().getSystemMemoryReservation().toBytes(),
                1_000_000L,
                "task level allocations should not be visible at the pipeline level");

        // Releasing the pipeline bytes leaves only the task level reservation in the pool.
        pipelineSystemMemory.setBytes(pipelineSystemMemory.getBytes() - 1_000_000);
        assertLocalMemoryAllocations(pipelineContext.getPipelineMemoryContext(), 10_000_000L, 0L, 0L);

        taskSystemMemory.setBytes(taskSystemMemory.getBytes() - 10_000_000);
        assertLocalMemoryAllocations(taskContext.getTaskMemoryContext(), 0L, 0L, 0L);
    }

    /**
     * Checks that user and system reservations of an operator are reported identically by the
     * operator, driver, pipeline and task stats, and that {@code destroy()} zeroes them.
     */
    @Test
    public void testStats() {
        LocalMemoryContext systemMemory = operatorContext.newLocalSystemMemoryContext("test");
        LocalMemoryContext userMemory = operatorContext.localUserMemoryContext();

        userMemory.setBytes(100_000_000L);
        systemMemory.setBytes(200_000_000L);
        assertCurrentStats(100_000_000L, 0L, 200_000_000L);

        userMemory.setBytes(600_000_000L);
        assertCurrentStats(600_000_000L, 0L, 200_000_000L);

        userMemory.setBytes(userMemory.getBytes() - 300_000_000);
        assertCurrentStats(300_000_000L, 0L, 200_000_000L);

        userMemory.setBytes(userMemory.getBytes() - 300_000_000);
        assertCurrentStats(0L, 0L, 200_000_000L);

        operatorContext.destroy();
        assertCurrentStats(0L, 0L, 0L);
    }

    /** Checks that revocable reservations are reported in the stats next to user and system memory. */
    @Test
    public void testRevocableMemoryAllocations() {
        LocalMemoryContext systemMemory = operatorContext.newLocalSystemMemoryContext("test");
        LocalMemoryContext userMemory = operatorContext.localUserMemoryContext();
        LocalMemoryContext revocableMemory = operatorContext.localRevocableMemoryContext();

        revocableMemory.setBytes(100_000_000L);
        assertCurrentStats(0L, 100_000_000L, 0L);

        userMemory.setBytes(100_000_000L);
        systemMemory.setBytes(100_000_000L);
        revocableMemory.setBytes(200_000_000L);
        assertCurrentStats(100_000_000L, 200_000_000L, 100_000_000L);
    }

    /**
     * Checks that {@code trySetBytes} succeeds for requests that fit, and that a request larger
     * than the pool is refused while leaving the previously reported reservation untouched.
     */
    @Test
    public void testTrySetBytes() {
        LocalMemoryContext userMemory = operatorContext.localUserMemoryContext();

        Assert.assertTrue(userMemory.trySetBytes(100_000_000L));
        assertCurrentStats(100_000_000L, 0L, 0L);

        Assert.assertTrue(userMemory.trySetBytes(200_000_000L));
        assertCurrentStats(200_000_000L, 0L, 0L);

        Assert.assertTrue(userMemory.trySetBytes(100_000_000L));
        assertCurrentStats(100_000_000L, 0L, 0L);

        // One byte more than the pool can ever hold must be rejected without side effects.
        Assert.assertFalse(userMemory.trySetBytes(memoryPool.getMaxBytes() + 1));
        assertCurrentStats(100_000_000L, 0L, 0L);
    }

    /**
     * Checks that re-setting a context to the byte count it already holds succeeds even when
     * every free byte of the pool has been reserved.
     */
    @Test
    public void testTrySetZeroBytesFullPool() {
        LocalMemoryContext userMemory = operatorContext.localUserMemoryContext();
        // Exhaust the pool directly, bypassing the context chain.
        memoryPool.reserve(new QueryId("test_query"), "test", memoryPool.getFreeBytes());
        Assert.assertTrue(userMemory.trySetBytes(userMemory.getBytes()));
    }

    /**
     * Checks that destroying the operator context resets its user, system and revocable memory
     * after all three have been given a non-zero reservation.
     */
    @Test
    public void testDestroy() {
        LocalMemoryContext systemMemory = operatorContext.newLocalSystemMemoryContext("test");
        LocalMemoryContext userMemory = operatorContext.localUserMemoryContext();
        LocalMemoryContext revocableMemory = operatorContext.localRevocableMemoryContext();

        systemMemory.setBytes(100_000L);
        revocableMemory.setBytes(200_000L);
        userMemory.setBytes(400_000L);
        Assert.assertEquals(operatorContext.getOperatorMemoryContext().getSystemMemory(), 100_000L);
        Assert.assertEquals(operatorContext.getOperatorMemoryContext().getUserMemory(), 400_000L);
        // Without this check the post-destroy revocable assertion would pass even if the
        // revocable reservation had never been recorded.
        Assert.assertEquals(operatorContext.getOperatorMemoryContext().getRevocableMemory(), 200_000L);

        operatorContext.destroy();
        assertOperatorMemoryAllocations(operatorContext.getOperatorMemoryContext(), 0L, 0L, 0L);
    }

    /**
     * Fetches the current operator, driver, pipeline and task stats from the contexts under test
     * and delegates to {@link #assertStats}.
     *
     * @param expectedUserMemory expected user memory reservation in bytes
     * @param expectedRevocableMemory expected revocable memory reservation in bytes
     * @param expectedSystemMemory expected system memory reservation in bytes
     */
    private void assertCurrentStats(long expectedUserMemory, long expectedRevocableMemory, long expectedSystemMemory) {
        assertStats(
                operatorContext.getOperatorStats(),
                driverContext.getDriverStats(),
                pipelineContext.getPipelineStats(),
                taskContext.getTaskStats(),
                expectedUserMemory,
                expectedRevocableMemory,
                expectedSystemMemory);
    }

    /**
     * Asserts that all four stats objects report the same user, system and revocable memory
     * reservations.
     *
     * <p>Note the order of the expected values: user, revocable, system. This differs from
     * {@link #assertOperatorMemoryAllocations}, which takes user, system, revocable.
     *
     * @param operatorStats stats of the operator under test
     * @param driverStats stats of the driver under test
     * @param pipelineStats stats of the pipeline under test
     * @param taskStats stats of the task under test
     * @param expectedUserMemory expected user memory reservation in bytes
     * @param expectedRevocableMemory expected revocable memory reservation in bytes
     * @param expectedSystemMemory expected system memory reservation in bytes
     */
    private void assertStats(
            OperatorStats operatorStats,
            DriverStats driverStats,
            PipelineStats pipelineStats,
            TaskStats taskStats,
            long expectedUserMemory,
            long expectedRevocableMemory,
            long expectedSystemMemory) {
        Assert.assertEquals(operatorStats.getUserMemoryReservation().toBytes(), expectedUserMemory);
        Assert.assertEquals(driverStats.getUserMemoryReservation().toBytes(), expectedUserMemory);
        Assert.assertEquals(pipelineStats.getUserMemoryReservation().toBytes(), expectedUserMemory);
        Assert.assertEquals(taskStats.getUserMemoryReservation().toBytes(), expectedUserMemory);

        Assert.assertEquals(operatorStats.getSystemMemoryReservation().toBytes(), expectedSystemMemory);
        Assert.assertEquals(driverStats.getSystemMemoryReservation().toBytes(), expectedSystemMemory);
        Assert.assertEquals(pipelineStats.getSystemMemoryReservation().toBytes(), expectedSystemMemory);
        Assert.assertEquals(taskStats.getSystemMemoryReservation().toBytes(), expectedSystemMemory);

        Assert.assertEquals(operatorStats.getRevocableMemoryReservation().toBytes(), expectedRevocableMemory);
        Assert.assertEquals(driverStats.getRevocableMemoryReservation().toBytes(), expectedRevocableMemory);
        Assert.assertEquals(pipelineStats.getRevocableMemoryReservation().toBytes(), expectedRevocableMemory);
        Assert.assertEquals(taskStats.getRevocableMemoryReservation().toBytes(), expectedRevocableMemory);
    }

    /**
     * Asserts the user, system and revocable memory of an operator level tracking context and
     * that the pool's reserved bytes equal the sum of the expected user and system memory
     * (the expected revocable memory is not part of that sum).
     *
     * @param memoryTrackingContext context to inspect
     * @param expectedUserMemory expected user memory in bytes
     * @param expectedSystemMemory expected system memory in bytes
     * @param expectedRevocableMemory expected revocable memory in bytes
     */
    private void assertOperatorMemoryAllocations(
            MemoryTrackingContext memoryTrackingContext,
            long expectedUserMemory,
            long expectedSystemMemory,
            long expectedRevocableMemory) {
        Assert.assertEquals(memoryTrackingContext.getUserMemory(), expectedUserMemory, "User memory verification failed");
        Assert.assertEquals(memoryPool.getReservedBytes(), expectedUserMemory + expectedSystemMemory, "Memory pool verification failed");
        Assert.assertEquals(memoryTrackingContext.getSystemMemory(), expectedSystemMemory, "System memory verification failed");
        Assert.assertEquals(memoryTrackingContext.getRevocableMemory(), expectedRevocableMemory, "Revocable memory verification failed");
    }

    /**
     * Asserts the pool's reserved bytes together with the user memory and the bytes held by the
     * local system memory context of the given tracking context.
     *
     * @param memoryTrackingContext context to inspect
     * @param expectedPoolMemory expected reserved bytes of the memory pool
     * @param expectedContextUserMemory expected user memory of the context in bytes
     * @param expectedContextSystemMemory expected bytes of the context's local system memory context
     */
    private void assertLocalMemoryAllocations(
            MemoryTrackingContext memoryTrackingContext,
            long expectedPoolMemory,
            long expectedContextUserMemory,
            long expectedContextSystemMemory) {
        Assert.assertEquals(memoryTrackingContext.getUserMemory(), expectedContextUserMemory, "User memory verification failed");
        Assert.assertEquals(memoryPool.getReservedBytes(), expectedPoolMemory, "Memory pool verification failed");
        Assert.assertEquals(memoryTrackingContext.localSystemMemoryContext().getBytes(), expectedContextSystemMemory, "Local system memory verification failed");
    }
}
