package org.neo4j.causalclustering.protocol.handshake;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import java.net.InetSocketAddress;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.neo4j.causalclustering.messaging.SimpleNettyChannel;
import org.neo4j.causalclustering.protocol.Protocol;
import org.neo4j.causalclustering.protocol.handshake.TestProtocols;
import org.neo4j.logging.NullLog;

/* loaded from: input_file:org/neo4j/causalclustering/protocol/handshake/NettyProtocolHandshakeIT.class */
public class NettyProtocolHandshakeIT {
    /*
     * Integration test for the protocol handshake between a HandshakeClient and a HandshakeServer,
     * each wired into a real Netty pipeline over a TCP connection to localhost.
     *
     * setUp() starts the server with the RAFT application repository and the COMPRESSION modifier
     * repository; the CATCHUP tests expect the handshake to fail on both client and server side.
     */

    private final ApplicationSupportedProtocols supportedRaftApplicationProtocol = new ApplicationSupportedProtocols(Protocol.ApplicationProtocolCategory.RAFT, Collections.emptyList());
    private final ApplicationSupportedProtocols supportedCatchupApplicationProtocol = new ApplicationSupportedProtocols(Protocol.ApplicationProtocolCategory.CATCHUP, Collections.emptyList());
    private final Collection<ModifierSupportedProtocols> supportedCompressionModifierProtocols = Arrays.asList(new ModifierSupportedProtocols(Protocol.ModifierProtocolCategory.COMPRESSION, TestProtocols.TestModifierProtocols.listVersionsOf(Protocol.ModifierProtocolCategory.COMPRESSION)));
    private final Collection<ModifierSupportedProtocols> noSupportedModifierProtocols = Collections.emptyList();
    private final ApplicationProtocolRepository raftApplicationProtocolRepository = new ApplicationProtocolRepository(TestProtocols.TestApplicationProtocols.values(), this.supportedRaftApplicationProtocol);
    private final ApplicationProtocolRepository catchupApplicationProtocolRepository = new ApplicationProtocolRepository(TestProtocols.TestApplicationProtocols.values(), this.supportedCatchupApplicationProtocol);
    private final ModifierProtocolRepository compressionModifierProtocolRepository = new ModifierProtocolRepository(TestProtocols.TestModifierProtocols.values(), this.supportedCompressionModifierProtocols);
    private final ModifierProtocolRepository unsupportingModifierProtocolRepository = new ModifierProtocolRepository(TestProtocols.TestModifierProtocols.values(), this.noSupportedModifierProtocols);
    private Server server;
    private HandshakeClient handshakeClient;
    private Client client;

    /** Minimal Netty client whose pipeline ends in a {@link NettyHandshakeClient}. */
    private static class Client {
        final Bootstrap bootstrap;
        final NioEventLoopGroup eventLoopGroup = new NioEventLoopGroup();
        Channel channel;

        Client(HandshakeClient handshakeClient) {
            this.bootstrap = new Bootstrap()
                    .group(eventLoopGroup)
                    .channel(NioSocketChannel.class)
                    .handler(new ClientInitializer(handshakeClient));
        }

        /** Connects to {@code localhost} on {@code port}, blocking until the connect attempt completes. */
        void connect(int port) {
            this.channel = bootstrap.connect("localhost", port).awaitUninterruptibly().channel();
        }

        /**
         * Closes the channel if one was opened, and always shuts down the event loop group so its
         * threads are released even when no channel was ever assigned.
         */
        void disconnect() {
            if (channel != null) {
                channel.close().awaitUninterruptibly();
            }
            eventLoopGroup.shutdownGracefully(0L, 0L, TimeUnit.SECONDS);
        }
    }

    /**
     * Builds the client pipeline: 4-byte length-prefixed framing, the client handshake message
     * codecs, and finally the {@link NettyHandshakeClient} handler.
     */
    static class ClientInitializer extends ChannelInitializer<SocketChannel> {
        private final HandshakeClient handshakeClient;

        ClientInitializer(HandshakeClient handshakeClient) {
            this.handshakeClient = handshakeClient;
        }

        @Override
        protected void initChannel(SocketChannel socketChannel) {
            ChannelPipeline pipeline = socketChannel.pipeline();
            // Outbound frames get a 4-byte length prefix; inbound frames have the same 4-byte header stripped.
            pipeline.addLast("frameEncoder", new LengthFieldPrepender(4));
            pipeline.addLast("frameDecoder", new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4));
            pipeline.addLast("requestMessageEncoder", new ClientMessageEncoder());
            pipeline.addLast("responseMessageDecoder", new ClientMessageDecoder());
            pipeline.addLast(new NettyHandshakeClient(handshakeClient));
        }
    }

    /**
     * Netty server bound to an ephemeral port ({@code localAddress(0)}) whose child pipelines end in
     * a {@link NettyHandshakeServer}. The most recently accepted connection's {@link HandshakeServer}
     * is exposed through {@link #handshakeServer}.
     */
    private static class Server {
        Channel channel;
        NioEventLoopGroup eventLoopGroup;
        // Assigned on a server event-loop thread inside initChannel and read from other threads
        // (see getServerHandshakeFuture), so it must be volatile for safe publication.
        volatile HandshakeServer handshakeServer;

        private Server() {
        }

        /** Binds the server and blocks until the bind has completed. */
        void start(final ApplicationProtocolRepository applicationProtocolRepository, final ModifierProtocolRepository modifierProtocolRepository) {
            this.eventLoopGroup = new NioEventLoopGroup();
            this.channel = new ServerBootstrap()
                    .group(eventLoopGroup)
                    .channel(NioServerSocketChannel.class)
                    .option(ChannelOption.SO_REUSEADDR, true)
                    .localAddress(0)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel socketChannel) {
                            ChannelPipeline pipeline = socketChannel.pipeline();
                            Server.this.handshakeServer = new HandshakeServer(applicationProtocolRepository, modifierProtocolRepository, new SimpleNettyChannel(socketChannel, NullLog.getInstance()));
                            pipeline.addLast("frameEncoder", new LengthFieldPrepender(4));
                            pipeline.addLast("frameDecoder", new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4));
                            pipeline.addLast("responseMessageEncoder", new ServerMessageEncoder());
                            pipeline.addLast("requestMessageDecoder", new ServerMessageDecoder());
                            pipeline.addLast(new NettyHandshakeServer(Server.this.handshakeServer));
                        }
                    })
                    .bind()
                    .syncUninterruptibly()
                    .channel();
        }

        /** Closes the server channel and shuts down the event loop; tolerates a failed or missing start. */
        void stop() {
            if (channel != null) {
                channel.close().awaitUninterruptibly();
                channel = null;
            }
            if (eventLoopGroup != null) {
                eventLoopGroup.shutdownGracefully(0L, 0L, TimeUnit.SECONDS);
            }
        }

        /** Returns the port the server channel is bound to. */
        int port() {
            return ((InetSocketAddress) channel.localAddress()).getPort();
        }
    }

    @Before
    public void setUp() {
        this.server = new Server();
        this.server.start(this.raftApplicationProtocolRepository, this.compressionModifierProtocolRepository);
        this.handshakeClient = new HandshakeClient();
        this.client = new Client(this.handshakeClient);
        this.client.connect(this.server.port());
    }

    @After
    public void tearDown() {
        // Guard against a partially failed setUp so the original failure is not masked by an NPE.
        if (client != null) {
            client.disconnect();
        }
        if (server != null) {
            server.stop();
        }
    }

    @Test
    public void shouldSuccessfullyHandshakeKnownProtocolOnClientWithCompression() throws Exception {
        ProtocolStack protocolStack = initiateHandshake(raftApplicationProtocolRepository, compressionModifierProtocolRepository).get(1L, TimeUnit.MINUTES);
        Assert.assertThat(protocolStack.applicationProtocol(), Matchers.equalTo(TestProtocols.TestApplicationProtocols.latest(Protocol.ApplicationProtocolCategory.RAFT)));
        Assert.assertThat(protocolStack.modifierProtocols(), Matchers.contains(new Protocol.ModifierProtocol[]{TestProtocols.TestModifierProtocols.latest(Protocol.ModifierProtocolCategory.COMPRESSION)}));
    }

    @Test
    public void shouldSuccessfullyHandshakeKnownProtocolOnServerWithCompression() throws Exception {
        ProtocolStack protocolStack = getServerHandshakeFuture(initiateHandshake(raftApplicationProtocolRepository, compressionModifierProtocolRepository)).get(1L, TimeUnit.MINUTES);
        Assert.assertThat(protocolStack.applicationProtocol(), Matchers.equalTo(TestProtocols.TestApplicationProtocols.latest(Protocol.ApplicationProtocolCategory.RAFT)));
        Assert.assertThat(protocolStack.modifierProtocols(), Matchers.contains(new Protocol.ModifierProtocol[]{TestProtocols.TestModifierProtocols.latest(Protocol.ModifierProtocolCategory.COMPRESSION)}));
    }

    @Test
    public void shouldSuccessfullyHandshakeKnownProtocolOnClientNoModifiers() throws Exception {
        ProtocolStack protocolStack = initiateHandshake(raftApplicationProtocolRepository, unsupportingModifierProtocolRepository).get(1L, TimeUnit.MINUTES);
        Assert.assertThat(protocolStack.applicationProtocol(), Matchers.equalTo(TestProtocols.TestApplicationProtocols.latest(Protocol.ApplicationProtocolCategory.RAFT)));
        Assert.assertThat(protocolStack.modifierProtocols(), Matchers.empty());
    }

    @Test
    public void shouldSuccessfullyHandshakeKnownProtocolOnServerNoModifiers() throws Exception {
        ProtocolStack protocolStack = getServerHandshakeFuture(initiateHandshake(raftApplicationProtocolRepository, unsupportingModifierProtocolRepository)).get(1L, TimeUnit.MINUTES);
        Assert.assertThat(protocolStack.applicationProtocol(), Matchers.equalTo(TestProtocols.TestApplicationProtocols.latest(Protocol.ApplicationProtocolCategory.RAFT)));
        Assert.assertThat(protocolStack.modifierProtocols(), Matchers.empty());
    }

    @Test(expected = ClientHandshakeException.class)
    public void shouldFailHandshakeForUnknownProtocolOnClient() throws Throwable {
        try {
            initiateHandshake(catchupApplicationProtocolRepository, compressionModifierProtocolRepository).get(1L, TimeUnit.MINUTES);
        } catch (ExecutionException e) {
            // Unwrap so @Test(expected = ...) can match the real failure type.
            throw e.getCause();
        }
    }

    @Test(expected = ServerHandshakeException.class)
    public void shouldFailHandshakeForUnknownProtocolOnServer() throws Throwable {
        try {
            getServerHandshakeFuture(initiateHandshake(catchupApplicationProtocolRepository, compressionModifierProtocolRepository)).get(1L, TimeUnit.MINUTES);
        } catch (ExecutionException e) {
            // Unwrap so @Test(expected = ...) can match the real failure type.
            throw e.getCause();
        }
    }

    /**
     * Starts a client-side handshake over the connected client channel.
     *
     * @param applicationProtocolRepository application protocols the client offers
     * @param modifierProtocolRepository    modifier protocols the client offers
     * @return the client's handshake future
     */
    private CompletableFuture<ProtocolStack> initiateHandshake(ApplicationProtocolRepository applicationProtocolRepository, ModifierProtocolRepository modifierProtocolRepository) {
        return handshakeClient.initiate(new SimpleNettyChannel(client.channel, NullLog.getInstance()), applicationProtocolRepository, modifierProtocolRepository);
    }

    /**
     * Waits for the client-side handshake to finish, successfully or not, and then switches to the
     * server-side handshake's protocol-stack future so tests can assert on the server's outcome.
     *
     * @param clientFuture the client's handshake future
     * @return the server's protocol-stack future
     */
    private CompletableFuture<ProtocolStack> getServerHandshakeFuture(CompletableFuture<ProtocolStack> clientFuture) {
        return clientFuture
                // The client's result or failure is deliberately discarded: only its completion
                // matters, since the server-side failure test also fails on the client side.
                .handle((ignoredStack, ignoredError) -> null)
                .thenCompose(ignored -> server.handshakeServer.protocolStackFuture());
    }
}
