package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.trino.block.BlockAssertions;
import io.trino.metadata.ResolvedFunction;
import io.trino.operator.PagesIndex;
import io.trino.operator.window.PagesWindowIndex;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.BlockBuilderStatus;
import io.trino.spi.function.AggregationImplementation;
import io.trino.spi.function.WindowAccumulator;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import java.util.List;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/operator/aggregation/TestDoubleAverageAggregation.class */
/**
 * Tests the {@code avg} aggregation over {@code DOUBLE} inputs, including incremental
 * sliding-window evaluation over inputs that contain {@code NaN} or positive infinity.
 */
public class TestDoubleAverageAggregation extends AbstractTestAggregationFunction {
    // Number of rows in each sliding-window NaN/Infinity fixture.
    private static final int TOTAL_POSITIONS = 12;

    // From this window start onward the accumulator is expected to report NaN/Infinity, and
    // removing a row at or beyond this index is expected to be refused (removeInput == false).
    private static final int FIRST_NON_FINITE_RESULT_POSITION = 4;

    /**
     * Builds a single DOUBLE block holding the consecutive values {@code start .. start + length - 1}.
     *
     * @param start first value written
     * @param length number of values written
     * @return a one-element array containing the built block
     */
    @Override
    protected Block[] getSequenceBlocks(int start, int length) {
        BlockBuilder blockBuilder = DoubleType.DOUBLE.createBlockBuilder(null, length);
        for (int value = start; value < start + length; value++) {
            DoubleType.DOUBLE.writeDouble(blockBuilder, value);
        }
        return new Block[] {blockBuilder.build()};
    }

    /**
     * Computes the arithmetic mean of the consecutive values {@code start .. start + length - 1}.
     *
     * @param start first value in the sequence
     * @param length number of values in the sequence
     * @return the mean as a {@link Double}, or {@code null} when {@code length} is 0 (empty input)
     */
    @Override
    public Number getExpectedValue(int start, int length) {
        if (length == 0) {
            return null;
        }
        double sum = 0.0d;
        for (int value = start; value < start + length; value++) {
            sum += value;
        }
        return sum / length;
    }

    @Override
    protected String getFunctionName() {
        return "avg";
    }

    @Override
    protected List<Type> getFunctionParameterTypes() {
        return ImmutableList.of(DoubleType.DOUBLE);
    }

    /**
     * Slides a window across fixtures containing NaN and positive infinity. It checks that the
     * incremental accumulator reports the expected value at every step, and checks which
     * removals it accepts.
     */
    @Test
    public void testSlidingWindowForNaNAndInfinity() {
        int[] windowWidths = new int[TOTAL_POSITIONS];
        Object[] expectedValuesNaN = new Object[TOTAL_POSITIONS];
        Object[] expectedValuesInfinity = new Object[TOTAL_POSITIONS];
        for (int position = 0; position < TOTAL_POSITIONS; position++) {
            // Windows grow from the left edge and shrink toward the right edge.
            int windowWidth = Integer.min(position, TOTAL_POSITIONS - 1 - position);
            windowWidths[position] = windowWidth;
            if (position < FIRST_NON_FINITE_RESULT_POSITION) {
                Number finiteAverage = getExpectedValue(position, windowWidth);
                expectedValuesNaN[position] = finiteAverage;
                expectedValuesInfinity[position] = finiteAverage;
            } else {
                expectedValuesNaN[position] = Double.NaN;
                expectedValuesInfinity[position] = Double.POSITIVE_INFINITY;
            }
        }

        ResolvedFunction resolvedFunction = functionResolution.resolveFunction(
                getFunctionName(),
                TypeSignatureProvider.fromTypes(getFunctionParameterTypes()));
        AggregationImplementation aggregationImplementation = functionResolution.getPlannerContext()
                .getFunctionManager()
                .getAggregationImplementation(resolvedFunction);
        Assertions.assertThat(resolvedFunction.signature().getReturnType().toString()).contains("double");
        Assertions.assertThat(resolvedFunction.signature().getName().toString()).contains("avg");

        assertSlidingWindow(
                TestDoubleSumAggregation.getSequenceBlocksForDoubleNaNTest(0, TOTAL_POSITIONS),
                windowWidths,
                expectedValuesNaN,
                resolvedFunction,
                aggregationImplementation);
        assertSlidingWindow(
                TestDoubleSumAggregation.getSequenceBlocksForDoubleInfinityTest(0, TOTAL_POSITIONS),
                windowWidths,
                expectedValuesInfinity,
                resolvedFunction,
                aggregationImplementation);
    }

    /**
     * Drives a fresh window accumulator over {@code sequenceBlocks}. For each start position
     * {@code p}, the window is rows {@code [p, p + windowWidths[p])}. Rows that left the window
     * are removed and rows that entered it are added; the output is then compared with
     * {@code expectedValues[p]}.
     *
     * @param sequenceBlocks input blocks with {@code windowWidths.length} positions
     * @param windowWidths window width for each start position
     * @param expectedValues expected accumulator output for each start position
     * @param resolvedFunction the resolved {@code avg} function
     * @param aggregationImplementation implementation backing {@code resolvedFunction}
     */
    private void assertSlidingWindow(
            Block[] sequenceBlocks,
            int[] windowWidths,
            Object[] expectedValues,
            ResolvedFunction resolvedFunction,
            AggregationImplementation aggregationImplementation) {
        int positionCount = windowWidths.length;
        PagesIndex pagesIndex = new PagesIndex.TestingFactory(false)
                .newPagesIndex(getFunctionParameterTypes(), positionCount);
        pagesIndex.addPage(new Page(positionCount, sequenceBlocks));
        PagesWindowIndex windowIndex = new PagesWindowIndex(pagesIndex, 0, positionCount - 1);
        WindowAccumulator accumulator = createWindowAccumulator(resolvedFunction, aggregationImplementation);
        Type returnType = resolvedFunction.signature().getReturnType();

        int previousStart = 0;
        int previousWidth = 0;
        for (int start = 0; start < positionCount; start++) {
            int width = windowWidths[start];

            // Drop rows from the previous window that are outside the current one.
            for (int row = previousStart; row < previousStart + previousWidth; row++) {
                if (row < start || row >= start + width) {
                    boolean removed = accumulator.removeInput(windowIndex, row, row);
                    if (row >= FIRST_NON_FINITE_RESULT_POSITION) {
                        Assertions.assertThat(removed).isFalse();
                    } else {
                        Assertions.assertThat(removed).isTrue();
                    }
                }
            }

            // Add rows of the current window that were not in the previous one.
            for (int row = start; row < start + width; row++) {
                if (row < previousStart || row >= previousStart + previousWidth) {
                    accumulator.addInput(windowIndex, row, row);
                }
            }

            previousStart = start;
            previousWidth = width;

            BlockBuilder output = returnType.createBlockBuilder(null, 1000);
            accumulator.output(output);
            Object expected = expectedValues[start];
            Object actual = BlockAssertions.getOnlyValue(returnType, output.build());
            Assertions.assertThat(AggregationTestUtils.makeValidityAssertion(expected).apply(actual, expected)).isTrue();
        }
    }
}
