/*
 * Decompiled with CFR 0.152.
 */
package com.amazonaws.athena.connector.lambda.examples;

import com.amazonaws.athena.connector.lambda.data.Block;
import com.amazonaws.athena.connector.lambda.data.BlockAllocator;
import com.amazonaws.athena.connector.lambda.data.BlockAllocatorImpl;
import com.amazonaws.athena.connector.lambda.data.BlockUtils;
import com.amazonaws.athena.connector.lambda.data.FieldBuilder;
import com.amazonaws.athena.connector.lambda.data.S3BlockSpillReader;
import com.amazonaws.athena.connector.lambda.data.SchemaBuilder;
import com.amazonaws.athena.connector.lambda.domain.Split;
import com.amazonaws.athena.connector.lambda.domain.TableName;
import com.amazonaws.athena.connector.lambda.domain.predicate.AllOrNoneValueSet;
import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints;
import com.amazonaws.athena.connector.lambda.domain.predicate.EquatableValueSet;
import com.amazonaws.athena.connector.lambda.domain.predicate.Range;
import com.amazonaws.athena.connector.lambda.domain.predicate.SortedRangeSet;
import com.amazonaws.athena.connector.lambda.domain.predicate.ValueSet;
import com.amazonaws.athena.connector.lambda.domain.spill.S3SpillLocation;
import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation;
import com.amazonaws.athena.connector.lambda.examples.ExampleRecordHandler;
import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest;
import com.amazonaws.athena.connector.lambda.records.ReadRecordsResponse;
import com.amazonaws.athena.connector.lambda.records.RecordRequest;
import com.amazonaws.athena.connector.lambda.records.RecordResponse;
import com.amazonaws.athena.connector.lambda.records.RecordService;
import com.amazonaws.athena.connector.lambda.records.RemoteReadRecordsResponse;
import com.amazonaws.athena.connector.lambda.security.EncryptionKey;
import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory;
import com.amazonaws.athena.connector.lambda.security.IdentityUtil;
import com.amazonaws.athena.connector.lambda.security.LocalKeyFactory;
import com.amazonaws.athena.connector.lambda.serde.ObjectMapperUtil;
import com.amazonaws.services.athena.AmazonAthena;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.model.PutObjectRequest;
import com.amazonaws.services.s3.model.PutObjectResult;
import com.amazonaws.services.s3.model.S3Object;
import com.amazonaws.services.s3.model.S3ObjectInputStream;
import com.amazonaws.services.secretsmanager.AWSSecretsManager;
import com.google.common.collect.ImmutableList;
import com.google.common.io.ByteStreams;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Matchers;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tests for {@link ExampleRecordHandler}, driven through an in-process {@link RecordService}.
 *
 * <p>S3 is replaced by a Mockito mock: {@code putObject} appends the uploaded bytes to an in-memory
 * FIFO list and {@code getObject} pops the oldest entry (bucket and key are ignored), so spilled
 * blocks can be read back without any network access. Secrets Manager and Athena clients are plain
 * mocks with no stubbing.
 */
public class ExampleRecordHandlerTest {
    private static final Logger logger = LoggerFactory.getLogger(ExampleRecordHandlerTest.class);

    private final EncryptionKeyFactory keyFactory = new LocalKeyFactory();
    /** Objects uploaded through the mocked {@code putObject}, consumed in FIFO order by {@code getObject}. */
    private final List<ByteHolder> mockS3Storage = new ArrayList<>();
    private RecordService recordService;
    private AmazonS3 amazonS3;
    private AWSSecretsManager awsSecretsManager;
    private AmazonAthena athena;
    private S3BlockSpillReader spillReader;
    private BlockAllocatorImpl allocator;
    private Schema schemaForRead;

    @Before
    public void setUp() {
        logger.info("setUpBefore - enter");
        schemaForRead = buildSchemaForRead();
        allocator = new BlockAllocatorImpl();
        amazonS3 = Mockito.mock(AmazonS3.class);
        awsSecretsManager = Mockito.mock(AWSSecretsManager.class);
        athena = Mockito.mock(AmazonAthena.class);

        // Capture every uploaded object so it can be served back by the getObject stub below.
        Mockito.when(amazonS3.putObject((PutObjectRequest) Matchers.anyObject()))
                .thenAnswer(new Answer<PutObjectResult>() {
                    @Override
                    public PutObjectResult answer(InvocationOnMock invocation) throws Throwable {
                        PutObjectRequest putRequest = (PutObjectRequest) invocation.getArguments()[0];
                        ByteHolder byteHolder = new ByteHolder();
                        byteHolder.setBytes(ByteStreams.toByteArray(putRequest.getInputStream()));
                        mockS3Storage.add(byteHolder);
                        return Mockito.mock(PutObjectResult.class);
                    }
                });

        // Serve uploaded objects back in upload order; the requested bucket and key are ignored.
        Mockito.when(amazonS3.getObject(Matchers.anyString(), Matchers.anyString()))
                .thenAnswer(new Answer<S3Object>() {
                    @Override
                    public S3Object answer(InvocationOnMock invocation) throws Throwable {
                        ByteHolder byteHolder = mockS3Storage.remove(0);
                        S3Object mockObject = Mockito.mock(S3Object.class);
                        Mockito.when(mockObject.getObjectContent()).thenReturn(
                                new S3ObjectInputStream(new ByteArrayInputStream(byteHolder.getBytes()), null));
                        return mockObject;
                    }
                });

        recordService = new LocalHandler(allocator, amazonS3, awsSecretsManager, athena);
        spillReader = new S3BlockSpillReader(amazonS3, allocator);
        logger.info("setUpBefore - exit");
    }

    @After
    public void after() {
        allocator.close();
    }

    @Test
    public void doReadRecordsNoSpill() {
        logger.info("doReadRecordsNoSpill: enter");
        // Run once with an encryption key (i == 0) and once without (i == 1).
        for (int i = 0; i < 2; i++) {
            EncryptionKey encryptionKey = (i % 2 == 0) ? keyFactory.create() : null;
            logger.info("doReadRecordsNoSpill: Using encryptionKey[{}]", encryptionKey);

            Map<String, ValueSet> constraintsMap = new HashMap<>();
            constraintsMap.put("col3", SortedRangeSet.copyOf(Types.MinorType.FLOAT8.getType(),
                    ImmutableList.of(Range.equal(allocator, Types.MinorType.FLOAT8.getType(), 22.0)), false));

            // Block limits far above the expected result size: the response is expected inline, not spilled.
            ReadRecordsRequest request = newReadRecordsRequest(encryptionKey, constraintsMap,
                    100_000_000_000L, 100_000_000_000L);
            ObjectMapperUtil.assertSerialization(request);

            RecordResponse rawResponse = recordService.readRecords(request);
            ObjectMapperUtil.assertSerialization(rawResponse);

            Assert.assertTrue(rawResponse instanceof ReadRecordsResponse);
            ReadRecordsResponse response = (ReadRecordsResponse) rawResponse;
            logger.info("doReadRecordsNoSpill: rows[{}]", response.getRecordCount());

            // Exactly one row is expected to satisfy col3 == 22.0.
            Assert.assertEquals(1, response.getRecords().getRowCount());
            logger.info("doReadRecordsNoSpill: {}", BlockUtils.rowToString(response.getRecords(), 0));
        }
        logger.info("doReadRecordsNoSpill: exit");
    }

    @Test
    public void doReadRecordsSpill() throws Exception {
        logger.info("doReadRecordsSpill: enter");
        // Run once with an encryption key (i == 0) and once without (i == 1).
        for (int i = 0; i < 2; i++) {
            EncryptionKey encryptionKey = (i % 2 == 0) ? keyFactory.create() : null;
            logger.info("doReadRecordsSpill: Using encryptionKey[{}]", encryptionKey);

            Map<String, ValueSet> constraintsMap = new HashMap<>();
            constraintsMap.put("col3", SortedRangeSet.copyOf(Types.MinorType.FLOAT8.getType(),
                    ImmutableList.of(Range.greaterThan(allocator, Types.MinorType.FLOAT8.getType(), -10000.0)),
                    false));
            // "unknown" and "unknown2" are not columns of schemaForRead; presumably these check that the
            // handler tolerates constraints on columns it does not know about.
            constraintsMap.put("unknown", EquatableValueSet.newBuilder(allocator,
                    Types.MinorType.FLOAT8.getType(), false, true).add(1.1).build());
            constraintsMap.put("unknown2", new AllOrNoneValueSet(Types.MinorType.FLOAT8.getType(), false, true));

            // Small block limits so the handler is expected to spill into the mocked S3.
            ReadRecordsRequest request = newReadRecordsRequest(encryptionKey, constraintsMap, 1_600_000L, 1_000L);
            ObjectMapperUtil.assertSerialization(request);

            RecordResponse rawResponse = recordService.readRecords(request);
            ObjectMapperUtil.assertSerialization(rawResponse);

            Assert.assertTrue(rawResponse instanceof RemoteReadRecordsResponse);
            try (RemoteReadRecordsResponse response = (RemoteReadRecordsResponse) rawResponse) {
                logger.info("doReadRecordsSpill: remoteBlocks[{}]", response.getRemoteBlocks().size());
                Assert.assertTrue(response.getNumberBlocks() > 1);

                int blockNum = 0;
                for (SpillLocation next : response.getRemoteBlocks()) {
                    S3SpillLocation spillLocation = (S3SpillLocation) next;
                    // try-with-resources closes the block without masking failures raised in the body.
                    try (Block block = spillReader.read(spillLocation, response.getEncryptionKey(),
                            response.getSchema())) {
                        Assert.assertNotNull("spill reader returned no block for " + spillLocation, block);
                        logger.info("doReadRecordsSpill: blockNum[{}] and recordCount[{}]",
                                blockNum++, block.getRowCount());
                        logger.info("doReadRecordsSpill: {}", BlockUtils.rowToString(block, 0));
                        Assert.assertNotNull(BlockUtils.rowToString(block, 0));
                    }
                }
            }
        }
        logger.info("doReadRecordsSpill: exit");
    }

    /**
     * Builds a read request against {@code schema.table} for a split partitioned at
     * year=10/month=10/day=10.
     *
     * @param encryptionKey key for the split's spill data, or {@code null} for no encryption
     * @param constraintsMap per-column constraints to push down
     * @param maxBlockSize maximum block size passed to the request
     * @param maxInlineBlockSize maximum inline block size passed to the request
     * @return a new request with a unique spill location and a time-based query id
     */
    private ReadRecordsRequest newReadRecordsRequest(EncryptionKey encryptionKey,
            Map<String, ValueSet> constraintsMap, long maxBlockSize, long maxInlineBlockSize) {
        Split split = Split.newBuilder(makeSpillLocation(), encryptionKey)
                .add("year", "10")
                .add("month", "10")
                .add("day", "10")
                .build();
        return new ReadRecordsRequest(IdentityUtil.fakeIdentity(),
                "catalog",
                "queryId-" + System.currentTimeMillis(),
                new TableName("schema", "table"),
                schemaForRead,
                split,
                new Constraints(constraintsMap),
                maxBlockSize,
                maxInlineBlockSize);
    }

    /** Creates a directory-style S3 spill location with random query and split ids. */
    private SpillLocation makeSpillLocation() {
        return S3SpillLocation.newBuilder()
                .withBucket("athena-virtuoso-test")
                .withPrefix("lambda-spill")
                .withQueryId(UUID.randomUUID().toString())
                .withSplitId(UUID.randomUUID().toString())
                .withIsDirectory(true)
                .build();
    }

    /**
     * Builds the read schema: year/month/day partition columns, a set of scalar columns covering many
     * Arrow minor types, and nested list, list-of-list and map columns.
     */
    private static Schema buildSchemaForRead() {
        return SchemaBuilder.newBuilder()
                .addField("year", new ArrowType.Int(32, true))
                .addField("month", new ArrowType.Int(32, true))
                .addField("day", new ArrowType.Int(32, true))
                .addField("col2", new ArrowType.Utf8())
                .addField("col3", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))
                .addField("int", Types.MinorType.INT.getType())
                .addField("tinyint", Types.MinorType.TINYINT.getType())
                .addField("smallint", Types.MinorType.SMALLINT.getType())
                .addField("bigint", Types.MinorType.BIGINT.getType())
                .addField("float4", Types.MinorType.FLOAT4.getType())
                .addField("float8", Types.MinorType.FLOAT8.getType())
                .addField("bit", Types.MinorType.BIT.getType())
                .addField("varchar", Types.MinorType.VARCHAR.getType())
                .addField("varbinary", Types.MinorType.VARBINARY.getType())
                .addField("datemilli", Types.MinorType.DATEMILLI.getType())
                .addField("dateday", Types.MinorType.DATEDAY.getType())
                .addField("decimal", new ArrowType.Decimal(10, 2))
                .addField("decimalLong", new ArrowType.Decimal(36, 2))
                .addField(FieldBuilder.newBuilder("list", new ArrowType.List())
                        .addField(FieldBuilder.newBuilder("innerStruct", Types.MinorType.STRUCT.getType())
                                .addStringField("varchar")
                                .addBigIntField("bigint")
                                .build())
                        .build())
                .addField(FieldBuilder.newBuilder("outerlist", new ArrowType.List())
                        .addListField("innerList", Types.MinorType.VARCHAR.getType())
                        .build())
                .addField(FieldBuilder.newBuilder("simplemap", new ArrowType.Map(false))
                        .addField("entries", Types.MinorType.STRUCT.getType(), false, Arrays.asList(
                                FieldBuilder.newBuilder("key", Types.MinorType.VARCHAR.getType(), false).build(),
                                FieldBuilder.newBuilder("value", Types.MinorType.INT.getType()).build()))
                        .build())
                .addMetadata("partitionCols", "year,month,day")
                .build();
    }

    /** Mutable holder for the bytes of one object written to the mocked S3. */
    private static class ByteHolder {
        private byte[] bytes;

        public void setBytes(byte[] bytes) {
            this.bytes = bytes;
        }

        public byte[] getBytes() {
            return bytes;
        }
    }

    /**
     * {@link RecordService} that calls an {@link ExampleRecordHandler} directly instead of invoking a
     * Lambda. The handler is configured to generate 20000 rows.
     */
    private static class LocalHandler implements RecordService {
        private final ExampleRecordHandler handler;
        private final BlockAllocatorImpl allocator;

        public LocalHandler(BlockAllocatorImpl allocator, AmazonS3 amazonS3, AWSSecretsManager secretsManager,
                AmazonAthena athena) {
            this.handler = new ExampleRecordHandler(amazonS3, secretsManager, athena);
            this.handler.setNumRows(20000);
            this.allocator = allocator;
        }

        /**
         * Dispatches {@code READ_RECORDS} requests to the handler.
         *
         * @throws RuntimeException for any other request type, or wrapping a checked exception thrown by
         *     the handler; unchecked exceptions propagate unwrapped
         */
        @Override
        public RecordResponse readRecords(RecordRequest request) {
            try {
                switch (request.getRequestType()) {
                    case READ_RECORDS:
                        return handler.doReadRecords(allocator, (ReadRecordsRequest) request);
                    default:
                        throw new RuntimeException("Unknown request type " + request.getRequestType());
                }
            }
            catch (RuntimeException ex) {
                // Don't re-wrap unchecked exceptions; it only obscures the original failure.
                throw ex;
            }
            catch (Exception ex) {
                throw new RuntimeException(ex);
            }
        }
    }
}

