package com.feingto.cloud.data.jdbc;

import com.feingto.cloud.data.jdbc.model.Record;
import com.feingto.cloud.data.jdbc.model.Records;
import org.apache.commons.lang3.StringUtils;
import org.hibernate.dialect.*;

import javax.sql.DataSource;
import java.sql.*;
import java.util.*;

/**
 * Small JDBC helper built around a {@link DataSource}.
 * <p>
 * The static methods inspect the JDBC URL reported by a data source to determine the database
 * type or the matching Hibernate dialect. The instance methods run a parameterized SQL statement
 * on a connection obtained from the data source and close every JDBC resource they open.
 * Any {@link SQLException} raised by the instance query/update methods is rethrown as a
 * {@link RuntimeException} that carries the original exception as its cause.
 *
 * @author longfei
 */
public class DBKit {
    /**
     * Database type names returned by {@link #getDbType(DataSource)}, in the order they are tested.
     */
    private static final String[] DB_TYPES = {"oracle", "mysql", "h2", "postgresql", "sqlserver"};

    private final DataSource dataSource;

    /**
     * Creates a helper that obtains its connections from the given data source.
     *
     * @param dataSource the data source used by {@link #getConnection()}
     */
    public DBKit(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    /**
     * Determines the database type from the JDBC URL of a connection taken from the data source.
     * <p>
     * The type name must appear as a colon-delimited URL segment (for example {@code :mysql:}),
     * the same rule {@link #getDialect(DataSource)} applies. A bare substring test would also
     * match host names, database names and connection parameters.
     *
     * @param dataSource DataSource
     * @return one of {@code "oracle"}, {@code "mysql"}, {@code "h2"}, {@code "postgresql"},
     * {@code "sqlserver"}
     * @throws IllegalArgumentException if the URL is {@code null} or names none of these databases
     */
    public static String getDbType(DataSource dataSource) {
        String jdbcUrl = getJdbcUrlFromDataSource(dataSource);

        for (String dbType : DB_TYPES) {
            if (hasSubprotocol(jdbcUrl, dbType)) {
                return dbType;
            }
        }

        throw new IllegalArgumentException("Unknown Database of " + jdbcUrl);
    }

    /**
     * Determines the Hibernate dialect class name from the JDBC URL of a connection taken from
     * the data source. Only Oracle, H2, MySQL, PostgreSQL and SQL Server are supported; write a
     * similar method if more database types are needed.
     *
     * @param dataSource DataSource
     * @return fully qualified class name of the matching Hibernate dialect
     * @throws IllegalArgumentException if the URL is {@code null} or names none of these databases
     */
    // The legacy dialect classes referenced here are the only deprecated API used by this class.
    @SuppressWarnings("deprecation")
    public static String getDialect(DataSource dataSource) {
        String jdbcUrl = getJdbcUrlFromDataSource(dataSource);

        if (hasSubprotocol(jdbcUrl, "h2")) {
            return H2Dialect.class.getName();
        } else if (hasSubprotocol(jdbcUrl, "mysql")) {
            return MySQL5InnoDBDialect.class.getName();
        } else if (hasSubprotocol(jdbcUrl, "oracle")) {
            return Oracle10gDialect.class.getName();
        } else if (hasSubprotocol(jdbcUrl, "postgresql")) {
            return PostgreSQL82Dialect.class.getName();
        } else if (hasSubprotocol(jdbcUrl, "sqlserver")) {
            return SQLServer2008Dialect.class.getName();
        } else {
            throw new IllegalArgumentException("Unknown Database of " + jdbcUrl);
        }
    }

    /**
     * Tells whether the JDBC URL contains {@code :name:}. Null-safe: a {@code null} URL never matches.
     */
    private static boolean hasSubprotocol(String jdbcUrl, String name) {
        return StringUtils.contains(jdbcUrl, ":" + name + ":");
    }

    /**
     * Opens a connection on the data source, reads the URL from its metadata and closes the
     * connection again.
     *
     * @param dataSource DataSource
     * @return the URL reported by the connection metadata
     * @throws IllegalStateException if the data source returned a {@code null} connection
     * @throws RuntimeException      if obtaining the connection or reading its metadata fails
     */
    private static String getJdbcUrlFromDataSource(DataSource dataSource) {
        Connection connection = null;

        try {
            connection = dataSource.getConnection();

            if (Objects.isNull(connection)) {
                throw new IllegalStateException("Connection returned by DataSource [" + dataSource + "] was null");
            }

            return connection.getMetaData().getURL();
        } catch (SQLException e) {
            throw new RuntimeException("Could not get database url", e);
        } finally {
            if (Objects.nonNull(connection)) {
                try {
                    connection.close();
                } catch (SQLException ignored) {
                    // Best effort: a failure to close must not replace the URL that was
                    // already read, nor the exception being propagated.
                }
            }
        }
    }

    /**
     * Obtains a database connection from the data source.
     *
     * @return Connection
     * @throws SQLException if the data source fails to provide a connection
     */
    public Connection getConnection() throws SQLException {
        return dataSource.getConnection();
    }

    /**
     * Binds the given values to the statement by position; {@code params[0]} goes to parameter 1.
     * Nothing is bound when {@code params} is {@code null}.
     *
     * @param statement statement to bind to
     * @param params    values to bind
     * @throws SQLException if the driver rejects a value
     */
    private void setParams(PreparedStatement statement, Object... params) throws SQLException {
        if (Objects.isNull(params)) {
            return;
        }

        for (int i = 1; i <= params.length; i++) {
            statement.setObject(i, params[i - 1]);
        }
    }

    /**
     * Reads the label of every column of the result set, in column order.
     */
    private static String[] columnLabels(ResultSet rs) throws SQLException {
        ResultSetMetaData metaData = rs.getMetaData();
        String[] labels = new String[metaData.getColumnCount()];
        for (int i = 0; i < labels.length; i++) {
            labels[i] = metaData.getColumnLabel(i + 1);
        }
        return labels;
    }

    /**
     * Converts an open result set into a value of type {@code T}.
     */
    @FunctionalInterface
    private interface ResultSetHandler<T> {
        T handle(ResultSet rs) throws SQLException;
    }

    /**
     * Runs a parameterized query and passes its result set to the handler.
     * <p>
     * try-with-resources closes the result set, statement and connection independently, so a
     * failing {@code close()} cannot leak the remaining resources or hide the exception that
     * made the query fail.
     *
     * @param sql     SQL string
     * @param params  values for the statement parameters, may be {@code null}
     * @param handler converts the result set into the return value
     * @return the value produced by the handler
     * @throws RuntimeException wrapping any {@link SQLException}
     */
    private <T> T query(String sql, Object[] params, ResultSetHandler<T> handler) {
        try (Connection connection = getConnection();
             PreparedStatement statement = connection.prepareStatement(sql)) {
            setParams(statement, params);
            try (ResultSet rs = statement.executeQuery()) {
                return handler.handle(rs);
            }
        } catch (SQLException e) {
            throw new RuntimeException("SQL 异常: " + e.getMessage(), e);
        }
    }

    /**
     * Fetches a single row.
     *
     * @param sql    SQL string
     * @param params parameter array
     * @return a Record holding the first row keyed by column label; left as created by
     * {@code new Record()} when the query returns no rows
     * @throws RuntimeException wrapping any {@link SQLException}
     */
    public Record get(String sql, Object... params) {
        return query(sql, params, rs -> {
            Record record = new Record();
            if (rs.next()) {
                for (String label : columnLabels(rs)) {
                    record.set(label, rs.getObject(label));
                }
            }
            return record;
        });
    }

    /**
     * Fetches a single row as a Map.
     *
     * @param sql    SQL string
     * @param params parameter array
     * @return a map from column label to value for the first row; empty when the query returns
     * no rows
     * @throws RuntimeException wrapping any {@link SQLException}
     */
    public Map<String, Object> getMap(String sql, Object... params) {
        return query(sql, params, rs -> {
            Map<String, Object> map = new HashMap<>();
            if (rs.next()) {
                for (String label : columnLabels(rs)) {
                    map.put(label, rs.getObject(label));
                }
            }
            return map;
        });
    }

    /**
     * Fetches all rows.
     *
     * @param sql    SQL string
     * @param params parameter array
     * @return Records with one Record per row, in result set order
     * @throws RuntimeException wrapping any {@link SQLException}
     */
    public Records list(String sql, Object... params) {
        return query(sql, params, rs -> {
            Records records = new Records();
            // The column labels are the same for every row, so read them once.
            String[] labels = columnLabels(rs);
            while (rs.next()) {
                Record record = new Record();
                for (String label : labels) {
                    record.set(label, rs.getObject(label));
                }
                records.add(record);
            }
            return records;
        });
    }

    /**
     * Fetches all rows as Maps.
     *
     * @param sql    SQL string
     * @param params parameter array
     * @return one map per row, in result set order; each map keeps the column order
     * @throws RuntimeException wrapping any {@link SQLException}
     */
    public List<Map<String, Object>> listToMap(String sql, Object... params) {
        return query(sql, params, rs -> {
            List<Map<String, Object>> list = new ArrayList<>();
            // The column labels are the same for every row, so read them once.
            String[] labels = columnLabels(rs);
            while (rs.next()) {
                Map<String, Object> map = new LinkedHashMap<>();
                for (String label : labels) {
                    map.put(label, rs.getObject(label));
                }
                list.add(map);
            }
            return list;
        });
    }

    /**
     * Executes an update statement.
     *
     * @param sql    SQL string
     * @param params parameter array
     * @return the value returned by {@link PreparedStatement#executeUpdate()}
     * @throws RuntimeException wrapping any {@link SQLException}
     */
    public int execute(String sql, Object... params) {
        try (Connection connection = getConnection();
             PreparedStatement statement = connection.prepareStatement(sql)) {
            setParams(statement, params);
            return statement.executeUpdate();
        } catch (SQLException e) {
            throw new RuntimeException("SQL 异常: " + e.getMessage(), e);
        }
    }
}
