package io.prestosql.operator;

import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.primitives.Ints;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.concurrent.Threads;
import io.airlift.slice.Slices;
import io.airlift.testing.Assertions;
import io.airlift.units.DataSize;
import io.prestosql.ExceededMemoryLimitException;
import io.prestosql.RowPagesBuilder;
import io.prestosql.SessionTestUtils;
import io.prestosql.memory.context.AggregatedMemoryContext;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.MetadataManager;
import io.prestosql.operator.GroupByHashYieldAssertion;
import io.prestosql.operator.HashAggregationOperator;
import io.prestosql.operator.aggregation.InternalAggregationFunction;
import io.prestosql.operator.aggregation.builder.InMemoryHashAggregationBuilder;
import io.prestosql.spi.Page;
import io.prestosql.spi.block.BlockBuilder;
import io.prestosql.spi.block.BlockBuilderStatus;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.BooleanType;
import io.prestosql.spi.type.DoubleType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.VarcharType;
import io.prestosql.spiller.Spiller;
import io.prestosql.spiller.SpillerFactory;
import io.prestosql.sql.analyzer.TypeSignatureProvider;
import io.prestosql.sql.gen.JoinCompiler;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.PlanNodeId;
import io.prestosql.sql.tree.QualifiedName;
import io.prestosql.testing.MaterializedResult;
import io.prestosql.testing.TestingTaskContext;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import org.testng.Assert;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

/**
 * Tests for {@link HashAggregationOperator}: grouped and global aggregation results, spilling,
 * memory reservation/accounting, memory limits, hash table resizing with yielding, and
 * partial-aggregation flushing.
 */
@Test(singleThreaded = true)
public class TestHashAggregationOperator {
    private static final Metadata metadata = MetadataManager.createTestMetadataManager();
    private static final InternalAggregationFunction LONG_AVERAGE = aggregation("avg", BigintType.BIGINT);
    private static final InternalAggregationFunction LONG_SUM = aggregation("sum", BigintType.BIGINT);
    private static final InternalAggregationFunction COUNT = aggregation("count");
    private static final InternalAggregationFunction LONG_MIN = aggregation("min", BigintType.BIGINT);
    private static final int MAX_BLOCK_SIZE_IN_BYTES = 65536;

    private ExecutorService executor;
    private ScheduledExecutorService scheduledExecutor;
    private final JoinCompiler joinCompiler = new JoinCompiler(MetadataManager.createTestMetadataManager());
    private DummySpillerFactory spillerFactory;

    /**
     * Spiller factory whose spillers always fail: every {@code spill} call returns a future
     * that is already failed with {@code IOException("Failed to spill")}.
     */
    private static class FailingSpillerFactory implements SpillerFactory {
        private FailingSpillerFactory() {
        }

        @Override
        public Spiller create(List<Type> types, SpillContext spillContext, AggregatedMemoryContext memoryContext) {
            return new Spiller() {
                @Override
                public ListenableFuture<?> spill(Iterator<Page> pageIterator) {
                    return Futures.immediateFailedFuture(new IOException("Failed to spill"));
                }

                @Override
                public List<Iterator<Page>> getSpills() {
                    return ImmutableList.of();
                }

                @Override
                public void close() {
                }
            };
        }
    }

    /**
     * Resolves an aggregation function by name and argument types against the shared test metadata.
     *
     * @param name          function name, e.g. {@code "count"}
     * @param argumentTypes argument types used for resolution (may be empty)
     * @return the resolved aggregation implementation
     */
    private static InternalAggregationFunction aggregation(String name, Type... argumentTypes) {
        return metadata.getAggregateFunctionImplementation(
                metadata.resolveFunction(QualifiedName.of(name), TypeSignatureProvider.fromTypes(argumentTypes)));
    }

    @BeforeMethod
    public void setUp() {
        this.executor = Executors.newCachedThreadPool(Threads.daemonThreadsNamed("test-executor-%s"));
        this.scheduledExecutor = Executors.newScheduledThreadPool(2, Threads.daemonThreadsNamed("test-scheduledExecutor-%s"));
        this.spillerFactory = new DummySpillerFactory();
    }

    /** Supplies {@code hashEnabled = true} and {@code false}. */
    @DataProvider(name = "hashEnabled")
    public static Object[][] hashEnabled() {
        return new Object[][] {{true}, {false}};
    }

    /**
     * Supplies {@code (hashEnabled, spillEnabled, revokeMemoryWhenAddingPages, memoryLimitForMerge,
     * memoryLimitForMergeWithMemory)} combinations. The int limits are passed to {@code long}
     * test parameters.
     */
    @DataProvider(name = "hashEnabledAndMemoryLimitForMergeValues")
    public static Object[][] hashEnabledAndMemoryLimitForMergeValuesProvider() {
        return new Object[][] {
                {true, true, true, 8, Integer.MAX_VALUE},
                {true, true, false, 8, Integer.MAX_VALUE},
                {false, false, false, 0, 0},
                {false, true, true, 0, 0},
                {false, true, false, 0, 0},
                {false, true, true, 8, 0},
                {false, true, false, 8, 0},
                {false, true, true, 8, Integer.MAX_VALUE},
                {false, true, false, 8, Integer.MAX_VALUE}};
    }

    /** Supplies the group-by key types exercised by {@link #testMemoryReservationYield}. */
    @DataProvider
    public Object[][] dataType() {
        return new Object[][] {{VarcharType.VARCHAR}, {BigintType.BIGINT}};
    }

    @AfterMethod(alwaysRun = true)
    public void tearDown() {
        this.spillerFactory = null;
        this.executor.shutdownNow();
        this.scheduledExecutor.shutdownNow();
    }

    /**
     * Groups three 40,000-row pages on VARCHAR column 1 and checks the SINGLE-step output.
     * Every key occurs three times. The test also checks that spilling happened exactly
     * when it was enabled.
     */
    @Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues")
    public void testHashAggregation(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) {
        InternalAggregationFunction countVarcharColumn = aggregation("count", VarcharType.VARCHAR);
        InternalAggregationFunction countBooleanColumn = aggregation("count", BooleanType.BOOLEAN);
        InternalAggregationFunction maxVarcharColumn = aggregation("max", VarcharType.VARCHAR);

        List<Integer> groupByChannels = Ints.asList(1);
        RowPagesBuilder rowPagesBuilder = RowPagesBuilder.rowPagesBuilder(hashEnabled, groupByChannels,
                VarcharType.VARCHAR, VarcharType.VARCHAR, VarcharType.VARCHAR, BigintType.BIGINT, BooleanType.BOOLEAN);
        List<Page> input = rowPagesBuilder
                .addSequencePage(40000, 100, 0, 100000, 0, 500)
                .addSequencePage(40000, 100, 0, 200000, 0, 500)
                .addSequencePage(40000, 100, 0, 300000, 0, 500)
                .build();

        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(VarcharType.VARCHAR),
                groupByChannels,
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                false,
                ImmutableList.of(
                        COUNT.bind(ImmutableList.of(0), Optional.empty()),
                        LONG_SUM.bind(ImmutableList.of(3), Optional.empty()),
                        LONG_AVERAGE.bind(ImmutableList.of(3), Optional.empty()),
                        maxVarcharColumn.bind(ImmutableList.of(2), Optional.empty()),
                        countVarcharColumn.bind(ImmutableList.of(0), Optional.empty()),
                        countBooleanColumn.bind(ImmutableList.of(4), Optional.empty())),
                rowPagesBuilder.getHashChannel(),
                Optional.empty(),
                100000,
                Optional.of(DataSize.of(16L, DataSize.Unit.MEGABYTE)),
                spillEnabled,
                DataSize.succinctBytes(memoryLimitForMerge),
                DataSize.succinctBytes(memoryLimitForMergeWithMemory),
                this.spillerFactory,
                this.joinCompiler,
                false);

        DriverContext driverContext = createDriverContext(memoryLimitForMerge);
        MaterializedResult.Builder expectedBuilder = MaterializedResult.resultBuilder(driverContext.getSession(),
                VarcharType.VARCHAR, BigintType.BIGINT, BigintType.BIGINT, DoubleType.DOUBLE, VarcharType.VARCHAR, BigintType.BIGINT, BigintType.BIGINT);
        for (int i = 0; i < 40000; i++) {
            // key i occurs once per page: count=3, sum=3*i, avg=i, max(column 2) comes from the third page
            expectedBuilder.row(Integer.toString(i), 3L, 3L * i, (double) i, Integer.toString(300000 + i), 3L, 3L);
        }
        MaterializedResult expected = expectedBuilder.build();

        List<Page> pages = OperatorAssertion.toPages(operatorFactory, driverContext, input, revokeMemoryWhenAddingPages);
        Assertions.assertGreaterThan(pages.size(), 1, "Expected more than one output page");
        OperatorAssertion.assertPagesEqualIgnoreOrder(driverContext, pages, expected, hashEnabled, Optional.of(groupByChannels.size()));

        long spillCount = this.spillerFactory.getSpillsCount();
        Assert.assertTrue(spillEnabled == (spillCount > 0),
                String.format("Spill state mismatch. Expected spill: %s, spill count: %s", spillEnabled, spillCount));
    }

    /**
     * Aggregates empty input with global aggregation group ids 42 and 49. The expected output
     * is exactly one default row per global group id.
     */
    @Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues")
    public void testHashAggregationWithGlobals(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) {
        InternalAggregationFunction countVarcharColumn = aggregation("count", VarcharType.VARCHAR);
        InternalAggregationFunction countBooleanColumn = aggregation("count", BooleanType.BOOLEAN);
        InternalAggregationFunction maxVarcharColumn = aggregation("max", VarcharType.VARCHAR);

        Optional<Integer> groupIdChannel = Optional.of(1);
        List<Integer> groupByChannels = Ints.asList(1, 2);
        List<Integer> globalAggregationGroupIds = Ints.asList(42, 49);
        RowPagesBuilder rowPagesBuilder = RowPagesBuilder.rowPagesBuilder(hashEnabled, groupByChannels,
                VarcharType.VARCHAR, VarcharType.VARCHAR, VarcharType.VARCHAR, BigintType.BIGINT, BigintType.BIGINT, BooleanType.BOOLEAN);
        // no rows are added: the only output comes from the global aggregation groups
        List<Page> input = rowPagesBuilder.build();

        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(VarcharType.VARCHAR, BigintType.BIGINT),
                groupByChannels,
                globalAggregationGroupIds,
                AggregationNode.Step.SINGLE,
                true,
                ImmutableList.of(
                        COUNT.bind(ImmutableList.of(0), Optional.empty()),
                        LONG_MIN.bind(ImmutableList.of(4), Optional.empty()),
                        LONG_AVERAGE.bind(ImmutableList.of(4), Optional.empty()),
                        maxVarcharColumn.bind(ImmutableList.of(2), Optional.empty()),
                        countVarcharColumn.bind(ImmutableList.of(0), Optional.empty()),
                        countBooleanColumn.bind(ImmutableList.of(5), Optional.empty())),
                rowPagesBuilder.getHashChannel(),
                groupIdChannel,
                100000,
                Optional.of(DataSize.of(16L, DataSize.Unit.MEGABYTE)),
                spillEnabled,
                DataSize.succinctBytes(memoryLimitForMerge),
                DataSize.succinctBytes(memoryLimitForMergeWithMemory),
                this.spillerFactory,
                this.joinCompiler,
                false);

        DriverContext driverContext = createDriverContext(memoryLimitForMerge);
        MaterializedResult expected = MaterializedResult.resultBuilder(driverContext.getSession(),
                        VarcharType.VARCHAR, BigintType.BIGINT, BigintType.BIGINT, BigintType.BIGINT, DoubleType.DOUBLE, VarcharType.VARCHAR, BigintType.BIGINT, BigintType.BIGINT)
                .row(null, 42L, 0L, null, null, null, 0L, 0L)
                .row(null, 49L, 0L, null, null, null, 0L, 0L)
                .build();

        OperatorAssertion.assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected,
                hashEnabled, Optional.of(groupByChannels.size()), revokeMemoryWhenAddingPages);
    }

    /** Runs {@code array_agg} to completion and checks that no user memory stays reserved afterwards. */
    @Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues")
    public void testHashAggregationMemoryReservation(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) {
        InternalAggregationFunction arrayAggColumn = aggregation("array_agg", BigintType.BIGINT);

        List<Integer> groupByChannels = Ints.asList(1);
        RowPagesBuilder rowPagesBuilder = RowPagesBuilder.rowPagesBuilder(hashEnabled, groupByChannels, BigintType.BIGINT, BigintType.BIGINT);
        List<Page> input = rowPagesBuilder
                .addSequencePage(10, 100, 0)
                .addSequencePage(10, 200, 0)
                .addSequencePage(10, 300, 0)
                .build();

        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(BigintType.BIGINT),
                groupByChannels,
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                true,
                ImmutableList.of(arrayAggColumn.bind(ImmutableList.of(0), Optional.empty())),
                rowPagesBuilder.getHashChannel(),
                Optional.empty(),
                100000,
                Optional.of(DataSize.of(16L, DataSize.Unit.MEGABYTE)),
                spillEnabled,
                DataSize.succinctBytes(memoryLimitForMerge),
                DataSize.succinctBytes(memoryLimitForMergeWithMemory),
                this.spillerFactory,
                this.joinCompiler,
                false);
        DriverContext driverContext = TestingTaskContext.createTaskContext(this.executor, this.scheduledExecutor, SessionTestUtils.TEST_SESSION, DataSize.of(10L, DataSize.Unit.MEGABYTE))
                .addPipelineContext(0, true, true, false)
                .addDriverContext();

        Operator operator = operatorFactory.createOperator(driverContext);
        OperatorAssertion.toPages(operator, input.iterator(), revokeMemoryWhenAddingPages);
        Assert.assertEquals(operator.getOperatorContext().getOperatorStats().getUserMemoryReservation().toBytes(), 0L);
    }

    /** Runs the aggregation under a 10-byte task memory limit and expects {@link ExceededMemoryLimitException}. */
    @Test(dataProvider = "hashEnabled", expectedExceptions = {ExceededMemoryLimitException.class}, expectedExceptionsMessageRegExp = "Query exceeded per-node user memory limit of 10B.*")
    public void testMemoryLimit(boolean hashEnabled) {
        InternalAggregationFunction maxVarcharColumn = aggregation("max", VarcharType.VARCHAR);

        List<Integer> groupByChannels = Ints.asList(1);
        RowPagesBuilder rowPagesBuilder = RowPagesBuilder.rowPagesBuilder(hashEnabled, groupByChannels,
                VarcharType.VARCHAR, BigintType.BIGINT, VarcharType.VARCHAR, BigintType.BIGINT);

        OperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(BigintType.BIGINT),
                groupByChannels,
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                ImmutableList.of(
                        COUNT.bind(ImmutableList.of(0), Optional.empty()),
                        LONG_MIN.bind(ImmutableList.of(3), Optional.empty()),
                        LONG_AVERAGE.bind(ImmutableList.of(3), Optional.empty()),
                        maxVarcharColumn.bind(ImmutableList.of(2), Optional.empty())),
                rowPagesBuilder.getHashChannel(),
                Optional.empty(),
                100000,
                Optional.of(DataSize.of(16L, DataSize.Unit.MEGABYTE)),
                this.joinCompiler,
                false);
        DriverContext driverContext = TestingTaskContext.createTaskContext(this.executor, this.scheduledExecutor, SessionTestUtils.TEST_SESSION, DataSize.ofBytes(10L))
                .addPipelineContext(0, true, true, false)
                .addDriverContext();
        List<Page> input = rowPagesBuilder
                .addSequencePage(10, 100, 0, 100, 0)
                .addSequencePage(10, 100, 0, 200, 0)
                .addSequencePage(10, 100, 0, 300, 0)
                .build();

        OperatorAssertion.toPages(operatorFactory, driverContext, input);
    }

    /** Feeds a single 200,000-byte VARCHAR value, larger than MAX_BLOCK_SIZE_IN_BYTES, between two small pages. */
    @Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues")
    public void testHashBuilderResize(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) {
        BlockBuilder largeValueBuilder = VarcharType.VARCHAR.createBlockBuilder((BlockBuilderStatus) null, 1, MAX_BLOCK_SIZE_IN_BYTES);
        VarcharType.VARCHAR.writeSlice(largeValueBuilder, Slices.allocate(200000));
        // the result of this build() is unused; the block is built again below when the page is added
        largeValueBuilder.build();

        List<Integer> groupByChannels = Ints.asList(0);
        RowPagesBuilder rowPagesBuilder = RowPagesBuilder.rowPagesBuilder(hashEnabled, groupByChannels, VarcharType.VARCHAR);

        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(VarcharType.VARCHAR),
                groupByChannels,
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                false,
                ImmutableList.of(COUNT.bind(ImmutableList.of(0), Optional.empty())),
                rowPagesBuilder.getHashChannel(),
                Optional.empty(),
                100000,
                Optional.of(DataSize.of(16L, DataSize.Unit.MEGABYTE)),
                spillEnabled,
                DataSize.succinctBytes(memoryLimitForMerge),
                DataSize.succinctBytes(memoryLimitForMergeWithMemory),
                this.spillerFactory,
                this.joinCompiler,
                false);
        DriverContext driverContext = createDriverContext(memoryLimitForMerge);
        List<Page> input = rowPagesBuilder
                .addSequencePage(10, 100)
                .addBlocksPage(largeValueBuilder.build())
                .addSequencePage(10, 100)
                .build();

        OperatorAssertion.toPages(operatorFactory, driverContext, input, revokeMemoryWhenAddingPages);
    }

    /**
     * Finishes the operator while the group-by hash yields. Expects more than 5 yields, a peak
     * reservation above 20MB, and a COUNT of 1 for every output position.
     */
    @Test(dataProvider = "dataType")
    public void testMemoryReservationYield(Type type) {
        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(type),
                ImmutableList.of(0),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                ImmutableList.of(COUNT.bind(ImmutableList.of(0), Optional.empty())),
                Optional.of(1),
                Optional.empty(),
                1,
                Optional.of(DataSize.of(16L, DataSize.Unit.MEGABYTE)),
                this.joinCompiler,
                false);

        GroupByHashYieldAssertion.GroupByHashYieldResult result = GroupByHashYieldAssertion.finishOperatorWithYieldingGroupByHash(
                GroupByHashYieldAssertion.createPagesWithDistinctHashKeys(type, 6000, 600),
                type,
                operatorFactory,
                this::getHashCapacity,
                1400000L);

        Assertions.assertGreaterThan(result.getYieldCount(), 5);
        Assertions.assertGreaterThan(result.getMaxReservedBytes(), 20L * 1024 * 1024);

        int positionCount = 0;
        for (Page page : result.getOutput()) {
            // NOTE(review): three channels, presumably the group key, its hash, and the COUNT result
            Assert.assertEquals(page.getChannelCount(), 3);
            for (int position = 0; position < page.getPositionCount(); position++) {
                Assert.assertEquals(page.getBlock(2).getLong(position, 0), 1L);
                positionCount++;
            }
        }
        // 3,600,000 == 6000 * 600, the arguments given to createPagesWithDistinctHashKeys
        Assert.assertEquals(positionCount, 3600000);
    }

    /** A 5,000,000-byte group key must exceed the 3MB task memory limit. */
    @Test(dataProvider = "hashEnabled", expectedExceptions = {ExceededMemoryLimitException.class}, expectedExceptionsMessageRegExp = "Query exceeded per-node user memory limit of 3MB.*")
    public void testHashBuilderResizeLimit(boolean hashEnabled) {
        BlockBuilder largeValueBuilder = VarcharType.VARCHAR.createBlockBuilder((BlockBuilderStatus) null, 1, MAX_BLOCK_SIZE_IN_BYTES);
        VarcharType.VARCHAR.writeSlice(largeValueBuilder, Slices.allocate(5000000));
        // the result of this build() is unused; the block is built again below when the page is added
        largeValueBuilder.build();

        List<Integer> groupByChannels = Ints.asList(0);
        RowPagesBuilder rowPagesBuilder = RowPagesBuilder.rowPagesBuilder(hashEnabled, groupByChannels, VarcharType.VARCHAR);

        OperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(VarcharType.VARCHAR),
                groupByChannels,
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                ImmutableList.of(COUNT.bind(ImmutableList.of(0), Optional.empty())),
                rowPagesBuilder.getHashChannel(),
                Optional.empty(),
                100000,
                Optional.of(DataSize.of(16L, DataSize.Unit.MEGABYTE)),
                this.joinCompiler,
                false);
        DriverContext driverContext = TestingTaskContext.createTaskContext(this.executor, this.scheduledExecutor, SessionTestUtils.TEST_SESSION, DataSize.of(3L, DataSize.Unit.MEGABYTE))
                .addPipelineContext(0, true, true, false)
                .addDriverContext();
        List<Page> input = rowPagesBuilder
                .addSequencePage(10, 100)
                .addBlocksPage(largeValueBuilder.build())
                .addSequencePage(10, 100)
                .build();

        OperatorAssertion.toPages(operatorFactory, driverContext, input);
    }

    /** Aggregates enough distinct BIGINT keys that the output is expected to span exactly two pages. */
    @Test(dataProvider = "hashEnabled")
    public void testMultiSliceAggregationOutput(boolean hashEnabled) {
        // NOTE(review): 1572864 == 1.5 * 1024 * 1024, presumably 1.5x the max output page size,
        // divided by an assumed 32-byte fixed row width -> 49152 rows
        int multiSlicePositionCount = (int) (1572864.0d / 32);

        List<Integer> groupByChannels = Ints.asList(1);
        RowPagesBuilder rowPagesBuilder = RowPagesBuilder.rowPagesBuilder(hashEnabled, groupByChannels, BigintType.BIGINT, BigintType.BIGINT);

        OperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(BigintType.BIGINT),
                groupByChannels,
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                ImmutableList.of(
                        COUNT.bind(ImmutableList.of(0), Optional.empty()),
                        LONG_AVERAGE.bind(ImmutableList.of(1), Optional.empty())),
                rowPagesBuilder.getHashChannel(),
                Optional.empty(),
                100000,
                Optional.of(DataSize.of(16L, DataSize.Unit.MEGABYTE)),
                this.joinCompiler,
                false);
        DriverContext driverContext = createDriverContext();
        List<Page> input = rowPagesBuilder.addSequencePage(multiSlicePositionCount, 0, 0).build();

        Assert.assertEquals(OperatorAssertion.toPages(operatorFactory, driverContext, input).size(), 2);
    }

    /**
     * Runs a PARTIAL aggregation with a 1KB partial-memory limit. The operator must flush output
     * before it has seen all input, accept more input afterwards, and still produce the full result.
     */
    @Test(dataProvider = "hashEnabled")
    public void testMultiplePartialFlushes(boolean hashEnabled) throws Exception {
        List<Integer> groupByChannels = Ints.asList(0);
        RowPagesBuilder rowPagesBuilder = RowPagesBuilder.rowPagesBuilder(hashEnabled, groupByChannels, BigintType.BIGINT);
        List<Page> input = rowPagesBuilder
                .addSequencePage(500, 0)
                .addSequencePage(500, 500)
                .addSequencePage(500, 1000)
                .addSequencePage(500, 1500)
                .build();

        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(BigintType.BIGINT),
                groupByChannels,
                ImmutableList.of(),
                AggregationNode.Step.PARTIAL,
                ImmutableList.of(LONG_MIN.bind(ImmutableList.of(0), Optional.empty())),
                rowPagesBuilder.getHashChannel(),
                Optional.empty(),
                100000,
                Optional.of(DataSize.of(1L, DataSize.Unit.KILOBYTE)),
                this.joinCompiler,
                true);

        DriverContext driverContext = createDriverContext(1024L);
        try (Operator operator = operatorFactory.createOperator(driverContext)) {
            MaterializedResult expected = MaterializedResult.resultBuilder(driverContext.getSession(), BigintType.BIGINT, BigintType.BIGINT)
                    .pages(RowPagesBuilder.rowPagesBuilder(BigintType.BIGINT, BigintType.BIGINT).addSequencePage(2000, 0, 0).build())
                    .build();

            // feed input only while the operator asks for more
            Iterator<Page> inputIterator = input.iterator();
            while (operator.needsInput() && inputIterator.hasNext()) {
                operator.addInput(inputIterator.next());
            }
            org.assertj.core.api.Assertions.assertThat(driverContext.getSystemMemoryUsage()).isGreaterThan(0L);
            Assert.assertEquals(driverContext.getMemoryUsage(), 0L);

            // drain whatever the operator has flushed so far
            List<Page> outputPages = new ArrayList<>();
            for (Page output = operator.getOutput(); output != null; output = operator.getOutput()) {
                outputPages.add(output);
            }
            Assert.assertFalse(outputPages.isEmpty());
            Assert.assertTrue(operator.needsInput());

            outputPages.addAll(OperatorAssertion.toPages(operator, inputIterator));
            if (hashEnabled) {
                // drop the precomputed hash channel before comparing
                outputPages = OperatorAssertion.dropChannel(outputPages, ImmutableList.of(1));
            }

            MaterializedResult actual = OperatorAssertion.toMaterializedResult(operator.getOperatorContext().getSession(), expected.getTypes(), outputPages);
            Assert.assertEquals(actual.getTypes(), expected.getTypes());
            Assertions.assertEqualsIgnoreOrder(actual.getMaterializedRows(), expected.getMaterializedRows());
        }

        // checked only after close(); a failure here no longer triggers a second close()
        Assert.assertEquals(driverContext.getSystemMemoryUsage(), 0L);
        Assert.assertEquals(driverContext.getMemoryUsage(), 0L);
    }

    /**
     * Aggregates 150,010 distinct keys with spilling enabled and a 150,000-byte merge memory
     * limit. Each key is its own MIN.
     */
    @Test
    public void testMergeWithMemorySpill() {
        int smallPagesSpillThresholdSize = 150000;

        RowPagesBuilder rowPagesBuilder = RowPagesBuilder.rowPagesBuilder(BigintType.BIGINT);
        List<Page> input = rowPagesBuilder
                .addSequencePage(smallPagesSpillThresholdSize, 0)
                .addSequencePage(10, smallPagesSpillThresholdSize)
                .build();

        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(BigintType.BIGINT),
                ImmutableList.of(0),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                false,
                ImmutableList.of(LONG_MIN.bind(ImmutableList.of(0), Optional.empty())),
                rowPagesBuilder.getHashChannel(),
                Optional.empty(),
                1,
                Optional.of(DataSize.of(16L, DataSize.Unit.MEGABYTE)),
                true,
                DataSize.ofBytes(smallPagesSpillThresholdSize),
                DataSize.succinctBytes(Integer.MAX_VALUE),
                this.spillerFactory,
                this.joinCompiler,
                false);

        DriverContext driverContext = createDriverContext(smallPagesSpillThresholdSize);
        MaterializedResult.Builder expectedBuilder = MaterializedResult.resultBuilder(driverContext.getSession(), BigintType.BIGINT, BigintType.BIGINT);
        for (int i = 0; i < smallPagesSpillThresholdSize + 10; i++) {
            expectedBuilder.row((long) i, (long) i);
        }

        OperatorAssertion.assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expectedBuilder.build());
    }

    /** A spiller that always fails must surface a RuntimeException whose message ends with " Failed to spill". */
    @Test
    public void testSpillerFailure() {
        InternalAggregationFunction maxVarcharColumn = aggregation("max", VarcharType.VARCHAR);

        List<Integer> groupByChannels = Ints.asList(1);
        List<Type> types = ImmutableList.of(VarcharType.VARCHAR, BigintType.BIGINT, VarcharType.VARCHAR, BigintType.BIGINT);
        RowPagesBuilder rowPagesBuilder = RowPagesBuilder.rowPagesBuilder(false, groupByChannels, types);

        OperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(BigintType.BIGINT),
                groupByChannels,
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                false,
                ImmutableList.of(
                        COUNT.bind(ImmutableList.of(0), Optional.empty()),
                        LONG_MIN.bind(ImmutableList.of(3), Optional.empty()),
                        LONG_AVERAGE.bind(ImmutableList.of(3), Optional.empty()),
                        maxVarcharColumn.bind(ImmutableList.of(2), Optional.empty())),
                rowPagesBuilder.getHashChannel(),
                Optional.empty(),
                100000,
                Optional.of(DataSize.of(16L, DataSize.Unit.MEGABYTE)),
                true,
                DataSize.succinctBytes(8L),
                DataSize.succinctBytes(Integer.MAX_VALUE),
                new FailingSpillerFactory(),
                this.joinCompiler,
                false);
        DriverContext driverContext = TestingTaskContext.builder(this.executor, this.scheduledExecutor, SessionTestUtils.TEST_SESSION)
                .setQueryMaxMemory(DataSize.valueOf("7MB"))
                .setMemoryPoolSize(DataSize.valueOf("1GB"))
                .build()
                .addPipelineContext(0, true, true, false)
                .addDriverContext();
        List<Page> input = rowPagesBuilder
                .addSequencePage(10, 100, 0, 100, 0)
                .addSequencePage(10, 100, 0, 200, 0)
                .addSequencePage(10, 100, 0, 300, 0)
                .build();

        try {
            OperatorAssertion.toPages(operatorFactory, driverContext, input);
            Assert.fail("An exception was expected");
        }
        catch (RuntimeException e) {
            if (!Strings.nullToEmpty(e.getMessage()).matches(".* Failed to spill")) {
                Assert.fail("Exception other than expected was thrown", e);
            }
        }
    }

    /** Checks memory accounting for both user-memory and system-memory modes. */
    @Test
    public void testMemoryTracking() throws Exception {
        testMemoryTracking(false);
        testMemoryTracking(true);
    }

    /**
     * Adds one 500-row page and checks that the memory usage lands in the expected pool: system
     * memory when {@code useSystemMemory} is true, user memory otherwise. Both pools must be
     * empty after the operator is closed.
     */
    private void testMemoryTracking(boolean useSystemMemory) throws Exception {
        List<Integer> groupByChannels = Ints.asList(0);
        RowPagesBuilder rowPagesBuilder = RowPagesBuilder.rowPagesBuilder(false, groupByChannels, BigintType.BIGINT);
        Page input = Iterables.getOnlyElement(rowPagesBuilder.addSequencePage(500, 0).build());

        HashAggregationOperator.HashAggregationOperatorFactory operatorFactory = new HashAggregationOperator.HashAggregationOperatorFactory(
                0,
                new PlanNodeId("test"),
                ImmutableList.of(BigintType.BIGINT),
                groupByChannels,
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                ImmutableList.of(LONG_MIN.bind(ImmutableList.of(0), Optional.empty())),
                rowPagesBuilder.getHashChannel(),
                Optional.empty(),
                100000,
                Optional.of(DataSize.of(16L, DataSize.Unit.MEGABYTE)),
                this.joinCompiler,
                useSystemMemory);

        DriverContext driverContext = createDriverContext(1024L);
        try (Operator operator = operatorFactory.createOperator(driverContext)) {
            Assert.assertTrue(operator.needsInput());
            operator.addInput(input);

            if (useSystemMemory) {
                org.assertj.core.api.Assertions.assertThat(driverContext.getSystemMemoryUsage()).isGreaterThan(0L);
                Assert.assertEquals(driverContext.getMemoryUsage(), 0L);
            }
            else {
                Assert.assertEquals(driverContext.getSystemMemoryUsage(), 0L);
                org.assertj.core.api.Assertions.assertThat(driverContext.getMemoryUsage()).isGreaterThan(0L);
            }

            OperatorAssertion.toPages(operator, Collections.<Page>emptyIterator());
        }

        // checked only after close(); a failure here no longer triggers a second close()
        Assert.assertEquals(driverContext.getSystemMemoryUsage(), 0L);
        Assert.assertEquals(driverContext.getMemoryUsage(), 0L);
    }

    /** Creates a driver context whose memory pool is {@code 2147483647} bytes (Integer.MAX_VALUE). */
    private DriverContext createDriverContext() {
        return createDriverContext(2147483647L);
    }

    /**
     * Creates a driver context on a fresh test task.
     *
     * @param memoryLimit memory pool size in bytes
     */
    private DriverContext createDriverContext(long memoryLimit) {
        return TestingTaskContext.builder(this.executor, this.scheduledExecutor, SessionTestUtils.TEST_SESSION)
                .setMemoryPoolSize(DataSize.succinctBytes(memoryLimit))
                .build()
                .addPipelineContext(0, true, true, false)
                .addDriverContext();
    }

    /**
     * Returns the hash capacity of the operator's in-memory aggregation builder, or 0 when no
     * builder exists yet.
     */
    private int getHashCapacity(Operator operator) {
        Assert.assertTrue(operator instanceof HashAggregationOperator);
        // NOTE(review): this local's declared type makes the instanceof check below trivially true;
        // if getAggregationBuilder() returns a broader builder type, widen the local and cast instead
        InMemoryHashAggregationBuilder aggregationBuilder = ((HashAggregationOperator) operator).getAggregationBuilder();
        if (aggregationBuilder == null) {
            return 0;
        }
        Assert.assertTrue(aggregationBuilder instanceof InMemoryHashAggregationBuilder);
        return aggregationBuilder.getCapacity();
    }
}
