package cn.wjee.commons.crypto;

import cn.wjee.commons.collection.Tuple2;
import cn.wjee.commons.enums.ApiStatusEnum;
import cn.wjee.commons.exception.BizException;
import cn.wjee.commons.exception.BusinessException;
import cn.wjee.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.security.*;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.util.Base64;

/**
 * RSA encryption, decryption and signature utility.
 * <p>
 * Typical workflow:
 * <ol>
 *   <li>Generate a key store with {@link #generateJks(String, String)}.</li>
 *   <li>Inspect it with {@code keytool -list -rfc -keystore wjee.jks}; to print the public key run
 *       {@code keytool -list -rfc --keystore wjee.jks | openssl x509 -inform pem -pubkey}.</li>
 *   <li>On the front end, encrypt with the {@code jsencrypt.js} library.</li>
 * </ol>
 * <p>
 * Segment sizes in bytes when PKCS#1 v1.5 padding is used (key size / max plaintext per
 * segment / ciphertext segment):
 * <pre>
 * 1024bit/PKCS     117               128
 * 2048bit/PKCS     245               256
 * </pre>
 * Input longer than one segment is processed segment by segment and the results are
 * concatenated; the actual segment sizes come from the configured {@code RsaType}.
 *
 * @author wjee
 * @version $Id: RSAUtils.java, v 0.1 2015-11-08 16:54:10 wjee Exp $
 */
public class RSA {
    /**
     * Logger.
     */
    private static final Logger log = LoggerFactory.getLogger(RSA.class);
    /**
     * Name constant for the public key.
     */
    public static final String PUBLIC_KEY = "PublicKey";
    /**
     * Name constant for the private key.
     */
    public static final String PRIVATE_KEY = "PrivateKey";
    /**
     * RSA flavour supplying the key algorithm, key size, signature algorithm and segment sizes.
     */
    private final RsaType rsaType;

    /**
     * Creates an instance bound to the given RSA flavour.
     *
     * @param rsaType RSA flavour, must not be {@code null}
     * @throws NullPointerException if {@code rsaType} is {@code null}
     */
    public RSA(RsaType rsaType) {
        // Fail at construction time instead of on the first cryptographic call.
        if (rsaType == null) {
            throw new NullPointerException("rsaType");
        }
        this.rsaType = rsaType;
    }

    /**
     * Creates an instance configured with {@code RsaType.RSA}.
     *
     * @return new instance
     */
    public static RSA newRsa() {
        return new RSA(RsaType.RSA);
    }

    /**
     * Creates an instance configured with {@code RsaType.RSA2}.
     *
     * @return new instance
     */
    public static RSA newRsa2() {
        return new RSA(RsaType.RSA2);
    }

    /**
     * Generates a new RSA key pair.
     * <p>
     * Both keys are returned as Base64 text of their encoded form. Only the public key is
     * logged; the private key is a secret and is deliberately never written to the log.
     *
     * @return tuple whose first element is the private key and whose second element is the public key
     * @throws BizException if the configured key algorithm is not available
     */
    public Tuple2<String, String> newRsaKeys() {
        try {
            KeyPairGenerator keyPairGen = KeyPairGenerator.getInstance(rsaType.getKeyAlgorithm());
            keyPairGen.initialize(rsaType.getKeySize());
            KeyPair keyPair = keyPairGen.generateKeyPair();
            RSAPublicKey publicKey = (RSAPublicKey) keyPair.getPublic();
            RSAPrivateKey privateKey = (RSAPrivateKey) keyPair.getPrivate();

            Base64.Encoder encoder = Base64.getEncoder();
            String publicKeyStr = encoder.encodeToString(publicKey.getEncoded());
            String privateKeyStr = encoder.encodeToString(privateKey.getEncoded());
            log.info("RSA::PublicKey::{}", publicKeyStr);
            return Tuple2.of(privateKeyStr, publicKeyStr);
        } catch (NoSuchAlgorithmException e) {
            throw new BizException("newRsaKeys fail", e);
        }
    }

    /**
     * Decrypts Base64 cipher text with a private key.
     *
     * @param value Base64 encoded cipher text
     * @param key   Base64 encoded PKCS#8 private key
     * @return decrypted text
     * @throws BizException             if the key is invalid or decryption fails
     * @throws IllegalArgumentException if {@code value} or {@code key} is not valid Base64
     */
    public String decryptByPrivateKey(String value, String key) {
        try {
            byte[] cipherBytes = Base64.getDecoder().decode(value);
            PrivateKey privateKey = toPrivateKey(key);
            Cipher cipher = Cipher.getInstance(rsaType.getKeyAlgorithm());
            cipher.init(Cipher.DECRYPT_MODE, privateKey);
            return EncodeUtils.getString(doFinalBySegment(cipher, cipherBytes, false));
        } catch (NoSuchAlgorithmException | InvalidKeySpecException | NoSuchPaddingException | InvalidKeyException e) {
            throw new BizException("decryptByPrivateKey fail", e);
        }
    }

    /**
     * Encrypts text with a public key.
     *
     * @param value plain text to encrypt
     * @param key   Base64 encoded X.509 public key
     * @return Base64 encoded cipher text
     * @throws BizException             if the key is invalid or encryption fails
     * @throws IllegalArgumentException if {@code key} is not valid Base64
     */
    public String encryptByPublicKey(String value, String key) {
        try {
            byte[] plainBytes = EncodeUtils.getBytes(value);
            PublicKey publicKey = toPublicKey(key);
            Cipher cipher = Cipher.getInstance(rsaType.getKeyAlgorithm());
            cipher.init(Cipher.ENCRYPT_MODE, publicKey);
            return Base64.getEncoder().encodeToString(doFinalBySegment(cipher, plainBytes, true));
        } catch (NoSuchAlgorithmException | InvalidKeySpecException | InvalidKeyException | NoSuchPaddingException e) {
            throw new BizException("encryptByPublicKey fail", e);
        }
    }

    /**
     * Signs data with a private key.
     *
     * @param data       data to sign
     * @param privateKey Base64 encoded PKCS#8 private key
     * @return Base64 encoded signature
     * @throws BizException             if the key is invalid or signing fails
     * @throws IllegalArgumentException if {@code privateKey} is not valid Base64
     */
    public String sign(String data, String privateKey) {
        try {
            PrivateKey signingKey = toPrivateKey(privateKey);
            Signature signature = Signature.getInstance(rsaType.getSignatureAlgorithm());
            signature.initSign(signingKey);
            signature.update(EncodeUtils.getBytes(data));
            return Base64.getEncoder().encodeToString(signature.sign());
        } catch (NoSuchAlgorithmException | InvalidKeySpecException | SignatureException | InvalidKeyException e) {
            throw new BizException("sign fail", e);
        }
    }

    /**
     * Verifies a signature with a public key.
     *
     * @param data      signed data
     * @param sign      Base64 encoded signature
     * @param publicKey Base64 encoded X.509 public key
     * @return {@code true} if the signature matches the data
     * @throws BizException             if the key is invalid or verification cannot be performed
     * @throws IllegalArgumentException if {@code sign} or {@code publicKey} is not valid Base64
     */
    public boolean verifySign(String data, String sign, String publicKey) {
        try {
            PublicKey verifyKey = toPublicKey(publicKey);
            Signature signature = Signature.getInstance(rsaType.getSignatureAlgorithm());
            signature.initVerify(verifyKey);
            signature.update(EncodeUtils.getBytes(data));
            return signature.verify(Base64.getDecoder().decode(sign));
        } catch (NoSuchAlgorithmException | InvalidKeySpecException | InvalidKeyException | SignatureException e) {
            throw new BizException("verifySign fail", e);
        }
    }

    /**
     * Parses a Base64 encoded PKCS#8 private key using the configured key algorithm.
     */
    private PrivateKey toPrivateKey(String base64Key) throws NoSuchAlgorithmException, InvalidKeySpecException {
        byte[] keyBytes = Base64.getDecoder().decode(base64Key);
        KeyFactory keyFactory = KeyFactory.getInstance(rsaType.getKeyAlgorithm());
        return keyFactory.generatePrivate(new PKCS8EncodedKeySpec(keyBytes));
    }

    /**
     * Parses a Base64 encoded X.509 public key using the configured key algorithm.
     */
    private PublicKey toPublicKey(String base64Key) throws NoSuchAlgorithmException, InvalidKeySpecException {
        byte[] keyBytes = Base64.getDecoder().decode(base64Key);
        KeyFactory keyFactory = KeyFactory.getInstance(rsaType.getKeyAlgorithm());
        return keyFactory.generatePublic(new X509EncodedKeySpec(keyBytes));
    }

    /**
     * Runs the cipher over {@code source} in segments, because a single RSA operation can only
     * process a limited number of bytes, and concatenates the segment results.
     *
     * @param cipher   initialised cipher
     * @param source   bytes to encrypt or decrypt
     * @param isEncode {@code true} to use the encrypt segment size, {@code false} for the decrypt segment size
     * @return concatenated result of all segments
     * @throws BizException if the cipher rejects a segment or the configured segment size is not positive
     */
    private byte[] doFinalBySegment(Cipher cipher, byte[] source, boolean isEncode) {
        int blockSize = isEncode ? rsaType.getMaxEncryptBlock() : rsaType.getMaxDecryptBlock();
        try {
            // Fits in one segment: no buffering needed.
            if (source.length <= blockSize) {
                return cipher.doFinal(source);
            }
            // A non-positive segment size would never advance the offset below.
            if (blockSize <= 0) {
                throw new BizException("doFinalBySegment fail",
                    new IllegalStateException("non-positive block size: " + blockSize));
            }
            ByteArrayOutputStream out = new ByteArrayOutputStream(source.length);
            for (int offset = 0; offset < source.length; offset += blockSize) {
                int size = Math.min(source.length - offset, blockSize);
                byte[] segment = cipher.doFinal(source, offset, size);
                out.write(segment, 0, segment.length);
            }
            return out.toByteArray();
        } catch (IllegalBlockSizeException | BadPaddingException e) {
            throw new BizException("doFinalBySegment fail", e);
        }
    }

    /**
     * Generates a key store by invoking the JDK {@code keytool} command.
     * <p>
     * The alias is the file name without its {@code .jks} extension, and {@code keyPass} is used
     * as both key password and store password. The call blocks until {@code keytool} terminates.
     * <p>
     * NOTE: the password is passed as a command line argument and may therefore be visible to
     * other local users through the process list while {@code keytool} is running.
     *
     * @param keyStoreFullSavePath full path of the key store file, must end with {@code .jks}
     *                             (for example {@code /temp/hello.jks})
     * @param keyPass              password, must not be {@code null}
     * @throws BusinessException if the path does not end with {@code .jks} or {@code keyPass} is {@code null}
     * @throws IOException       if {@code keytool} cannot be started, exits with a non-zero code,
     *                           or the wait is interrupted
     */
    public void generateJks(String keyStoreFullSavePath, String keyPass) throws IOException {
        if (!StringUtils.endsWith(keyStoreFullSavePath, ".jks")) {
            throw new BusinessException(ApiStatusEnum.FAILURE_500.getCode(), "保存路径不正确");
        }
        if (keyPass == null) {
            throw new BusinessException(ApiStatusEnum.FAILURE_500.getCode(), "keyPass must not be null");
        }
        String fileName = new File(keyStoreFullSavePath).getName();
        String alias = StringUtils.substring(fileName, 0, fileName.lastIndexOf("."));

        // Pass every argument separately: no whitespace tokenising, so paths or passwords that
        // contain spaces or option-like text cannot split into or inject extra arguments.
        String[] command = {
            "keytool", "-genkey", "-v",
            "-alias", alias,
            "-keyalg", "RSA",
            "-keysize", String.valueOf(rsaType.getKeySize()),
            "-validity", "36500",
            "-keystore", keyStoreFullSavePath,
            "-keypass", keyPass,
            "-storepass", keyPass,
            "-dname", "CN=CN,OU=CN,O=CN,L=CN,ST=CN,C=zh_CN",
            "-deststoretype", "pkcs12"
        };
        ProcessBuilder builder = new ProcessBuilder(command);
        builder.redirectErrorStream(true);
        Process process = builder.start();
        // Close stdin so an unexpected interactive prompt sees end-of-input instead of hanging.
        process.getOutputStream().close();

        // Drain the merged stdout/stderr so the child cannot block on a full pipe.
        ByteArrayOutputStream output = new ByteArrayOutputStream();
        byte[] buffer = new byte[1024];
        int read;
        while ((read = process.getInputStream().read(buffer)) != -1) {
            output.write(buffer, 0, read);
        }

        int exitCode;
        try {
            exitCode = process.waitFor();
        } catch (InterruptedException e) {
            process.destroy();
            Thread.currentThread().interrupt();
            throw new IOException("Interrupted while waiting for keytool", e);
        }
        if (exitCode != 0) {
            throw new IOException("keytool exited with code " + exitCode + ": " + output.toString().trim());
        }
    }
}
