/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.trino.operator.aggregation.DecimalAverageAggregation;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.Int128ArrayBlock;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.Int128;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.List;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link DecimalAverageAggregation} that drive a
 * {@link LongDecimalWithOverflowAndLongState} directly and check three things after
 * each step: the value returned by {@code getLong()} (expected to match the number of
 * inputs added), the value returned by {@code getOverflow()}, and the 128-bit decimal
 * stored in the state's backing array.
 */
public class TestDecimalAverageAggregation {
    private static final BigInteger TWO = BigInteger.valueOf(2);
    private static final BigInteger ONE_HUNDRED = BigInteger.valueOf(100);
    private static final BigInteger TWO_HUNDRED = BigInteger.valueOf(200);

    /** Default type used by the overflow/underflow tests: decimal(38, 0). */
    private static final DecimalType TYPE = DecimalType.createDecimalType(38, 0);

    /** Adding 2^126 twice yields 2^127, which no longer fits in a signed 128-bit value. */
    @Test
    public void testOverflow() {
        LongDecimalWithOverflowAndLongState state = new LongDecimalWithOverflowAndLongStateFactory().createSingleState();

        addToState(state, TWO.pow(126));
        Assertions.assertThat(state.getLong()).isEqualTo(1L);
        Assertions.assertThat(state.getOverflow()).isEqualTo(0L);
        Assertions.assertThat(getDecimal(state)).isEqualTo(Int128.valueOf(TWO.pow(126)));

        addToState(state, TWO.pow(126));
        Assertions.assertThat(state.getLong()).isEqualTo(2L);
        Assertions.assertThat(state.getOverflow()).isEqualTo(1L);
        // The stored 128 bits are expected to have wrapped: only the top bit is set.
        Assertions.assertThat(getDecimal(state)).isEqualTo(Int128.valueOf(Long.MIN_VALUE, 0L));

        assertAverageEquals(state, TWO.pow(126));
    }

    /** Adding the minimum unscaled decimal twice is expected to set the overflow value to -1. */
    @Test
    public void testUnderflow() {
        LongDecimalWithOverflowAndLongState state = new LongDecimalWithOverflowAndLongStateFactory().createSingleState();
        BigInteger minimum = Decimals.MIN_UNSCALED_DECIMAL.toBigInteger();

        addToState(state, minimum);
        Assertions.assertThat(state.getLong()).isEqualTo(1L);
        Assertions.assertThat(state.getOverflow()).isEqualTo(0L);
        Assertions.assertThat(getDecimal(state)).isEqualTo(Decimals.MIN_UNSCALED_DECIMAL);

        addToState(state, minimum);
        Assertions.assertThat(state.getLong()).isEqualTo(2L);
        Assertions.assertThat(state.getOverflow()).isEqualTo(-1L);
        // Expected (high, low) words of the stored value once the sum has wrapped,
        // i.e. 2 * MIN_UNSCALED_DECIMAL modulo 2^128.
        Assertions.assertThat(getDecimal(state)).isEqualTo(Int128.valueOf(7604722348854507275L, -1374799102801346558L));

        assertAverageEquals(state, minimum);
    }

    /** Overflows first, then adds enough negative values to bring the overflow value back to zero. */
    @Test
    public void testUnderflowAfterOverflow() {
        LongDecimalWithOverflowAndLongState state = new LongDecimalWithOverflowAndLongStateFactory().createSingleState();

        // Running sum: 2^126 + 2^126 + 2^125 = 2^127 + 2^125.
        addToState(state, TWO.pow(126));
        addToState(state, TWO.pow(126));
        addToState(state, TWO.pow(125));
        Assertions.assertThat(state.getOverflow()).isEqualTo(1L);
        // High word 0xA000... has bits 127 and 125 set; low word is zero.
        Assertions.assertThat(getDecimal(state)).isEqualTo(Int128.valueOf(0xA000000000000000L, 0L));

        // Subtract 3 * 2^126, leaving a total of -2^125 over six inputs.
        BigInteger negative = TWO.pow(126).negate();
        addToState(state, negative);
        addToState(state, negative);
        addToState(state, negative);
        Assertions.assertThat(state.getOverflow()).isEqualTo(0L);
        Assertions.assertThat(getDecimal(state)).isEqualTo(Int128.valueOf(TWO.pow(125).negate()));

        assertAverageEquals(state, TWO.pow(125).negate().divide(BigInteger.valueOf(6L)));
    }

    /** Combines two states that have each already overflowed once. */
    @Test
    public void testCombineOverflow() {
        LongDecimalWithOverflowAndLongState state = new LongDecimalWithOverflowAndLongStateFactory().createSingleState();
        addToState(state, TWO.pow(126));
        addToState(state, TWO.pow(126));

        LongDecimalWithOverflowAndLongState otherState = new LongDecimalWithOverflowAndLongStateFactory().createSingleState();
        addToState(otherState, TWO.pow(126));
        addToState(otherState, TWO.pow(126));

        DecimalAverageAggregation.combine(state, otherState);
        Assertions.assertThat(state.getLong()).isEqualTo(4L);
        Assertions.assertThat(state.getOverflow()).isEqualTo(1L);
        Assertions.assertThat(getDecimal(state)).isEqualTo(Int128.ZERO);

        // Sum of the four inputs divided by their count.
        BigInteger expectedAverage = BigInteger.ZERO
                .add(TWO.pow(126))
                .add(TWO.pow(126))
                .add(TWO.pow(126))
                .add(TWO.pow(126))
                .divide(BigInteger.valueOf(4L));
        assertAverageEquals(state, expectedAverage);
    }

    /** Combines two negative partial sums whose total falls below the signed 128-bit range. */
    @Test
    public void testCombineUnderflow() {
        LongDecimalWithOverflowAndLongState state = new LongDecimalWithOverflowAndLongStateFactory().createSingleState();
        addToState(state, TWO.pow(125).negate());
        addToState(state, TWO.pow(126).negate());

        LongDecimalWithOverflowAndLongState otherState = new LongDecimalWithOverflowAndLongStateFactory().createSingleState();
        addToState(otherState, TWO.pow(125).negate());
        addToState(otherState, TWO.pow(126).negate());

        DecimalAverageAggregation.combine(state, otherState);
        Assertions.assertThat(state.getLong()).isEqualTo(4L);
        Assertions.assertThat(state.getOverflow()).isEqualTo(-1L);
        // Total is -3 * 2^126; the expected stored value after wrapping is 2^126.
        Assertions.assertThat(getDecimal(state)).isEqualTo(Int128.valueOf(0x4000000000000000L, 0L));

        // Negated sum of the four magnitudes divided by their count.
        BigInteger expectedAverage = BigInteger.ZERO
                .add(TWO.pow(126))
                .add(TWO.pow(126))
                .add(TWO.pow(125))
                .add(TWO.pow(125))
                .negate()
                .divide(BigInteger.valueOf(4L));
        assertAverageEquals(state, expectedAverage);
    }

    /** Runs the same set of small and large inputs for decimal(38, 0) and decimal(38, 2). */
    @Test
    public void testNoOverflow() {
        for (int scale : new int[] {0, 2}) {
            DecimalType type = DecimalType.createDecimalType(38, scale);

            testNoOverflow(type, ImmutableList.of(BigInteger.TEN.pow(37), BigInteger.ZERO));
            testNoOverflow(type, ImmutableList.of(BigInteger.TEN.pow(37).negate(), BigInteger.ZERO));
            testNoOverflow(type, ImmutableList.of(TWO, BigInteger.ONE));
            testNoOverflow(type, ImmutableList.of(BigInteger.ZERO, BigInteger.ONE));
            testNoOverflow(type, ImmutableList.of(TWO.negate(), BigInteger.ONE.negate()));
            testNoOverflow(type, ImmutableList.of(BigInteger.ONE.negate(), BigInteger.ZERO));
            testNoOverflow(type, ImmutableList.of(BigInteger.ONE.negate(), BigInteger.ZERO, BigInteger.ZERO));
            testNoOverflow(type, ImmutableList.of(TWO.negate(), BigInteger.ZERO, BigInteger.ZERO));
            testNoOverflow(type, ImmutableList.of(TWO.negate(), BigInteger.ZERO));
            testNoOverflow(type, ImmutableList.of(TWO_HUNDRED, ONE_HUNDRED));
            testNoOverflow(type, ImmutableList.of(BigInteger.ZERO, ONE_HUNDRED));
            testNoOverflow(type, ImmutableList.of(TWO_HUNDRED.negate(), ONE_HUNDRED.negate()));
            testNoOverflow(type, ImmutableList.of(ONE_HUNDRED.negate(), BigInteger.ZERO));
        }
    }

    /**
     * Adds {@code numbers} to a fresh state and checks that the overflow value stays zero,
     * that the stored decimal equals the exact sum, and that the computed average equals
     * the sum divided by the count, rounded HALF_UP at the type's scale.
     *
     * @param type decimal type used for both the inputs and the average
     * @param numbers unscaled input values
     */
    private void testNoOverflow(DecimalType type, List<BigInteger> numbers) {
        LongDecimalWithOverflowAndLongState state = new LongDecimalWithOverflowAndLongStateFactory().createSingleState();
        for (BigInteger number : numbers) {
            addToState(type, state, number);
        }

        Assertions.assertThat(state.getOverflow()).isEqualTo(0L);

        BigInteger sum = numbers.stream().reduce(BigInteger.ZERO, BigInteger::add);
        Assertions.assertThat(getDecimal(state)).isEqualTo(Int128.valueOf(sum));

        BigDecimal expectedAverage = new BigDecimal(sum, type.getScale())
                .divide(BigDecimal.valueOf(numbers.size()), type.getScale(), RoundingMode.HALF_UP);
        Int128 actualAverage = DecimalAverageAggregation.average(state, type);
        Assertions.assertThat(decodeBigDecimal(type, actualAverage)).isEqualTo(expectedAverage);
    }

    /** Interprets {@code average} as an unscaled value with the scale and precision of {@code type}. */
    private static BigDecimal decodeBigDecimal(DecimalType type, Int128 average) {
        BigInteger unscaledVal = average.toBigInteger();
        return new BigDecimal(unscaledVal, type.getScale(), new MathContext(type.getPrecision()));
    }

    /** Asserts that the average computed from {@code state} for {@link #TYPE} equals {@code expectedAverage}. */
    private static void assertAverageEquals(LongDecimalWithOverflowAndLongState state, BigInteger expectedAverage) {
        Int128 actualAverage = DecimalAverageAggregation.average(state, TYPE);
        Assertions.assertThat(actualAverage.toBigInteger()).isEqualTo(expectedAverage);
    }

    /** Adds {@code value} to {@code state} using the default {@link #TYPE}. */
    private static void addToState(LongDecimalWithOverflowAndLongState state, BigInteger value) {
        addToState(TYPE, state, value);
    }

    /**
     * Feeds {@code value} into the aggregation through the input function matching
     * {@code type}: the short-decimal entry point when {@code type.isShort()}, otherwise
     * the long-decimal entry point reading position 0 of a single-value block.
     */
    private static void addToState(DecimalType type, LongDecimalWithOverflowAndLongState state, BigInteger value) {
        Int128 unscaled = Int128.valueOf(value);
        if (type.isShort()) {
            DecimalAverageAggregation.inputShortDecimal(state, unscaled.toLongExact());
            return;
        }

        BlockBuilder blockBuilder = type.createFixedSizeBlockBuilder(1);
        type.writeObject(blockBuilder, unscaled);
        DecimalAverageAggregation.inputLongDecimal(state, (Int128ArrayBlock) blockBuilder.buildValueBlock(), 0);
    }

    /** Reads the two consecutive longs at the state's decimal offset as an {@link Int128} (first word, second word). */
    private static Int128 getDecimal(LongDecimalWithOverflowAndLongState state) {
        long[] decimal = state.getDecimalArray();
        int offset = state.getDecimalArrayOffset();
        return Int128.valueOf(decimal[offset], decimal[offset + 1]);
    }
}

