package io.airlift.drift.transport.netty.server;

import com.google.common.base.Defaults;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Primitives;
import com.google.common.util.concurrent.FluentFuture;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import io.airlift.drift.TApplicationException;
import io.airlift.drift.codec.ThriftCodec;
import io.airlift.drift.codec.internal.ProtocolReader;
import io.airlift.drift.codec.internal.ProtocolWriter;
import io.airlift.drift.protocol.TMessage;
import io.airlift.drift.protocol.TProtocol;
import io.airlift.drift.protocol.TProtocolReader;
import io.airlift.drift.protocol.TProtocolWriter;
import io.airlift.drift.protocol.TTransport;
import io.airlift.drift.transport.MethodMetadata;
import io.airlift.drift.transport.ParameterMetadata;
import io.airlift.drift.transport.netty.codec.Protocol;
import io.airlift.drift.transport.netty.codec.ThriftFrame;
import io.airlift.drift.transport.netty.codec.Transport;
import io.airlift.drift.transport.netty.ssl.TChannelBufferInputTransport;
import io.airlift.drift.transport.netty.ssl.TChannelBufferOutputTransport;
import io.airlift.drift.transport.server.ServerInvokeRequest;
import io.airlift.drift.transport.server.ServerMethodInvoker;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalDouble;
import java.util.OptionalInt;
import java.util.OptionalLong;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

/**
 * Netty handler that decodes incoming Thrift request frames, dispatches them to a
 * {@link ServerMethodInvoker}, and writes the encoded response frame back to the channel.
 *
 * <p>Inbound messages that are not {@link ThriftFrame}s are forwarded unchanged to the next
 * handler in the pipeline.
 */
public class ThriftServerHandler extends ChannelDuplexHandler {
    private static final Logger log = Logger.get(ThriftServerHandler.class);

    // Thrift message type codes written to / read from message headers.
    private static final byte MESSAGE_TYPE_CALL = 1;
    private static final byte MESSAGE_TYPE_REPLY = 2;
    private static final byte MESSAGE_TYPE_EXCEPTION = 3;
    private static final byte MESSAGE_TYPE_ONEWAY = 4;

    private final ServerMethodInvoker methodInvoker;
    private final ScheduledExecutorService timeoutExecutor;
    private final Duration requestTimeout;

    /**
     * @param methodInvoker resolves method metadata and invokes the service implementation
     * @param requestTimeout maximum time allowed for invocation plus response encoding
     * @param timeoutExecutor executor used to schedule the request timeout
     * @throws NullPointerException if any argument is null
     */
    public ThriftServerHandler(ServerMethodInvoker methodInvoker, Duration requestTimeout, ScheduledExecutorService timeoutExecutor) {
        this.methodInvoker = Objects.requireNonNull(methodInvoker, "methodInvoker is null");
        this.requestTimeout = Objects.requireNonNull(requestTimeout, "requestTimeout is null");
        this.timeoutExecutor = Objects.requireNonNull(timeoutExecutor, "timeoutExecutor is null");
    }

    @Override
    public void channelRead(ChannelHandlerContext context, Object message) {
        if (message instanceof ThriftFrame) {
            messageReceived(context, (ThriftFrame) message);
        } else {
            context.fireChannelRead(message);
        }
    }

    /**
     * Decodes and dispatches one request frame. The response is written asynchronously when
     * the invocation completes. Any failure to produce a response disconnects the channel.
     * The frame and its input transport are released exactly once, whatever the outcome.
     */
    private void messageReceived(ChannelHandlerContext context, ThriftFrame frame) {
        TChannelBufferInputTransport messageData = new TChannelBufferInputTransport(frame.getMessage());
        try {
            ListenableFuture<ThriftFrame> response = decodeMessage(
                    context,
                    messageData,
                    frame.getTransport(),
                    frame.getProtocol(),
                    frame.getHeaders(),
                    frame.isSupportOutOfOrderResponse());
            Futures.addCallback(response, new FutureCallback<ThriftFrame>() {
                @Override
                public void onSuccess(ThriftFrame responseFrame) {
                    context.writeAndFlush(responseFrame);
                }

                @Override
                public void onFailure(Throwable failure) {
                    context.disconnect();
                }
            }, MoreExecutors.directExecutor());
        } catch (Exception e) {
            log.error(e, "Exception processing request");
            context.disconnect();
        } catch (Throwable t) {
            // Errors are logged and the channel closed, but they are still propagated.
            log.error(t, "Error processing request");
            context.disconnect();
            throw t;
        } finally {
            // The arguments are fully read synchronously inside decodeMessage,
            // so the request buffers can be released here.
            messageData.release();
            frame.release();
        }
    }

    /**
     * Reads the message header and arguments from {@code messageData}, invokes the method and
     * returns a future of the encoded response frame.
     *
     * <p>Unknown method names and message types other than CALL/ONEWAY are answered
     * immediately with a {@link TApplicationException} frame.
     *
     * @throws Exception if the header or arguments cannot be decoded
     */
    private ListenableFuture<ThriftFrame> decodeMessage(
            ChannelHandlerContext context,
            TTransport messageData,
            Transport transport,
            Protocol protocol,
            Map<String, String> headers,
            boolean supportOutOfOrderResponse)
            throws Exception {
        long start = System.nanoTime();
        TProtocol messageProtocol = protocol.createProtocol(messageData);
        TMessage message = messageProtocol.readMessageBegin();
        String methodName = message.getName();
        int sequenceId = message.getSequenceId();

        Optional<MethodMetadata> methodMetadata = methodInvoker.getMethodMetadata(methodName);
        if (!methodMetadata.isPresent()) {
            return Futures.immediateFuture(writeApplicationException(
                    context,
                    methodName,
                    transport,
                    protocol,
                    sequenceId,
                    supportOutOfOrderResponse,
                    TApplicationException.Type.UNKNOWN_METHOD,
                    "Invalid method name: '" + methodName + "'",
                    null));
        }
        MethodMetadata method = methodMetadata.get();

        // NOTE(review): ONEWAY calls are accepted here and a response frame is still produced
        // and written; confirm clients tolerate (or ignore) that frame.
        if (message.getType() != MESSAGE_TYPE_CALL && message.getType() != MESSAGE_TYPE_ONEWAY) {
            return Futures.immediateFuture(writeApplicationException(
                    context,
                    methodName,
                    transport,
                    protocol,
                    sequenceId,
                    supportOutOfOrderResponse,
                    TApplicationException.Type.INVALID_MESSAGE_TYPE,
                    "Invalid method message type: '" + ((int) message.getType()) + "'",
                    null));
        }

        Map<Short, Object> arguments = readArguments(method, messageProtocol);
        ListenableFuture<Object> result = methodInvoker.invoke(new ServerInvokeRequest(method, headers, arguments));
        methodInvoker.recordResult(methodName, start, result);

        // AsyncFunction.apply may throw; Guava converts a thrown exception into a failed future.
        // The timeout covers invocation plus success encoding. Any Exception (including the
        // timeout) is encoded as an exception response. Anything else fails the future, and
        // messageReceived then disconnects.
        return FluentFuture.from(result)
                .transformAsync(
                        value -> Futures.immediateFuture(writeSuccessResponse(context, method, transport, protocol, sequenceId, supportOutOfOrderResponse, value)),
                        MoreExecutors.directExecutor())
                .withTimeout(requestTimeout.toMillis(), TimeUnit.MILLISECONDS, timeoutExecutor)
                .catchingAsync(
                        Exception.class,
                        failure -> Futures.immediateFuture(writeExceptionResponse(context, method, transport, protocol, sequenceId, supportOutOfOrderResponse, failure)),
                        MoreExecutors.directExecutor());
    }

    /**
     * Reads the argument struct into a map keyed by field id. Fields with no matching parameter
     * are skipped. Parameters missing from the request receive a default value.
     */
    private static Map<Short, Object> readArguments(MethodMetadata method, TProtocolReader protocolReader) throws Exception {
        Map<Short, Object> arguments = new HashMap<>(method.getParameters().size());
        ProtocolReader reader = new ProtocolReader(protocolReader);
        reader.readStructBegin();
        while (reader.nextField()) {
            short fieldId = reader.getFieldId();
            ParameterMetadata parameter = method.getParameterByFieldId(fieldId);
            if (parameter == null) {
                reader.skipFieldData();
            } else {
                arguments.put(fieldId, reader.readField(parameter.getCodec()));
            }
        }
        reader.readStructEnd();

        // containsKey (not computeIfAbsent) so that a null default is still stored explicitly.
        for (ParameterMetadata parameter : method.getParameters()) {
            if (!arguments.containsKey(parameter.getFieldId())) {
                arguments.put(parameter.getFieldId(), defaultValueFor(parameter.getCodec().getType().getJavaType()));
            }
        }
        return arguments;
    }

    /**
     * Returns the value used for a parameter absent from the request.
     *
     * <p>Primitives get their zero value. {@code OptionalInt}, {@code OptionalLong},
     * {@code OptionalDouble} and {@code Optional<T>} get their empty instance. Every other type
     * gets {@code null}.
     */
    private static Object defaultValueFor(Type javaType) {
        if (javaType instanceof Class) {
            Class<?> javaClass = (Class<?>) javaType;
            if (javaClass.isPrimitive()) {
                return Defaults.defaultValue(Primitives.unwrap(javaClass));
            }
            if (javaClass == OptionalInt.class) {
                return OptionalInt.empty();
            }
            if (javaClass == OptionalLong.class) {
                return OptionalLong.empty();
            }
            if (javaClass == OptionalDouble.class) {
                return OptionalDouble.empty();
            }
            return null;
        }
        if (javaType instanceof ParameterizedType && ((ParameterizedType) javaType).getRawType().equals(Optional.class)) {
            return Optional.empty();
        }
        return null;
    }

    /** Encodes a REPLY frame carrying the method's return value as field 0 ("success"). */
    private static ThriftFrame writeSuccessResponse(ChannelHandlerContext context, MethodMetadata method, Transport transport, Protocol protocol, int sequenceId, boolean supportOutOfOrderResponse, Object result) throws Exception {
        return encodeFrame(context, transport, protocol, sequenceId, supportOutOfOrderResponse,
                protocolWriter -> writeResponse(method.getName(), protocolWriter, sequenceId, "success", (short) 0, method.getResultCodec(), result));
    }

    /**
     * Encodes a failed invocation. An exception declared by the method is written as a REPLY
     * whose "exception" field carries the declared id. Any other exception becomes an
     * INTERNAL_ERROR {@link TApplicationException}.
     */
    private static ThriftFrame writeExceptionResponse(ChannelHandlerContext context, MethodMetadata method, Transport transport, Protocol protocol, int sequenceId, boolean supportOutOfOrderResponse, Throwable exception) throws Exception {
        Optional<Short> exceptionId = method.getExceptionId(exception.getClass());
        if (!exceptionId.isPresent()) {
            return writeApplicationException(
                    context,
                    method.getName(),
                    transport,
                    protocol,
                    sequenceId,
                    supportOutOfOrderResponse,
                    TApplicationException.Type.INTERNAL_ERROR,
                    "Internal error processing " + method.getName() + ": " + exception.getMessage(),
                    exception);
        }
        Short id = exceptionId.get();
        // The codec registered for this id is expected to accept this exception instance.
        @SuppressWarnings("unchecked")
        ThriftCodec<Object> codec = (ThriftCodec<Object>) method.getExceptionCodecs().get(id);
        return encodeFrame(context, transport, protocol, sequenceId, supportOutOfOrderResponse,
                protocolWriter -> writeResponse(method.getName(), protocolWriter, sequenceId, "exception", id, codec, exception));
    }

    /** Encodes an EXCEPTION-type message carrying a {@link TApplicationException}. */
    private static ThriftFrame writeApplicationException(ChannelHandlerContext context, String methodName, Transport transport, Protocol protocol, int sequenceId, boolean supportOutOfOrderResponse, TApplicationException.Type errorType, String errorMessage, Throwable cause) throws Exception {
        TApplicationException applicationException = new TApplicationException(errorType, errorMessage);
        if (cause != null) {
            applicationException.initCause(cause);
        }
        return encodeFrame(context, transport, protocol, sequenceId, supportOutOfOrderResponse, protocolWriter -> {
            protocolWriter.writeMessageBegin(new TMessage(methodName, MESSAGE_TYPE_EXCEPTION, sequenceId));
            ExceptionWriter.writeTApplicationException(applicationException, protocolWriter);
            protocolWriter.writeMessageEnd();
        });
    }

    /** Writes one complete Thrift message to the protocol it is given. */
    @FunctionalInterface
    private interface MessageWriter {
        void write(TProtocol protocol) throws Exception;
    }

    /**
     * Allocates an output transport, lets {@code writer} encode a message into it and wraps the
     * resulting buffer in a {@link ThriftFrame}. The transport's own reference is always
     * released, on success and on failure.
     */
    private static ThriftFrame encodeFrame(ChannelHandlerContext context, Transport transport, Protocol protocol, int sequenceId, boolean supportOutOfOrderResponse, MessageWriter writer) throws Exception {
        TChannelBufferOutputTransport outputTransport = new TChannelBufferOutputTransport(context.alloc());
        try {
            writer.write(protocol.createProtocol(outputTransport));
            return new ThriftFrame(sequenceId, outputTransport.getBuffer(), ImmutableMap.of(), transport, protocol, supportOutOfOrderResponse);
        } finally {
            outputTransport.release();
        }
    }

    /** Writes a REPLY message whose result struct contains the single field {@code fieldName}/{@code fieldId}. */
    private static void writeResponse(String methodName, TProtocolWriter protocolWriter, int sequenceId, String fieldName, short fieldId, ThriftCodec<Object> codec, Object value) throws Exception {
        protocolWriter.writeMessageBegin(new TMessage(methodName, MESSAGE_TYPE_REPLY, sequenceId));
        ProtocolWriter writer = new ProtocolWriter(protocolWriter);
        writer.writeStructBegin(methodName + "_result");
        writer.writeField(fieldName, fieldId, codec, value);
        writer.writeStructEnd();
        protocolWriter.writeMessageEnd();
    }
}
