package com.github.hepeng86.mybatisplus.encrypt.plugin;

import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
import com.baomidou.mybatisplus.core.enums.SqlKeyword;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.Constants;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.github.hepeng86.mybatisplus.encrypt.Encrypt;
import com.github.hepeng86.mybatisplus.encrypt.properties.EncryptConfigProperties;
import com.github.hepeng86.mybatisplus.encrypt.constant.EncryptConstants;
import com.github.hepeng86.mybatisplus.encrypt.util.ReflectionUtils;
import com.github.hepeng86.mybatisplus.encrypt.model.SensitiveField;
import java.sql.PreparedStatement;
import java.util.Objects;
import java.util.regex.Matcher;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.binding.MapperMethod.ParamMap;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.springframework.beans.BeanUtils;

/**
 * Parameter interceptor: encrypts sensitive values before MyBatis binds them to the
 * {@link PreparedStatement}.
 *
 * <p>Handles three parameter shapes:
 * <ul>
 *   <li>a {@link ParamMap} carrying an {@link AbstractWrapper} under {@link Constants#WRAPPER}:
 *       the wrapper's entity is copied and encrypted, its {@code SET} fragment values are
 *       encrypted, and its normal condition segments are handed to the inherited two-argument
 *       {@code handleParameters} overload;</li>
 *   <li>a {@link ParamMap} carrying an entity under {@link Constants#ENTITY};</li>
 *   <li>a plain parameter object whose class has encrypted columns registered in
 *       {@link ReflectionUtils}.</li>
 * </ul>
 *
 * <p>Entities are replaced with shallow copies (see {@link #clone(Object)}) before encryption, so
 * the caller's own object is not modified. Wrapper parameter values, however, are rewritten in
 * place in the wrapper's {@code paramNameValuePairs}.
 *
 * <p>Any exception during encryption is logged and the statement still proceeds with whatever
 * parameters had been processed up to that point.
 *
 * @author hepeng
 * @since 2024-07-10
 */
@Intercepts({
        @Signature(type = ParameterHandler.class, method = "setParameters", args = {PreparedStatement.class}),
})
@Slf4j
public class EncryptParameterInterceptor extends AbstractInterceptor {

    /**
     * Creates the interceptor.
     *
     * @param encrypt                 encryptor used by {@link #convert(String)}
     * @param encryptConfigProperties plugin configuration, passed through to the base class
     */
    public EncryptParameterInterceptor(Encrypt encrypt, EncryptConfigProperties encryptConfigProperties) {
        super(encrypt, encryptConfigProperties);
    }

    /**
     * Encrypts the sensitive parts of the statement's parameter object, then lets MyBatis bind
     * the parameters.
     *
     * @param invocation the intercepted {@code ParameterHandler#setParameters} call
     * @return the result of {@link Invocation#proceed()}
     * @throws Exception whatever {@link Invocation#proceed()} throws; encryption errors are only
     *                   logged
     */
    @SuppressWarnings("unchecked")
    @Override
    public Object intercept(Invocation invocation) throws Exception {
        // nanoTime is monotonic; currentTimeMillis follows the wall clock and can jump.
        long start = System.nanoTime();
        ParameterHandler parameterHandler = (ParameterHandler) invocation.getTarget();
        Object parameterObject = parameterHandler.getParameterObject();
        if (Objects.isNull(parameterObject)) {
            return invocation.proceed();
        }

        try {
            if (parameterObject instanceof ParamMap) {
                ParamMap<Object> paramMap = (ParamMap<Object>) parameterObject;
                // containsKey guards get(): MyBatis' ParamMap#get rejects unknown keys.
                if (paramMap.containsKey(Constants.WRAPPER)) {
                    Object wrapperParam = paramMap.get(Constants.WRAPPER);
                    // Only AbstractWrapper exposes the entity / SQL-set / expression API used here.
                    // A blind cast would throw for other Wrapper implementations and, through the
                    // catch below, also skip encrypting the ENTITY parameter handled afterwards.
                    // instanceof is also false for null, matching the former nonNull check.
                    if (wrapperParam instanceof AbstractWrapper) {
                        AbstractWrapper<Object, ?, ?> wrapper = (AbstractWrapper<Object, ?, ?>) wrapperParam;
                        if (Objects.nonNull(wrapper.getEntity())) {
                            // Encrypt a copy so the caller's entity instance keeps its plain values.
                            wrapper.setEntity(clone(wrapper.getEntity()));
                            handleParameters(wrapper.getEntity());
                        }

                        if (Objects.nonNull(wrapper.getEntityClass()) && StringUtils.isNotBlank(wrapper.getSqlSet())) {
                            handleParameters(wrapper);
                        }

                        if (Objects.nonNull(wrapper.getEntityClass())
                                && CollectionUtils.isNotEmpty(wrapper.getExpression().getNormal())) {
                            handleParameters(wrapper, null);
                        }
                    }
                }

                if (paramMap.containsKey(Constants.ENTITY) && Objects.nonNull(paramMap.get(Constants.ENTITY))) {
                    paramMap.put(Constants.ENTITY, clone(paramMap.get(Constants.ENTITY)));
                    handleParameters(paramMap.get(Constants.ENTITY));
                }
            } else if (!CollectionUtils.isEmpty(ReflectionUtils.getEncryptColumnNameAndFieldMapFromCache(parameterObject.getClass()))) {
                // Swap the handler's parameter object for a copy, then encrypt that copy.
                ReflectionUtils.setField(parameterHandler, "parameterObject", clone(parameterObject));
                handleParameters(parameterHandler.getParameterObject());
            }
        } catch (Exception e) {
            log.error("parameter encrypt fail", e);
        }

        long cost = (System.nanoTime() - start) / 1_000_000L;
        if (cost > 10) {
            log.info("parameter encrypt cost:{}ms", cost);
        }

        return invocation.proceed();
    }

    /**
     * Creates a shallow copy of a bean.
     *
     * <p>Uses Spring's {@link BeanUtils#instantiateClass(Class)}, which needs a no-arg
     * constructor. It then uses {@link BeanUtils#copyProperties(Object, Object)}, which copies
     * only bean properties that are readable on the source and writable on the target. Nested
     * objects are shared, not copied.
     *
     * @param source bean to copy; must not be {@code null}
     * @return a new instance of {@code source}'s class with its properties copied
     */
    public static Object clone(Object source) {
        Object result = BeanUtils.instantiateClass(source.getClass());
        BeanUtils.copyProperties(source, result);
        return result;
    }

    /**
     * Encrypts the wrapper's parameter values referenced from its {@code SET} SQL fragment when
     * the target column maps to a sensitive field of the wrapper's entity class.
     *
     * <p>Values are replaced in place in {@code wrapper.getParamNameValuePairs()}. Blank strings
     * are left untouched. Non-String values cannot be encrypted here, so they are skipped with a
     * warning rather than aborting the rest of the statement's encryption.
     *
     * @param wrapper wrapper with a non-blank {@code sqlSet} and a non-null entity class
     */
    protected void handleParameters(AbstractWrapper<?, ?, ?> wrapper) {
        String[] sqlSet = wrapper.getSqlSet().split(",");
        for (String sql : sqlSet) {
            Matcher matcher = EncryptConstants.SQL_SET_PATTERN.matcher(sql);
            while (matcher.find()) {
                // group(1) is used as the column name and group(2) as the generated parameter
                // suffix; both are defined by SQL_SET_PATTERN (declared elsewhere).
                String columnName = matcher.group(1);
                SensitiveField sensitiveField = ReflectionUtils.getSensitiveField(wrapper.getEntityClass(), columnName);
                if (Objects.isNull(sensitiveField)) {
                    continue;
                }

                String genName = Constants.WRAPPER_PARAM + matcher.group(2);
                Object rawValue = wrapper.getParamNameValuePairs().get(genName);
                if (!(rawValue instanceof String)) {
                    // A blind (String) cast here used to throw and abort all remaining encryption.
                    if (Objects.nonNull(rawValue)) {
                        log.warn("skip encrypting non-string value for column {}", columnName);
                    }
                    continue;
                }

                String value = (String) rawValue;
                if (StringUtils.isBlank(value)) {
                    continue;
                }

                String encryptField = convert(value, sensitiveField.getJsonPaths());
                wrapper.getParamNameValuePairs().put(genName, encryptField);
            }
        }
    }

    /**
     * Encrypts {@code value} and stores it in the wrapper under the generated parameter name for
     * {@code seq} ({@link Constants#WRAPPER_PARAM} + {@code seq}), replacing any existing value.
     *
     * <p>{@code boundSql} and {@code sqlKeyword} are not used by this implementation.
     *
     * @param seq            generated parameter sequence number
     * @param value          plain value to encrypt
     * @param wrapper        wrapper whose parameter map is updated
     * @param sensitiveField metadata of the sensitive field; its JSON paths are passed to
     *                       {@code convert}
     * @param boundSql       unused here
     * @param sqlKeyword     unused here
     */
    protected void setParamNameValuePairs(int seq, String value, AbstractWrapper<?, ?, ?> wrapper,
            SensitiveField sensitiveField, BoundSql boundSql, SqlKeyword sqlKeyword) {
        String genName = Constants.WRAPPER_PARAM + seq;
        String encryptField = convert(value, sensitiveField.getJsonPaths());
        wrapper.getParamNameValuePairs().put(genName, encryptField);
    }

    /**
     * Encrypts a single plain value with the configured {@link Encrypt}.
     *
     * @param value plain value
     * @return the value returned by {@code encrypt.encrypt(value)}
     */
    @Override
    protected String convert(String value) {
        return encrypt.encrypt(value);
    }
}
