package com.github.hepeng86.mybatisplus.encrypt.plugin;

import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
import com.baomidou.mybatisplus.core.enums.SqlKeyword;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.Constants;
import com.baomidou.mybatisplus.core.toolkit.sql.StringEscape;
import com.github.hepeng86.mybatisplus.encrypt.Encrypt;
import com.github.hepeng86.mybatisplus.encrypt.model.SensitiveField;
import com.github.hepeng86.mybatisplus.encrypt.properties.EncryptConfigProperties;
import com.github.hepeng86.mybatisplus.encrypt.util.ReflectionUtils;
import com.github.hepeng86.mybatisplus.encrypt.util.StringUtil;
import java.sql.Connection;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.binding.MapperMethod.ParamMap;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;

/**
 * Statement interceptor for encrypted-field queries.
 *
 * <p>Hooks {@code StatementHandler#prepare(Connection, Integer)}. When {@code isMixedQueryMode()}
 * is true, it looks for a MyBatis-Plus {@link AbstractWrapper} bound under
 * {@link Constants#WRAPPER}. If that wrapper has an entity class and at least one "normal"
 * expression segment, it is handed to {@code handleParameters} (inherited from
 * {@link AbstractInterceptor}) so the SQL can be adjusted before the statement is prepared.
 *
 * @author hepeng
 * @since 2024-07-10
 */
@Intercepts({
        @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class}),
})
@Slf4j
public class EncryptPrepareInterceptor extends AbstractInterceptor {

    /** Condition keywords whose placeholders get an extra literal value inserted in front of them. */
    private static final Set<SqlKeyword> LIST_KEYWORDS = EnumSet.of(SqlKeyword.IN, SqlKeyword.NOT_IN);

    /** Rewrites taking longer than this many milliseconds are logged. */
    private static final long SLOW_PREPARE_THRESHOLD_MS = 10L;

    /** Property prefix of wrapper-generated parameters, keyed by the same name used for the lookup. */
    private static final String WRAPPER_PARAM_PROPERTY_PREFIX = Constants.WRAPPER + ".paramNameValuePairs.";

    public EncryptPrepareInterceptor(Encrypt encrypt, EncryptConfigProperties encryptConfigProperties) {
        super(encrypt, encryptConfigProperties);
    }

    /**
     * Adjusts the SQL of wrapper-based statements in mixed query mode, then proceeds with
     * {@code prepare}.
     *
     * <p>Any exception raised while adjusting the SQL is logged and swallowed. The statement is
     * therefore still prepared, possibly with only part of the rewrite applied.
     *
     * @param invocation the intercepted {@code prepare} call; its target is the {@link StatementHandler}
     * @return the result of {@link Invocation#proceed()}
     * @throws Exception if the underlying {@code prepare} call fails
     */
    @Override
    public Object intercept(Invocation invocation) throws Exception {
        if (isMixedQueryMode()) {
            long start = System.nanoTime();
            try {
                handleWrapper((StatementHandler) invocation.getTarget());
            } catch (Exception e) {
                // Best effort: a failed rewrite must not prevent the query from running.
                log.error("prepare encrypt fail", e);
            }
            // nanoTime is monotonic, unlike currentTimeMillis, so it is safe for measuring durations.
            long cost = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start);
            if (cost > SLOW_PREPARE_THRESHOLD_MS) {
                log.info("prepare encrypt cost:{}ms", cost);
            }
        }

        return invocation.proceed();
    }

    /**
     * Passes the query wrapper bound under {@link Constants#WRAPPER} to {@code handleParameters}
     * when that wrapper has an entity class and at least one normal expression segment.
     *
     * @param statementHandler handler whose bound SQL may be rewritten
     * @throws Exception propagated from {@code handleParameters}
     */
    private void handleWrapper(StatementHandler statementHandler) throws Exception {
        BoundSql boundSql = statementHandler.getBoundSql();
        Object parameterObject = boundSql.getParameterObject();
        if (!(parameterObject instanceof ParamMap)) {
            return;
        }
        ParamMap<?> paramMap = (ParamMap<?>) parameterObject;
        // MyBatis' ParamMap#get throws BindingException for missing keys, so check containsKey first.
        if (!paramMap.containsKey(Constants.WRAPPER)) {
            return;
        }
        Object ew = paramMap.get(Constants.WRAPPER);
        // Only AbstractWrapper exposes the entity class and expression segments used below.
        // A null value or any other Wrapper implementation is left untouched instead of
        // failing with a ClassCastException.
        if (!(ew instanceof AbstractWrapper)) {
            return;
        }
        @SuppressWarnings("unchecked")
        AbstractWrapper<Object, ?, ?> wrapper = (AbstractWrapper<Object, ?, ?>) ew;
        if (Objects.nonNull(wrapper.getEntityClass())
                && CollectionUtils.isNotEmpty(wrapper.getExpression().getNormal())) {
            handleParameters(wrapper, boundSql);
        }
    }

    /**
     * For {@code IN} / {@code NOT IN} conditions, inserts {@code value} as an escaped SQL literal
     * in front of the placeholder of the wrapper parameter {@code Constants.WRAPPER_PARAM + seq}.
     * For example, {@code IN (?)} becomes {@code IN ('<value>', ?)}. The bound parameter itself is
     * not modified. Other keywords, and a {@code null} value, are ignored.
     *
     * <p>NOTE(review): placeholders are located by counting {@code "?"} from the end of the SQL
     * text. This assumes the n-th {@code "?"} from the end corresponds to the n-th parameter
     * mapping from the end, i.e. there is no literal {@code '?'} in the SQL and no {@code '?'} in
     * previously inserted values.
     *
     * @param seq            sequence number of the wrapper-generated parameter
     * @param value          value to insert as a literal (presumably the converted/encrypted value
     *                       — confirm with the caller); {@code null} is skipped
     * @param wrapper        the query wrapper (not used by this implementation)
     * @param sensitiveField the sensitive field being queried (not used by this implementation)
     * @param boundSql       bound SQL whose {@code sql} text is replaced via reflection
     * @param sqlKeyword     keyword of the condition the parameter belongs to; may be {@code null}
     * @throws IllegalAccessException propagated from {@code ReflectionUtils.setField}
     * @throws NoSuchFieldException   propagated from {@code ReflectionUtils.setField}
     */
    protected void setParamNameValuePairs(int seq, String value, AbstractWrapper<?, ?, ?> wrapper,
            SensitiveField sensitiveField, BoundSql boundSql, SqlKeyword sqlKeyword) throws IllegalAccessException, NoSuchFieldException {
        // A null value cannot be escaped into a literal (StringEscape would throw), so skip it.
        if (value == null || !LIST_KEYWORDS.contains(sqlKeyword)) {
            return;
        }

        String property = WRAPPER_PARAM_PROPERTY_PREFIX + Constants.WRAPPER_PARAM + seq;
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        int size = parameterMappings.size();
        for (int i = 0; i < size; i++) {
            if (property.equals(parameterMappings.get(i).getProperty())) {
                // Mapping i corresponds to the (size - i)-th "?" counted from the end of the SQL.
                String literal = StringEscape.escapeString(value);
                String newSql = StringUtil.replaceNthOccurrenceReverse(boundSql.getSql(), "?", literal + ", ?", size - i);
                ReflectionUtils.setField(boundSql, "sql", newSql);
                return;
            }
        }
    }

    /**
     * Converts a query value by encrypting it with {@code encrypt}.
     *
     * @param value the value to convert
     * @return the result of {@code encrypt.encrypt(value)}
     */
    @Override
    protected String convert(String value) {
        return encrypt.encrypt(value);
    }
}
