package cn.wjee.commons.lang;

import cn.wjee.commons.crypto.EncodeUtils;
import cn.wjee.commons.exception.BizException;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

import javax.xml.XMLConstants;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.xpath.XPath;
import javax.xml.xpath.XPathConstants;
import javax.xml.xpath.XPathExpression;
import javax.xml.xpath.XPathFactory;
import java.io.*;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/**
 * XML helpers built on the JAXP DOM and XPath APIs: parsing documents from
 * streams, strings and files, evaluating XPath expressions, and reading
 * attributes and direct child elements.
 *
 * @author lxn
 */
public final class XmlUtils {

    /** Tag name that matches every element, mirroring {@link Element#getElementsByTagName(String)}. */
    private static final String ANY_TAG = "*";

    private XmlUtils() {
        // static utility class, not meant to be instantiated
    }

    /**
     * Parses an XML document from a stream.
     *
     * <p>The parser is namespace aware, runs with secure processing enabled and is
     * not allowed to fetch external DTDs or external schemas. The stream is not
     * closed by this method.
     *
     * @param input stream holding the XML content
     * @return the parsed document
     * @throws BizException if the parser cannot be configured or the content cannot be parsed
     */
    public static Document getDocument(InputStream input) {
        try {
            DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
            // Ask for secure processing explicitly rather than relying on the implementation default.
            factory.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true);
            // Empty protocol lists: no external DTD or schema may be resolved while parsing.
            factory.setAttribute(XMLConstants.ACCESS_EXTERNAL_DTD, "");
            factory.setAttribute(XMLConstants.ACCESS_EXTERNAL_SCHEMA, "");
            factory.setNamespaceAware(true);
            DocumentBuilder builder = factory.newDocumentBuilder();
            return builder.parse(input);
        } catch (Exception e) {
            throw new BizException("Get Document Fail", e);
        }
    }

    /**
     * Parses an XML document from its textual form.
     *
     * <p>The text is converted to bytes with {@code EncodeUtils.getBytes} and then
     * parsed by {@link #getDocument(InputStream)}.
     *
     * @param resource XML source text
     * @return the parsed document
     * @throws BizException if the text cannot be converted or parsed
     */
    public static Document getDocument(String resource) {
        try (InputStream input = new ByteArrayInputStream(EncodeUtils.getBytes(resource))) {
            return getDocument(input);
        } catch (Exception e) {
            throw new BizException("Get Document By String Fail", e);
        }
    }

    /**
     * Parses an XML document from a file. The file stream is closed before returning.
     *
     * @param file file holding the XML content
     * @return the parsed document
     * @throws BizException if the file cannot be opened or read, or its content cannot be parsed
     */
    public static Document getDocument(File file) {
        try (FileInputStream input = new FileInputStream(file)) {
            return getDocument(input);
        } catch (IOException e) {
            throw new BizException("getDocument from file fail", e);
        }
    }

    /**
     * Creates a new {@link XPathFactory} for the default object model.
     *
     * <p>A fresh instance is returned on every call; nothing is cached.
     *
     * @return a new XPath factory
     */
    public static XPathFactory getXPathFactory() {
        return XPathFactory.newInstance();
    }

    /**
     * Evaluates an XPath expression against a document and returns a single node.
     *
     * @param document document used as the evaluation context
     * @param xpathExp XPath expression
     * @return the node produced by evaluating the expression as {@link XPathConstants#NODE}
     * @throws BizException if the expression cannot be compiled or evaluated
     */
    public static Node getByXPath(Document document, String xpathExp) {
        try {
            XPath xpath = getXPathFactory().newXPath();
            XPathExpression expression = xpath.compile(xpathExp);
            return (Node) expression.evaluate(document, XPathConstants.NODE);
        } catch (Exception e) {
            throw new BizException("FindElements fail", e);
        }
    }

    /**
     * Evaluates an XPath expression against a document and returns all matching nodes.
     *
     * @param document document used as the evaluation context
     * @param xpathExp XPath expression
     * @return the nodes produced by evaluating the expression as {@link XPathConstants#NODESET}
     * @throws BizException if the expression cannot be compiled or evaluated
     */
    public static NodeList getListByXPath(Document document, String xpathExp) {
        try {
            XPath xpath = getXPathFactory().newXPath();
            XPathExpression expression = xpath.compile(xpathExp);
            return (NodeList) expression.evaluate(document, XPathConstants.NODESET);
        } catch (Exception e) {
            throw new BizException("FindElements fail", e);
        }
    }

    /**
     * Returns the root element of a document.
     *
     * @param document document to inspect, may be {@code null}
     * @return the document element, or {@code null} when {@code document} is {@code null}
     */
    public static Element getRoot(Document document) {
        return document != null ? document.getDocumentElement() : null;
    }

    /**
     * Reads an attribute value from an element.
     *
     * @param element       element to read from, may be {@code null}
     * @param attributeName name of the attribute
     * @return the result of {@link Element#getAttribute(String)}, or {@code null}
     *         when {@code element} is {@code null}
     */
    public static String getAttribute(Element element, String attributeName) {
        return element != null ? element.getAttribute(attributeName) : null;
    }

    /**
     * Collects the direct child elements of a parent that carry the given tag name.
     *
     * <p>Only immediate children are returned, in document order; deeper descendants
     * with the same tag name are skipped. The tag name {@code "*"} matches every
     * child element.
     *
     * @param parentElement parent element, may be {@code null}
     * @param tagName       tag name to match, or {@code "*"} for any element
     * @return a new mutable list of matching children; empty when {@code parentElement}
     *         is {@code null} or nothing matches
     */
    public static List<Element> getChildren(Element parentElement, String tagName) {
        List<Element> result = new ArrayList<>();
        if (parentElement == null) {
            return result;
        }
        final boolean matchAny = ANY_TAG.equals(tagName);
        // Walk the immediate children only; a subtree-wide search would visit every
        // descendant just to throw away those whose parent is not parentElement.
        for (Node child = parentElement.getFirstChild(); child != null; child = child.getNextSibling()) {
            if (child.getNodeType() != Node.ELEMENT_NODE) {
                continue;
            }
            Element element = (Element) child;
            if (matchAny || element.getTagName().equals(tagName)) {
                result.add(element);
            }
        }
        return result;
    }

    /**
     * Collects the direct child elements with the given tag name and converts each
     * of them with the supplied function.
     *
     * @param parentElement parent element, may be {@code null}
     * @param tagName       tag name to match, or {@code "*"} for any element
     * @param func          conversion applied to every matching child
     * @param <R>           type of the converted values
     * @return a new mutable list with one converted value per matching child, in document order
     */
    public static <R> List<R> getChildren(Element parentElement, String tagName, Function<Element, R> func) {
        final List<Element> children = getChildren(parentElement, tagName);
        List<R> result = new ArrayList<>(children.size());
        for (Element child : children) {
            result.add(func.apply(child));
        }
        return result;
    }
}
