package com.groupbyinc.common.util;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Supplier;

/**
 * Static helpers for building named thread factories and executors, and for sleeping or waiting on
 * tasks without forcing callers to handle {@link InterruptedException}.
 *
 * <p>Every helper that catches {@link InterruptedException} restores the calling thread's interrupt
 * status, so code further up the stack can still observe the interruption.
 */
public final class ThreadUtils {

  private static final Logger LOG = LoggerFactory.getLogger(ThreadUtils.class);

  /** Default parallelism: one less than the available processors, but never less than one. */
  private static final int MAX_PARALLELISM = Math.max(1, Runtime.getRuntime().availableProcessors() - 1);

  private ThreadUtils() {
    // not publicly instantiable
  }

  /**
   * Sleeps the current thread for {@code millis} milliseconds.
   *
   * <p>If the sleep is interrupted, the exception is logged at WARN level, the thread's interrupt
   * status is restored, and the method returns early instead of throwing.
   *
   * @param millis how long to sleep, in milliseconds (must not be negative, per {@link Thread#sleep(long)})
   */
  public static void sleep(long millis) {
    try {
      Thread.sleep(millis);
    } catch (InterruptedException ie) {
      // Restore the flag: catching InterruptedException clears it, and callers may rely on it to stop.
      Thread.currentThread().interrupt();
      LOG.warn(ie.getMessage(), ie);
    }
  }

  /**
   * Creates a thread factory with the given name. Thread numbers are not included and the parent
   * thread name is included.
   *
   * @param name base name handed to {@code DefaultThreadFactory}
   * @return a new thread factory
   */
  public static ThreadFactory defaultThreadFactory(String name) {
    return defaultThreadFactory(name, false);
  }

  /**
   * Creates a thread factory with the given name. The parent thread name is included.
   *
   * @param name                base name handed to {@code DefaultThreadFactory}
   * @param includeThreadNumber forwarded to {@code DefaultThreadFactory} (presumably toggles a
   *                            per-thread counter in the name; verify against that class)
   * @return a new thread factory
   */
  public static ThreadFactory defaultThreadFactory(String name, boolean includeThreadNumber) {
    return defaultThreadFactory(name, includeThreadNumber, true);
  }

  /**
   * Creates a {@code DefaultThreadFactory} with the given naming options.
   *
   * @param name                    base name handed to {@code DefaultThreadFactory}
   * @param includeThreadNumber     forwarded unchanged to {@code DefaultThreadFactory}
   * @param includeParentThreadName forwarded unchanged to {@code DefaultThreadFactory}
   * @return a new thread factory
   */
  public static ThreadFactory defaultThreadFactory(String name, boolean includeThreadNumber, boolean includeParentThreadName) {
    return new DefaultThreadFactory(name, includeThreadNumber, includeParentThreadName);
  }

  /**
   * Creates a fork/join pool whose parallelism is one less than the available processors, with a
   * minimum of one.
   *
   * @param name              base name for worker threads
   * @param includeThreadName forwarded to {@code DefaultForkJoinWorkerThreadFactory}
   * @return a new fork/join pool
   */
  public static ForkJoinPool newForkJoinPool(String name, boolean includeThreadName) {
    return newForkJoinPool(MAX_PARALLELISM, name, includeThreadName);
  }

  /**
   * Creates a fork/join pool that uses {@code DefaultForkJoinWorkerThreadFactory}.
   *
   * <p>The creating thread's name is captured once, at construction time, and passed to the factory.
   * The pool uses the default uncaught-exception handling ({@code null} handler) and LIFO
   * ({@code asyncMode = false}) scheduling.
   *
   * @param parallelism       target parallelism; values below one are raised to one
   * @param name              base name for worker threads
   * @param includeThreadName forwarded to {@code DefaultForkJoinWorkerThreadFactory}
   * @return a new fork/join pool
   */
  public static ForkJoinPool newForkJoinPool(int parallelism, String name, boolean includeThreadName) {
    ForkJoinPool.ForkJoinWorkerThreadFactory factory = new DefaultForkJoinWorkerThreadFactory(Thread.currentThread().getName(), name, includeThreadName);
    return new ForkJoinPool(Math.max(1, parallelism), factory, null, false);
  }

  /**
   * Creates a fixed-size pool with an effectively unbounded ({@link Integer#MAX_VALUE}) work queue.
   *
   * @param parallelism       number of threads (must be positive)
   * @param name              base name for pool threads
   * @param includeThreadName forwarded to the thread factory
   * @return a new executor service
   */
  public static ExecutorService newFixedPool(int parallelism, String name, boolean includeThreadName) {
    return newFixedPool(parallelism, Integer.MAX_VALUE, name, includeThreadName);
  }

  /**
   * Creates a fixed-size pool with a bounded work queue. The parent thread name is included in
   * thread names.
   *
   * @param parallelism       number of threads (must be positive)
   * @param queueSize         work-queue capacity (must be positive)
   * @param name              base name for pool threads
   * @param includeThreadName forwarded to the thread factory
   * @return a new executor service
   */
  public static ExecutorService newFixedPool(int parallelism, int queueSize, String name, boolean includeThreadName) {
    return newFixedPool(parallelism, queueSize, name, includeThreadName, true);
  }

  /**
   * Creates a fixed-size pool ({@code core == max == parallelism}) backed by a bounded
   * {@link LinkedBlockingQueue}.
   *
   * <p>When the queue is full, submissions are rejected with {@link RejectedExecutionException}.
   * This is the default abort policy of {@link ThreadPoolExecutor}.
   *
   * @param parallelism       number of threads (must be positive)
   * @param queueSize         work-queue capacity (must be positive)
   * @param name              base name for pool threads
   * @param includeThreadName forwarded to the thread factory
   * @param includeParentName forwarded to the thread factory
   * @return a new executor service
   * @throws IllegalArgumentException if {@code parallelism} or {@code queueSize} is not positive
   */
  public static ExecutorService newFixedPool(int parallelism, int queueSize, String name, boolean includeThreadName, boolean includeParentName) {
    ThreadFactory factory = defaultThreadFactory(name, includeThreadName, includeParentName);
    return new ThreadPoolExecutor(parallelism, parallelism, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(queueSize), factory);
  }

  /**
   * Creates a blocking fixed-size pool with no work queue: each task is handed directly to an idle
   * worker through a {@link SynchronousQueue}.
   *
   * <p>The same caveats apply as for {@link #newBlockingFixedPool(int, int, String, boolean)}.
   *
   * @param parallelism       number of threads (must be positive)
   * @param name              base name for pool threads
   * @param includeThreadName forwarded to the thread factory
   * @return a new executor service
   */
  public static ExecutorService newBlockingFixedPool(int parallelism, String name, boolean includeThreadName) {
    return newBlockingFixedPool(parallelism, 0, name, includeThreadName);
  }

  /**
   * Creates a fixed-size pool in which a submitting thread blocks, rather than being rejected, while
   * all workers are busy and the queue is full.
   *
   * <p>SUPER IMPORTANT NOTE. Only use this if both of the following are true:
   * <ul>
   *   <li>You have only one thread producing tasks.</li>
   *   <li>You absolutely need the producer to block until a worker or queue slot is free.</li>
   * </ul>
   *
   * <p>This can be a very bad idea: adding to the queue through {@code getQueue()} is strongly
   * discouraged by the {@link ThreadPoolExecutor} API, and may be prohibited at some point.
   *
   * <p>This only works because the executor has {@code corePoolSize == maxPoolSize}. DO NOT CHANGE
   * THAT.
   *
   * <p>The blocking rejection handler throws {@link RejectedExecutionException} instead of blocking
   * when the pool has already been shut down. A shut-down pool may never drain its queue, so blocking
   * there could hang the producer. A narrow race with a concurrent {@code shutdown()} remains.
   * If the producer is interrupted while blocked, its interrupt status is restored and the task is
   * rejected with {@link RejectedExecutionException}; it is not silently dropped.
   *
   * @param parallelism       number of threads (must be positive)
   * @param queueSize         queue capacity; values below one select a direct-handoff {@link SynchronousQueue}
   * @param name              base name for pool threads
   * @param includeThreadName forwarded to the thread factory
   * @return a new executor service
   */
  public static ExecutorService newBlockingFixedPool(int parallelism, int queueSize, String name, boolean includeThreadName) {
    BlockingQueue<Runnable> queue = queueSize < 1 ? new SynchronousQueue<>() : new LinkedBlockingQueue<>(queueSize);
    ThreadFactory factory = defaultThreadFactory(name, includeThreadName);
    return new ThreadPoolExecutor(parallelism, parallelism, 0L, TimeUnit.MILLISECONDS, queue, factory, (runnable, executor) -> {
      // execute() also routes submissions to a shut-down pool here; never block on those.
      if (executor.isShutdown()) {
        throw new RejectedExecutionException("Executor has been shut down; task rejected");
      }
      try {
        executor.getQueue().put(runnable);
      } catch (InterruptedException e) {
        // Keep the interrupt visible to the producer and report the lost task instead of dropping it.
        Thread.currentThread().interrupt();
        throw new RejectedExecutionException("Interrupted while waiting to enqueue task", e);
      }
    });
  }

  /**
   * Runs {@code runnable} on {@code workers}, waits for it, and logs failures. No success message is
   * logged.
   *
   * @param workers      pool that runs the task
   * @param runnable     task to run
   * @param errorMessage builds the ERROR log message from the caught exception
   */
  public static void async(ForkJoinPool workers, Runnable runnable, Function<Exception, String> errorMessage) {
    async(workers, runnable, null, errorMessage);
  }

  /**
   * Submits {@code runnable} to {@code workers} and then <em>blocks</em> until it completes, despite
   * the method name.
   *
   * <p>On success, {@code successMessage} (if non-null) is logged at INFO level. If the task throws,
   * or the wait is interrupted, {@code errorMessage} is logged at ERROR level together with the
   * exception. On interruption the thread's interrupt status is also restored. A cancelled task
   * surfaces as an unchecked {@link java.util.concurrent.CancellationException}, which is not caught.
   *
   * @param workers        pool that runs the task
   * @param runnable       task to run
   * @param successMessage optional supplier of the INFO message; may be {@code null}
   * @param errorMessage   builds the ERROR log message from the caught exception
   */
  public static void async(
      ForkJoinPool workers, Runnable runnable, Supplier<String> successMessage, Function<Exception, String> errorMessage) {
    ForkJoinTask<?> task = workers.submit(runnable);
    try {
      task.get();
      if (successMessage != null) {
        LOG.info("{}", successMessage.get());
      }
    } catch (InterruptedException e) {
      // Catching InterruptedException clears the flag; put it back for callers up the stack.
      Thread.currentThread().interrupt();
      LOG.error(errorMessage.apply(e), e);
    } catch (ExecutionException e) {
      LOG.error(errorMessage.apply(e), e);
    }
  }
}
