/*
 * Decompiled with CFR 0.152.
 */
package net.snowflake.client.jdbc.cloud.storage;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.attribute.FileAttribute;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Base64;
import javax.crypto.AEADBadTagException;
import javax.crypto.BadPaddingException;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import net.snowflake.client.jdbc.MatDesc;
import net.snowflake.client.jdbc.cloud.storage.GcmEncryptionProvider;
import net.snowflake.client.jdbc.cloud.storage.SnowflakeStorageClient;
import net.snowflake.client.jdbc.cloud.storage.StorageObjectMetadata;
import net.snowflake.common.core.RemoteStoreFileEncryptionMaterial;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

/**
 * Tests for {@link GcmEncryptionProvider}.
 *
 * <p>Each test round-trips a fixed plaintext through the provider's GCM encryption and decryption
 * (stream and in-place file variants). The negative tests check that tampering with any of these
 * is rejected with {@link AEADBadTagException}:
 *
 * <ul>
 *   <li>the ciphertext or its trailing tag,
 *   <li>the wrapped file key,
 *   <li>either IV,
 *   <li>either AAD.
 * </ul>
 *
 * <p>{@code GcmEncryptionProvider.encrypt} returns only the ciphertext stream. The wrapped key,
 * IVs and AAD are passed to {@code SnowflakeStorageClient.addEncryptionMetadataForGcm}. The
 * storage client is therefore a Mockito mock, and those arguments are captured after every
 * encryption (see {@link #captureKeysAndIvs()}).
 */
public class GcmEncryptionProviderTest {
    private final SecureRandom random = new SecureRandom();
    private final ArgumentCaptor<StorageObjectMetadata> storageObjectMetadataArgumentCaptor =
            ArgumentCaptor.forClass(StorageObjectMetadata.class);
    private final ArgumentCaptor<MatDesc> matDescArgumentCaptor = ArgumentCaptor.forClass(MatDesc.class);
    private final ArgumentCaptor<byte[]> dataIvDataArgumentCaptor = ArgumentCaptor.forClass(byte[].class);
    private final ArgumentCaptor<byte[]> keyIvDataArgumentCaptor = ArgumentCaptor.forClass(byte[].class);
    private final ArgumentCaptor<byte[]> encKeyArgumentCaptor = ArgumentCaptor.forClass(byte[].class);
    private final ArgumentCaptor<byte[]> keyAadArgumentCaptor = ArgumentCaptor.forClass(byte[].class);
    private final ArgumentCaptor<byte[]> dataAadArgumentCaptor = ArgumentCaptor.forClass(byte[].class);
    private final ArgumentCaptor<Long> contentLengthArgumentCaptor = ArgumentCaptor.forClass(Long.class);
    private final StorageObjectMetadata meta = Mockito.mock(StorageObjectMetadata.class);
    private final SnowflakeStorageClient storageClient = Mockito.mock(SnowflakeStorageClient.class);
    // Key material comes from nextBytes(). generateSeed() reads the seed source: it may block and is
    // intended for seeding other generators, not for producing keys.
    private final String queryStageMasterKey = Base64.getEncoder().encodeToString(randomBytes(random, 16));
    private final RemoteStoreFileEncryptionMaterial encMat = new RemoteStoreFileEncryptionMaterial();
    private final byte[] plainText = "the quick brown fox jumps over the lazy dog".getBytes(StandardCharsets.UTF_8);
    private final byte[] dataAad = "data aad".getBytes(StandardCharsets.UTF_8);
    private final byte[] keyAad = "key aad".getBytes(StandardCharsets.UTF_8);

    @BeforeEach
    public void setUp() {
        encMat.setQueryStageMasterKey(queryStageMasterKey);
        encMat.setSmkId(123L);
        encMat.setQueryId("query-id");
    }

    @Test
    public void testEncryptAndDecryptStreamWithoutAad() throws Exception {
        byte[] cipherText = encryptStream(new ByteArrayInputStream(plainText), null, null);
        byte[] decryptedPlainText = IOUtils.toByteArray(decryptStream(cipherText, null, null));
        Assertions.assertArrayEquals(plainText, decryptedPlainText);
    }

    @Test
    public void testEncryptAndDecryptStreamWithAad() throws Exception {
        byte[] cipherText = encryptStream(new ByteArrayInputStream(plainText), dataAad, keyAad);
        byte[] decryptedPlainText = IOUtils.toByteArray(decryptStream(cipherText, dataAad, keyAad));
        Assertions.assertArrayEquals(plainText, decryptedPlainText);
    }

    @Test
    public void testDecryptStreamWithInvalidKeyAad() throws Exception {
        byte[] cipherText = encryptStream(new ByteArrayInputStream(plainText), dataAad, keyAad);
        // The failure is raised by decryptStream itself, before any data is read.
        Assertions.assertThrows(
                AEADBadTagException.class, () -> decryptStream(cipherText, dataAad, new byte[] {'a'}));
    }

    @Test
    public void testDecryptStreamWithInvalidDataAad() throws Exception {
        byte[] cipherText = encryptStream(new ByteArrayInputStream(plainText), dataAad, keyAad);
        // A data-level failure surfaces while reading, as an IOException wrapping the tag error.
        IOException ioException = Assertions.assertThrows(
                IOException.class,
                () -> IOUtils.toByteArray(decryptStream(cipherText, new byte[] {'a'}, keyAad)));
        assertCausedByBadTag(ioException);
    }

    @Test
    public void testDecryptStreamWithInvalidCipherText() throws Exception {
        byte[] cipherText = withFlippedBit(
                encryptStream(new ByteArrayInputStream(plainText), dataAad, keyAad), 0);
        IOException ioException = Assertions.assertThrows(
                IOException.class, () -> IOUtils.toByteArray(decryptStream(cipherText, dataAad, keyAad)));
        assertCausedByBadTag(ioException);
    }

    @Test
    public void testDecryptStreamWithInvalidTag() throws Exception {
        byte[] original = encryptStream(new ByteArrayInputStream(plainText), dataAad, keyAad);
        // Tampers with the final byte of the output, presumably part of the GCM tag.
        byte[] cipherText = withFlippedBit(original, original.length - 1);
        IOException ioException = Assertions.assertThrows(
                IOException.class, () -> IOUtils.toByteArray(decryptStream(cipherText, dataAad, keyAad)));
        assertCausedByBadTag(ioException);
    }

    @Test
    public void testDecryptStreamWithInvalidKey() throws Exception {
        byte[] cipherText = encryptStream(new ByteArrayInputStream(plainText), dataAad, keyAad);
        byte[] encryptedKey = withFlippedBit(encKeyArgumentCaptor.getValue(), 0);
        Assertions.assertThrows(
                AEADBadTagException.class,
                () -> IOUtils.toByteArray(decryptStreamWithMetadata(
                        cipherText,
                        encryptedKey,
                        dataIvDataArgumentCaptor.getValue(),
                        keyIvDataArgumentCaptor.getValue(),
                        dataAad,
                        keyAad)));
    }

    @Test
    public void testDecryptStreamWithInvalidDataIV() throws Exception {
        byte[] cipherText = encryptStream(new ByteArrayInputStream(plainText), dataAad, keyAad);
        byte[] dataIv = withFlippedBit(dataIvDataArgumentCaptor.getValue(), 0);
        IOException ioException = Assertions.assertThrows(
                IOException.class,
                () -> IOUtils.toByteArray(decryptStreamWithMetadata(
                        cipherText,
                        encKeyArgumentCaptor.getValue(),
                        dataIv,
                        keyIvDataArgumentCaptor.getValue(),
                        dataAad,
                        keyAad)));
        assertCausedByBadTag(ioException);
    }

    @Test
    public void testDecryptStreamWithInvalidKeyIV() throws Exception {
        byte[] cipherText = encryptStream(new ByteArrayInputStream(plainText), dataAad, keyAad);
        byte[] keyIv = withFlippedBit(keyIvDataArgumentCaptor.getValue(), 0);
        Assertions.assertThrows(
                AEADBadTagException.class,
                () -> IOUtils.toByteArray(decryptStreamWithMetadata(
                        cipherText,
                        encKeyArgumentCaptor.getValue(),
                        dataIvDataArgumentCaptor.getValue(),
                        keyIv,
                        dataAad,
                        keyAad)));
    }

    @Test
    public void testEncryptAndDecryptFileWithoutAad() throws Exception {
        assertFileRoundTrip(null, null);
    }

    @Test
    public void testEncryptAndDecryptFileWithAad() throws Exception {
        assertFileRoundTrip(dataAad, keyAad);
    }

    /**
     * Encrypts {@code plainTextStream} and records the metadata passed to the storage client.
     *
     * <p>The declared content length is always {@code plainText.length}, so callers must pass a
     * stream over {@link #plainText}.
     *
     * @param plainTextStream stream over the test plaintext
     * @param dataAad data AAD, or {@code null} for none
     * @param keyAad key AAD, or {@code null} for none
     * @return the complete ciphertext
     */
    private byte[] encryptStream(InputStream plainTextStream, byte[] dataAad, byte[] keyAad)
            throws InvalidKeyException, InvalidAlgorithmParameterException, IllegalBlockSizeException,
                    BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException, IOException {
        byte[] cipherText;
        try (InputStream encrypted = GcmEncryptionProvider.encrypt(
                meta, plainText.length, plainTextStream, encMat, storageClient, dataAad, keyAad)) {
            cipherText = IOUtils.toByteArray(encrypted);
        }
        captureKeysAndIvs();
        return cipherText;
    }

    /**
     * Decrypts {@code cipherText} using the key and IVs captured from the last encryption.
     *
     * @param dataAad data AAD, or {@code null} for none
     * @param keyAad key AAD, or {@code null} for none
     */
    private InputStream decryptStream(byte[] cipherText, byte[] dataAad, byte[] keyAad)
            throws InvalidKeyException, BadPaddingException, IllegalBlockSizeException,
                    InvalidAlgorithmParameterException, NoSuchPaddingException, NoSuchAlgorithmException {
        return decryptStreamWithMetadata(
                cipherText,
                encKeyArgumentCaptor.getValue(),
                dataIvDataArgumentCaptor.getValue(),
                keyIvDataArgumentCaptor.getValue(),
                dataAad,
                keyAad);
    }

    /**
     * Decrypts {@code cipherText} with explicitly supplied metadata, so tests can substitute
     * tampered values.
     *
     * <p>Key and IVs must be non-null. A {@code null} AAD is sent as the empty string.
     */
    private InputStream decryptStreamWithMetadata(
            byte[] cipherText, byte[] encKey, byte[] dataIv, byte[] keyIv, byte[] dataAad, byte[] keyAad)
            throws InvalidKeyException, BadPaddingException, IllegalBlockSizeException,
                    InvalidAlgorithmParameterException, NoSuchPaddingException, NoSuchAlgorithmException {
        Base64.Encoder encoder = Base64.getEncoder();
        return GcmEncryptionProvider.decryptStream(
                new ByteArrayInputStream(cipherText),
                encoder.encodeToString(encKey),
                encoder.encodeToString(dataIv),
                encoder.encodeToString(keyIv),
                encMat,
                base64OrEmpty(dataAad),
                base64OrEmpty(keyAad));
    }

    /**
     * Encrypts the plaintext into a temp file, decrypts that file in place, and checks the result.
     *
     * @param dataAad data AAD, or {@code null} for none
     * @param keyAad key AAD, or {@code null} for none
     */
    private void assertFileRoundTrip(byte[] dataAad, byte[] keyAad) throws Exception {
        File tempFile = Files.createTempFile("encryption", "").toFile();
        tempFile.deleteOnExit();
        try {
            FileUtils.writeByteArrayToFile(
                    tempFile, encryptStream(new ByteArrayInputStream(plainText), dataAad, keyAad));
            decryptFile(tempFile, dataAad, keyAad);
            Assertions.assertArrayEquals(plainText, FileUtils.readFileToByteArray(tempFile));
        } finally {
            Files.deleteIfExists(tempFile.toPath());
        }
    }

    /**
     * Decrypts {@code tempFile} in place using the key and IVs captured from the last encryption.
     *
     * <p>The AAD comes from the arguments, matching {@link #decryptStream(byte[], byte[], byte[])}.
     * The previous version silently ignored these arguments in favour of the captured AAD.
     *
     * @param dataAad data AAD, or {@code null} for none
     * @param keyAad key AAD, or {@code null} for none
     */
    private void decryptFile(File tempFile, byte[] dataAad, byte[] keyAad)
            throws InvalidKeyException, IllegalBlockSizeException, BadPaddingException,
                    InvalidAlgorithmParameterException, IOException, NoSuchPaddingException,
                    NoSuchAlgorithmException {
        Base64.Encoder encoder = Base64.getEncoder();
        GcmEncryptionProvider.decryptFile(
                tempFile,
                encoder.encodeToString(encKeyArgumentCaptor.getValue()),
                encoder.encodeToString(dataIvDataArgumentCaptor.getValue()),
                encoder.encodeToString(keyIvDataArgumentCaptor.getValue()),
                encMat,
                base64OrEmpty(dataAad),
                base64OrEmpty(keyAad));
    }

    /**
     * Verifies the single {@code addEncryptionMetadataForGcm} call on the mock and captures its
     * arguments.
     *
     * <p>The captor order mirrors the argument order used here: key AAD before data AAD. Each
     * verify appends to the captors, and {@code getValue()} returns the most recent capture.
     */
    private void captureKeysAndIvs() {
        Mockito.verify(storageClient).addEncryptionMetadataForGcm(
                storageObjectMetadataArgumentCaptor.capture(),
                matDescArgumentCaptor.capture(),
                encKeyArgumentCaptor.capture(),
                dataIvDataArgumentCaptor.capture(),
                keyIvDataArgumentCaptor.capture(),
                keyAadArgumentCaptor.capture(),
                dataAadArgumentCaptor.capture(),
                contentLengthArgumentCaptor.capture());
    }

    /** Asserts that a decryption IOException was caused by a GCM authentication-tag failure. */
    private static void assertCausedByBadTag(IOException e) {
        Assertions.assertNotNull(e.getCause(), "expected IOException to wrap an AEADBadTagException");
        Assertions.assertEquals(AEADBadTagException.class, e.getCause().getClass());
    }

    /** Returns a copy of {@code bytes} with the lowest bit of {@code bytes[index]} flipped. */
    private static byte[] withFlippedBit(byte[] bytes, int index) {
        // XOR guarantees the copy differs at index, whatever the original byte value was.
        byte[] copy = bytes.clone();
        copy[index] ^= 0x01;
        return copy;
    }

    /** Base64-encodes {@code bytes}, treating {@code null} as "no AAD", i.e. the empty string. */
    private static String base64OrEmpty(byte[] bytes) {
        return bytes == null ? "" : Base64.getEncoder().encodeToString(bytes);
    }

    /** Returns {@code length} cryptographically random bytes drawn from {@code rnd}. */
    private static byte[] randomBytes(SecureRandom rnd, int length) {
        byte[] bytes = new byte[length];
        rnd.nextBytes(bytes);
        return bytes;
    }
}

