package org.fluentlenium.adapter;

import com.google.common.base.Supplier;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.experimental.Delegate;
import org.fluentlenium.configuration.ConfigurationProperties.DriverLifecycle;
import org.openqa.selenium.WebDriver;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
 * A singleton container for all running {@link SharedWebDriver} instances in the JVM.
 * <p>
 * Drivers are registered according to their {@link DriverLifecycle}: a single JVM-wide driver, one driver per test
 * class, or one driver per test method. Creating the singleton registers a JVM shutdown hook.
 */
public enum SharedWebDriverContainer {
    INSTANCE;

    @Delegate
    private final Impl impl = new Impl();

    private final SharedWebDriverContainerShutdownHook shutdownHook; // NOPMD SingularField

    SharedWebDriverContainer() {
        shutdownHook = new SharedWebDriverContainerShutdownHook("SharedWebDriverContainerShutdownHook");
        Runtime.getRuntime().addShutdownHook(shutdownHook);
    }

    /**
     * Map key identifying a driver bound to a single test method: the test class together with the test name.
     * Fields are final because instances are used as {@link HashMap} keys.
     */
    @EqualsAndHashCode
    @AllArgsConstructor
    private static class ClassAndTestName {
        private final Class<?> testClass;
        private final String testName;
    }

    /**
     * Registry of shared drivers. Every public method synchronizes on this instance.
     */
    static class Impl {
        /** Driver shared by the whole JVM ({@link DriverLifecycle#JVM}), or {@code null} when none is registered. */
        private SharedWebDriver jvmDriver;

        /** Drivers shared by all tests of a class ({@link DriverLifecycle#CLASS}), keyed by test class. */
        private final Map<Class<?>, SharedWebDriver> classDrivers = new HashMap<>();

        /** Drivers bound to a single test ({@link DriverLifecycle#METHOD} and any other lifecycle value). */
        private final Map<ClassAndTestName, SharedWebDriver> methodDrivers = new HashMap<>();

        /**
         * Get an existing or create a new driver for the given test, with the given shared driver
         * strategy.
         *
         * @param webDriverFactory Supplier supplying new WebDriver instances
         * @param testClass        Test class
         * @param testName         Test name
         * @param driverLifecycle  WebDriver lifecycle
         * @param <T>              type of the test class
         * @return the already registered driver for this test and lifecycle, or a newly created and registered one
         */
        public <T> SharedWebDriver getOrCreateDriver(final Supplier<WebDriver> webDriverFactory, final Class<T> testClass,
                final String testName, final DriverLifecycle driverLifecycle) {
            synchronized (this) {
                SharedWebDriver driver = getDriver(testClass, testName, driverLifecycle);
                if (driver == null) {
                    driver = createDriver(webDriverFactory, testClass, testName, driverLifecycle);
                    registerDriver(driver);
                }
                return driver;
            }
        }

        /** Builds a new {@link SharedWebDriver} around a WebDriver obtained from the factory. */
        private <T> SharedWebDriver createDriver(final Supplier<WebDriver> webDriverFactory, final Class<T> testClass,
                final String testName, final DriverLifecycle driverLifecycle) {
            final WebDriver webDriver = webDriverFactory.get();
            return new SharedWebDriver(webDriver, testClass, testName, driverLifecycle);
        }

        /** Stores the driver in the slot matching its lifecycle, replacing any driver already stored there. */
        private void registerDriver(final SharedWebDriver driver) {
            switch (driver.getDriverLifecycle()) {
            case JVM:
                jvmDriver = driver;
                break;
            case CLASS:
                classDrivers.put(driver.getTestClass(), driver);
                break;
            case METHOD:
            default:
                methodDrivers.put(new ClassAndTestName(driver.getTestClass(), driver.getTestName()), driver);
                break;
            }
        }

        /**
         * Get the registered driver for the given test and lifecycle, without creating one.
         *
         * @param testClass       Test class (ignored for {@link DriverLifecycle#JVM})
         * @param testName        Test name (only used for method-scoped drivers)
         * @param driverLifecycle WebDriver lifecycle
         * @param <T>             type of the test class
         * @return the registered driver, or {@code null} if there is none
         */
        public <T> SharedWebDriver getDriver(final Class<T> testClass, final String testName,
                final DriverLifecycle driverLifecycle) {
            synchronized (this) {
                switch (driverLifecycle) {
                case JVM:
                    return jvmDriver;
                case CLASS:
                    return classDrivers.get(testClass);
                case METHOD:
                default:
                    return methodDrivers.get(new ClassAndTestName(testClass, testName));
                }
            }
        }

        /**
         * Unregister and quit the given driver.
         * <p>
         * Nothing happens unless {@code driver} is the very instance currently registered for its lifecycle and
         * key. In particular, a stale driver never evicts the driver that replaced it. The driver is unregistered
         * before its WebDriver is quit, so a failing {@code quit()} does not leave a dead driver registered.
         *
         * @param driver driver to quit
         */
        public void quit(final SharedWebDriver driver) {
            synchronized (this) {
                if (unregister(driver)) {
                    quitWebDriver(driver);
                }
            }
        }

        /**
         * Removes the given driver from its lifecycle slot if, and only if, it is the registered instance.
         *
         * @return {@code true} if the driver was registered and has been removed
         */
        private boolean unregister(final SharedWebDriver driver) {
            switch (driver.getDriverLifecycle()) {
            case JVM:
                if (jvmDriver != driver) { // NOPMD CompareObjectsWithEquals
                    return false;
                }
                jvmDriver = null;
                return true;
            case CLASS:
                if (classDrivers.get(driver.getTestClass()) != driver) { // NOPMD CompareObjectsWithEquals
                    return false;
                }
                classDrivers.remove(driver.getTestClass());
                return true;
            case METHOD:
            default:
                final ClassAndTestName key = new ClassAndTestName(driver.getTestClass(), driver.getTestName());
                if (methodDrivers.get(key) != driver) { // NOPMD CompareObjectsWithEquals
                    return false;
                }
                methodDrivers.remove(key);
                return true;
            }
        }

        /** Quits the underlying {@link WebDriver} of the given shared driver, if it has one. */
        private static void quitWebDriver(final SharedWebDriver driver) {
            final WebDriver webDriver = driver.getDriver();
            if (webDriver != null) {
                webDriver.quit();
            }
        }

        /**
         * Get all WebDriver of this container.
         *
         * @return unmodifiable snapshot of {@link SharedWebDriver}: the JVM driver first (if any), then class
         * drivers, then method drivers
         */
        public List<SharedWebDriver> getAllDrivers() {
            final List<SharedWebDriver> drivers = new ArrayList<>();
            synchronized (this) {
                if (jvmDriver != null) {
                    drivers.add(jvmDriver);
                }
                drivers.addAll(classDrivers.values());
                drivers.addAll(methodDrivers.values());
            }
            return Collections.unmodifiableList(drivers);
        }

        /**
         * Get all WebDriver of this container for given class.
         *
         * @param testClass test class
         * @return unmodifiable snapshot containing the class-scoped driver of {@code testClass} (if any) followed by
         * every method-scoped driver whose test class is exactly {@code testClass}
         */
        public List<SharedWebDriver> getTestClassDrivers(final Class<?> testClass) {
            final List<SharedWebDriver> drivers = new ArrayList<>();

            synchronized (this) {
                final SharedWebDriver classDriver = classDrivers.get(testClass);
                if (classDriver != null) {
                    drivers.add(classDriver);
                }

                for (final SharedWebDriver testDriver : methodDrivers.values()) {
                    if (testDriver.getTestClass() == testClass) {
                        drivers.add(testDriver);
                    }
                }

                return Collections.unmodifiableList(drivers);
            }
        }

        /**
         * Unregister and quit every driver of this container.
         * <p>
         * The registry is cleared first. Every driver is then quit even if quitting one of them fails, in the same
         * order as {@link #getAllDrivers()}. The first failure is rethrown after all drivers have been attempted,
         * with any later failures attached as suppressed exceptions.
         */
        public void quitAll() {
            synchronized (this) {
                final List<SharedWebDriver> drivers = getAllDrivers();
                jvmDriver = null;
                classDrivers.clear();
                methodDrivers.clear();

                RuntimeException failure = null;
                for (final SharedWebDriver driver : drivers) {
                    try {
                        quitWebDriver(driver);
                    } catch (final RuntimeException e) {
                        if (failure == null) {
                            failure = e;
                        } else {
                            failure.addSuppressed(e);
                        }
                    }
                }
                if (failure != null) {
                    throw failure;
                }
            }
        }
    }

}
