package cn.wjee.commons.thread;

import cn.wjee.commons.WJeeVar;
import cn.wjee.commons.collection.CollectionUtils;
import cn.wjee.commons.collection.StepWatch;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.*;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntPredicate;

/**
 * Thread utilities: shared executor pools plus helpers for sleeping, fire-and-forget execution,
 * time-bounded execution, fan-out/wait batch processing, simple concurrency load tests and
 * timeout-wrapping of {@link CompletableFuture}s.
 *
 * @author wjee
 * @version $Id: ThreadUtils.java, v 0.1 2016-02-14 21:09:27 wjee Exp $
 */
public final class ThreadUtils {
    /**
     * Logger.
     */
    private static final Logger log = LoggerFactory.getLogger(ThreadUtils.class);
    /**
     * Single-thread daemon scheduler that runs periodic pool-monitoring tasks.
     * Must be declared before {@link #THREAD_POOL_WJEE}, whose initializer schedules work on it.
     */
    private static final ScheduledThreadPoolExecutor MONITOR_POOL = getMonitorPool();
    /**
     * Single-thread daemon scheduler that fires the timeouts used by {@link #within}.
     */
    private static final ScheduledThreadPoolExecutor FAIL_FAST_POOL = getFailFastPool();
    /**
     * Default shared pool used by the {@code execute} helpers.
     */
    public static final ThreadPoolExecutor THREAD_POOL_WJEE = getDefaultPool();

    private ThreadUtils() {
    }

    /**
     * Creates the scheduler used for pool-monitoring tasks.
     *
     * @return ScheduledThreadPoolExecutor
     */
    private static ScheduledThreadPoolExecutor getMonitorPool() {
        return newDaemonScheduler("monitor-pool");
    }

    /**
     * Creates the fail-fast scheduler used to time out {@link CompletableFuture}s in {@link #within}.
     *
     * @return ScheduledThreadPoolExecutor
     */
    private static ScheduledThreadPoolExecutor getFailFastPool() {
        return newDaemonScheduler("fail-fast");
    }

    /**
     * Builds a single-thread scheduler on daemon threads that drops cancelled tasks from its queue
     * and is shut down by a JVM shutdown hook.
     *
     * @param name name handed to {@code MyThreadFactory} (presumably used for thread naming — confirm)
     * @return ScheduledThreadPoolExecutor
     */
    private static ScheduledThreadPoolExecutor newDaemonScheduler(String name) {
        MyThreadFactory threadFactory = new MyThreadFactory(name);
        threadFactory.setDaemon(true);
        ScheduledThreadPoolExecutor scheduler = new ScheduledThreadPoolExecutor(1, threadFactory);
        // lets cancelled timeouts/monitors leave the work queue immediately instead of lingering
        scheduler.setRemoveOnCancelPolicy(true);
        Runtime.getRuntime().addShutdownHook(new Thread(scheduler::shutdown));
        return scheduler;
    }

    /**
     * Creates the default shared pool and schedules a status log line for it on the monitor
     * scheduler every 30 seconds (after an initial 30-second delay).
     *
     * @return ThreadPoolExecutor
     */
    private static ThreadPoolExecutor getDefaultPool() {
        final String poolName = WJeeVar.NAMESPACE + "-";
        ThreadPoolExecutor simplePool = MyThreadPoolExecutor.getSimplePool(poolName);
        // NOTE(review): core/max are configured sizes, not live metrics; the queue size is the only live value here
        MONITOR_POOL.scheduleAtFixedRate(() -> log.info("ThreadPool-{}status -> core: {}, max: {}, queue:{}",
            poolName,
            simplePool.getCorePoolSize(),
            simplePool.getMaximumPoolSize(),
            simplePool.getQueue().size()
        ), 30, 30, TimeUnit.SECONDS);
        return simplePool;
    }

    /**
     * Sleeps the current thread for the given number of milliseconds.
     * <p>
     * A {@code null} or negative value is logged as a warning and ignored. If the thread is interrupted
     * while sleeping, the interrupt flag is restored and the error is logged instead of being thrown.
     *
     * @param mills milliseconds to sleep
     */
    public static void sleep(Long mills) {
        if (mills == null || mills < 0) {
            // previously surfaced as a caught NPE/IllegalArgumentException that wrongly set the interrupt flag
            log.warn("Thread Sleep skipped, invalid mills: {}", mills);
            return;
        }
        try {
            Thread.sleep(mills);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("Thread Sleep Error", e);
        }
    }

    /**
     * Submits a task to the default shared pool.
     *
     * @param runnable task to run
     */
    public static void execute(Runnable runnable) {
        THREAD_POOL_WJEE.execute(runnable);
    }

    /**
     * Submits each task, in array order, to the default shared pool.
     *
     * @param runnable tasks to run
     */
    public static void execute(Runnable... runnable) {
        Arrays.stream(runnable).forEach(THREAD_POOL_WJEE::execute);
    }

    /**
     * Runs {@code function} asynchronously on the default shared pool and waits up to
     * {@code timeoutSeconds} for its result.
     * <p>
     * Returns {@code null} when the function throws (the error is logged), when the wait times out,
     * or when the calling thread is interrupted (the interrupt flag is restored). A timed-out task is
     * not stopped; it keeps running on the pool.
     *
     * @param t              argument passed to the function
     * @param timeoutSeconds maximum time to wait, in seconds
     * @param function       business callback
     * @param <T>            argument type
     * @param <R>            result type
     * @return the function result, or {@code null} as described above
     */
    public static <T, R> R execute(T t, Long timeoutSeconds, Function<T, R> function) {
        CompletableFuture<R> future = CompletableFuture
            .supplyAsync(() -> function.apply(t), THREAD_POOL_WJEE)
            .exceptionally(throwable -> {
                log.error("ThreadUtils execute fail", throwable);
                return null;
            });
        try {
            return future.get(timeoutSeconds, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return null;
        } catch (ExecutionException | TimeoutException e) {
            return null;
        }
    }

    /**
     * Simple concurrency load test: for each round, runs {@code consumer} {@code concurrent} times in
     * parallel, waits for all of them, and logs a summary (success/failure counts, success rate,
     * total/min/max cost) plus per-task details.
     * <p>
     * A task whose predicate throws is logged and counted as a failure. If the waiting thread is
     * interrupted, the current round is reported, the interrupt flag is restored and no further rounds run.
     * The internal pool is shut down when this method returns.
     *
     * @param concurrent number of parallel tasks per round (also the pool size)
     * @param rounds     number of rounds
     * @param consumer   task body; receives the 1-based task index and returns whether it succeeded
     */
    public static void concurrentTest(int concurrent, int rounds, IntPredicate consumer) {
        final ExecutorService fixedThreadPool = Executors.newFixedThreadPool(concurrent);
        try {
            for (int roundIndex = 1; roundIndex <= rounds; roundIndex++) {
                long start = System.currentTimeMillis();
                CountDownLatch latch = new CountDownLatch(concurrent);
                // written by all worker threads at once, so it must be concurrent; sorted keeps the report ordered
                Map<Integer, StepWatch.StepTraceInfo> taskDetailMap = new ConcurrentSkipListMap<>();

                for (int i = 1; i <= concurrent; i++) {
                    final int taskIndex = i;
                    fixedThreadPool.submit(() -> runTestTask(taskIndex, consumer, taskDetailMap, latch));
                }

                boolean interrupted = false;
                try {
                    latch.await();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    log.error("latch.await fail", e);
                    interrupted = true;
                }

                long end = System.currentTimeMillis();
                log.info(formatRoundReport(roundIndex, concurrent, end - start, taskDetailMap));
                if (interrupted) {
                    break;
                }
            }
        } finally {
            // already-submitted tasks still run to completion; only the threads are released afterwards
            fixedThreadPool.shutdown();
        }
    }

    /**
     * Runs one {@link #concurrentTest} task, records its outcome and cost, and always releases the latch.
     */
    private static void runTestTask(int taskIndex, IntPredicate consumer,
                                    Map<Integer, StepWatch.StepTraceInfo> details, CountDownLatch latch) {
        long taskStart = System.currentTimeMillis();
        boolean success = false;
        try {
            success = consumer.test(taskIndex);
        } catch (RuntimeException e) {
            log.error("concurrentTest task fail", e);
        } finally {
            long cost = System.currentTimeMillis() - taskStart;
            details.put(taskIndex, new StepWatch.StepTraceInfo(taskIndex + "", cost, success));
            latch.countDown();
        }
    }

    /**
     * Formats the summary and per-task detail text for one {@link #concurrentTest} round.
     */
    private static String formatRoundReport(int round, int concurrent, long totalCostMillis,
                                            Map<Integer, StepWatch.StepTraceInfo> details) {
        long[] successCosts = details.values().stream()
            .filter(StepWatch.StepTraceInfo::isSuccess)
            .mapToLong(StepWatch.StepTraceInfo::getCostTime)
            .toArray();
        long successCount = successCosts.length;
        long minCost = Arrays.stream(successCosts).min().orElse(0L);
        long maxCost = Arrays.stream(successCosts).max().orElse(0L);
        // tasks that never reported (e.g. the wait was interrupted) count as failures
        long failCount = concurrent - successCount;
        BigDecimal successRate = BigDecimal.valueOf(successCount)
            .divide(BigDecimal.valueOf(concurrent), 5, RoundingMode.HALF_UP);
        String successRateFormat = new DecimalFormat(" #.00%").format(successRate);

        StringBuilder buffer = new StringBuilder();
        buffer.append("\n")
            .append("---------------------------------------------\n")
            .append("第").append(round).append("轮, ")
            .append(concurrent).append("并发, ")
            .append("失败/成功[").append(failCount).append("/").append(successCount).append("],")
            .append("成功率:").append(successRateFormat).append("\n")
            .append("总耗时:").append(totalCostMillis).append("ms, ")
            .append("最小请求耗时:").append(minCost).append("ms, ")
            .append("最大请求耗时:").append(maxCost).append("ms ")
            .append("\n")
            .append("---------------------------------------------\n");

        details.forEach((key, value) ->
            buffer.append("任务[").append(key).append("] ")
                .append("=> 结果[").append(value.isSuccess() ? "成功" : "失败").append("], ")
                .append("耗时: ").append(value.getCostTime()).append("ms \n")
        );
        return buffer.toString();
    }

    /**
     * Submits one task per element to {@code executor} and blocks until all of them finish.
     * Exceptions thrown by {@code consumer} are logged per element and do not stop the others.
     * Does nothing when the collection is empty or the executor is {@code null}.
     *
     * @param executor   pool that runs the tasks
     * @param collection elements to process
     * @param consumer   per-element callback; {@code null} means tasks do nothing
     * @param <T>        element type
     */
    public static <T> void concurrentRun(ThreadPoolExecutor executor, Collection<T> collection, Consumer<T> consumer) {
        submitAllAndAwait(executor, collection, consumer, "concurrentRun");
    }

    /**
     * Submits one task per batch to {@code executor} and blocks until all of them finish.
     * Exceptions thrown by {@code consumer} are logged per batch and do not stop the others.
     * Does nothing when the batch collection is empty or the executor is {@code null}.
     *
     * @param executor    pool that runs the tasks
     * @param collections batches to process
     * @param consumer    per-batch callback; {@code null} means tasks do nothing
     * @param <T>         element type inside each batch
     */
    public static <T> void concurrentRunBatch(ThreadPoolExecutor executor, Collection<Collection<T>> collections, Consumer<Collection<T>> consumer) {
        submitAllAndAwait(executor, collections, consumer, "concurrentRunBatch");
    }

    /**
     * Shared fan-out/wait implementation for {@link #concurrentRun} and {@link #concurrentRunBatch}.
     * {@code opName} only affects the log messages.
     */
    private static <E> void submitAllAndAwait(ThreadPoolExecutor executor, Collection<E> items,
                                              Consumer<E> consumer, String opName) {
        if (CollectionUtils.isEmpty(items) || executor == null) {
            return;
        }
        final CountDownLatch latch = new CountDownLatch(items.size());
        for (E item : items) {
            executor.submit(() -> {
                try {
                    if (consumer != null) {
                        consumer.accept(item);
                    }
                } catch (Exception e) {
                    log.error("ThreadUtils " + opName + " fail", e);
                } finally {
                    latch.countDown();
                }
            });
        }
        try {
            latch.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error(opName + " await fail", e);
        }
    }

    /**
     * Wraps a future so the returned future fails if {@code feature} does not complete within the given time.
     * <p>
     * On timeout the returned future completes exceptionally with a {@link TimeoutException} (delivered
     * wrapped in a {@link CompletionException}). The original future is not cancelled and keeps running.
     * When the original completes first, the pending timeout task is cancelled.
     * <pre>
     * Usage:
     *
     * ThreadUtils
     *     .within(3, TimeUnit.SECONDS, CompletableFuture.supplyAsync(() -&gt; {
     *         ThreadUtils.sleep(10_000L);
     *         System.err.println("run...");
     *         return "abc";
     *     }))
     *     .exceptionally(throwable -&gt; {
     *         System.err.println("timeout...");
     *         return "timeout";
     *     });
     *
     * </pre>
     *
     * @param time    timeout amount
     * @param unit    timeout unit
     * @param feature future to guard
     * @param <T>     result type
     * @return a future that completes with {@code feature}'s result or fails on timeout
     */
    public static <T> CompletableFuture<T> within(Integer time, TimeUnit unit, CompletableFuture<T> feature) {
        CompletableFuture<T> timeoutFeature = new CompletableFuture<>();
        ScheduledFuture<?> timeoutTask = FAIL_FAST_POOL.schedule(
            () -> timeoutFeature.completeExceptionally(new TimeoutException()), time, unit);
        // no need to keep the timeout queued once the real future is done (removeOnCancel frees the slot)
        feature.whenComplete((result, error) -> timeoutTask.cancel(false));
        return feature.applyToEither(timeoutFeature, Function.identity());
    }
}
