package io.trino.operator.aggregation.state;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.SizeOf;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;
import io.trino.array.BooleanBigArray;
import io.trino.array.DoubleBigArray;
import io.trino.array.LongBigArray;
import io.trino.block.BlockAssertions;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.BlockBuilderStatus;
import io.trino.spi.block.RowBlockBuilder;
import io.trino.spi.block.SqlMap;
import io.trino.spi.block.SqlRow;
import io.trino.spi.block.VariableWidthBlockBuilder;
import io.trino.spi.function.AccumulatorState;
import io.trino.spi.function.AccumulatorStateFactory;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.function.InOut;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.util.StructuralTestUtil;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link StateCompiler}.
 *
 * <p>Each serialization test generates a state factory and serializer for a state interface,
 * populates one state, serializes it into a block, deserializes that block into a second, fresh
 * state, and compares the two. The size tests compare the estimated size reported by generated
 * single and grouped states against the sizes of their backing arrays.
 */
public class TestStateCompiler {

    /** State with a single {@code boolean} field exposed through an {@code is}-prefixed getter. */
    public interface BooleanState extends AccumulatorState {
        boolean isBoolean();

        void setBoolean(boolean value);
    }

    /** State with a single {@code byte} field. */
    public interface ByteState extends AccumulatorState {
        byte getByte();

        void setByte(byte value);
    }

    /** State with a single {@link Slice} field. */
    public interface SliceState extends AccumulatorState {
        Slice getSlice();

        void setSlice(Slice value);
    }

    /**
     * State mixing primitive, {@link Slice} and structural fields. The SQL types of the
     * {@code Block}, {@code SqlMap} and {@code SqlRow} fields are supplied to the compiler through
     * the field-type map built in {@link #testComplexSerialization()}.
     */
    public interface TestComplexState extends AccumulatorState {
        double getDouble();

        void setDouble(double value);

        boolean getBoolean();

        void setBoolean(boolean value);

        long getLong();

        void setLong(long value);

        byte getByte();

        void setByte(byte value);

        int getInt();

        void setInt(int value);

        Slice getSlice();

        void setSlice(Slice value);

        Slice getAnotherSlice();

        void setAnotherSlice(Slice value);

        Slice getYetAnotherSlice();

        void setYetAnotherSlice(Slice value);

        Block getBlock();

        void setBlock(Block value);

        SqlMap getSqlMap();

        void setSqlMap(SqlMap value);

        SqlRow getSqlRow();

        void setSqlRow(SqlRow value);
    }

    /** State with one {@code long} and one {@code double} field, used by the size-estimate test. */
    public interface TestSimpleState extends AccumulatorState {
        long getLong();

        void setLong(long value);

        double getDouble();

        void setDouble(double value);
    }

    /**
     * Serializes a non-null and then a null {@link NullableLongState} into one BIGINT block. It
     * checks that position 0 carries the value and round-trips, and that position 1 is null.
     */
    @Test
    public void testPrimitiveNullableLongSerialization() {
        AccumulatorStateFactory<NullableLongState> factory = StateCompiler.generateStateFactory(NullableLongState.class);
        AccumulatorStateSerializer<NullableLongState> serializer = StateCompiler.generateStateSerializer(NullableLongState.class);
        NullableLongState state = factory.createSingleState();
        NullableLongState deserializedState = factory.createSingleState();

        state.setValue(2L);
        state.setNull(false);

        // Position 0: the state while non-null; position 1: the same state after being marked null.
        BlockBuilder builder = BigintType.BIGINT.createBlockBuilder(null, 2);
        serializer.serialize(state, builder);
        state.setNull(true);
        serializer.serialize(state, builder);
        Block block = builder.build();

        Assertions.assertThat(block.isNull(0)).isFalse();
        Assertions.assertThat(BigintType.BIGINT.getLong(block, 0)).isEqualTo(state.getValue());
        serializer.deserialize(block, 0, deserializedState);
        Assertions.assertThat(deserializedState.getValue()).isEqualTo(state.getValue());

        Assertions.assertThat(block.isNull(1)).isTrue();
    }

    /** Round-trips a {@link LongState} through a BIGINT block. */
    @Test
    public void testPrimitiveLongSerialization() {
        AccumulatorStateFactory<LongState> factory = StateCompiler.generateStateFactory(LongState.class);
        AccumulatorStateSerializer<LongState> serializer = StateCompiler.generateStateSerializer(LongState.class);
        LongState state = factory.createSingleState();
        LongState deserializedState = factory.createSingleState();

        state.setValue(2L);

        BlockBuilder builder = BigintType.BIGINT.createBlockBuilder(null, 1);
        serializer.serialize(state, builder);
        Block block = builder.build();

        Assertions.assertThat(BigintType.BIGINT.getLong(block, 0)).isEqualTo(state.getValue());
        serializer.deserialize(block, 0, deserializedState);
        Assertions.assertThat(deserializedState.getValue()).isEqualTo(state.getValue());
    }

    /** A single-{@code long} state is expected to serialize as BIGINT. */
    @Test
    public void testGetSerializedType() {
        AccumulatorStateSerializer<LongState> serializer = StateCompiler.generateStateSerializer(LongState.class);
        Assertions.assertThat(serializer.getSerializedType()).isEqualTo(BigintType.BIGINT);
    }

    /** Round-trips a {@link BooleanState} through a BOOLEAN block. */
    @Test
    public void testPrimitiveBooleanSerialization() {
        AccumulatorStateFactory<BooleanState> factory = StateCompiler.generateStateFactory(BooleanState.class);
        AccumulatorStateSerializer<BooleanState> serializer = StateCompiler.generateStateSerializer(BooleanState.class);
        BooleanState state = factory.createSingleState();
        BooleanState deserializedState = factory.createSingleState();

        state.setBoolean(true);

        BlockBuilder builder = BooleanType.BOOLEAN.createBlockBuilder(null, 1);
        serializer.serialize(state, builder);
        serializer.deserialize(builder.build(), 0, deserializedState);

        Assertions.assertThat(deserializedState.isBoolean()).isEqualTo(state.isBoolean());
    }

    /** Round-trips a {@link ByteState} through a TINYINT block. */
    @Test
    public void testPrimitiveByteSerialization() {
        AccumulatorStateFactory<ByteState> factory = StateCompiler.generateStateFactory(ByteState.class);
        AccumulatorStateSerializer<ByteState> serializer = StateCompiler.generateStateSerializer(ByteState.class);
        ByteState state = factory.createSingleState();
        ByteState deserializedState = factory.createSingleState();

        state.setByte((byte) 3);

        BlockBuilder builder = TinyintType.TINYINT.createBlockBuilder(null, 1);
        serializer.serialize(state, builder);
        serializer.deserialize(builder.build(), 0, deserializedState);

        Assertions.assertThat(deserializedState.getByte()).isEqualTo(state.getByte());
    }

    /** Round-trips a {@link SliceState} through VARCHAR blocks, first with a null slice, then with a non-null one. */
    @Test
    public void testNonPrimitiveSerialization() {
        AccumulatorStateFactory<SliceState> factory = StateCompiler.generateStateFactory(SliceState.class);
        AccumulatorStateSerializer<SliceState> serializer = StateCompiler.generateStateSerializer(SliceState.class);
        SliceState state = factory.createSingleState();
        SliceState deserializedState = factory.createSingleState();

        state.setSlice(null);
        VariableWidthBlockBuilder nullBuilder = VarcharType.VARCHAR.createBlockBuilder(null, 1);
        serializer.serialize(state, nullBuilder);
        serializer.deserialize(nullBuilder.build(), 0, deserializedState);
        Assertions.assertThat(deserializedState.getSlice()).isEqualTo(state.getSlice());

        state.setSlice(Slices.utf8Slice("test"));
        VariableWidthBlockBuilder valueBuilder = VarcharType.VARCHAR.createBlockBuilder(null, 1);
        serializer.serialize(state, valueBuilder);
        serializer.deserialize(valueBuilder.build(), 0, deserializedState);
        Assertions.assertThat(deserializedState.getSlice()).isEqualTo(state.getSlice());
    }

    /** Round-trips a {@link VarianceState} through a row(BIGINT, DOUBLE, DOUBLE) block. */
    @Test
    public void testVarianceStateSerialization() {
        AccumulatorStateFactory<VarianceState> factory = StateCompiler.generateStateFactory(VarianceState.class);
        AccumulatorStateSerializer<VarianceState> serializer = StateCompiler.generateStateSerializer(VarianceState.class);
        VarianceState state = factory.createSingleState();
        VarianceState deserializedState = factory.createSingleState();

        state.setMean(1.0d);
        state.setCount(2L);
        state.setM2(3.0d);

        RowBlockBuilder builder = RowType.anonymous(ImmutableList.of(BigintType.BIGINT, DoubleType.DOUBLE, DoubleType.DOUBLE))
                .createBlockBuilder(null, 1);
        serializer.serialize(state, builder);
        serializer.deserialize(builder.build(), 0, deserializedState);

        Assertions.assertThat(deserializedState.getCount()).isEqualTo(state.getCount());
        Assertions.assertThat(deserializedState.getMean()).isEqualTo(state.getMean());
        Assertions.assertThat(deserializedState.getM2()).isEqualTo(state.getM2());
    }

    /**
     * Round-trips a {@link TestComplexState} whose structural fields are typed through an explicit
     * field-name to SQL-type map, then compares every field of the deserialized state.
     */
    @Test
    public void testComplexSerialization() {
        // Built once so the declared field type and the row value set below cannot diverge.
        RowType rowType = RowType.anonymousRow(VarcharType.VARCHAR, BigintType.BIGINT, VarcharType.VARCHAR);
        ImmutableMap<String, Type> fieldTypes = ImmutableMap.of(
                "Block", new ArrayType(BigintType.BIGINT),
                "SqlMap", StructuralTestUtil.mapType(BigintType.BIGINT, VarcharType.VARCHAR),
                "SqlRow", rowType);
        AccumulatorStateFactory<TestComplexState> factory = StateCompiler.generateStateFactory(TestComplexState.class, fieldTypes);
        AccumulatorStateSerializer<TestComplexState> serializer = StateCompiler.generateStateSerializer(TestComplexState.class, fieldTypes);
        TestComplexState state = factory.createSingleState();
        TestComplexState deserializedState = factory.createSingleState();

        state.setBoolean(true);
        state.setLong(1L);
        state.setDouble(2.0d);
        state.setByte((byte) 3);
        state.setInt(4);
        state.setSlice(Slices.utf8Slice("test"));
        state.setAnotherSlice(toSlice(1.0d, 2.0d, 3.0d));
        state.setYetAnotherSlice(null);
        state.setBlock(BlockAssertions.createLongsBlock(45));
        state.setSqlMap(StructuralTestUtil.sqlMapOf(BigintType.BIGINT, VarcharType.VARCHAR, ImmutableMap.of(123L, "testBlock")));
        state.setSqlRow(StructuralTestUtil.sqlRowOf(rowType, "a", 777, "b"));

        BlockBuilder builder = serializer.getSerializedType().createBlockBuilder(null, 1);
        serializer.serialize(state, builder);
        serializer.deserialize(builder.build(), 0, deserializedState);

        Assertions.assertThat(deserializedState.getBoolean()).isEqualTo(state.getBoolean());
        Assertions.assertThat(deserializedState.getLong()).isEqualTo(state.getLong());
        Assertions.assertThat(deserializedState.getDouble()).isEqualTo(state.getDouble());
        Assertions.assertThat(deserializedState.getByte()).isEqualTo(state.getByte());
        Assertions.assertThat(deserializedState.getInt()).isEqualTo(state.getInt());
        Assertions.assertThat(deserializedState.getSlice()).isEqualTo(state.getSlice());
        Assertions.assertThat(deserializedState.getAnotherSlice()).isEqualTo(state.getAnotherSlice());
        Assertions.assertThat(deserializedState.getYetAnotherSlice()).isEqualTo(state.getYetAnotherSlice());
        Assertions.assertThat(BigintType.BIGINT.getLong(deserializedState.getBlock(), 0))
                .isEqualTo(BigintType.BIGINT.getLong(state.getBlock(), 0));

        // The single map entry lives at the map's raw offset in its key/value blocks.
        SqlMap actualMap = deserializedState.getSqlMap();
        SqlMap expectedMap = state.getSqlMap();
        Assertions.assertThat(BigintType.BIGINT.getLong(actualMap.getRawKeyBlock(), actualMap.getRawOffset()))
                .isEqualTo(BigintType.BIGINT.getLong(expectedMap.getRawKeyBlock(), expectedMap.getRawOffset()));
        Assertions.assertThat(VarcharType.VARCHAR.getSlice(actualMap.getRawValueBlock(), actualMap.getRawOffset()))
                .isEqualTo(VarcharType.VARCHAR.getSlice(expectedMap.getRawValueBlock(), expectedMap.getRawOffset()));

        // Each row field is read from its raw field block at the row's raw index.
        SqlRow actualRow = deserializedState.getSqlRow();
        SqlRow expectedRow = state.getSqlRow();
        Assertions.assertThat(VarcharType.VARCHAR.getSlice(actualRow.getRawFieldBlock(0), actualRow.getRawIndex()))
                .isEqualTo(VarcharType.VARCHAR.getSlice(expectedRow.getRawFieldBlock(0), expectedRow.getRawIndex()));
        Assertions.assertThat(BigintType.BIGINT.getLong(actualRow.getRawFieldBlock(1), actualRow.getRawIndex()))
                .isEqualTo(BigintType.BIGINT.getLong(expectedRow.getRawFieldBlock(1), expectedRow.getRawIndex()));
        Assertions.assertThat(VarcharType.VARCHAR.getSlice(actualRow.getRawFieldBlock(2), actualRow.getRawIndex()))
                .isEqualTo(VarcharType.VARCHAR.getSlice(expectedRow.getRawFieldBlock(2), expectedRow.getRawIndex()));
    }

    /**
     * Checks the estimated sizes of a BIGINT in/out state.
     *
     * <p>Grouped state: instance size plus an empty {@link LongBigArray} and {@link BooleanBigArray}.
     * Single state: instance size only.
     */
    @Test
    public void testEstimatedInOutStatesInstanceSizes() {
        AccumulatorStateFactory<InOut> factory = StateCompiler.generateInOutStateFactory(BigintType.BIGINT);
        InOut groupedState = factory.createGroupedState();
        InOut singleState = factory.createSingleState();

        // NOTE(review): the literal byte counts presumably depend on the JVM's object layout
        // (e.g. compressed oops) — confirm before running on a different JVM configuration.
        Assertions.assertThat(groupedState.getEstimatedSize())
                .isEqualTo(SizeOf.instanceSize(groupedState.getClass()) + new LongBigArray().sizeOf() + new BooleanBigArray().sizeOf())
                .isEqualTo(17568L);
        Assertions.assertThat(singleState.getEstimatedSize())
                .isEqualTo(SizeOf.instanceSize(singleState.getClass()))
                .isEqualTo(24L);
    }

    /**
     * Checks the estimated sizes of a generated {@link TestSimpleState}.
     *
     * <p>Grouped state: instance size plus an empty {@link LongBigArray} and {@link DoubleBigArray}.
     * Single state: instance size only.
     */
    @Test
    public void testEstimatedStateInstanceSizes() {
        AccumulatorStateFactory<TestSimpleState> factory = StateCompiler.generateStateFactory(TestSimpleState.class);
        TestSimpleState groupedState = factory.createGroupedState();
        TestSimpleState singleState = factory.createSingleState();

        // NOTE(review): the literal byte counts presumably depend on the JVM's object layout
        // (e.g. compressed oops) — confirm before running on a different JVM configuration.
        Assertions.assertThat(groupedState.getEstimatedSize())
                .isEqualTo(SizeOf.instanceSize(groupedState.getClass()) + new LongBigArray().sizeOf() + new DoubleBigArray().sizeOf())
                .isEqualTo(24744L);
        Assertions.assertThat(singleState.getEstimatedSize())
                .isEqualTo(SizeOf.instanceSize(singleState.getClass()))
                .isEqualTo(32L);
    }

    /**
     * Packs {@code values} into a freshly allocated slice, writing each as a {@code double} in order.
     *
     * @param values the doubles to write
     * @return a slice of exactly {@code values.length * Double.BYTES} bytes
     */
    private static Slice toSlice(double... values) {
        Slice slice = Slices.allocate(values.length * Double.BYTES);
        SliceOutput output = slice.getOutput();
        for (double value : values) {
            output.writeDouble(value);
        }
        return slice;
    }
}
