package org.apache.spark.network.sasl;

import com.google.common.collect.Lists;
import com.google.common.io.ByteStreams;
import com.google.common.io.Files;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import javax.security.sasl.SaslException;
import org.apache.spark.network.TestUtils;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.sasl.SaslEncryption;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.util.ByteArrayWritableChannel;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

/* loaded from: input_file:org/apache/spark/network/sasl/SparkSaslSuite.class */
public class SparkSaslSuite {
    /**
     * Key holder shared by the handshake tests in this suite. The SASL user is always
     * {@code "user"} and the secret key is the identifier passed in, returned unchanged.
     */
    private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() {
        @Override
        public String getSaslUser(String id) {
            return "user";
        }

        @Override
        public String getSecretKey(String id) {
            // The secret is just the identifier echoed back.
            // NOTE(review): presumably this is why testMatching/testNonMatching vary the first
            // constructor argument of the SASL client and server — confirm against those classes.
            return id;
        }
    };

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/spark/network/sasl/SparkSaslSuite$EncryptionCheckerBootstrap.class */
    public static class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAdapter implements TransportServerBootstrap {
        boolean foundEncryptionHandler;

        private EncryptionCheckerBootstrap() {
        }

        public void write(ChannelHandlerContext channelHandlerContext, Object obj, ChannelPromise channelPromise) throws Exception {
            if (!this.foundEncryptionHandler) {
                this.foundEncryptionHandler = channelHandlerContext.channel().pipeline().get("saslEncryption") != null;
            }
            channelHandlerContext.write(obj, channelPromise);
        }

        public void handlerRemoved(ChannelHandlerContext channelHandlerContext) throws Exception {
            super.handlerRemoved(channelHandlerContext);
        }

        public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
            channel.pipeline().addFirst("encryptionChecker", this);
            return rpcHandler;
        }
    }

    /* loaded from: input_file:org/apache/spark/network/sasl/SparkSaslSuite$EncryptionDisablerBootstrap.class */
    private static class EncryptionDisablerBootstrap implements TransportClientBootstrap {
        private EncryptionDisablerBootstrap() {
        }

        public void doBootstrap(TransportClient transportClient, Channel channel) {
            channel.pipeline().remove("saslEncryption");
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/spark/network/sasl/SparkSaslSuite$SaslTestCtx.class */
    /**
     * Test fixture that starts a SASL-enabled transport server and connects one client to it.
     * Callers must invoke {@link #close()} when done.
     */
    public static class SaslTestCtx {
        final TransportClient client;
        final TransportServer server;
        private final boolean encrypt;
        private final boolean disableClientEncryption;
        private final EncryptionCheckerBootstrap checker;

        /**
         * Creates the server and connects a client to it.
         *
         * @param rpcHandler              handler given to the {@link TransportContext}
         * @param encrypt                 passed as the last argument of {@link SaslClientBootstrap};
         *                                also the value {@link #close()} expects the checker to report
         * @param disableClientEncryption when {@code true}, an {@link EncryptionDisablerBootstrap}
         *                                is added after the SASL client bootstrap
         * @throws Exception if the client cannot be created; the fixture is closed before rethrowing
         */
        SaslTestCtx(RpcHandler rpcHandler, boolean encrypt, boolean disableClientEncryption) throws Exception {
            TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());

            // Both endpoints share a mocked key holder with fixed credentials.
            SecretKeyHolder keyHolder = Mockito.mock(SecretKeyHolder.class);
            Mockito.when(keyHolder.getSaslUser(Mockito.anyString())).thenReturn("user");
            Mockito.when(keyHolder.getSecretKey(Mockito.anyString())).thenReturn("secret");

            TransportContext context = new TransportContext(conf, rpcHandler);
            this.checker = new EncryptionCheckerBootstrap();
            this.server = context.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder), this.checker));
            try {
                ArrayList<TransportClientBootstrap> clientBootstraps = new ArrayList<>();
                clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder, encrypt));
                if (disableClientEncryption) {
                    clientBootstraps.add(new EncryptionDisablerBootstrap());
                }
                this.client = context.createClientFactory(clientBootstraps)
                        .createClient(TestUtils.getLocalHost(), this.server.getPort());
                this.encrypt = encrypt;
                this.disableClientEncryption = disableClientEncryption;
            } catch (Exception e) {
                // The server is already running; shut it down before propagating the failure.
                close();
                throw e;
            }
        }

        /**
         * Verifies the observed encryption state and shuts down the client and server.
         *
         * <p>Unless client-side encryption was deliberately disabled, asserts that the checker saw
         * the encryption handler exactly when encryption was requested. The client and server are
         * closed even if that assertion fails, and the server is closed even if closing the client
         * throws.
         */
        void close() {
            try {
                if (!this.disableClientEncryption) {
                    Assert.assertEquals(Boolean.valueOf(this.encrypt), Boolean.valueOf(this.checker.foundEncryptionHandler));
                }
            } finally {
                try {
                    // client is still null when close() is called from the constructor's failure path.
                    if (this.client != null) {
                        this.client.close();
                    }
                } finally {
                    if (this.server != null) {
                        this.server.close();
                    }
                }
            }
        }
    }

    /**
     * Runs a full handshake between a client and server created with the same identifier and
     * checks that both report completion, and that {@code dispose()} resets that state.
     */
    @Test
    public void testMatching() {
        SparkSaslClient client = new SparkSaslClient("shared-secret", this.secretKeyHolder, false);
        SparkSaslServer server = new SparkSaslServer("shared-secret", this.secretKeyHolder, false);
        Assert.assertFalse(client.isComplete());
        Assert.assertFalse(server.isComplete());

        // Exchange tokens until the client considers the negotiation finished.
        byte[] token = client.firstToken();
        while (!client.isComplete()) {
            token = client.response(server.response(token));
        }
        Assert.assertTrue(server.isComplete());

        // After disposal neither side should report completion any more.
        server.dispose();
        Assert.assertFalse(server.isComplete());
        client.dispose();
        Assert.assertFalse(client.isComplete());
    }

    /**
     * Runs a handshake between a client and server created with different identifiers and
     * expects it to fail with a "Mismatched response" error, leaving both sides incomplete.
     */
    @Test
    public void testNonMatching() {
        SparkSaslClient client = new SparkSaslClient("my-secret", this.secretKeyHolder, false);
        SparkSaslServer server = new SparkSaslServer("your-secret", this.secretKeyHolder, false);
        Assert.assertFalse(client.isComplete());
        Assert.assertFalse(server.isComplete());

        byte[] token = client.firstToken();
        while (!client.isComplete()) {
            try {
                byte[] serverReply = server.response(token);
                token = client.response(serverReply);
            } catch (Exception e) {
                // Expected outcome: the exchange is rejected before either side completes.
                Assert.assertTrue(e.getMessage().contains("Mismatched response"));
                Assert.assertFalse(client.isComplete());
                Assert.assertFalse(server.isComplete());
                return;
            }
        }
        Assert.fail("Should not have completed");
    }

    /** Ping/pong round trip over a SASL connection with the encryption flag off. */
    @Test
    public void testSaslAuthentication() throws Exception {
        final boolean encrypt = false;
        testBasicSasl(encrypt);
    }

    /** Ping/pong round trip over a SASL connection with the encryption flag on. */
    @Test
    public void testSaslEncryption() throws Exception {
        final boolean encrypt = true;
        testBasicSasl(encrypt);
    }

    /**
     * Sends a "Ping" RPC through a {@link SaslTestCtx} and expects "Pong" back.
     *
     * @param encrypt forwarded to {@link SaslTestCtx} as its encryption flag
     * @throws Exception if the fixture cannot be created or the RPC fails
     */
    private void testBasicSasl(boolean encrypt) throws Exception {
        RpcHandler rpcHandler = Mockito.mock(RpcHandler.class);
        // Server side: check the payload and answer through the supplied callback.
        Mockito.doAnswer(new Answer<Void>() {
            @Override
            public Void answer(InvocationOnMock invocation) {
                Object[] args = invocation.getArguments();
                byte[] message = (byte[]) args[1];
                RpcResponseCallback callback = (RpcResponseCallback) args[2];
                Assert.assertEquals("Ping", new String(message, StandardCharsets.UTF_8));
                callback.onSuccess("Pong".getBytes(StandardCharsets.UTF_8));
                return null;
            }
        }).when(rpcHandler).receive(
                Mockito.any(TransportClient.class),
                Mockito.any(byte[].class),
                Mockito.any(RpcResponseCallback.class));

        SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
        try {
            byte[] reply = ctx.client.sendRpcSync(
                    "Ping".getBytes(StandardCharsets.UTF_8), TimeUnit.SECONDS.toMillis(10L));
            Assert.assertEquals("Pong", new String(reply, StandardCharsets.UTF_8));
        } finally {
            // Close exactly once, whether or not the round trip succeeded.
            ctx.close();
        }
    }

    /**
     * Drives {@code SaslEncryption.EncryptedMessage.transferTo} against a channel much smaller
     * than the message and checks the counts reported for each partial transfer.
     */
    @Test
    public void testEncryptedMessage() throws Exception {
        SaslEncryptionBackend backend = Mockito.mock(SaslEncryptionBackend.class);
        byte[] data = new byte[1024];
        new Random().nextBytes(data);
        // The mocked backend returns the same 1024 bytes for every wrap() call.
        Mockito.when(backend.wrap(Mockito.any(byte[].class), Mockito.anyInt(), Mockito.anyInt())).thenReturn(data);

        ByteBuf msg = Unpooled.buffer();
        try {
            msg.writeBytes(data);

            // 32-byte channel versus 1024 bytes of data: no single call can write everything.
            ByteArrayWritableChannel channel = new ByteArrayWritableChannel(32);
            SaslEncryption.EncryptedMessage emsg = new SaslEncryption.EncryptedMessage(backend, msg, 1024);

            long count = emsg.transferTo(channel, emsg.transfered());
            Assert.assertTrue(count < ((long) data.length));
            Assert.assertTrue(count > 0);

            // The channel has not been reset, so this call is expected to transfer nothing.
            Assert.assertEquals(0L, emsg.transferTo(channel, emsg.transfered()));

            // After each reset there is room again; every intermediate call is expected to report 1.
            channel.reset();
            Assert.assertEquals(1L, emsg.transferTo(channel, emsg.transfered()));
            for (int i = 0; i < (data.length / 32) - 2; i++) {
                channel.reset();
                Assert.assertEquals(1L, emsg.transferTo(channel, emsg.transfered()));
            }

            // The final call is expected to report a larger count and leave the message fully transferred.
            channel.reset();
            count = emsg.transferTo(channel, emsg.transfered());
            Assert.assertTrue("Unexpected count: " + count, count > 1 && count < ((long) data.length));
            Assert.assertEquals(data.length, emsg.transfered());
        } finally {
            // Release exactly once on every path.
            msg.release();
        }
    }

    /**
     * Transfers an 8192-byte file-backed message through an {@code EncryptedMessage} whose last
     * constructor argument is one eighth of the data size, and checks that the backend's
     * {@code wrap()} was invoked exactly eight times.
     */
    @Test
    public void testEncryptedMessageChunking() throws Exception {
        File file = File.createTempFile("sasltest", ".txt");
        try {
            TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());

            byte[] data = new byte[8192];
            new Random().nextBytes(data);
            Files.write(data, file);

            SaslEncryptionBackend backend = Mockito.mock(SaslEncryptionBackend.class);
            Mockito.when(backend.wrap(Mockito.any(byte[].class), Mockito.anyInt(), Mockito.anyInt())).thenReturn(data);

            FileSegmentManagedBuffer segment = new FileSegmentManagedBuffer(conf, file, 0L, file.length());
            SaslEncryption.EncryptedMessage emsg =
                    new SaslEncryption.EncryptedMessage(backend, segment.convertToNetty(), data.length / 8);

            // The channel is as large as the data and is reset before each call.
            ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length);
            while (emsg.transfered() < emsg.count()) {
                channel.reset();
                emsg.transferTo(channel, emsg.transfered());
            }

            Mockito.verify(backend, Mockito.times(8)).wrap(Mockito.any(byte[].class), Mockito.anyInt(), Mockito.anyInt());
        } finally {
            file.delete();
        }
    }

    /**
     * Fetches a file-backed chunk over an encrypted connection, with the
     * {@code spark.network.sasl.maxEncryptedBlockSize} system property set to {@code "1k"}, and
     * checks that the received bytes equal the file contents.
     */
    @Test
    public void testFileRegionEncryption() throws Exception {
        final String blockSizeConf = "spark.network.sasl.maxEncryptedBlockSize";
        // Create the file before touching global state, so a failure here leaves nothing to undo.
        final File file = File.createTempFile("sasltest", ".txt");
        final AtomicReference<ManagedBuffer> response = new AtomicReference<>();
        SaslTestCtx ctx = null;
        try {
            System.setProperty(blockSizeConf, "1k");
            final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());

            // Every chunk request is served with the whole temp file.
            StreamManager streamManager = Mockito.mock(StreamManager.class);
            Mockito.when(streamManager.getChunk(Mockito.anyLong(), Mockito.anyInt())).thenAnswer(new Answer<ManagedBuffer>() {
                @Override
                public ManagedBuffer answer(InvocationOnMock invocation) {
                    return new FileSegmentManagedBuffer(conf, file, 0L, file.length());
                }
            });
            RpcHandler rpcHandler = Mockito.mock(RpcHandler.class);
            Mockito.when(rpcHandler.getStreamManager()).thenReturn(streamManager);

            byte[] data = new byte[8192];
            new Random().nextBytes(data);
            Files.write(data, file);

            ctx = new SaslTestCtx(rpcHandler, true, false);

            final Object lock = new Object();
            ChunkReceivedCallback callback = Mockito.mock(ChunkReceivedCallback.class);
            Mockito.doAnswer(new Answer<Void>() {
                @Override
                public Void answer(InvocationOnMock invocation) {
                    ManagedBuffer chunk = (ManagedBuffer) invocation.getArguments()[1];
                    // Retain before publishing, so a visible response is always a retained one.
                    chunk.retain();
                    response.set(chunk);
                    synchronized (lock) {
                        lock.notifyAll();
                    }
                    return null;
                }
            }).when(callback).onSuccess(Mockito.anyInt(), Mockito.any(ManagedBuffer.class));

            synchronized (lock) {
                ctx.client.fetchChunk(0L, 0, callback);
                // Wait up to 10 seconds for the response, re-checking the condition after each
                // wakeup so a spurious wakeup cannot end the wait early.
                final long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(10L);
                while (response.get() == null) {
                    long remainingMillis = TimeUnit.NANOSECONDS.toMillis(deadline - System.nanoTime());
                    if (remainingMillis <= 0) {
                        break;
                    }
                    lock.wait(remainingMillis);
                }
            }

            Mockito.verify(callback, Mockito.times(1)).onSuccess(Mockito.anyInt(), Mockito.any(ManagedBuffer.class));
            Mockito.verify(callback, Mockito.never()).onFailure(Mockito.anyInt(), Mockito.any(Throwable.class));

            byte[] received = ByteStreams.toByteArray(response.get().createInputStream());
            Assert.assertTrue(Arrays.equals(data, received));
        } finally {
            try {
                file.delete();
                if (ctx != null) {
                    ctx.close();
                }
                if (response.get() != null) {
                    response.get().release();
                }
            } finally {
                // Always restore global state, even if the cleanup above throws.
                System.clearProperty(blockSizeConf);
            }
        }
    }

    /**
     * With {@code spark.network.sasl.serverAlwaysEncrypt} set to {@code "true"}, creating a
     * fixture whose client does not request encryption must fail with an exception whose cause
     * is a {@link SaslException}.
     */
    @Test
    public void testServerAlwaysEncrypt() throws Exception {
        final String alwaysEncryptConf = "spark.network.sasl.serverAlwaysEncrypt";
        System.setProperty(alwaysEncryptConf, "true");

        SaslTestCtx ctx = null;
        try {
            ctx = new SaslTestCtx(Mockito.mock(RpcHandler.class), false, false);
            Assert.fail("Should have failed to connect without encryption.");
        } catch (Exception e) {
            Assert.assertTrue(e.getCause() instanceof SaslException);
        } finally {
            try {
                if (ctx != null) {
                    ctx.close();
                }
            } finally {
                // Always restore global state, even if closing the fixture throws.
                System.clearProperty(alwaysEncryptConf);
            }
        }
    }

    /**
     * Negotiates encryption but strips the client's encryption handler, then sends an RPC.
     * The RPC must fail, and the failure must not be caused by a {@link TimeoutException}.
     */
    @Test
    public void testDataEncryptionIsActuallyEnabled() throws Exception {
        SaslTestCtx ctx = null;
        try {
            ctx = new SaslTestCtx(Mockito.mock(RpcHandler.class), true, true);
            ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8), TimeUnit.SECONDS.toMillis(10L));
            Assert.fail("Should have failed to send RPC to server.");
        } catch (Exception e) {
            Assert.assertFalse(e.getCause() instanceof TimeoutException);
        } finally {
            // Close exactly once on every path.
            if (ctx != null) {
                ctx.close();
            }
        }
    }
}
