package io.trino.operator.aggregation;

import com.google.common.primitives.Ints;
import io.trino.block.BlockAssertions;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.operator.GroupByIdBlock;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.BlockBuilderStatus;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import java.util.Collections;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.stream.IntStream;
import org.apache.commons.math3.util.Precision;
import org.testng.Assert;

/* loaded from: input_file:io/trino/operator/aggregation/AggregationTestUtils.class */
/**
 * Test helpers that drive an {@link InternalAggregationFunction} through several accumulator paths
 * (single-step, partial/intermediate, grouped, grouped partial, and masked) and assert that those
 * paths agree with each other and with an expected value.
 */
public final class AggregationTestUtils {
    /** Number of leading padding channels used to check that argument channel offsets are honored. */
    private static final int CHANNEL_OFFSET = 3;

    /** Second group id used to check that grouped accumulators handle ids far from the first group. */
    private static final int LARGE_GROUP_ID = 4000;

    private AggregationTestUtils() {
    }

    /**
     * Asserts that aggregating {@code blocks} (used as the channels of a single page) produces
     * {@code expectedValue}.
     */
    public static void assertAggregation(Metadata metadata, ResolvedFunction function, Object expectedValue, Block... blocks) {
        assertAggregation(metadata, function, expectedValue, new Page(blocks));
    }

    /**
     * Asserts that aggregating {@code page} produces {@code expectedValue}, comparing results with
     * the predicate returned by {@link #makeValidityAssertion(Object)}.
     */
    public static void assertAggregation(Metadata metadata, ResolvedFunction function, Object expectedValue, Page page) {
        assertAggregation(metadata, function, makeValidityAssertion(expectedValue), null, page, expectedValue);
    }

    /**
     * Returns the predicate used to compare an actual result (first argument) with an expected
     * result (second argument).
     *
     * <p>For a non-NaN {@code Double} or {@code Float} expectation, values are compared with
     * {@link Precision#equals} using a tolerance of {@code 1e-10}. Every other expectation uses
     * {@link Objects#equals}. That includes NaN, which a tolerance comparison would never accept,
     * while {@code Double.equals}/{@code Float.equals} treat NaN as equal to itself. If either
     * argument is not of the boxed floating type (for example a {@code null} result), the predicate
     * also falls back to {@link Objects#equals}. The caller then reports a mismatch instead of
     * failing with a cast or unboxing exception.
     */
    public static BiFunction<Object, Object, Boolean> makeValidityAssertion(Object expectedValue) {
        if (expectedValue instanceof Double && !expectedValue.equals(Double.NaN)) {
            return (actual, expected) -> (actual instanceof Double && expected instanceof Double)
                    ? Precision.equals(((Double) actual).doubleValue(), ((Double) expected).doubleValue(), 1.0e-10d)
                    : Objects.equals(actual, expected);
        }
        if (expectedValue instanceof Float && !expectedValue.equals(Float.NaN)) {
            return (actual, expected) -> (actual instanceof Float && expected instanceof Float)
                    ? Precision.equals(((Float) actual).floatValue(), ((Float) expected).floatValue(), 1.0e-10f)
                    : Objects.equals(actual, expected);
        }
        return Objects::equals;
    }

    /**
     * Checks that the function implementation matches its resolved signature and metadata, then
     * checks {@code expectedValue} against every accumulator path.
     *
     * <p>Inputs with more than one row are split into two pages, so combining state across pages is
     * exercised as well.
     *
     * @param isEqual predicate invoked as {@code isEqual.apply(actual, expected)}
     * @param testDescription optional label prefixed to failure messages; may be {@code null}
     * @param page input rows; every block must have the page's position count
     * @param expectedValue the value every aggregation path is expected to produce
     */
    public static void assertAggregation(Metadata metadata, ResolvedFunction function, BiFunction<Object, Object, Boolean> isEqual, String testDescription, Page page, Object expectedValue) {
        AggregationFunctionMetadata functionMetadata = metadata.getAggregationFunctionMetadata(function);
        InternalAggregationFunction implementation = metadata.getAggregateFunctionImplementation(function);
        Assert.assertEquals(implementation.getParameterTypes(), function.getSignature().getArgumentTypes());
        Assert.assertEquals(implementation.getFinalType(), function.getSignature().getReturnType());
        Assert.assertEquals(implementation.getIntermediateType().map(type -> type.getTypeSignature()), functionMetadata.getIntermediateType());

        int positionCount = page.getPositionCount();
        for (int channel = 0; channel < page.getChannelCount(); channel++) {
            // TestNG's assertEquals takes (actual, expected, message).
            Assert.assertEquals(page.getBlock(channel).getPositionCount(), positionCount, "input blocks provided are not equal in position count");
        }

        if (positionCount == 0) {
            assertAggregationInternal(implementation, isEqual, testDescription, expectedValue);
        } else if (positionCount == 1) {
            assertAggregationInternal(implementation, isEqual, testDescription, expectedValue, page);
        } else {
            int midPoint = positionCount / 2;
            assertAggregationInternal(implementation, isEqual, testDescription, expectedValue,
                    page.getRegion(0, midPoint),
                    page.getRegion(midPoint, positionCount - midPoint));
        }
    }

    /** Evaluates the accumulator's intermediate state into a newly built block. */
    public static Block getIntermediateBlock(Accumulator accumulator) {
        BlockBuilder blockBuilder = accumulator.getIntermediateType().createBlockBuilder(null, 1000);
        accumulator.evaluateIntermediate(blockBuilder);
        return blockBuilder.build();
    }

    /** Evaluates the intermediate state of group 0 into a newly built block. */
    public static Block getIntermediateBlock(GroupedAccumulator groupedAccumulator) {
        BlockBuilder blockBuilder = groupedAccumulator.getIntermediateType().createBlockBuilder(null, 1000);
        groupedAccumulator.evaluateIntermediate(0, blockBuilder);
        return blockBuilder.build();
    }

    /** Evaluates the accumulator's final result into a newly built block. */
    public static Block getFinalBlock(Accumulator accumulator) {
        BlockBuilder blockBuilder = accumulator.getFinalType().createBlockBuilder(null, 1000);
        accumulator.evaluateFinal(blockBuilder);
        return blockBuilder.build();
    }

    /** Evaluates the final result of group 0 into a newly built block. */
    public static Block getFinalBlock(GroupedAccumulator groupedAccumulator) {
        BlockBuilder blockBuilder = groupedAccumulator.getFinalType().createBlockBuilder(null, 1000);
        groupedAccumulator.evaluateFinal(0, blockBuilder);
        return blockBuilder.build();
    }

    /**
     * Runs {@code pages} through each accumulator path and compares each result with
     * {@code expectedValue}.
     *
     * <p>The grouped and masked paths need at least one page, so they are skipped for empty input.
     */
    private static void assertAggregationInternal(InternalAggregationFunction function, BiFunction<Object, Object, Boolean> isEqual, String testDescription, Object expectedValue, Page... pages) {
        assertFunctionEquals(isEqual, testDescription, aggregation(function, pages), expectedValue);
        assertFunctionEquals(isEqual, testDescription, partialAggregation(function, pages), expectedValue);
        if (pages.length > 0) {
            assertFunctionEquals(isEqual, testDescription, groupedAggregation(isEqual, function, pages), expectedValue);
            assertFunctionEquals(isEqual, testDescription, groupedPartialAggregation(isEqual, function, pages), expectedValue);
            assertFunctionEquals(isEqual, testDescription, distinctAggregation(function, pages), expectedValue);
        }
    }

    /**
     * Fails the test unless {@code isEqual.apply(actualValue, expectedValue)} is true.
     *
     * <p>When {@code testDescription} is non-null it prefixes the failure message.
     */
    private static void assertFunctionEquals(BiFunction<Object, Object, Boolean> isEqual, String testDescription, Object actualValue, Object expectedValue) {
        if (isEqual.apply(actualValue, expectedValue)) {
            return;
        }
        StringBuilder message = new StringBuilder();
        if (testDescription != null) {
            message.append(String.format("Test: %s, ", testDescription));
        }
        message.append(String.format("Expected: %s, actual: %s", expectedValue, actualValue));
        Assert.fail(message.toString());
    }

    /**
     * Aggregates {@code pages} with a mask channel appended in which every row is selected, and
     * returns that result.
     *
     * <p>It also checks that appending fully de-selected copies of the input (using both a regular
     * and a run-length-encoded mask block) does not change the result.
     *
     * @throws IllegalArgumentException if {@code pages} is empty
     */
    public static Object distinctAggregation(InternalAggregationFunction function, Page... pages) {
        if (pages.length == 0) {
            throw new IllegalArgumentException("distinctAggregation requires at least one page");
        }
        // The mask is appended after the existing channels, so its index is the original channel count.
        // NOTE(review): assumes every page has the same channel count as pages[0].
        Optional<Integer> maskChannel = Optional.of(pages[0].getChannelCount());

        Page[] allSelected = maskPages(true, pages);
        Object result = aggregation(function, createArgs(function), maskChannel, allSelected);

        Page[] withDeselectedCopies = new Page[pages.length * 2];
        System.arraycopy(allSelected, 0, withDeselectedCopies, 0, pages.length);
        System.arraycopy(maskPages(false, pages), 0, withDeselectedCopies, pages.length, pages.length);
        Assert.assertEquals(aggregation(function, createArgs(function), maskChannel, withDeselectedCopies), result, "Inconsistent results with mask");

        Page[] withRleMasks = new Page[pages.length * 2];
        System.arraycopy(maskPagesWithRle(true, pages), 0, withRleMasks, 0, pages.length);
        System.arraycopy(maskPagesWithRle(false, pages), 0, withRleMasks, pages.length, pages.length);
        Assert.assertEquals(aggregation(function, createArgs(function), maskChannel, withRleMasks), result, "Inconsistent results with RLE mask");

        return result;
    }

    /** Returns copies of {@code pages}, each with a run-length-encoded boolean column of {@code maskValue} appended. */
    private static Page[] maskPagesWithRle(boolean maskValue, Page... pages) {
        Page[] maskedPages = new Page[pages.length];
        for (int i = 0; i < pages.length; i++) {
            Page page = pages[i];
            maskedPages[i] = page.appendColumn(new RunLengthEncodedBlock(BooleanType.createBlockForSingleNonNullValue(maskValue), page.getPositionCount()));
        }
        return maskedPages;
    }

    /** Returns copies of {@code pages}, each with a flat boolean column filled with {@code maskValue} appended. */
    private static Page[] maskPages(boolean maskValue, Page... pages) {
        Page[] maskedPages = new Page[pages.length];
        for (int i = 0; i < pages.length; i++) {
            Page page = pages[i];
            BlockBuilder maskBuilder = BooleanType.BOOLEAN.createBlockBuilder(null, page.getPositionCount());
            for (int position = 0; position < page.getPositionCount(); position++) {
                BooleanType.BOOLEAN.writeBoolean(maskBuilder, maskValue);
            }
            maskedPages[i] = page.appendColumn(maskBuilder.build());
        }
        return maskedPages;
    }

    /**
     * Aggregates {@code pages} with a single ungrouped accumulator and returns the final value.
     *
     * <p>It also checks that the result is unchanged when the argument channels are reversed (for
     * functions with more than one argument) or shifted right by padding columns.
     */
    public static Object aggregation(InternalAggregationFunction function, Page... pages) {
        Object result = aggregation(function, createArgs(function), Optional.empty(), pages);
        if (function.getParameterTypes().size() > 1) {
            Assert.assertEquals(aggregation(function, reverseArgs(function), Optional.empty(), reverseColumns(pages)), result, "Inconsistent results with reversed channels");
        }
        Assert.assertEquals(aggregation(function, offsetArgs(function, CHANNEL_OFFSET), Optional.empty(), offsetColumns(pages, CHANNEL_OFFSET)), result, "Inconsistent results with channel offset");
        return result;
    }

    /** Feeds every non-empty page to one accumulator bound to {@code args} and {@code maskChannel}, then returns its final value. */
    private static Object aggregation(InternalAggregationFunction function, int[] args, Optional<Integer> maskChannel, Page... pages) {
        Accumulator accumulator = function.bind(Ints.asList(args), maskChannel).createAccumulator();
        for (Page page : pages) {
            if (page.getPositionCount() > 0) {
                accumulator.addInput(page);
            }
        }
        return BlockAssertions.getOnlyValue(accumulator.getFinalType(), getFinalBlock(accumulator));
    }

    /**
     * Aggregates {@code pages} through the partial → intermediate path and returns the final value.
     *
     * <p>It also checks that the result is unchanged under reversed or offset argument channels.
     */
    public static Object partialAggregation(InternalAggregationFunction function, Page... pages) {
        Object result = partialAggregation(function, createArgs(function), pages);
        if (function.getParameterTypes().size() > 1) {
            Assert.assertEquals(partialAggregation(function, reverseArgs(function), reverseColumns(pages)), result, "Inconsistent results with reversed channels");
        }
        Assert.assertEquals(partialAggregation(function, offsetArgs(function, CHANNEL_OFFSET), offsetColumns(pages, CHANNEL_OFFSET)), result, "Inconsistent results with channel offset");
        return result;
    }

    /**
     * Builds one partial accumulator per page and merges their intermediate states into an
     * intermediate accumulator. The state of an accumulator that received no input is merged both
     * before and after the real states. Returns the intermediate accumulator's final value.
     */
    public static Object partialAggregation(InternalAggregationFunction function, int[] args, Page... pages) {
        AccumulatorFactory factory = function.bind(Ints.asList(args), Optional.empty());
        Accumulator finalAccumulator = factory.createIntermediateAccumulator();

        Block emptyState = getIntermediateBlock(factory.createAccumulator());
        finalAccumulator.addIntermediate(emptyState);
        for (Page page : pages) {
            Accumulator partialAccumulator = factory.createAccumulator();
            if (page.getPositionCount() > 0) {
                partialAccumulator.addInput(page);
            }
            finalAccumulator.addIntermediate(getIntermediateBlock(partialAccumulator));
        }
        finalAccumulator.addIntermediate(emptyState);

        return BlockAssertions.getOnlyValue(finalAccumulator.getFinalType(), getFinalBlock(finalAccumulator));
    }

    /** Same as {@link #groupedAggregation(BiFunction, InternalAggregationFunction, Page...)} using {@link Objects#equals}. */
    public static Object groupedAggregation(InternalAggregationFunction function, Page... pages) {
        BiFunction<Object, Object, Boolean> isEqual = Objects::equals;
        return groupedAggregation(isEqual, function, pages);
    }

    /**
     * Aggregates {@code pages} with a grouped accumulator and returns the value of the group.
     *
     * <p>It also checks, using {@code isEqual}, that the result is unchanged under reversed or offset
     * argument channels.
     */
    public static Object groupedAggregation(BiFunction<Object, Object, Boolean> isEqual, InternalAggregationFunction function, Page... pages) {
        Object result = groupedAggregation(function, createArgs(function), pages);
        if (function.getParameterTypes().size() > 1) {
            assertFunctionEquals(isEqual, "Inconsistent results with reversed channels", groupedAggregation(function, reverseArgs(function), reverseColumns(pages)), result);
        }
        assertFunctionEquals(isEqual, "Inconsistent results with channel offset", groupedAggregation(function, offsetArgs(function, CHANNEL_OFFSET), offsetColumns(pages, CHANNEL_OFFSET)), result);
        return result;
    }

    /**
     * Feeds all pages into group 0 of a grouped accumulator and returns that group's value.
     *
     * <p>The same input is then fed into a much larger group id on the same accumulator, and the
     * method asserts that group produces the same value.
     */
    public static Object groupedAggregation(InternalAggregationFunction function, int[] args, Page... pages) {
        GroupedAccumulator groupedAccumulator = function.bind(Ints.asList(args), Optional.empty()).createGroupedAccumulator();
        for (Page page : pages) {
            groupedAccumulator.addInput(createGroupByIdBlock(0, page.getPositionCount()), page);
        }
        Object groupValue = getGroupValue(groupedAccumulator, 0);

        for (Page page : pages) {
            groupedAccumulator.addInput(createGroupByIdBlock(LARGE_GROUP_ID, page.getPositionCount()), page);
        }
        Assert.assertEquals(getGroupValue(groupedAccumulator, LARGE_GROUP_ID), groupValue, "Inconsistent results with large group id");
        return groupValue;
    }

    /**
     * Aggregates {@code pages} through the grouped partial → intermediate path and returns the value
     * of group 0.
     *
     * <p>It also checks, using {@code isEqual}, that the result is unchanged under reversed or offset
     * argument channels.
     */
    public static Object groupedPartialAggregation(BiFunction<Object, Object, Boolean> isEqual, InternalAggregationFunction function, Page... pages) {
        Object result = groupedPartialAggregation(function, createArgs(function), pages);
        if (function.getParameterTypes().size() > 1) {
            assertFunctionEquals(isEqual, "Inconsistent results with reversed channels", groupedPartialAggregation(function, reverseArgs(function), reverseColumns(pages)), result);
        }
        assertFunctionEquals(isEqual, "Inconsistent results with channel offset", groupedPartialAggregation(function, offsetArgs(function, CHANNEL_OFFSET), offsetColumns(pages, CHANNEL_OFFSET)), result);
        return result;
    }

    /**
     * Builds one grouped partial accumulator per page (all rows in group 0) and merges their
     * intermediate states into a grouped intermediate accumulator. The state of an accumulator that
     * received no input is merged before and after the real states. Returns the value of group 0.
     */
    public static Object groupedPartialAggregation(InternalAggregationFunction function, int[] args, Page... pages) {
        AccumulatorFactory factory = function.bind(Ints.asList(args), Optional.empty());
        GroupedAccumulator finalAccumulator = factory.createGroupedIntermediateAccumulator();

        Block emptyState = getIntermediateBlock(factory.createGroupedAccumulator());
        finalAccumulator.addIntermediate(createGroupByIdBlock(0, emptyState.getPositionCount()), emptyState);
        for (Page page : pages) {
            GroupedAccumulator partialAccumulator = factory.createGroupedAccumulator();
            partialAccumulator.addInput(createGroupByIdBlock(0, page.getPositionCount()), page);
            Block partialState = getIntermediateBlock(partialAccumulator);
            finalAccumulator.addIntermediate(createGroupByIdBlock(0, partialState.getPositionCount()), partialState);
        }
        finalAccumulator.addIntermediate(createGroupByIdBlock(0, emptyState.getPositionCount()), emptyState);

        return getGroupValue(finalAccumulator, 0);
    }

    /**
     * Creates a group-id block with {@code positionCount} rows, all assigned to {@code groupId}.
     *
     * <p>NOTE(review): {@code groupId} is also passed as the {@code GroupByIdBlock} group count. A
     * count of {@code groupId + 1} would seem more accurate. This is kept unchanged because callers'
     * accumulators may depend on it; confirm against the {@code GroupByIdBlock} contract.
     */
    public static GroupByIdBlock createGroupByIdBlock(int groupId, int positionCount) {
        BlockBuilder blockBuilder = BigintType.BIGINT.createBlockBuilder(null, positionCount);
        for (int position = 0; position < positionCount; position++) {
            BigintType.BIGINT.writeLong(blockBuilder, groupId);
        }
        return new GroupByIdBlock(groupId, blockBuilder.build());
    }

    /** Returns the identity channel mapping {@code [0, 1, ..., n - 1]} for the function's {@code n} parameters. */
    public static int[] createArgs(InternalAggregationFunction function) {
        return IntStream.range(0, function.getParameterTypes().size()).toArray();
    }

    /** Returns the channel mapping {@code [n - 1, ..., 1, 0]}, which matches {@link #reverseColumns(Page[])}. */
    public static int[] reverseArgs(InternalAggregationFunction function) {
        int[] args = createArgs(function);
        // Ints.asList is a write-through view, so this reverses the array in place.
        Collections.reverse(Ints.asList(args));
        return args;
    }

    /** Returns the identity channel mapping shifted by {@code offset}, which matches {@link #offsetColumns(Page[], int)}. */
    public static int[] offsetArgs(InternalAggregationFunction function, int offset) {
        int[] args = createArgs(function);
        for (int i = 0; i < args.length; i++) {
            args[i] += offset;
        }
        return args;
    }

    /** Returns copies of {@code pages} with their channel order reversed; pages with no rows are returned as-is. */
    public static Page[] reverseColumns(Page[] pages) {
        Page[] reversedPages = new Page[pages.length];
        for (int i = 0; i < pages.length; i++) {
            Page page = pages[i];
            if (page.getPositionCount() == 0) {
                reversedPages[i] = page;
                continue;
            }
            int channelCount = page.getChannelCount();
            Block[] blocks = new Block[channelCount];
            for (int channel = 0; channel < channelCount; channel++) {
                blocks[channel] = page.getBlock(channelCount - channel - 1);
            }
            reversedPages[i] = new Page(page.getPositionCount(), blocks);
        }
        return reversedPages;
    }

    /** Returns copies of {@code pages} with {@code offset} all-null boolean columns prepended. */
    public static Page[] offsetColumns(Page[] pages, int offset) {
        Page[] offsetPages = new Page[pages.length];
        for (int i = 0; i < pages.length; i++) {
            Page page = pages[i];
            int positionCount = page.getPositionCount();
            Block[] blocks = new Block[offset + page.getChannelCount()];
            for (int channel = 0; channel < offset; channel++) {
                blocks[channel] = createNullRLEBlock(positionCount);
            }
            for (int channel = 0; channel < page.getChannelCount(); channel++) {
                blocks[offset + channel] = page.getBlock(channel);
            }
            offsetPages[i] = new Page(positionCount, blocks);
        }
        return offsetPages;
    }

    /** Creates a run-length-encoded boolean block of {@code positionCount} nulls. */
    private static RunLengthEncodedBlock createNullRLEBlock(int positionCount) {
        return RunLengthEncodedBlock.create(BooleanType.BOOLEAN, (Object) null, positionCount);
    }

    /** Evaluates the final value of {@code groupId} and returns it as a Java object. */
    public static Object getGroupValue(GroupedAccumulator groupedAccumulator, int groupId) {
        BlockBuilder blockBuilder = groupedAccumulator.getFinalType().createBlockBuilder(null, 1);
        groupedAccumulator.evaluateFinal(groupId, blockBuilder);
        return BlockAssertions.getOnlyValue(groupedAccumulator.getFinalType(), blockBuilder.build());
    }

    /** Returns {@code length} consecutive doubles starting at {@code start}: {@code [start, start + 1, ...]}. */
    public static double[] constructDoublePrimitiveArray(int start, int length) {
        return IntStream.range(start, start + length).asDoubleStream().toArray();
    }
}
