/*
 * Decompiled with CFR 0.152.
 */
package com.amazonaws.athena.connector.lambda.handlers;

import com.amazonaws.athena.connector.lambda.data.Block;
import com.amazonaws.athena.connector.lambda.data.BlockAllocator;
import com.amazonaws.athena.connector.lambda.data.BlockAllocatorImpl;
import com.amazonaws.athena.connector.lambda.data.BlockUtils;
import com.amazonaws.athena.connector.lambda.data.FieldBuilder;
import com.amazonaws.athena.connector.lambda.data.FieldResolver;
import com.amazonaws.athena.connector.lambda.data.UnitTestBlockUtils;
import com.amazonaws.athena.connector.lambda.handlers.UserDefinedFunctionHandler;
import com.amazonaws.athena.connector.lambda.metadata.ListSchemasRequest;
import com.amazonaws.athena.connector.lambda.serde.VersionedObjectMapperFactory;
import com.amazonaws.athena.connector.lambda.udf.UserDefinedFunctionRequest;
import com.amazonaws.athena.connector.lambda.udf.UserDefinedFunctionResponse;
import com.amazonaws.athena.connector.lambda.udf.UserDefinedFunctionType;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.OutputStream;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import junit.framework.TestCase;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.complex.reader.FieldReader;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.Text;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

public class UserDefinedFunctionHandlerTest {
    private static final String COLUMN_PREFIX = "col_";
    private TestUserDefinedFunctionHandler handler;
    private BlockAllocatorImpl allocator;

    @Before
    public void setUp() {
        this.handler = new TestUserDefinedFunctionHandler();
        this.allocator = new BlockAllocatorImpl();
    }

    @After
    public void tearDown() {
        this.allocator.close();
    }

    @Test
    public void testInvocationWithBasicType() throws Exception {
        int rowCount = 20;
        UserDefinedFunctionRequest udfRequest = this.createUDFRequest(rowCount, Integer.class, "test_scalar_udf", true, Integer.class, Integer.class);
        UserDefinedFunctionResponse udfResponse = this.handler.processFunction((BlockAllocator)this.allocator, udfRequest);
        Block responseBlock = udfResponse.getRecords();
        Assert.assertEquals((long)1L, (long)responseBlock.getFieldReaders().size());
        Assert.assertEquals((long)rowCount, (long)responseBlock.getRowCount());
        FieldReader fieldReader = (FieldReader)responseBlock.getFieldReaders().get(0);
        for (int pos = 0; pos < rowCount; ++pos) {
            fieldReader.setPosition(pos);
            int val = (Integer)UnitTestBlockUtils.getValue(fieldReader, pos);
            int expected = this.handler.test_scalar_udf(pos + 100, pos + 100);
            Assert.assertEquals((long)expected, (long)val);
        }
    }

    @Test
    public void testInvocationWithListType() throws Exception {
        int rowCount = 20;
        UserDefinedFunctionRequest udfRequest = this.createUDFRequest(rowCount, List.class, "test_list_type", true, List.class);
        UserDefinedFunctionResponse udfResponse = this.handler.processFunction((BlockAllocator)this.allocator, udfRequest);
        Block responseBlock = udfResponse.getRecords();
        Assert.assertEquals((long)1L, (long)responseBlock.getFieldReaders().size());
        Assert.assertEquals((long)rowCount, (long)responseBlock.getRowCount());
        FieldReader fieldReader = (FieldReader)responseBlock.getFieldReaders().get(0);
        for (int pos = 0; pos < rowCount; ++pos) {
            fieldReader.setPosition(pos);
            List result = (List)UnitTestBlockUtils.getValue(fieldReader, pos);
            List<Integer> expected = this.handler.test_list_type((List<Integer>)ImmutableList.of((Object)(pos + 100), (Object)(pos + 200), (Object)(pos + 300)));
            Assert.assertArrayEquals((Object[])expected.toArray(), (Object[])result.toArray());
        }
    }

    @Test
    public void testInvocationWithStructType() throws Exception {
        int rowCount = 20;
        UserDefinedFunctionRequest udfRequest = this.createUDFRequest(rowCount, Map.class, "test_row_type", true, Map.class);
        UserDefinedFunctionResponse udfResponse = this.handler.processFunction((BlockAllocator)this.allocator, udfRequest);
        Block responseBlock = udfResponse.getRecords();
        Assert.assertEquals((long)1L, (long)responseBlock.getFieldReaders().size());
        Assert.assertEquals((long)rowCount, (long)responseBlock.getRowCount());
        FieldReader fieldReader = (FieldReader)responseBlock.getFieldReaders().get(0);
        for (int pos = 0; pos < rowCount; ++pos) {
            fieldReader.setPosition(pos);
            Map actual = (Map)UnitTestBlockUtils.getValue(fieldReader, pos);
            ImmutableMap input = ImmutableMap.of((Object)"intVal", (Object)(pos + 100), (Object)"doubleVal", (Object)((double)pos + 200.2));
            Map<String, Object> expected = this.handler.test_row_type((Map<String, Object>)input);
            for (Map.Entry<String, Object> entry : expected.entrySet()) {
                String key = entry.getKey();
                TestCase.assertTrue((boolean)actual.containsKey(key));
                Assert.assertEquals((Object)expected.get(key), actual.get(key));
            }
        }
    }

    @Test
    public void testInvocationWithNullVAlue() throws Exception {
        int rowCount = 20;
        UserDefinedFunctionRequest udfRequest = this.createUDFRequest(rowCount, Boolean.class, "test_scalar_function_with_null_value", false, Integer.class);
        UserDefinedFunctionResponse udfResponse = this.handler.processFunction((BlockAllocator)this.allocator, udfRequest);
        Block responseBlock = udfResponse.getRecords();
        Assert.assertEquals((long)1L, (long)responseBlock.getFieldReaders().size());
        Assert.assertEquals((long)rowCount, (long)responseBlock.getRowCount());
        FieldReader fieldReader = (FieldReader)responseBlock.getFieldReaders().get(0);
        for (int pos = 0; pos < rowCount; ++pos) {
            fieldReader.setPosition(pos);
            TestCase.assertTrue((boolean)fieldReader.isSet());
            Boolean expected = this.handler.test_scalar_function_with_null_value(null);
            Boolean actual = fieldReader.readBoolean();
            Assert.assertEquals((Object)expected, (Object)actual);
        }
    }

    @Test
    public void testRequestTypeValidation() throws Exception {
        ListSchemasRequest federationRequest = new ListSchemasRequest(null, "dummy_catalog", "dummy_qid");
        ObjectMapper objectMapper = VersionedObjectMapperFactory.create((BlockAllocator)this.allocator);
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        objectMapper.writeValue((OutputStream)byteArrayOutputStream, (Object)federationRequest);
        byte[] inputData = byteArrayOutputStream.toByteArray();
        ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(inputData);
        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
        try {
            this.handler.handleRequest(byteArrayInputStream, outputStream, null);
            Assert.fail();
        }
        catch (Exception e) {
            TestCase.assertTrue((boolean)e.getMessage().contains("Expected a UserDefinedFunctionRequest but found"));
        }
    }

    @Test
    public void testMethodNotFound() {
        int rowCount = 20;
        UserDefinedFunctionRequest udfRequest = this.createUDFRequest(rowCount, Integer.class, "method_that_does_not_exsit", true, Integer.class, Integer.class);
        try {
            UserDefinedFunctionResponse udfResponse = this.handler.processFunction((BlockAllocator)this.allocator, udfRequest);
            Assert.fail((String)"Expected function to fail due to method not found, but succeeded.");
        }
        catch (Exception e) {
            TestCase.assertTrue((boolean)(e.getCause() instanceof NoSuchMethodException));
        }
    }

    private UserDefinedFunctionRequest createUDFRequest(int rowCount, Class returnType, String methodName, boolean nonNullData, Class ... argumentTypes) {
        Schema inputSchema = this.buildSchema(argumentTypes);
        Schema outputSchema = this.buildSchema(returnType);
        Block block = this.allocator.createBlock(inputSchema);
        block.setRowCount(rowCount);
        if (nonNullData) {
            this.writeData(block, rowCount);
        }
        return new UserDefinedFunctionRequest(null, block, outputSchema, methodName, UserDefinedFunctionType.SCALAR);
    }

    private void writeData(Block block, int numOfRows) {
        for (FieldVector fieldVector : block.getFieldVectors()) {
            fieldVector.setInitialCapacity(numOfRows);
            fieldVector.allocateNew();
            fieldVector.setValueCount(numOfRows);
            for (int idx = 0; idx < numOfRows; ++idx) {
                this.writeColumn(fieldVector, idx);
            }
        }
    }

    private void writeColumn(FieldVector fieldVector, int idx) {
        if (fieldVector instanceof IntVector) {
            IntVector intVector = (IntVector)fieldVector;
            intVector.setSafe(idx, idx + 100);
            return;
        }
        if (fieldVector instanceof Float4Vector) {
            Float4Vector float4Vector = (Float4Vector)fieldVector;
            float4Vector.setSafe(idx, (float)idx + 100.1f);
            return;
        }
        if (fieldVector instanceof Float8Vector) {
            Float8Vector float8Vector = (Float8Vector)fieldVector;
            float8Vector.setSafe(idx, (double)idx + 100.2);
            return;
        }
        if (fieldVector instanceof VarCharVector) {
            VarCharVector varCharVector = (VarCharVector)fieldVector;
            varCharVector.setSafe(idx, new Text(idx + "-my-varchar"));
            return;
        }
        if (fieldVector instanceof ListVector) {
            BlockUtils.setComplexValue((FieldVector)fieldVector, (int)idx, (FieldResolver)FieldResolver.DEFAULT, (Object)ImmutableList.of((Object)(idx + 100), (Object)(idx + 200), (Object)(idx + 300)));
            return;
        }
        if (fieldVector instanceof StructVector) {
            ImmutableMap input = ImmutableMap.of((Object)"intVal", (Object)(idx + 100), (Object)"doubleVal", (Object)((double)idx + 200.2));
            BlockUtils.setComplexValue((FieldVector)fieldVector, (int)idx, (FieldResolver)FieldResolver.DEFAULT, (Object)input);
            return;
        }
        throw new IllegalArgumentException("Unsupported fieldVector " + fieldVector.getClass().getCanonicalName());
    }

    private Schema buildSchema(Class ... types) {
        ImmutableList.Builder fieldsBuilder = ImmutableList.builder();
        for (int i = 0; i < types.length; ++i) {
            String columnName = COLUMN_PREFIX + i;
            Field field = this.getArrowField(types[i], columnName);
            fieldsBuilder.add((Object)field);
        }
        return new Schema((Iterable)fieldsBuilder.build(), null);
    }

    private Field getArrowField(Class type, String columnName) {
        if (type == Integer.class) {
            return new Field(columnName, FieldType.nullable((ArrowType)new ArrowType.Int(32, true)), null);
        }
        if (type == Float.class) {
            return new Field(columnName, FieldType.nullable((ArrowType)new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), null);
        }
        if (type == Double.class) {
            return new Field(columnName, FieldType.nullable((ArrowType)new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), null);
        }
        if (type == String.class) {
            return new Field(columnName, FieldType.nullable((ArrowType)new ArrowType.Utf8()), null);
        }
        if (type == Boolean.class) {
            return new Field(columnName, FieldType.nullable((ArrowType)new ArrowType.Bool()), null);
        }
        if (type == List.class) {
            Field childField = new Field(columnName, FieldType.nullable((ArrowType)new ArrowType.Int(32, true)), null);
            return new Field(columnName, FieldType.nullable((ArrowType)Types.MinorType.LIST.getType()), Collections.singletonList(childField));
        }
        if (type == Map.class) {
            FieldBuilder fieldBuilder = FieldBuilder.newBuilder((String)columnName, (ArrowType)Types.MinorType.STRUCT.getType());
            Field childField1 = new Field("intVal", FieldType.nullable((ArrowType)new ArrowType.Int(32, true)), null);
            Field childField2 = new Field("doubleVal", FieldType.nullable((ArrowType)new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), null);
            fieldBuilder.addField(childField1);
            fieldBuilder.addField(childField2);
            return fieldBuilder.build();
        }
        throw new IllegalArgumentException("Unsupported type " + type);
    }

    private static class TestUserDefinedFunctionHandler
    extends UserDefinedFunctionHandler {
        public TestUserDefinedFunctionHandler() {
            super("test_type");
        }

        public Integer test_scalar_udf(Integer col1, Integer col2) {
            return col1 + col2;
        }

        public Boolean test_scalar_function_with_null_value(Integer col1) {
            if (col1 == null) {
                return true;
            }
            return false;
        }

        public List<Integer> test_list_type(List<Integer> input) {
            return input.stream().map(val -> val + 1).collect(Collectors.toList());
        }

        public Map<String, Object> test_row_type(Map<String, Object> input) {
            Integer intVal = (Integer)input.get("intVal");
            Double doubleVal = (Double)input.get("doubleVal");
            return ImmutableMap.of((Object)"intVal", (Object)(intVal + 1), (Object)"doubleVal", (Object)(doubleVal + 1.0));
        }
    }
}

