package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.trino.SessionTestUtils;
import io.trino.jmh.Benchmarks;
import io.trino.metadata.Metadata;
import io.trino.metadata.MetadataManager;
import io.trino.operator.GroupByIdBlock;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.BlockBuilderStatus;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Type;
import io.trino.spi.type.UnscaledDecimal128Arithmetic;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.planner.TestTableScanNodePartitioning;
import io.trino.sql.tree.QualifiedName;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OperationsPerInvocation;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.options.WarmupMode;
import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * JMH benchmark for grouped decimal aggregations ({@code avg} and {@code sum}).
 *
 * <p>Input is {@code ELEMENT_COUNT} decimal rows, each assigned to a random group in
 * {@code [0, groupCount)}. Two decimal types are covered: {@code decimal(14, 3)} ("SHORT") and
 * {@code decimal(30, 10)} ("LONG").
 */
@State(Scope.Thread)
@Warmup(iterations = 10)
@Measurement(iterations = 10)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@Fork(3)
@BenchmarkMode(Mode.AverageTime)
public class BenchmarkDecimalAggregation {
    private static final int ELEMENT_COUNT = 1_000_000;

    /**
     * Per-thread fixture.
     *
     * <p>Holds:
     * <ul>
     *   <li>the resolved aggregation factory;</li>
     *   <li>a long-lived accumulator;</li>
     *   <li>the decimal input page;</li>
     *   <li>the random group id of every input row;</li>
     *   <li>one intermediate state per group, computed up front.</li>
     * </ul>
     */
    @State(Scope.Thread)
    public static class BenchmarkData {
        @Param({"SHORT", "LONG"})
        private String type = "SHORT";

        @Param({"avg", "sum"})
        private String function = "avg";

        @Param({"10", "1000"})
        private int groupCount = 10;

        private AccumulatorFactory factory;
        private GroupedAccumulator accumulator;
        private GroupByIdBlock groupIds;
        private Page values;
        private Block intermediateValues;

        /** Appends the decimal whose unscaled value is {@code unscaledValue} to {@code blockBuilder}. */
        public interface ValueWriter {
            void write(BlockBuilder blockBuilder, int unscaledValue);
        }

        /**
         * Builds all benchmark inputs for the current {@code type}/{@code function}/{@code groupCount}.
         *
         * @throws IllegalArgumentException if {@code type} is neither "SHORT" nor "LONG"
         */
        @Setup
        public void setup() {
            Metadata metadata = MetadataManager.createTestMetadataManager();

            // createValues also resolves the aggregation and initializes factory/accumulator.
            switch (type) {
                case "SHORT": {
                    DecimalType shortDecimalType = DecimalType.createDecimalType(14, 3);
                    // Short decimals store their unscaled value as a long.
                    values = createValues(metadata, shortDecimalType, shortDecimalType::writeLong);
                    break;
                }
                case "LONG": {
                    DecimalType longDecimalType = DecimalType.createDecimalType(30, 10);
                    // Long decimals store their unscaled value as a 128-bit slice.
                    values = createValues(
                            metadata,
                            longDecimalType,
                            (blockBuilder, unscaledValue) -> longDecimalType.writeSlice(
                                    blockBuilder,
                                    UnscaledDecimal128Arithmetic.unscaledDecimal(unscaledValue)));
                    break;
                }
                default:
                    throw new IllegalArgumentException("Unsupported decimal type: " + type);
            }

            // Assign every input row to a uniformly random group.
            BlockBuilder groupIdBuilder = BigintType.BIGINT.createBlockBuilder(null, ELEMENT_COUNT);
            for (int position = 0; position < ELEMENT_COUNT; position++) {
                BigintType.BIGINT.writeLong(groupIdBuilder, ThreadLocalRandom.current().nextLong(groupCount));
            }
            groupIds = new GroupByIdBlock(groupCount, groupIdBuilder.build());

            intermediateValues = createIntermediateValues(factory.createGroupedAccumulator(), groupIds, values);
        }

        /**
         * Feeds {@code page} into {@code groupedAccumulator}.
         *
         * @return a block holding one intermediate state per group, in group-id order
         */
        private Block createIntermediateValues(GroupedAccumulator groupedAccumulator, GroupByIdBlock groupByIdBlock, Page page) {
            groupedAccumulator.addInput(groupByIdBlock, page);
            int stateCount = Math.toIntExact(groupByIdBlock.getGroupCount());
            BlockBuilder stateBuilder = groupedAccumulator.getIntermediateType().createBlockBuilder(null, stateCount);
            for (int groupId = 0; groupId < stateCount; groupId++) {
                groupedAccumulator.evaluateIntermediate(groupId, stateBuilder);
            }
            return stateBuilder.build();
        }

        /**
         * Prepares the aggregation and its decimal input.
         *
         * <p>Resolves {@code function} for a single argument of {@code decimalType} and stores the
         * bound factory in {@code factory}. Creates the long-lived {@code accumulator}. Then
         * returns a single-channel page of {@code ELEMENT_COUNT} rows, where row {@code i} is
         * written by {@code valueWriter} with unscaled value {@code i}.
         */
        private Page createValues(Metadata metadata, DecimalType decimalType, ValueWriter valueWriter) {
            factory = metadata.getAggregateFunctionImplementation(
                            metadata.resolveFunction(
                                    SessionTestUtils.TEST_SESSION,
                                    QualifiedName.of(function),
                                    TypeSignatureProvider.fromTypes(decimalType)))
                    .bind(ImmutableList.of(0), Optional.empty());
            accumulator = factory.createGroupedAccumulator();

            BlockBuilder valueBuilder = decimalType.createBlockBuilder(null, ELEMENT_COUNT);
            for (int row = 0; row < ELEMENT_COUNT; row++) {
                valueWriter.write(valueBuilder, row);
            }
            return new Page(valueBuilder.build());
        }

        public AccumulatorFactory getAccumulatorFactory() {
            return factory;
        }

        public GroupedAccumulator getAccumulator() {
            return accumulator;
        }

        public Page getValues() {
            return values;
        }

        public GroupByIdBlock getGroupIds() {
            return groupIds;
        }

        public int getGroupCount() {
            return groupCount;
        }

        public Block getIntermediateValues() {
            return intermediateValues;
        }
    }

    /**
     * Measures {@code addInput} on the fixture's long-lived accumulator.
     *
     * <p>The same accumulator is reused for every invocation. It is returned so that the result
     * is consumed.
     */
    @Benchmark
    @OperationsPerInvocation(ELEMENT_COUNT)
    public GroupedAccumulator benchmark(BenchmarkData benchmarkData) {
        GroupedAccumulator sharedAccumulator = benchmarkData.getAccumulator();
        sharedAccumulator.addInput(benchmarkData.getGroupIds(), benchmarkData.getValues());
        return sharedAccumulator;
    }

    /**
     * Measures adding all input to a fresh accumulator, then emitting one intermediate state per
     * group.
     */
    @Benchmark
    @OperationsPerInvocation(ELEMENT_COUNT)
    public Block benchmarkEvaluateIntermediate(BenchmarkData benchmarkData) {
        GroupedAccumulator freshAccumulator = benchmarkData.getAccumulatorFactory().createGroupedAccumulator();
        freshAccumulator.addInput(benchmarkData.getGroupIds(), benchmarkData.getValues());

        int groups = benchmarkData.getGroupCount();
        BlockBuilder output = freshAccumulator.getIntermediateType().createBlockBuilder(null, groups);
        for (int groupId = 0; groupId < groups; groupId++) {
            freshAccumulator.evaluateIntermediate(groupId, output);
        }
        return output.build();
    }

    /**
     * Measures merging the precomputed intermediate states (twice) into a fresh intermediate
     * accumulator, then emitting one final value per group.
     *
     * <p>Reported per invocation: there is no {@code @OperationsPerInvocation} here.
     *
     * <p>NOTE(review): {@code getGroupIds()} has {@code ELEMENT_COUNT} positions, while
     * {@code getIntermediateValues()} has {@code groupCount} positions. This presumably relies on
     * {@code addIntermediate} iterating over the intermediate block's positions. Confirm.
     */
    @Benchmark
    public Block benchmarkEvaluateFinal(BenchmarkData benchmarkData) {
        GroupedAccumulator merger = benchmarkData.getAccumulatorFactory().createGroupedIntermediateAccumulator();
        merger.addIntermediate(benchmarkData.getGroupIds(), benchmarkData.getIntermediateValues());
        merger.addIntermediate(benchmarkData.getGroupIds(), benchmarkData.getIntermediateValues());

        int groups = benchmarkData.getGroupCount();
        BlockBuilder output = merger.getFinalType().createBlockBuilder(null, groups);
        for (int groupId = 0; groupId < groups; groupId++) {
            merger.evaluateFinal(groupId, output);
        }
        return output.build();
    }

    /** Smoke test: the fixture builds, every input row has a group id, and {@link #benchmark} runs. */
    @Test
    public void verify() {
        BenchmarkData benchmarkData = new BenchmarkData();
        benchmarkData.setup();
        Assert.assertEquals(benchmarkData.getGroupIds().getPositionCount(), benchmarkData.getValues().getPositionCount());
        new BenchmarkDecimalAggregation().benchmark(benchmarkData);
    }

    public static void main(String[] args) throws Exception {
        new BenchmarkDecimalAggregation().verify();
        Benchmarks.benchmark(BenchmarkDecimalAggregation.class, WarmupMode.BULK).run();
    }
}
