package com.amazonaws.athena.connector.lambda.handlers;

import com.amazonaws.athena.connector.lambda.data.Block;
import com.amazonaws.athena.connector.lambda.data.BlockAllocator;
import com.amazonaws.athena.connector.lambda.data.BlockAllocatorImpl;
import com.amazonaws.athena.connector.lambda.data.BlockUtils;
import com.amazonaws.athena.connector.lambda.data.FieldResolver;
import com.amazonaws.athena.connector.lambda.data.projectors.ArrowValueProjector;
import com.amazonaws.athena.connector.lambda.data.projectors.ProjectorUtils;
import com.amazonaws.athena.connector.lambda.data.writers.GeneratedRowWriter;
import com.amazonaws.athena.connector.lambda.data.writers.extractors.Extractor;
import com.amazonaws.athena.connector.lambda.data.writers.fieldwriters.FieldWriterFactory;
import com.amazonaws.athena.connector.lambda.request.FederationRequest;
import com.amazonaws.athena.connector.lambda.request.FederationResponse;
import com.amazonaws.athena.connector.lambda.request.PingRequest;
import com.amazonaws.athena.connector.lambda.request.PingResponse;
import com.amazonaws.athena.connector.lambda.serde.VersionedObjectMapperFactory;
import com.amazonaws.athena.connector.lambda.udf.UserDefinedFunctionRequest;
import com.amazonaws.athena.connector.lambda.udf.UserDefinedFunctionResponse;
import com.amazonaws.athena.connector.lambda.udf.UserDefinedFunctionType;
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestStreamHandler;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.math.BigDecimal;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/amazonaws/athena/connector/lambda/handlers/UserDefinedFunctionHandler.class */
public abstract class UserDefinedFunctionHandler implements RequestStreamHandler {
    // Class-wide SLF4J logger.
    private static final Logger logger = LoggerFactory.getLogger(UserDefinedFunctionHandler.class);
    // Number of columns a UDF output schema must contain; checked in extractScalarFunctionMethod.
    private static final int RETURN_COLUMN_COUNT = 1;
    // Value handed to the constructor; doPing passes it to every PingResponse it builds.
    private final String sourceType;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: com.amazonaws.athena.connector.lambda.handlers.UserDefinedFunctionHandler$1, reason: invalid class name */
    /* loaded from: input_file:com/amazonaws/athena/connector/lambda/handlers/UserDefinedFunctionHandler$1.class */
    /*
     * Lookup tables indexed by enum ordinal. Each table slot holds a small int; code that
     * switches through these tables must use the ints assigned below as its case labels,
     * so the numbering must not be changed independently of those switches.
     *
     * Every assignment is individually guarded: a NoSuchFieldError for one constant is
     * ignored, which leaves that slot at the int-array default of 0 (no label below is 0).
     *
     * NOTE(review): the "synthetic" markers suggest this is a compiler-generated switch map
     * recovered by the decompiler rather than hand-written code -- confirm before editing.
     */
    public static /* synthetic */ class AnonymousClass1 {
        // Ordinal of UserDefinedFunctionType -> case label (SCALAR = 1). Allocated in the static block.
        static final /* synthetic */ int[] $SwitchMap$com$amazonaws$athena$connector$lambda$udf$UserDefinedFunctionType;
        // Ordinal of Types.MinorType -> case label (1..14, assigned in the static block).
        static final /* synthetic */ int[] $SwitchMap$org$apache$arrow$vector$types$Types$MinorType = new int[Types.MinorType.values().length];

        static {
            try {
                // RETURN_COLUMN_COUNT equals 1; here it only serves as the label value 1.
                $SwitchMap$org$apache$arrow$vector$types$Types$MinorType[Types.MinorType.INT.ordinal()] = UserDefinedFunctionHandler.RETURN_COLUMN_COUNT;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$apache$arrow$vector$types$Types$MinorType[Types.MinorType.DATEMILLI.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$apache$arrow$vector$types$Types$MinorType[Types.MinorType.DATEDAY.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$org$apache$arrow$vector$types$Types$MinorType[Types.MinorType.TINYINT.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$org$apache$arrow$vector$types$Types$MinorType[Types.MinorType.SMALLINT.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$org$apache$arrow$vector$types$Types$MinorType[Types.MinorType.FLOAT4.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$org$apache$arrow$vector$types$Types$MinorType[Types.MinorType.FLOAT8.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
            }
            try {
                $SwitchMap$org$apache$arrow$vector$types$Types$MinorType[Types.MinorType.DECIMAL.ordinal()] = 8;
            } catch (NoSuchFieldError e8) {
            }
            try {
                $SwitchMap$org$apache$arrow$vector$types$Types$MinorType[Types.MinorType.BIT.ordinal()] = 9;
            } catch (NoSuchFieldError e9) {
            }
            try {
                $SwitchMap$org$apache$arrow$vector$types$Types$MinorType[Types.MinorType.BIGINT.ordinal()] = 10;
            } catch (NoSuchFieldError e10) {
            }
            try {
                $SwitchMap$org$apache$arrow$vector$types$Types$MinorType[Types.MinorType.VARCHAR.ordinal()] = 11;
            } catch (NoSuchFieldError e11) {
            }
            try {
                $SwitchMap$org$apache$arrow$vector$types$Types$MinorType[Types.MinorType.VARBINARY.ordinal()] = 12;
            } catch (NoSuchFieldError e12) {
            }
            try {
                $SwitchMap$org$apache$arrow$vector$types$Types$MinorType[Types.MinorType.LIST.ordinal()] = 13;
            } catch (NoSuchFieldError e13) {
            }
            try {
                $SwitchMap$org$apache$arrow$vector$types$Types$MinorType[Types.MinorType.STRUCT.ordinal()] = 14;
            } catch (NoSuchFieldError e14) {
            }
            // Second table: function type ordinal -> case label.
            $SwitchMap$com$amazonaws$athena$connector$lambda$udf$UserDefinedFunctionType = new int[UserDefinedFunctionType.values().length];
            try {
                $SwitchMap$com$amazonaws$athena$connector$lambda$udf$UserDefinedFunctionType[UserDefinedFunctionType.SCALAR.ordinal()] = UserDefinedFunctionHandler.RETURN_COLUMN_COUNT;
            } catch (NoSuchFieldError e15) {
            }
        }
    }

    /**
     * Creates a handler that reports the given source type.
     *
     * @param str the source type; stored as-is (no validation here) and later passed to the
     *            PingResponse built by doPing
     */
    public UserDefinedFunctionHandler(String str) {
        this.sourceType = str;
    }

    /**
     * Entry point: deserializes one {@link FederationRequest} from {@code inputStream} and
     * writes the matching response to {@code outputStream}.
     *
     * <p>A {@link PingRequest} is answered with the response built by {@code doPing}; a
     * {@link UserDefinedFunctionRequest} is delegated to {@code doHandleRequest}. Anything
     * else (including a request that deserializes to {@code null}) is rejected.
     *
     * <p>The allocator, the request and the ping response are all closed on every path,
     * successful or not.
     *
     * @param inputStream  stream holding the serialized request
     * @param outputStream stream the serialized response is written to
     * @param context      Lambda invocation context; not used by this method
     * @throws RuntimeException on any failure: unchecked exceptions propagate unchanged,
     *                          checked exceptions are wrapped with the original as cause
     */
    public final void handleRequest(InputStream inputStream, OutputStream outputStream, Context context) {
        try (BlockAllocatorImpl allocator = new BlockAllocatorImpl()) {
            ObjectMapper mapper = VersionedObjectMapperFactory.create(allocator);
            try (FederationRequest request = mapper.readValue(inputStream, FederationRequest.class)) {
                if (request instanceof PingRequest) {
                    try (PingResponse pingResponse = doPing((PingRequest) request)) {
                        assertNotNull(pingResponse);
                        mapper.writeValue(outputStream, pingResponse);
                    }
                    return;
                }
                if (!(request instanceof UserDefinedFunctionRequest)) {
                    // Guard the null case explicitly so the message is produced instead of an NPE.
                    throw new RuntimeException("Expected a UserDefinedFunctionRequest but found " + (request == null ? null : request.getClass()));
                }
                doHandleRequest(allocator, mapper, (UserDefinedFunctionRequest) request, outputStream);
            } catch (Exception e) {
                // Always surface the failure to the caller; never swallow unchecked exceptions.
                throw (e instanceof RuntimeException) ? (RuntimeException) e : new RuntimeException(e);
            }
        }
    }

    /**
     * Runs the requested function and serializes its response to {@code outputStream}.
     * The request and the response are both logged at INFO level, and the response is
     * closed whether or not serialization succeeds.
     *
     * <p>NOTE(review): a decompiler marker on this method said its access was changed from
     * protected; it is public here -- confirm the intended visibility.
     *
     * @param blockAllocator             allocator handed on to processFunction
     * @param objectMapper               mapper used to write the response
     * @param userDefinedFunctionRequest the request to execute
     * @param outputStream               destination of the serialized response
     * @throws Exception anything raised by processFunction, serialization, or closing the
     *                   response; a RuntimeException if the response is null
     */
    public final void doHandleRequest(BlockAllocator blockAllocator, ObjectMapper objectMapper, UserDefinedFunctionRequest userDefinedFunctionRequest, OutputStream outputStream) throws Exception {
        logger.info("doHandleRequest: request[{}]", userDefinedFunctionRequest);
        try (UserDefinedFunctionResponse response = processFunction(blockAllocator, userDefinedFunctionRequest)) {
            logger.info("doHandleRequest: response[{}]", response);
            assertNotNull(response);
            objectMapper.writeValue(outputStream, response);
        }
    }

    /**
     * Dispatches the request according to its function type.
     *
     * <p>The switch is written directly on the enum so that the dispatch does not depend on
     * an ordinal lookup table or on the numeric value of unrelated constants.
     *
     * @param blockAllocator             allocator used for the output block
     * @param userDefinedFunctionRequest the request to execute
     * @return the response produced for a SCALAR request
     * @throws UnsupportedOperationException if the function type is anything other than SCALAR
     * @throws Exception                     anything raised while executing the function
     */
    @VisibleForTesting
    UserDefinedFunctionResponse processFunction(BlockAllocator blockAllocator, UserDefinedFunctionRequest userDefinedFunctionRequest) throws Exception {
        UserDefinedFunctionType functionType = userDefinedFunctionRequest.getFunctionType();
        switch (functionType) {
            case SCALAR:
                return processScalarFunction(blockAllocator, userDefinedFunctionRequest);
            default:
                throw new UnsupportedOperationException("Unsupported function type " + functionType);
        }
    }

    /**
     * Resolves the UDF method named by the request, applies it to every input row and
     * wraps the resulting block, together with the method's name, in a response.
     *
     * @param blockAllocator             allocator used for the output block
     * @param userDefinedFunctionRequest request carrying method name, input records and output schema
     * @return response holding the output block and the resolved method's name
     * @throws Exception anything raised while resolving or applying the method
     */
    private UserDefinedFunctionResponse processScalarFunction(BlockAllocator blockAllocator, UserDefinedFunctionRequest userDefinedFunctionRequest) throws Exception {
        Method udf = extractScalarFunctionMethod(userDefinedFunctionRequest);
        Block inputRecords = userDefinedFunctionRequest.getInputRecords();
        Schema outputSchema = userDefinedFunctionRequest.getOutputSchema();
        Block outputRecords = processRows(blockAllocator, udf, inputRecords, outputSchema);
        return new UserDefinedFunctionResponse(outputRecords, udf.getName());
    }

    /**
     * Applies {@code method} to every row of {@code block} and collects the results in a
     * newly created block with the given output schema.
     *
     * <p>One value projector is created per input field, in field order; the first field
     * of {@code schema} is the output column that receives the results.
     *
     * @param blockAllocator allocator that creates the output block
     * @param method         the UDF method to invoke for each row
     * @param block          input records
     * @param schema         output schema; its first field is the result column
     * @return the populated output block, with the same row count as {@code block}
     * @throws Exception anything raised while building the writer or writing a row; if
     *                   populating the output block fails, the block is closed before rethrow
     */
    protected Block processRows(BlockAllocator blockAllocator, Method method, Block block, Schema schema) throws Exception {
        int rowCount = block.getRowCount();
        List<ArrowValueProjector> projectors = Lists.newArrayList();
        for (Field inputField : block.getFields()) {
            projectors.add(ProjectorUtils.createArrowValueProjector(block.getFieldReader(inputField.getName())));
        }
        Field outputField = schema.getFields().get(0);
        GeneratedRowWriter rowWriter = createOutputRowWriter(outputField, projectors, method);
        Block outputRecords = blockAllocator.createBlock(schema);
        try {
            // Inside the try so a failure here also releases the freshly created block.
            outputRecords.setRowCount(rowCount);
            for (int row = 0; row < rowCount; row++) {
                // The row number doubles as the writer's context object.
                rowWriter.writeRow(outputRecords, row, row);
            }
        } catch (Throwable t) {
            try {
                outputRecords.close();
            } catch (Exception e) {
                // Do not let a close failure hide the original error.
                logger.error("Error closing output block", e);
            }
            throw t;
        }
        return outputRecords;
    }

    /**
     * Looks up, on this handler's runtime class, the public method whose name and
     * parameter types match the request, and verifies its return type against the
     * request's output schema.
     *
     * @param userDefinedFunctionRequest request carrying the method name, the input records
     *                                   (for argument types) and the output schema
     * @return the matching method
     * @throws IllegalStateException    if the output schema does not have exactly
     *                                  RETURN_COLUMN_COUNT columns
     * @throws IllegalArgumentException if the method's return type differs from the
     *                                  type derived from the output schema
     * @throws RuntimeException         if no method with that name and parameter types exists
     */
    private Method extractScalarFunctionMethod(UserDefinedFunctionRequest userDefinedFunctionRequest) {
        String udfName = userDefinedFunctionRequest.getMethodName();
        Class[] argumentTypes = extractJavaTypes(userDefinedFunctionRequest.getInputRecords().getSchema());
        Class[] returnTypes = extractJavaTypes(userDefinedFunctionRequest.getOutputSchema());
        Preconditions.checkState(returnTypes.length == RETURN_COLUMN_COUNT, String.format("Expecting %d return columns, found %d in method signature.", RETURN_COLUMN_COUNT, returnTypes.length));
        Class expectedReturnType = returnTypes[0];

        Method udf;
        try {
            udf = getClass().getMethod(udfName, argumentTypes);
        } catch (NoSuchMethodException e) {
            throw new RuntimeException("Failed to find UDF method. " + e.getMessage() + " Please make sure the method name contains only lowercase and the method signature (name and argument types) in Lambda matches the function signature defined in SQL.", e);
        }

        logger.info(String.format("Found UDF method %s with input types [%s] and output types [%s]", udfName, Arrays.toString(argumentTypes), expectedReturnType.getName()));
        if (!expectedReturnType.equals(udf.getReturnType())) {
            throw new IllegalArgumentException("signature return type " + expectedReturnType + " does not match udf implementation return type " + udf.getReturnType());
        }
        return udf;
    }

    /**
     * Maps every field of {@code schema}, in order, to the Java class that
     * {@code BlockUtils.getJavaType} associates with the field's Arrow minor type.
     *
     * @param schema the schema whose fields are translated
     * @return one class per field, in field order; empty if the schema has no fields
     */
    private Class<?>[] extractJavaTypes(Schema schema) {
        List<Field> fields = schema.getFields();
        Class<?>[] javaTypes = new Class<?>[fields.size()];
        for (int idx = 0; idx < javaTypes.length; idx++) {
            javaTypes[idx] = BlockUtils.getJavaType(Types.getMinorTypeForArrowType(fields.get(idx).getType()));
        }
        return javaTypes;
    }

    /**
     * Builds the response to a ping and then gives subclasses a chance to react via
     * {@code onPing}. The response is created first; an Exception thrown by {@code onPing}
     * is logged as a warning and does not affect the returned response.
     *
     * <p>NOTE(review): the trailing constructor arguments 24 and 2 are presumably a
     * capabilities level and a serialization version -- confirm against PingResponse.
     *
     * @param pingRequest the incoming ping
     * @return a response echoing the request's catalog name and query id plus this
     *         handler's source type
     */
    private final PingResponse doPing(PingRequest pingRequest) {
        String catalogName = pingRequest.getCatalogName();
        String queryId = pingRequest.getQueryId();
        PingResponse response = new PingResponse(catalogName, queryId, this.sourceType, 24, 2);
        try {
            onPing(pingRequest);
        } catch (Exception e) {
            logger.warn("doPing: encountered an exception while delegating onPing.", e);
        }
        return response;
    }

    /**
     * Hook invoked by doPing for every ping request; does nothing by default. An Exception
     * thrown from an override is caught and logged by doPing and does not fail the ping.
     *
     * @param pingRequest the ping being handled
     */
    protected void onPing(PingRequest pingRequest) {
    }

    /**
     * Fails fast when a handler produced no response.
     *
     * @param federationResponse the response to check
     * @throws RuntimeException if {@code federationResponse} is null
     */
    private void assertNotNull(FederationResponse federationResponse) {
        if (federationResponse != null) {
            return;
        }
        throw new RuntimeException("Response was null");
    }

    /**
     * Builds the row writer for the single output column. An extractor is preferred; when
     * makeExtractor yields none for the field's type, a field-writer factory from
     * makeFactory is registered instead.
     *
     * @param field  the output column
     * @param list   projectors that supply the UDF arguments, one per input column
     * @param method the UDF method to invoke per row
     * @return a row writer configured for {@code field}
     * @throws IllegalArgumentException if neither an extractor nor a factory supports the type
     */
    private GeneratedRowWriter createOutputRowWriter(Field field, List<ArrowValueProjector> list, Method method) {
        GeneratedRowWriter.RowWriterBuilder builder = GeneratedRowWriter.newBuilder();
        Extractor extractor = makeExtractor(field, list, method);
        if (extractor == null) {
            FieldWriterFactory factory = makeFactory(field, list, method);
            builder.withFieldWriterFactory(field.getName(), factory);
        } else {
            builder.withExtractor(field.getName(), extractor);
        }
        return builder.build();
    }

    /**
     * Creates the extractor that computes the output column for simple (non-complex)
     * return types. Each extractor receives the row number as its context object, invokes
     * the UDF for that row through invokeMethod, and stores the result in the destination
     * holder: {@code isSet = 0} for a null result, otherwise {@code isSet = 1} and the
     * converted value.
     *
     * <p>The switch is written directly on the minor type, so dispatch no longer depends on
     * an ordinal lookup table or on the numeric values of unrelated constants.
     *
     * <p>NOTE(review): the lambdas rely on target typing from the declared return type;
     * confirm that Extractor (or the context these are compiled in) provides the
     * two-argument functional shape used here.
     *
     * @param field  the output column, whose Arrow type selects the extractor
     * @param list   projectors that supply the UDF arguments, one per input column
     * @param method the UDF method to invoke per row
     * @return the extractor, or {@code null} when the type has no extractor here
     *         (createOutputRowWriter then falls back to makeFactory)
     */
    private Extractor makeExtractor(Field field, List<ArrowValueProjector> list, Method method) {
        Types.MinorType returnType = Types.getMinorTypeForArrowType(field.getType());
        // Argument buffer shared by all rows; invokeMethod refills it before every call.
        Object[] args = new Object[list.size()];
        switch (returnType) {
            case INT:
                return (rowNum, dst) -> {
                    Object result = invokeMethod(method, args, (Integer) rowNum, list);
                    if (result == null) {
                        dst.isSet = 0;
                    } else {
                        dst.isSet = 1;
                        dst.value = ((Integer) result).intValue();
                    }
                };
            case DATEMILLI:
                return (rowNum, dst) -> {
                    Object result = invokeMethod(method, args, (Integer) rowNum, list);
                    if (result == null) {
                        dst.isSet = 0;
                    } else {
                        dst.isSet = 1;
                        // LocalDateTime is interpreted in BlockUtils.UTC_ZONE_ID and stored as epoch millis.
                        dst.value = ((LocalDateTime) result).atZone(BlockUtils.UTC_ZONE_ID).toInstant().toEpochMilli();
                    }
                };
            case DATEDAY:
                return (rowNum, dst) -> {
                    Object result = invokeMethod(method, args, (Integer) rowNum, list);
                    if (result == null) {
                        dst.isSet = 0;
                    } else {
                        dst.isSet = 1;
                        // Stored as days since epoch, narrowed to int.
                        dst.value = (int) ((LocalDate) result).toEpochDay();
                    }
                };
            case TINYINT:
                return (rowNum, dst) -> {
                    Object result = invokeMethod(method, args, (Integer) rowNum, list);
                    if (result == null) {
                        dst.isSet = 0;
                    } else {
                        dst.isSet = 1;
                        dst.value = ((Byte) result).byteValue();
                    }
                };
            case SMALLINT:
                return (rowNum, dst) -> {
                    Object result = invokeMethod(method, args, (Integer) rowNum, list);
                    if (result == null) {
                        dst.isSet = 0;
                    } else {
                        dst.isSet = 1;
                        dst.value = ((Short) result).shortValue();
                    }
                };
            case FLOAT4:
                return (rowNum, dst) -> {
                    Object result = invokeMethod(method, args, (Integer) rowNum, list);
                    if (result == null) {
                        dst.isSet = 0;
                    } else {
                        dst.isSet = 1;
                        dst.value = ((Float) result).floatValue();
                    }
                };
            case FLOAT8:
                return (rowNum, dst) -> {
                    Object result = invokeMethod(method, args, (Integer) rowNum, list);
                    if (result == null) {
                        dst.isSet = 0;
                    } else {
                        dst.isSet = 1;
                        dst.value = ((Double) result).doubleValue();
                    }
                };
            case DECIMAL:
                return (rowNum, dst) -> {
                    Object result = invokeMethod(method, args, (Integer) rowNum, list);
                    if (result == null) {
                        dst.isSet = 0;
                    } else {
                        dst.isSet = 1;
                        dst.value = (BigDecimal) result;
                    }
                };
            case BIT:
                return (rowNum, dst) -> {
                    Object result = invokeMethod(method, args, (Integer) rowNum, list);
                    if (result == null) {
                        dst.isSet = 0;
                    } else {
                        dst.isSet = 1;
                        dst.value = ((Boolean) result).booleanValue() ? 1 : 0;
                    }
                };
            case BIGINT:
                return (rowNum, dst) -> {
                    Object result = invokeMethod(method, args, (Integer) rowNum, list);
                    if (result == null) {
                        dst.isSet = 0;
                    } else {
                        dst.isSet = 1;
                        dst.value = ((Long) result).longValue();
                    }
                };
            case VARCHAR:
                return (rowNum, dst) -> {
                    Object result = invokeMethod(method, args, (Integer) rowNum, list);
                    if (result == null) {
                        dst.isSet = 0;
                    } else {
                        dst.isSet = 1;
                        dst.value = (String) result;
                    }
                };
            case VARBINARY:
                return (rowNum, dst) -> {
                    Object result = invokeMethod(method, args, (Integer) rowNum, list);
                    if (result == null) {
                        dst.isSet = 0;
                    } else {
                        dst.isSet = 1;
                        dst.value = (byte[]) result;
                    }
                };
            default:
                // No extractor for this type; the caller decides what to do next.
                return null;
        }
    }

    /**
     * Creates the field-writer factory used for complex return types (LIST and STRUCT).
     * The produced writer receives the row number as its context object, invokes the UDF
     * for that row through invokeMethod, stores the result with
     * {@code BlockUtils.setComplexValue} and always reports {@code true}.
     *
     * <p>The switch is written directly on the minor type, so dispatch no longer depends on
     * an ordinal lookup table.
     *
     * @param field  the output column, whose Arrow type must be LIST or STRUCT
     * @param list   projectors that supply the UDF arguments, one per input column
     * @param method the UDF method to invoke per row
     * @return a factory producing the writer described above
     * @throws IllegalArgumentException if the field's type is neither LIST nor STRUCT
     */
    private FieldWriterFactory makeFactory(Field field, List<ArrowValueProjector> list, Method method) {
        // Argument buffer shared by all rows; invokeMethod refills it before every call.
        Object[] args = new Object[list.size()];
        Types.MinorType returnType = Types.getMinorTypeForArrowType(field.getType());
        switch (returnType) {
            case LIST:
            case STRUCT:
                return (vector, extractor, constraint) -> (rowNum, rowIndex) -> {
                    Object result = invokeMethod(method, args, (Integer) rowNum, list);
                    BlockUtils.setComplexValue(vector, rowIndex, FieldResolver.DEFAULT, result);
                    return true;
                };
            default:
                throw new IllegalArgumentException("Unsupported type " + returnType);
        }
    }

    /**
     * Projects the arguments for row {@code i} into {@code objArr} and reflectively invokes
     * {@code method} on this handler with them.
     *
     * @param method the UDF method to call
     * @param objArr reusable argument buffer; must have at least {@code list.size()} slots
     *               and is overwritten on every call
     * @param i      the row whose values are projected
     * @param list   one projector per argument, in parameter order
     * @return whatever the UDF returned (possibly null)
     * @throws RuntimeException wrapping the failure: an IllegalAccessException, an argument
     *                          mismatch (message lists expected and actual types), or the
     *                          exception thrown by the UDF itself
     */
    private Object invokeMethod(Method method, Object[] objArr, int i, List<ArrowValueProjector> list) {
        for (int col = 0; col < list.size(); col++) {
            objArr[col] = list.get(col).project(i);
        }
        try {
            return method.invoke(this, objArr);
        } catch (IllegalAccessException e) {
            throw new RuntimeException(e);
        } catch (IllegalArgumentException e) {
            List<String> expectedTypes = Arrays.stream(method.getParameterTypes())
                    .map(Class::getName)
                    .collect(Collectors.toList());
            // Arguments may legitimately be null; report them as "null" instead of failing
            // with an NPE that would hide the real argument mismatch.
            List<String> actualTypes = Arrays.stream(objArr)
                    .map(arg -> arg == null ? "null" : arg.getClass().getName())
                    .collect(Collectors.toList());
            throw new RuntimeException(String.format("%s. Expected function types %s, got types %s", e.getMessage(), expectedTypes, actualTypes), e);
        } catch (InvocationTargetException e) {
            // Prefer the UDF's own exception; keep the wrapper when no cause is available
            // so the failure is never reported without any detail.
            Throwable cause = e.getCause();
            throw new RuntimeException(cause == null ? e : cause);
        }
    }
}
