package org.apache.spark.network.sasl;

import java.io.File;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import javax.security.sasl.SaslException;

import com.google.common.collect.ImmutableMap;
import com.google.common.io.ByteStreams;
import com.google.common.io.Files;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;

import org.apache.spark.network.TestUtils;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.sasl.SaslEncryption;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.util.ByteArrayWritableChannel;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;

/* loaded from: input_file:org/apache/spark/network/sasl/SparkSaslSuite.class */
public class SparkSaslSuite {
    private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() { // from class: org.apache.spark.network.sasl.SparkSaslSuite.1
        public String getSaslUser(String str) {
            return "user";
        }

        public String getSecretKey(String str) {
            return str;
        }
    };

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/spark/network/sasl/SparkSaslSuite$EncryptionCheckerBootstrap.class */
    public static class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAdapter implements TransportServerBootstrap {
        boolean foundEncryptionHandler;
        String encryptHandlerName;

        EncryptionCheckerBootstrap(String str) {
            this.encryptHandlerName = str;
        }

        public void write(ChannelHandlerContext channelHandlerContext, Object obj, ChannelPromise channelPromise) throws Exception {
            if (!this.foundEncryptionHandler) {
                this.foundEncryptionHandler = channelHandlerContext.channel().pipeline().get(this.encryptHandlerName) != null;
            }
            channelHandlerContext.write(obj, channelPromise);
        }

        public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
            channel.pipeline().addFirst("encryptionChecker", this);
            return rpcHandler;
        }
    }

    /* loaded from: input_file:org/apache/spark/network/sasl/SparkSaslSuite$EncryptionDisablerBootstrap.class */
    private static class EncryptionDisablerBootstrap implements TransportClientBootstrap {
        private EncryptionDisablerBootstrap() {
        }

        public void doBootstrap(TransportClient transportClient, Channel channel) {
            channel.pipeline().remove("saslEncryption");
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/spark/network/sasl/SparkSaslSuite$SaslTestCtx.class */
    public static class SaslTestCtx implements AutoCloseable {
        final TransportClient client;
        final TransportServer server;
        final TransportContext ctx;
        private final boolean encrypt;
        private final boolean disableClientEncryption;
        private final EncryptionCheckerBootstrap checker;

        SaslTestCtx(RpcHandler rpcHandler, boolean z, boolean z2) throws Exception {
            this(rpcHandler, z, z2, Collections.emptyMap());
        }

        SaslTestCtx(RpcHandler rpcHandler, boolean z, boolean z2, Map<String, String> map) throws Exception {
            TransportConf transportConf = new TransportConf("shuffle", new MapConfigProvider(ImmutableMap.builder().putAll(map).put("spark.authenticate.enableSaslEncryption", String.valueOf(z)).build()));
            SecretKeyHolder secretKeyHolder = (SecretKeyHolder) Mockito.mock(SecretKeyHolder.class);
            Mockito.when(secretKeyHolder.getSaslUser(Mockito.anyString())).thenReturn("user");
            Mockito.when(secretKeyHolder.getSecretKey(Mockito.anyString())).thenReturn("secret");
            this.ctx = new TransportContext(transportConf, rpcHandler);
            this.checker = new EncryptionCheckerBootstrap("saslEncryption");
            this.server = this.ctx.createServer(Arrays.asList(new SaslServerBootstrap(transportConf, secretKeyHolder), this.checker));
            try {
                ArrayList arrayList = new ArrayList();
                arrayList.add(new SaslClientBootstrap(transportConf, "user", secretKeyHolder));
                if (z2) {
                    arrayList.add(new EncryptionDisablerBootstrap());
                }
                this.client = this.ctx.createClientFactory(arrayList).createClient(TestUtils.getLocalHost(), this.server.getPort());
                this.encrypt = z;
                this.disableClientEncryption = z2;
            } catch (Exception e) {
                close();
                throw e;
            }
        }

        @Override // java.lang.AutoCloseable
        public void close() {
            if (!this.disableClientEncryption) {
                Assert.assertEquals(Boolean.valueOf(this.encrypt), Boolean.valueOf(this.checker.foundEncryptionHandler));
            }
            if (this.client != null) {
                this.client.close();
            }
            if (this.server != null) {
                this.server.close();
            }
            if (this.ctx != null) {
                this.ctx.close();
            }
        }
    }

    @Test
    public void testMatching() {
        SparkSaslClient sparkSaslClient = new SparkSaslClient("shared-secret", this.secretKeyHolder, false);
        SparkSaslServer sparkSaslServer = new SparkSaslServer("shared-secret", this.secretKeyHolder, false);
        Assert.assertFalse(sparkSaslClient.isComplete());
        Assert.assertFalse(sparkSaslServer.isComplete());
        byte[] firstToken = sparkSaslClient.firstToken();
        while (true) {
            byte[] bArr = firstToken;
            if (sparkSaslClient.isComplete()) {
                Assert.assertTrue(sparkSaslServer.isComplete());
                sparkSaslServer.dispose();
                Assert.assertFalse(sparkSaslServer.isComplete());
                sparkSaslClient.dispose();
                Assert.assertFalse(sparkSaslClient.isComplete());
                return;
            }
            firstToken = sparkSaslClient.response(sparkSaslServer.response(bArr));
        }
    }

    @Test
    public void testNonMatching() {
        SparkSaslClient sparkSaslClient = new SparkSaslClient("my-secret", this.secretKeyHolder, false);
        SparkSaslServer sparkSaslServer = new SparkSaslServer("your-secret", this.secretKeyHolder, false);
        Assert.assertFalse(sparkSaslClient.isComplete());
        Assert.assertFalse(sparkSaslServer.isComplete());
        byte[] firstToken = sparkSaslClient.firstToken();
        while (!sparkSaslClient.isComplete()) {
            try {
                firstToken = sparkSaslClient.response(sparkSaslServer.response(firstToken));
            } catch (Exception e) {
                Assert.assertTrue(e.getMessage().contains("Mismatched response"));
                Assert.assertFalse(sparkSaslClient.isComplete());
                Assert.assertFalse(sparkSaslServer.isComplete());
                return;
            }
        }
        Assert.fail("Should not have completed");
    }

    @Test
    public void testSaslAuthentication() throws Throwable {
        testBasicSasl(false);
    }

    @Test
    public void testSaslEncryption() throws Throwable {
        testBasicSasl(true);
    }

    private static void testBasicSasl(boolean z) throws Throwable {
        RpcHandler rpcHandler = (RpcHandler) Mockito.mock(RpcHandler.class);
        ((RpcHandler) Mockito.doAnswer(invocationOnMock -> {
            ByteBuffer byteBuffer = (ByteBuffer) invocationOnMock.getArguments()[1];
            RpcResponseCallback rpcResponseCallback = (RpcResponseCallback) invocationOnMock.getArguments()[2];
            Assert.assertEquals("Ping", JavaUtils.bytesToString(byteBuffer));
            rpcResponseCallback.onSuccess(JavaUtils.stringToBytes("Pong"));
            return null;
        }).when(rpcHandler)).receive((TransportClient) Mockito.any(TransportClient.class), (ByteBuffer) Mockito.any(ByteBuffer.class), (RpcResponseCallback) Mockito.any(RpcResponseCallback.class));
        try {
            SaslTestCtx saslTestCtx = new SaslTestCtx(rpcHandler, z, false);
            Throwable th = null;
            try {
                try {
                    Assert.assertEquals("Pong", JavaUtils.bytesToString(saslTestCtx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), TimeUnit.SECONDS.toMillis(10L))));
                    if (saslTestCtx != null) {
                        if (0 != 0) {
                            try {
                                saslTestCtx.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        } else {
                            saslTestCtx.close();
                        }
                    }
                    Throwable th3 = null;
                    long nanoTime = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10L, TimeUnit.SECONDS);
                    while (nanoTime > System.nanoTime()) {
                        try {
                            ((RpcHandler) Mockito.verify(rpcHandler, Mockito.times(2))).channelInactive((TransportClient) Mockito.any(TransportClient.class));
                            th3 = null;
                            break;
                        } catch (Throwable th4) {
                            th3 = th4;
                            TimeUnit.MILLISECONDS.sleep(10L);
                        }
                    }
                    if (th3 != null) {
                        throw th3;
                    }
                } finally {
                }
            } finally {
            }
        } catch (Throwable th5) {
            Throwable th6 = null;
            long nanoTime2 = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10L, TimeUnit.SECONDS);
            while (nanoTime2 > System.nanoTime()) {
                try {
                    ((RpcHandler) Mockito.verify(rpcHandler, Mockito.times(2))).channelInactive((TransportClient) Mockito.any(TransportClient.class));
                    th6 = null;
                    break;
                } catch (Throwable th7) {
                    th6 = th7;
                    TimeUnit.MILLISECONDS.sleep(10L);
                }
            }
            if (th6 == null) {
                throw th5;
            }
            throw th6;
        }
    }

    @Test
    public void testEncryptedMessage() throws Exception {
        SaslEncryptionBackend saslEncryptionBackend = (SaslEncryptionBackend) Mockito.mock(SaslEncryptionBackend.class);
        byte[] bArr = new byte[1024];
        new Random().nextBytes(bArr);
        Mockito.when(saslEncryptionBackend.wrap((byte[]) Mockito.any(byte[].class), Mockito.anyInt(), Mockito.anyInt())).thenReturn(bArr);
        ByteBuf buffer = Unpooled.buffer();
        try {
            buffer.writeBytes(bArr);
            ByteArrayWritableChannel byteArrayWritableChannel = new ByteArrayWritableChannel(32);
            SaslEncryption.EncryptedMessage encryptedMessage = new SaslEncryption.EncryptedMessage(saslEncryptionBackend, buffer, 1024);
            long transferTo = encryptedMessage.transferTo(byteArrayWritableChannel, encryptedMessage.transferred());
            Assert.assertTrue(transferTo < ((long) bArr.length));
            Assert.assertTrue(transferTo > 0);
            Assert.assertEquals(0L, encryptedMessage.transferTo(byteArrayWritableChannel, encryptedMessage.transferred()));
            byteArrayWritableChannel.reset();
            Assert.assertEquals(1L, encryptedMessage.transferTo(byteArrayWritableChannel, encryptedMessage.transferred()));
            for (int i = 0; i < (bArr.length / 32) - 2; i++) {
                byteArrayWritableChannel.reset();
                Assert.assertEquals(1L, encryptedMessage.transferTo(byteArrayWritableChannel, encryptedMessage.transferred()));
            }
            byteArrayWritableChannel.reset();
            long transferTo2 = encryptedMessage.transferTo(byteArrayWritableChannel, encryptedMessage.transferred());
            Assert.assertTrue("Unexpected count: " + transferTo2, transferTo2 > 1 && transferTo2 < ((long) bArr.length));
            Assert.assertEquals(bArr.length, encryptedMessage.transferred());
            buffer.release();
        } catch (Throwable th) {
            buffer.release();
            throw th;
        }
    }

    @Test
    public void testEncryptedMessageChunking() throws Exception {
        File createTempFile = File.createTempFile("sasltest", ".txt");
        try {
            TransportConf transportConf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
            byte[] bArr = new byte[8192];
            new Random().nextBytes(bArr);
            Files.write(bArr, createTempFile);
            SaslEncryptionBackend saslEncryptionBackend = (SaslEncryptionBackend) Mockito.mock(SaslEncryptionBackend.class);
            Mockito.when(saslEncryptionBackend.wrap((byte[]) Mockito.any(byte[].class), Mockito.anyInt(), Mockito.anyInt())).thenReturn(bArr);
            SaslEncryption.EncryptedMessage encryptedMessage = new SaslEncryption.EncryptedMessage(saslEncryptionBackend, new FileSegmentManagedBuffer(transportConf, createTempFile, 0L, createTempFile.length()).convertToNetty(), bArr.length / 8);
            ByteArrayWritableChannel byteArrayWritableChannel = new ByteArrayWritableChannel(bArr.length);
            while (encryptedMessage.transferred() < encryptedMessage.count()) {
                byteArrayWritableChannel.reset();
                encryptedMessage.transferTo(byteArrayWritableChannel, encryptedMessage.transferred());
            }
            ((SaslEncryptionBackend) Mockito.verify(saslEncryptionBackend, Mockito.times(8))).wrap((byte[]) Mockito.any(byte[].class), Mockito.anyInt(), Mockito.anyInt());
            createTempFile.delete();
        } catch (Throwable th) {
            createTempFile.delete();
            throw th;
        }
    }

    @Test
    public void testFileRegionEncryption() throws Exception {
        ImmutableMap of = ImmutableMap.of("spark.network.sasl.maxEncryptedBlockSize", "1k");
        AtomicReference atomicReference = new AtomicReference();
        File createTempFile = File.createTempFile("sasltest", ".txt");
        SaslTestCtx saslTestCtx = null;
        try {
            TransportConf transportConf = new TransportConf("shuffle", new MapConfigProvider(of));
            StreamManager streamManager = (StreamManager) Mockito.mock(StreamManager.class);
            Mockito.when(streamManager.getChunk(Mockito.anyLong(), Mockito.anyInt())).thenAnswer(invocationOnMock -> {
                return new FileSegmentManagedBuffer(transportConf, createTempFile, 0L, createTempFile.length());
            });
            RpcHandler rpcHandler = (RpcHandler) Mockito.mock(RpcHandler.class);
            Mockito.when(rpcHandler.getStreamManager()).thenReturn(streamManager);
            byte[] bArr = new byte[8192];
            new Random().nextBytes(bArr);
            Files.write(bArr, createTempFile);
            saslTestCtx = new SaslTestCtx(rpcHandler, true, false, of);
            CountDownLatch countDownLatch = new CountDownLatch(1);
            ChunkReceivedCallback chunkReceivedCallback = (ChunkReceivedCallback) Mockito.mock(ChunkReceivedCallback.class);
            ((ChunkReceivedCallback) Mockito.doAnswer(invocationOnMock2 -> {
                atomicReference.set((ManagedBuffer) invocationOnMock2.getArguments()[1]);
                ((ManagedBuffer) atomicReference.get()).retain();
                countDownLatch.countDown();
                return null;
            }).when(chunkReceivedCallback)).onSuccess(Mockito.anyInt(), (ManagedBuffer) Mockito.any(ManagedBuffer.class));
            saslTestCtx.client.fetchChunk(0L, 0, chunkReceivedCallback);
            countDownLatch.await(10L, TimeUnit.SECONDS);
            ((ChunkReceivedCallback) Mockito.verify(chunkReceivedCallback, Mockito.times(1))).onSuccess(Mockito.anyInt(), (ManagedBuffer) Mockito.any(ManagedBuffer.class));
            ((ChunkReceivedCallback) Mockito.verify(chunkReceivedCallback, Mockito.never())).onFailure(Mockito.anyInt(), (Throwable) Mockito.any(Throwable.class));
            Assert.assertArrayEquals(bArr, ByteStreams.toByteArray(((ManagedBuffer) atomicReference.get()).createInputStream()));
            createTempFile.delete();
            if (saslTestCtx != null) {
                saslTestCtx.close();
            }
            if (atomicReference.get() != null) {
                ((ManagedBuffer) atomicReference.get()).release();
            }
        } catch (Throwable th) {
            createTempFile.delete();
            if (saslTestCtx != null) {
                saslTestCtx.close();
            }
            if (atomicReference.get() != null) {
                ((ManagedBuffer) atomicReference.get()).release();
            }
            throw th;
        }
    }

    @Test
    public void testServerAlwaysEncrypt() {
        Assert.assertTrue(((Exception) Assert.assertThrows(Exception.class, () -> {
            new SaslTestCtx((RpcHandler) Mockito.mock(RpcHandler.class), false, false, ImmutableMap.of("spark.network.sasl.serverAlwaysEncrypt", "true"));
        })).getCause() instanceof SaslException);
    }

    @Test
    public void testDataEncryptionIsActuallyEnabled() throws Exception {
        SaslTestCtx saslTestCtx = new SaslTestCtx((RpcHandler) Mockito.mock(RpcHandler.class), true, true);
        Throwable th = null;
        try {
            Assert.assertFalse(((Exception) Assert.assertThrows(Exception.class, () -> {
                saslTestCtx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), TimeUnit.SECONDS.toMillis(10L));
            })).getCause() instanceof TimeoutException);
            if (saslTestCtx != null) {
                if (0 == 0) {
                    saslTestCtx.close();
                    return;
                }
                try {
                    saslTestCtx.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
        } catch (Throwable th3) {
            if (saslTestCtx != null) {
                if (0 != 0) {
                    try {
                        saslTestCtx.close();
                    } catch (Throwable th4) {
                        th.addSuppressed(th4);
                    }
                } else {
                    saslTestCtx.close();
                }
            }
            throw th3;
        }
    }

    @Test
    public void testRpcHandlerDelegate() throws Exception {
        RpcHandler rpcHandler = (RpcHandler) Mockito.mock(RpcHandler.class);
        SaslRpcHandler saslRpcHandler = new SaslRpcHandler((TransportConf) null, (Channel) null, rpcHandler, (SecretKeyHolder) null);
        saslRpcHandler.getStreamManager();
        ((RpcHandler) Mockito.verify(rpcHandler)).getStreamManager();
        saslRpcHandler.channelInactive((TransportClient) null);
        ((RpcHandler) Mockito.verify(rpcHandler)).channelInactive((TransportClient) Mockito.isNull());
        saslRpcHandler.exceptionCaught((Throwable) null, (TransportClient) null);
        ((RpcHandler) Mockito.verify(rpcHandler)).exceptionCaught((Throwable) Mockito.isNull(), (TransportClient) Mockito.isNull());
    }

    @Test
    public void testDelegates() throws Exception {
        for (Method method : RpcHandler.class.getDeclaredMethods()) {
            Assert.assertNotEquals(SaslRpcHandler.class.getMethod(method.getName(), method.getParameterTypes()).getDeclaringClass(), RpcHandler.class);
        }
    }
}
