package io.airlift.drift.transport.netty.server;

import com.google.common.base.Defaults;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Primitives;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import io.airlift.drift.TApplicationException;
import io.airlift.drift.codec.ThriftCodec;
import io.airlift.drift.codec.internal.ProtocolReader;
import io.airlift.drift.codec.internal.ProtocolWriter;
import io.airlift.drift.protocol.TMessage;
import io.airlift.drift.protocol.TProtocol;
import io.airlift.drift.protocol.TProtocolFactory;
import io.airlift.drift.transport.MethodMetadata;
import io.airlift.drift.transport.ParameterMetadata;
import io.airlift.drift.transport.netty.TChannelBufferInputTransport;
import io.airlift.drift.transport.netty.TChannelBufferOutputTransport;
import io.airlift.drift.transport.server.ServerInvokeRequest;
import io.airlift.drift.transport.server.ServerMethodInvoker;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import java.lang.reflect.Type;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

/* loaded from: input_file:io/airlift/drift/transport/netty/server/ThriftServerHandler.class */
/**
 * Netty handler that decodes inbound {@link ThriftFrame} requests, dispatches them to a
 * {@link ServerMethodInvoker}, and writes an encoded response frame back to the channel.
 *
 * <p>Messages that are not {@code ThriftFrame}s are forwarded unchanged to the next handler.
 * Each invocation (plus encoding of its result) is bounded by {@code requestTimeout}; if it
 * fails with an {@link Exception} or times out, an exception response is written instead.
 * If no response at all can be produced, the channel is disconnected.
 */
public class ThriftServerHandler extends ChannelDuplexHandler {
    private static final Logger log = Logger.get(ThriftServerHandler.class);

    // Thrift wire message types (TMessageType): CALL=1, REPLY=2, EXCEPTION=3, ONEWAY=4.
    private static final byte MESSAGE_TYPE_CALL = 1;
    private static final byte MESSAGE_TYPE_REPLY = 2;
    private static final byte MESSAGE_TYPE_EXCEPTION = 3;
    private static final byte MESSAGE_TYPE_ONEWAY = 4;

    // Field id of the "success" member in a Thrift "<method>_result" struct.
    private static final short SUCCESS_FIELD_ID = 0;

    // Initial capacity of response buffers; ByteBuf grows on demand.
    private static final int INITIAL_RESPONSE_BUFFER_SIZE = 1024;

    private final ServerMethodInvoker methodInvoker;
    private final ScheduledExecutorService timeoutExecutor;
    private final Duration requestTimeout;

    /**
     * @param methodInvoker resolves method metadata and invokes the service implementation
     * @param requestTimeout maximum time allowed for invoking a method and encoding its result
     * @param timeoutExecutor executor used to schedule the request timeout
     * @throws NullPointerException if any argument is null
     */
    public ThriftServerHandler(ServerMethodInvoker methodInvoker, Duration requestTimeout, ScheduledExecutorService timeoutExecutor) {
        this.methodInvoker = Objects.requireNonNull(methodInvoker, "methodInvoker is null");
        this.requestTimeout = Objects.requireNonNull(requestTimeout, "requestTimeout is null");
        this.timeoutExecutor = Objects.requireNonNull(timeoutExecutor, "timeoutExecutor is null");
    }

    /**
     * Handles {@link ThriftFrame} messages; passes any other message down the pipeline.
     */
    @Override
    public void channelRead(ChannelHandlerContext context, Object message) throws Exception {
        if (message instanceof ThriftFrame) {
            messageReceived(context, (ThriftFrame) message);
        } else {
            context.fireChannelRead(message);
        }
    }

    /**
     * Starts processing a request frame and arranges for the response frame to be written and
     * flushed when ready. Synchronous decoding failures disconnect the channel; {@link Error}s
     * are logged, disconnect the channel, and are rethrown.
     */
    private void messageReceived(ChannelHandlerContext context, ThriftFrame frame) {
        // NOTE(review): frame.getMessage() is a ByteBuf that is never released in this class;
        // confirm which component owns and releases it, otherwise it may leak.
        try {
            ListenableFuture<ThriftFrame> response = decodeMessage(
                    context,
                    frame.getProtocolFactory(),
                    frame.getMessage(),
                    frame.getHeaders(),
                    frame.isSupportOutOfOrderResponse());
            Futures.addCallback(response, new FutureCallback<ThriftFrame>() {
                @Override
                public void onSuccess(ThriftFrame result) {
                    context.writeAndFlush(result);
                }

                @Override
                public void onFailure(Throwable failure) {
                    // Reached when the request failed with a non-Exception Throwable, or when
                    // encoding the exception response itself failed. Record it instead of
                    // dropping the connection silently.
                    log.error(failure, "Failed to produce response for request");
                    context.disconnect();
                }
            }, MoreExecutors.directExecutor());
        } catch (Exception e) {
            log.error(e, "Exception processing request");
            context.disconnect();
        } catch (Throwable t) {
            log.error(t, "Error processing request");
            context.disconnect();
            throw t;
        }
    }

    /**
     * Decodes the request message, invokes the target method and returns a future of the encoded
     * response frame.
     *
     * <p>Unknown method names and message types other than CALL/ONEWAY produce an immediate
     * {@link TApplicationException} response. Otherwise the invocation result is encoded as a
     * REPLY; an Exception from invocation or encoding (including the request timeout) is
     * encoded via {@link #writeExceptionResponse}.
     *
     * @throws Exception if decoding the request, invoking the method, or encoding an immediate
     *     error response fails synchronously
     */
    private ListenableFuture<ThriftFrame> decodeMessage(ChannelHandlerContext context, TProtocolFactory protocolFactory, ByteBuf message, Map<String, String> headers, boolean supportOutOfOrderResponse) throws Exception {
        long start = System.nanoTime();
        TProtocol protocol = protocolFactory.getProtocol(new TChannelBufferInputTransport(message));
        TMessage request = protocol.readMessageBegin();
        String methodName = request.getName();
        int sequenceId = request.getSequenceId();

        Optional<MethodMetadata> methodMetadata = methodInvoker.getMethodMetadata(methodName);
        if (!methodMetadata.isPresent()) {
            return Futures.immediateFuture(writeApplicationException(
                    context, methodName, protocolFactory, sequenceId, supportOutOfOrderResponse,
                    TApplicationException.Type.UNKNOWN_METHOD,
                    "Invalid method name: '" + methodName + "'",
                    null));
        }
        MethodMetadata method = methodMetadata.get();

        // NOTE(review): ONEWAY requests are accepted here and still get a reply frame written;
        // confirm whether clients expect that.
        if (request.getType() != MESSAGE_TYPE_CALL && request.getType() != MESSAGE_TYPE_ONEWAY) {
            return Futures.immediateFuture(writeApplicationException(
                    context, methodName, protocolFactory, sequenceId, supportOutOfOrderResponse,
                    TApplicationException.Type.INVALID_MESSAGE_TYPE,
                    "Invalid method message type: '" + ((int) request.getType()) + "'",
                    null));
        }

        List<Object> arguments = readArguments(method, protocol);
        ListenableFuture<Object> result = methodInvoker.invoke(new ServerInvokeRequest(method, headers, arguments));
        methodInvoker.recordResult(methodName, start, result);

        // AsyncFunction.apply may throw; Guava turns a thrown exception into a failed future.
        ListenableFuture<ThriftFrame> encodedResult = Futures.transformAsync(
                result,
                value -> Futures.immediateFuture(writeSuccessResponse(context, method, protocolFactory, sequenceId, supportOutOfOrderResponse, value)),
                MoreExecutors.directExecutor());
        ListenableFuture<ThriftFrame> boundedResult = Futures.withTimeout(
                encodedResult,
                requestTimeout.toMillis(),
                TimeUnit.MILLISECONDS,
                timeoutExecutor);
        return Futures.catchingAsync(
                boundedResult,
                Exception.class,
                exception -> Futures.immediateFuture(writeExceptionResponse(context, method, protocolFactory, sequenceId, supportOutOfOrderResponse, exception)),
                MoreExecutors.directExecutor());
    }

    /**
     * Reads the method's argument struct from the protocol, positioned after the message header.
     *
     * @return one entry per declared parameter, ordered by parameter index
     */
    private static List<Object> readArguments(MethodMetadata method, TProtocol protocol) throws Exception {
        Object[] arguments = new Object[method.getParameters().size()];
        ProtocolReader reader = new ProtocolReader(protocol);
        reader.readStructBegin();
        while (reader.nextField()) {
            ParameterMetadata parameter = method.getParameterByFieldId(reader.getFieldId());
            if (parameter == null) {
                // field id is not declared for this method; skip its data
                reader.skipFieldData();
            } else {
                arguments[parameter.getIndex()] = reader.readField(parameter.getCodec());
            }
        }
        reader.readStructEnd();

        // Absent arguments whose Java type is a primitive or boxed primitive get the primitive's
        // default value (Primitives.unwrap + Defaults.defaultValue); other types remain null.
        for (ParameterMetadata parameter : method.getParameters()) {
            int index = parameter.getIndex();
            if (arguments[index] == null) {
                Type javaType = parameter.getCodec().getType().getJavaType();
                if (javaType instanceof Class) {
                    arguments[index] = Defaults.defaultValue(Primitives.unwrap((Class<?>) javaType));
                }
            }
        }
        return Arrays.asList(arguments);
    }

    /**
     * Encodes a REPLY message carrying {@code result} in the "success" field.
     * The allocated buffer is released if encoding fails.
     */
    private static ThriftFrame writeSuccessResponse(ChannelHandlerContext context, MethodMetadata method, TProtocolFactory protocolFactory, int sequenceId, boolean supportOutOfOrderResponse, Object result) throws Exception {
        ByteBuf buffer = context.alloc().buffer(INITIAL_RESPONSE_BUFFER_SIZE);
        try {
            TProtocol protocol = protocolFactory.getProtocol(new TChannelBufferOutputTransport(buffer));
            writeResponse(method.getName(), protocol, sequenceId, "success", SUCCESS_FIELD_ID, method.getResultCodec(), result);
            return new ThriftFrame(OptionalInt.of(sequenceId), buffer, ImmutableMap.of(), protocolFactory, supportOutOfOrderResponse);
        } catch (Throwable t) {
            // no frame took ownership of the reference-counted buffer; release it to avoid a leak
            buffer.release();
            throw t;
        }
    }

    /**
     * Encodes {@code exception} as a declared Thrift exception if the method declares its class;
     * otherwise encodes a {@link TApplicationException} of type INTERNAL_ERROR.
     * The allocated buffer is released if encoding fails.
     */
    private static ThriftFrame writeExceptionResponse(ChannelHandlerContext context, MethodMetadata method, TProtocolFactory protocolFactory, int sequenceId, boolean supportOutOfOrderResponse, Throwable exception) throws Exception {
        Optional<Short> exceptionId = method.getExceptionId(exception.getClass());
        if (!exceptionId.isPresent()) {
            return writeApplicationException(
                    context, method.getName(), protocolFactory, sequenceId, supportOutOfOrderResponse,
                    TApplicationException.Type.INTERNAL_ERROR,
                    "Internal error processing " + method.getName() + ": " + exception.getMessage(),
                    exception);
        }

        ByteBuf buffer = context.alloc().buffer(INITIAL_RESPONSE_BUFFER_SIZE);
        try {
            TProtocol protocol = protocolFactory.getProtocol(new TChannelBufferOutputTransport(buffer));
            ThriftCodec<Object> codec = (ThriftCodec<Object>) method.getExceptionCodecs().get(exceptionId.get());
            writeResponse(method.getName(), protocol, sequenceId, "exception", exceptionId.get(), codec, exception);
            return new ThriftFrame(OptionalInt.of(sequenceId), buffer, ImmutableMap.of(), protocolFactory, supportOutOfOrderResponse);
        } catch (Throwable t) {
            // no frame took ownership of the reference-counted buffer; release it to avoid a leak
            buffer.release();
            throw t;
        }
    }

    /**
     * Encodes an EXCEPTION message containing a {@link TApplicationException} of the given type
     * and message, with {@code cause} (if non-null) attached as its cause.
     * The allocated buffer is released if encoding fails.
     */
    private static ThriftFrame writeApplicationException(ChannelHandlerContext context, String methodName, TProtocolFactory protocolFactory, int sequenceId, boolean supportOutOfOrderResponse, TApplicationException.Type type, String message, Throwable cause) throws Exception {
        TApplicationException applicationException = new TApplicationException(type, message);
        if (cause != null) {
            applicationException.initCause(cause);
        }

        ByteBuf buffer = context.alloc().buffer(INITIAL_RESPONSE_BUFFER_SIZE);
        try {
            TChannelBufferOutputTransport transport = new TChannelBufferOutputTransport(buffer);
            TProtocol protocol = protocolFactory.getProtocol(transport);
            protocol.writeMessageBegin(new TMessage(methodName, MESSAGE_TYPE_EXCEPTION, sequenceId));
            ExceptionWriter.writeTApplicationException(applicationException, protocol);
            protocol.writeMessageEnd();
            return new ThriftFrame(OptionalInt.of(sequenceId), transport.getOutputBuffer(), ImmutableMap.of(), protocolFactory, supportOutOfOrderResponse);
        } catch (Throwable t) {
            // no frame took ownership of the reference-counted buffer; release it to avoid a leak
            buffer.release();
            throw t;
        }
    }

    /**
     * Writes a REPLY message whose body is a {@code <methodName>_result} struct containing a
     * single field ({@code responseFieldName}/{@code responseFieldId}) encoded with
     * {@code responseCodec}.
     */
    private static void writeResponse(String methodName, TProtocol protocol, int sequenceId, String responseFieldName, short responseFieldId, ThriftCodec<Object> responseCodec, Object value) throws Exception {
        protocol.writeMessageBegin(new TMessage(methodName, MESSAGE_TYPE_REPLY, sequenceId));
        ProtocolWriter writer = new ProtocolWriter(protocol);
        writer.writeStructBegin(methodName + "_result");
        writer.writeField(responseFieldName, responseFieldId, responseCodec, value);
        writer.writeStructEnd();
        protocol.writeMessageEnd();
    }
}
