package com.amazonaws.athena.connector.lambda.examples;

import com.amazonaws.athena.connector.lambda.data.Block;
import com.amazonaws.athena.connector.lambda.data.BlockAllocatorImpl;
import com.amazonaws.athena.connector.lambda.data.BlockUtils;
import com.amazonaws.athena.connector.lambda.data.FieldBuilder;
import com.amazonaws.athena.connector.lambda.data.S3BlockSpillReader;
import com.amazonaws.athena.connector.lambda.data.SchemaBuilder;
import com.amazonaws.athena.connector.lambda.domain.Split;
import com.amazonaws.athena.connector.lambda.domain.TableName;
import com.amazonaws.athena.connector.lambda.domain.predicate.AllOrNoneValueSet;
import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints;
import com.amazonaws.athena.connector.lambda.domain.predicate.EquatableValueSet;
import com.amazonaws.athena.connector.lambda.domain.predicate.Range;
import com.amazonaws.athena.connector.lambda.domain.predicate.SortedRangeSet;
import com.amazonaws.athena.connector.lambda.domain.predicate.ValueSet;
import com.amazonaws.athena.connector.lambda.domain.spill.S3SpillLocation;
import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation;
import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest;
import com.amazonaws.athena.connector.lambda.records.ReadRecordsResponse;
import com.amazonaws.athena.connector.lambda.records.RecordRequest;
import com.amazonaws.athena.connector.lambda.records.RecordRequestType;
import com.amazonaws.athena.connector.lambda.records.RecordResponse;
import com.amazonaws.athena.connector.lambda.records.RecordService;
import com.amazonaws.athena.connector.lambda.records.RemoteReadRecordsResponse;
import com.amazonaws.athena.connector.lambda.security.EncryptionKey;
import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory;
import com.amazonaws.athena.connector.lambda.security.IdentityUtil;
import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory;
import com.amazonaws.athena.connector.lambda.serde.ObjectMapperUtil;
import com.amazonaws.athena.connector.lambda.utils.TestUtils;
import com.amazonaws.services.athena.AmazonAthena;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.model.PutObjectRequest;
import com.amazonaws.services.s3.model.PutObjectResult;
import com.amazonaws.services.s3.model.S3Object;
import com.amazonaws.services.s3.model.S3ObjectInputStream;
import com.amazonaws.services.secretsmanager.AWSSecretsManager;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.ByteStreams;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.http.client.methods.HttpRequestBase;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/amazonaws/athena/connector/lambda/examples/ExampleRecordHandlerTest.class */
public class ExampleRecordHandlerTest {
    private static final Logger logger = LoggerFactory.getLogger(ExampleRecordHandlerTest.class);
    private RecordService recordService;
    private AmazonS3 amazonS3;
    private AWSSecretsManager awsSecretsManager;
    private AmazonAthena athena;
    private S3BlockSpillReader spillReader;
    private BlockAllocatorImpl allocator;
    private Schema schemaForRead;
    private EncryptionKeyFactory keyFactory = new LocalKeyFactory();
    private List<ByteHolder> mockS3Storage = new ArrayList();

    /* renamed from: com.amazonaws.athena.connector.lambda.examples.ExampleRecordHandlerTest$3, reason: invalid class name */
    /* loaded from: input_file:com/amazonaws/athena/connector/lambda/examples/ExampleRecordHandlerTest$3.class */
    static /* synthetic */ class AnonymousClass3 {
        static final /* synthetic */ int[] $SwitchMap$com$amazonaws$athena$connector$lambda$records$RecordRequestType = new int[RecordRequestType.values().length];

        static {
            try {
                $SwitchMap$com$amazonaws$athena$connector$lambda$records$RecordRequestType[RecordRequestType.READ_RECORDS.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
        }
    }

    /* loaded from: input_file:com/amazonaws/athena/connector/lambda/examples/ExampleRecordHandlerTest$ByteHolder.class */
    private class ByteHolder {
        private byte[] bytes;

        private ByteHolder() {
        }

        public void setBytes(byte[] bArr) {
            this.bytes = bArr;
        }

        public byte[] getBytes() {
            return this.bytes;
        }
    }

    /* loaded from: input_file:com/amazonaws/athena/connector/lambda/examples/ExampleRecordHandlerTest$LocalHandler.class */
    private static class LocalHandler implements RecordService {
        private ExampleRecordHandler handler;
        private final BlockAllocatorImpl allocator;

        public LocalHandler(BlockAllocatorImpl blockAllocatorImpl, AmazonS3 amazonS3, AWSSecretsManager aWSSecretsManager, AmazonAthena amazonAthena, Map<String, String> map) {
            this.handler = new ExampleRecordHandler(amazonS3, aWSSecretsManager, amazonAthena, map);
            this.handler.setNumRows(20000);
            this.allocator = blockAllocatorImpl;
        }

        public RecordResponse readRecords(RecordRequest recordRequest) {
            try {
                switch (AnonymousClass3.$SwitchMap$com$amazonaws$athena$connector$lambda$records$RecordRequestType[recordRequest.getRequestType().ordinal()]) {
                    case TestUtils.SERDE_VERSION_ONE /* 1 */:
                        return this.handler.doReadRecords(this.allocator, (ReadRecordsRequest) recordRequest);
                    default:
                        throw new RuntimeException("Unknown request type " + recordRequest.getRequestType());
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

    @Before
    public void setUp() {
        logger.info("setUpBefore - enter");
        this.schemaForRead = SchemaBuilder.newBuilder().addField("year", new ArrowType.Int(32, true)).addField("month", new ArrowType.Int(32, true)).addField("day", new ArrowType.Int(32, true)).addField("col2", new ArrowType.Utf8()).addField("col3", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)).addField("int", Types.MinorType.INT.getType()).addField("tinyint", Types.MinorType.TINYINT.getType()).addField("smallint", Types.MinorType.SMALLINT.getType()).addField("bigint", Types.MinorType.BIGINT.getType()).addField("float4", Types.MinorType.FLOAT4.getType()).addField("float8", Types.MinorType.FLOAT8.getType()).addField("bit", Types.MinorType.BIT.getType()).addField("varchar", Types.MinorType.VARCHAR.getType()).addField("varbinary", Types.MinorType.VARBINARY.getType()).addField("datemilli", Types.MinorType.DATEMILLI.getType()).addField("dateday", Types.MinorType.DATEDAY.getType()).addField("decimal", new ArrowType.Decimal(10, 2)).addField("decimalLong", new ArrowType.Decimal(36, 2)).addField(FieldBuilder.newBuilder("list", new ArrowType.List()).addField(FieldBuilder.newBuilder("innerStruct", Types.MinorType.STRUCT.getType()).addStringField("varchar").addBigIntField("bigint").build()).build()).addField(FieldBuilder.newBuilder("outerlist", new ArrowType.List()).addListField("innerList", Types.MinorType.VARCHAR.getType()).build()).addField(FieldBuilder.newBuilder("simplemap", new ArrowType.Map(false)).addField("entries", Types.MinorType.STRUCT.getType(), false, Arrays.asList(FieldBuilder.newBuilder("key", Types.MinorType.VARCHAR.getType(), false).build(), FieldBuilder.newBuilder("value", Types.MinorType.INT.getType()).build())).build()).addMetadata("partitionCols", "year,month,day").build();
        this.allocator = new BlockAllocatorImpl();
        this.amazonS3 = (AmazonS3) Mockito.mock(AmazonS3.class);
        this.awsSecretsManager = (AWSSecretsManager) Mockito.mock(AWSSecretsManager.class);
        this.athena = (AmazonAthena) Mockito.mock(AmazonAthena.class);
        Mockito.when(this.amazonS3.putObject((PutObjectRequest) ArgumentMatchers.any())).thenAnswer(new Answer<Object>() { // from class: com.amazonaws.athena.connector.lambda.examples.ExampleRecordHandlerTest.1
            public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
                InputStream inputStream = ((PutObjectRequest) invocationOnMock.getArguments()[0]).getInputStream();
                ByteHolder byteHolder = new ByteHolder();
                byteHolder.setBytes(ByteStreams.toByteArray(inputStream));
                ExampleRecordHandlerTest.this.mockS3Storage.add(byteHolder);
                return Mockito.mock(PutObjectResult.class);
            }
        });
        Mockito.when(this.amazonS3.getObject((String) ArgumentMatchers.nullable(String.class), (String) ArgumentMatchers.nullable(String.class))).thenAnswer(new Answer<Object>() { // from class: com.amazonaws.athena.connector.lambda.examples.ExampleRecordHandlerTest.2
            public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
                S3Object s3Object = (S3Object) Mockito.mock(S3Object.class);
                ByteHolder byteHolder = ExampleRecordHandlerTest.this.mockS3Storage.get(0);
                ExampleRecordHandlerTest.this.mockS3Storage.remove(0);
                Mockito.when(s3Object.getObjectContent()).thenReturn(new S3ObjectInputStream(new ByteArrayInputStream(byteHolder.getBytes()), (HttpRequestBase) null));
                return s3Object;
            }
        });
        this.recordService = new LocalHandler(this.allocator, this.amazonS3, this.awsSecretsManager, this.athena, ImmutableMap.of());
        this.spillReader = new S3BlockSpillReader(this.amazonS3, this.allocator);
        logger.info("setUpBefore - exit");
    }

    @After
    public void after() {
        this.allocator.close();
    }

    @Test
    public void doReadRecordsNoSpill() {
        logger.info("doReadRecordsNoSpill: enter");
        for (int i = 0; i < 2; i++) {
            EncryptionKey create = i % 2 == 0 ? this.keyFactory.create() : null;
            logger.info("doReadRecordsNoSpill: Using encryptionKey[" + create + "]");
            HashMap hashMap = new HashMap();
            hashMap.put("col3", SortedRangeSet.copyOf(Types.MinorType.FLOAT8.getType(), ImmutableList.of(Range.equal(this.allocator, Types.MinorType.FLOAT8.getType(), Double.valueOf(22.0d))), false));
            ReadRecordsRequest readRecordsRequest = new ReadRecordsRequest(IdentityUtil.fakeIdentity(), "catalog", "queryId-" + System.currentTimeMillis(), new TableName("schema", "table"), this.schemaForRead, Split.newBuilder(makeSpillLocation(), create).add("year", "10").add("month", "10").add("day", "10").build(), new Constraints(hashMap), 100000000000L, 100000000000L);
            ObjectMapperUtil.assertSerialization(readRecordsRequest);
            ReadRecordsResponse readRecords = this.recordService.readRecords(readRecordsRequest);
            ObjectMapperUtil.assertSerialization(readRecords);
            Assert.assertTrue(readRecords instanceof ReadRecordsResponse);
            ReadRecordsResponse readRecordsResponse = readRecords;
            logger.info("doReadRecordsNoSpill: rows[{}]", Integer.valueOf(readRecordsResponse.getRecordCount()));
            Assert.assertTrue(readRecordsResponse.getRecords().getRowCount() == 1);
            logger.info("doReadRecordsNoSpill: {}", BlockUtils.rowToString(readRecordsResponse.getRecords(), 0));
        }
        logger.info("doReadRecordsNoSpill: exit");
    }

    @Test
    public void doReadRecordsSpill() throws Exception {
        logger.info("doReadRecordsSpill: enter");
        for (int i = 0; i < 2; i++) {
            EncryptionKey create = i % 2 == 0 ? this.keyFactory.create() : null;
            logger.info("doReadRecordsSpill: Using encryptionKey[" + create + "]");
            HashMap hashMap = new HashMap();
            hashMap.put("col3", SortedRangeSet.copyOf(Types.MinorType.FLOAT8.getType(), ImmutableList.of(Range.greaterThan(this.allocator, Types.MinorType.FLOAT8.getType(), Double.valueOf(-10000.0d))), false));
            hashMap.put("unknown", EquatableValueSet.newBuilder(this.allocator, Types.MinorType.FLOAT8.getType(), false, true).add(Double.valueOf(1.1d)).build());
            hashMap.put("unknown2", new AllOrNoneValueSet(Types.MinorType.FLOAT8.getType(), false, true));
            ReadRecordsRequest readRecordsRequest = new ReadRecordsRequest(IdentityUtil.fakeIdentity(), "catalog", "queryId-" + System.currentTimeMillis(), new TableName("schema", "table"), this.schemaForRead, Split.newBuilder(makeSpillLocation(), create).add("year", "10").add("month", "10").add("day", "10").build(), new Constraints(hashMap), 1600000L, 1000L);
            ObjectMapperUtil.assertSerialization(readRecordsRequest);
            RemoteReadRecordsResponse readRecords = this.recordService.readRecords(readRecordsRequest);
            ObjectMapperUtil.assertSerialization(readRecords);
            Assert.assertTrue(readRecords instanceof RemoteReadRecordsResponse);
            RemoteReadRecordsResponse remoteReadRecordsResponse = readRecords;
            try {
                logger.info("doReadRecordsSpill: remoteBlocks[{}]", Integer.valueOf(remoteReadRecordsResponse.getRemoteBlocks().size()));
                Assert.assertTrue(remoteReadRecordsResponse.getNumberBlocks() > 1);
                int i2 = 0;
                Iterator it = remoteReadRecordsResponse.getRemoteBlocks().iterator();
                while (it.hasNext()) {
                    Block read = this.spillReader.read((SpillLocation) it.next(), remoteReadRecordsResponse.getEncryptionKey(), remoteReadRecordsResponse.getSchema());
                    try {
                        int i3 = i2;
                        i2++;
                        logger.info("doReadRecordsSpill: blockNum[{}] and recordCount[{}]", Integer.valueOf(i3), Integer.valueOf(read.getRowCount()));
                        logger.info("doReadRecordsSpill: {}", BlockUtils.rowToString(read, 0));
                        Assert.assertNotNull(BlockUtils.rowToString(read, 0));
                        if (read != null) {
                            read.close();
                        }
                    } finally {
                    }
                }
                if (remoteReadRecordsResponse != null) {
                    remoteReadRecordsResponse.close();
                }
            } catch (Throwable th) {
                if (remoteReadRecordsResponse != null) {
                    try {
                        remoteReadRecordsResponse.close();
                    } catch (Throwable th2) {
                        th.addSuppressed(th2);
                    }
                }
                throw th;
            }
        }
        logger.info("doReadRecordsSpill: exit");
    }

    private SpillLocation makeSpillLocation() {
        return S3SpillLocation.newBuilder().withBucket("athena-virtuoso-test").withPrefix("lambda-spill").withQueryId(UUID.randomUUID().toString()).withSplitId(UUID.randomUUID().toString()).withIsDirectory(true).build();
    }
}
