package com.gccloud.starter.common.utils;

import com.gccloud.starter.common.exception.GlobalException;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.config.AutowireCapableBeanFactory;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.xml.XmlBeanFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.stereotype.Component;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

/**
 * Static access point to the Spring {@link ApplicationContext}.
 * <p>
 * Lets code that is not itself managed by Spring look up beans, and register or remove bean
 * definitions at runtime.
 * <p>
 * Reference: https://www.cnblogs.com/dongguangming/p/12792789.html
 *
 * @author liuchengbiao
 * @date 2020-06-16 15:10
 */
@Slf4j
@Component
public class SpringContextUtils implements ApplicationContextAware {

    /**
     * The first application context passed to {@link #setApplicationContext}; later calls do not
     * replace it.
     */
    private static ApplicationContext applicationContext;

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        // Keep only the first context received; subsequent calls are ignored.
        if (SpringContextUtils.applicationContext == null) {
            log.info("初始化:{},通过该工具类可以实现获取其他组件", this.getClass().getSimpleName());
            SpringContextUtils.applicationContext = applicationContext;
        }
    }

    /**
     * Returns the stored context. Fails fast with a descriptive message (instead of an anonymous
     * NullPointerException) when Spring has not initialized this component yet.
     *
     * @return the stored application context, never {@code null}
     * @throws IllegalStateException if no context has been set
     */
    private static ApplicationContext requireContext() {
        ApplicationContext context = applicationContext;
        if (context == null) {
            throw new IllegalStateException(
                    "SpringContextUtils has not been initialized: no ApplicationContext is available yet");
        }
        return context;
    }

    /**
     * Returns the context's bean factory as a {@link DefaultListableBeanFactory}. This is the
     * factory type this class relies on for registering and removing bean definitions.
     *
     * @return the bean factory of the current context
     * @throws IllegalStateException if the context is not initialized, or its bean factory is not
     *                               a {@link DefaultListableBeanFactory}
     */
    private static DefaultListableBeanFactory getDefaultListableBeanFactory() {
        AutowireCapableBeanFactory factory = requireContext().getAutowireCapableBeanFactory();
        if (!(factory instanceof DefaultListableBeanFactory)) {
            throw new IllegalStateException(
                    "The ApplicationContext bean factory is not a DefaultListableBeanFactory: "
                            + (factory == null ? null : factory.getClass().getName()));
        }
        return (DefaultListableBeanFactory) factory;
    }

    /**
     * Looks up a bean by name.
     *
     * @param name the bean name
     * @return the bean, or {@code null} if the lookup failed for any reason (the failure is logged)
     */
    public static Object getBean(String name) {
        try {
            return requireContext().getBean(name);
        } catch (Exception e) {
            log.error(ExceptionUtils.getStackTrace(e));
        }
        return null;
    }

    /**
     * Looks up a bean by name and required type.
     *
     * @param name         the bean name
     * @param requiredType the type the bean must match
     * @param <T>          the bean type
     * @return the bean, or {@code null} if the lookup failed for any reason (the failure is logged)
     */
    public static <T> T getBean(String name, Class<T> requiredType) {
        try {
            return requireContext().getBean(name, requiredType);
        } catch (Exception e) {
            log.error(ExceptionUtils.getStackTrace(e));
        }
        return null;
    }

    /**
     * Looks up the bean registered for the given type.
     *
     * @param requiredType the type the bean must match
     * @param <T>          the bean type
     * @return the bean, or {@code null} if the lookup failed for any reason (the failure is logged)
     */
    public static <T> T getBean(Class<T> requiredType) {
        try {
            return requireContext().getBean(requiredType);
        } catch (Exception e) {
            log.error(ExceptionUtils.getStackTrace(e));
        }
        return null;
    }

    /**
     * Checks whether the context knows a bean with the given name.
     *
     * @param name the bean name
     * @return {@code true} if the context contains such a bean; {@code false} otherwise, or if the
     *         check itself failed (the failure is logged)
     */
    public static boolean containsBean(String name) {
        try {
            return requireContext().containsBean(name);
        } catch (Exception e) {
            log.error(ExceptionUtils.getStackTrace(e));
        }
        return false;
    }

    /**
     * Checks whether the named bean is a singleton. Unlike the lookup methods above, exceptions
     * from the context are propagated to the caller.
     *
     * @param name the bean name
     * @return {@code true} if the bean is a singleton
     * @throws IllegalStateException if this component has not been initialized
     */
    public static boolean isSingleton(String name) {
        return requireContext().isSingleton(name);
    }

    /**
     * Returns the type of the named bean, as reported by the context. Exceptions from the context
     * are propagated to the caller.
     *
     * @param name the bean name
     * @return the bean type as reported by {@code ApplicationContext#getType}
     * @throws IllegalStateException if this component has not been initialized
     */
    public static Class<?> getType(String name) {
        return requireContext().getType(name);
    }

    /**
     * Removes a bean definition from the context's bean factory.
     *
     * @param beanName the name of the definition to remove
     * @throws IllegalStateException if the context is not initialized, or its bean factory does not
     *                               support removing definitions
     */
    public static void removeBeanDefinition(String beanName) {
        getDefaultListableBeanFactory().removeBeanDefinition(beanName);
    }

    /**
     * Registers (or replaces) a bean defined by an XML fragment, then fetches it from the context.
     * <p>
     * The parsed XML is checked for a definition of {@code beanName} before any existing definition
     * is removed, so a mismatched name leaves the current registration untouched. If registration
     * or bean retrieval fails after the old definition was removed, the old definition is not
     * restored.
     *
     * @param beanName   the name under which the bean is defined in {@code beanXmlDef}
     * @param beanXmlDef XML bean element(s); this method wraps them in a {@code <beans>} root element
     * @return {@code true} if the definition was registered and the bean was obtained from the
     *         context; {@code false} otherwise (the failure is logged)
     * @throws GlobalException        if {@code beanXmlDef} is {@code null} or empty
     * @throws BeansException         if the XML document cannot be loaded
     * @throws IllegalStateException  if the context is not initialized, or its bean factory does not
     *                                support registering definitions
     */
    public static boolean registBean(String beanName, String beanXmlDef) {
        // Validate the input before doing any work.
        if (StringUtils.isEmpty(beanXmlDef)) {
            throw new GlobalException("Bean的定义不能为空");
        }
        DefaultListableBeanFactory beanFactory = getDefaultListableBeanFactory();
        String xml = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>"
                + "<beans xmlns=\"http://www.springframework.org/schema/beans\""
                + "       xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\""
                + "       xsi:schemaLocation=\"http://www.springframework.org/schema/beans http://www.springframework.org/schema/beans/spring-beans.xsd"
                + "       \">"
                + beanXmlDef
                + "</beans>";
        // NOTE(review): XmlBeanFactory is deprecated since Spring 3.1 in favor of
        // DefaultListableBeanFactory + XmlBeanDefinitionReader; consider migrating.
        // Failures while loading the XML propagate to the caller, as before.
        XmlBeanFactory factory = new XmlBeanFactory(new ByteArrayResource(xml.getBytes(StandardCharsets.UTF_8)));

        // Make sure the parsed XML actually defines the requested bean BEFORE touching the current
        // registration. Otherwise a mismatched name would delete the live bean and register nothing.
        try {
            factory.getMergedBeanDefinition(beanName);
        } catch (BeansException e) {
            log.error("注册bean:{}失败：", beanName);
            log.error(ExceptionUtils.getStackTrace(e));
            return false;
        }

        // Drop the definition this factory currently holds (if any) before registering the new one.
        // containsBeanDefinition only looks at this factory's own definitions, so beans from a
        // parent context or manually registered singletons no longer trigger a failing removal.
        try {
            if (beanFactory.containsBeanDefinition(beanName)) {
                beanFactory.removeBeanDefinition(beanName);
            }
        } catch (NoSuchBeanDefinitionException e) {
            log.error(ExceptionUtils.getStackTrace(e));
        }
        try {
            beanFactory.registerBeanDefinition(beanName, factory.getMergedBeanDefinition(beanName));
            // Fetch the bean so that creation failures are reported (and rolled back) here.
            Object obj = requireContext().getBean(beanName);
            log.info("注册bean:{},{}", beanName, obj == null ? "失败" : "成功");
            return true;
        } catch (Exception e) {
            log.error("注册bean:{}失败：", beanName);
            // Roll back the half-registered definition so the factory is not left with a broken bean.
            try {
                beanFactory.removeBeanDefinition(beanName);
            } catch (Exception e1) {
                log.error(ExceptionUtils.getStackTrace(e1));
            }
            log.error(ExceptionUtils.getStackTrace(e));
        }
        return false;
    }

}
