package io.airlift.drift.transport.netty;

import com.google.common.net.HostAndPort;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.airlift.drift.protocol.TBinaryProtocol;
import io.airlift.drift.protocol.TCompactProtocol;
import io.airlift.drift.protocol.TFacebookCompactProtocol;
import io.airlift.drift.protocol.TProtocolFactory;
import io.airlift.drift.transport.MethodInvoker;
import io.airlift.drift.transport.MethodInvokerFactory;
import io.airlift.drift.transport.netty.DriftNettyClientConfig;
import io.airlift.units.Duration;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.handler.ssl.SslContext;
import io.netty.util.concurrent.Future;
import java.io.Closeable;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.annotation.PreDestroy;
import javax.inject.Inject;

/* loaded from: input_file:io/airlift/drift/transport/netty/DriftNettyMethodInvokerFactory.class */
public class DriftNettyMethodInvokerFactory<I> implements MethodInvokerFactory<I>, Closeable {
    private final Function<I, DriftNettyClientConfig> clientConfigurationProvider;
    private final EventLoopGroup group;
    private final SslContextFactory sslContextFactory;
    private final HostAndPort defaultSocksProxy;

    public static DriftNettyMethodInvokerFactory<?> createStaticDriftNettyMethodInvokerFactory(DriftNettyClientConfig driftNettyClientConfig) {
        return new DriftNettyMethodInvokerFactory<>(new DriftNettyConnectionFactoryConfig(), obj -> {
            return driftNettyClientConfig;
        });
    }

    @Inject
    public DriftNettyMethodInvokerFactory(DriftNettyConnectionFactoryConfig driftNettyConnectionFactoryConfig, Function<I, DriftNettyClientConfig> function) {
        Objects.requireNonNull(driftNettyConnectionFactoryConfig, "factoryConfig is null");
        this.group = new NioEventLoopGroup(driftNettyConnectionFactoryConfig.getThreadCount(), new ThreadFactoryBuilder().setNameFormat("drift-client-%s").setDaemon(true).build());
        this.clientConfigurationProvider = (Function) Objects.requireNonNull(function, "clientConfigurationProvider is null");
        this.sslContextFactory = SslContextFactory.createSslContextFactory(driftNettyConnectionFactoryConfig.getSslContextRefreshTime(), this.group);
        this.defaultSocksProxy = driftNettyConnectionFactoryConfig.getSocksProxy();
    }

    public MethodInvoker createMethodInvoker(I i) {
        TBinaryProtocol.Factory factory;
        MessageFraming lengthPrefixedMessageFraming;
        MessageEncoding headerMessageEncoding;
        DriftNettyClientConfig apply = this.clientConfigurationProvider.apply(i);
        if (apply.getSocksProxy() == null) {
            apply.setSocksProxy(this.defaultSocksProxy);
        }
        switch (apply.getProtocol()) {
            case BINARY:
                factory = new TBinaryProtocol.Factory(false, true, -1L, apply.getMaxFrameSize().toBytes());
                break;
            case COMPACT:
                if (apply.getTransport() != DriftNettyClientConfig.Transport.HEADER) {
                    factory = new TCompactProtocol.Factory(-1L, apply.getMaxFrameSize().toBytes());
                    break;
                } else {
                    factory = new TFacebookCompactProtocol.Factory(Math.toIntExact(apply.getMaxFrameSize().toBytes()));
                    break;
                }
            default:
                throw new IllegalArgumentException("Unknown protocol: " + apply.getProtocol());
        }
        switch (apply.getTransport()) {
            case UNFRAMED:
                lengthPrefixedMessageFraming = new NoMessageFraming(factory, apply.getMaxFrameSize());
                headerMessageEncoding = new SimpleMessageEncoding(factory);
                break;
            case FRAMED:
                lengthPrefixedMessageFraming = new LengthPrefixedMessageFraming(apply.getMaxFrameSize());
                headerMessageEncoding = new SimpleMessageEncoding(factory);
                break;
            case HEADER:
                lengthPrefixedMessageFraming = new LengthPrefixedMessageFraming(apply.getMaxFrameSize());
                headerMessageEncoding = new HeaderMessageEncoding(factory);
                break;
            default:
                throw new IllegalArgumentException("Unknown transport: " + apply.getTransport());
        }
        Optional empty = Optional.empty();
        if (apply.isSslEnabled()) {
            empty = Optional.of(this.sslContextFactory.get(apply.getTrustCertificate(), Optional.ofNullable(apply.getKey()), Optional.ofNullable(apply.getKey()), Optional.ofNullable(apply.getKeyPassword()), apply.getSessionCacheSize(), apply.getSessionTimeout(), apply.getCiphers()));
            ((Supplier) empty.get()).get();
        }
        ConnectionManager connectionFactory = new ConnectionFactory(this.group, lengthPrefixedMessageFraming, headerMessageEncoding, empty, apply);
        if (apply.isPoolEnabled()) {
            connectionFactory = new ConnectionPool(connectionFactory, this.group, apply);
        }
        return new DriftNettyMethodInvoker(connectionFactory, this.group, new Duration(TimeUnit.SECONDS.toMillis(10L) + apply.getConnectTimeout().toMillis() + apply.getRequestTimeout().toMillis(), TimeUnit.MILLISECONDS));
    }

    @PreDestroy
    public void shutdownGracefully() {
        shutdownGracefully(true);
    }

    @Override // java.io.Closeable, java.lang.AutoCloseable
    public void close() {
        shutdownGracefully(false);
    }

    private void shutdownGracefully(boolean z) {
        Future shutdownGracefully = this.group.shutdownGracefully();
        if (z) {
            try {
                shutdownGracefully.get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } catch (ExecutionException e2) {
            }
        }
    }
}
