/*
 * Decompiled with CFR 0.152.
 */
package org.hibernate.ogm.utils;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.transaction.TransactionManager;
import org.hibernate.SessionFactory;
import org.hibernate.engine.spi.SessionFactoryImplementor;
import org.hibernate.engine.transaction.jta.platform.spi.JtaPlatform;
import org.hibernate.ogm.OgmSessionFactory;
import org.hibernate.ogm.exception.impl.Exceptions;
import org.hibernate.ogm.util.impl.Log;
import org.hibernate.ogm.util.impl.LoggerFactory;
import org.hibernate.ogm.utils.SkippableTestRunner;
import org.hibernate.ogm.utils.TestEntities;
import org.hibernate.ogm.utils.TestHelper;
import org.hibernate.ogm.utils.TestSessionFactory;
import org.hibernate.ogm.utils.TestSessionFactoryConfiguration;
import org.junit.runner.notification.RunNotifier;
import org.junit.runners.model.FrameworkField;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.InitializationError;
import org.junit.runners.model.TestClass;

/**
 * JUnit runner that builds OGM session factories for a test class and injects them into fields
 * annotated with {@link TestSessionFactory}.
 * <p>
 * A factory for {@link TestSessionFactory.Scope#TEST_CLASS} fields is built once in
 * {@link #run(RunNotifier)}. A factory for {@link TestSessionFactory.Scope#TEST_METHOD} fields is
 * built per test method in {@link #runChild(FrameworkMethod, RunNotifier)}. Both are torn down
 * afterwards: a pending transaction is rolled back, the schema is dropped, and the factory is
 * closed.
 */
public class OgmTestRunner
extends SkippableTestRunner {
    private static final Log LOG = LoggerFactory.make();
    private final Set<Field> testScopedFactoryFields = OgmTestRunner.getTestFactoryFields(this.getTestClass(), TestSessionFactory.Scope.TEST_CLASS);
    private final Set<Field> testMethodScopedFactoryFields = OgmTestRunner.getTestFactoryFields(this.getTestClass(), TestSessionFactory.Scope.TEST_METHOD);
    private SessionFactory testScopedSessionFactory;
    private SessionFactory testMethodScopedSessionFactory;

    public OgmTestRunner(Class<?> klass) throws InitializationError {
        super(klass);
    }

    /**
     * Collects the fields annotated with {@link TestSessionFactory} whose scope equals
     * {@code scope}, making each of them accessible for later injection.
     */
    private static Set<Field> getTestFactoryFields(TestClass testClass, TestSessionFactory.Scope scope) {
        Set<Field> testFactoryFields = new HashSet<Field>();
        for (FrameworkField frameworkField : testClass.getAnnotatedFields(TestSessionFactory.class)) {
            Field field = frameworkField.getField();
            if (scope != field.getAnnotation(TestSessionFactory.class).scope()) {
                continue;
            }
            field.setAccessible(true);
            testFactoryFields.add(field);
        }
        return testFactoryFields;
    }

    @Override
    public void run(RunNotifier notifier) {
        if (this.isTestScopedSessionFactoryRequired()) {
            this.testScopedSessionFactory = this.buildSessionFactory();
            // test == null: only static class-scoped fields are populated here
            this.injectSessionFactory(null, this.testScopedFactoryFields, this.testScopedSessionFactory);
        }
        try {
            super.run(notifier);
        }
        finally {
            SessionFactory factory = this.testScopedSessionFactory;
            this.testScopedSessionFactory = null;
            if (factory != null) {
                this.tearDown(factory);
            }
        }
    }

    @Override
    protected void runChild(FrameworkMethod method, RunNotifier notifier) {
        if (this.isTestMethodScopedSessionFactoryRequired(method)) {
            this.testMethodScopedSessionFactory = this.buildSessionFactory();
        }
        try {
            super.runChild(method, notifier);
        }
        finally {
            // Clear the field first so that a later (e.g. skipped) method neither re-closes nor
            // gets injected with this already-closed factory.
            SessionFactory factory = this.testMethodScopedSessionFactory;
            this.testMethodScopedSessionFactory = null;
            if (factory != null) {
                this.tearDown(factory);
            }
        }
    }

    /**
     * Rolls back any pending transaction and drops the schema of {@code sessionFactory}, then
     * closes it. The factory is closed even if one of the cleanup steps throws.
     */
    private void tearDown(SessionFactory sessionFactory) {
        try {
            this.cleanUpPendingTransactionIfRequired(sessionFactory);
            TestHelper.dropSchemaAndDatabase(sessionFactory);
        }
        finally {
            sessionFactory.close();
        }
    }

    private boolean isTestScopedSessionFactoryRequired() {
        return !this.isTestClassSkipped() && !this.areAllTestMethodsSkipped();
    }

    private boolean isTestMethodScopedSessionFactoryRequired(FrameworkMethod method) {
        return !this.testMethodScopedFactoryFields.isEmpty() && !super.isTestMethodSkipped(method);
    }

    /**
     * Rolls back a transaction left open by the test, using the JTA transaction manager of
     * {@code sessionFactory}.
     *
     * @throws RuntimeException wrapping any failure while inspecting or rolling back the transaction
     */
    private void cleanUpPendingTransactionIfRequired(SessionFactory sessionFactory) {
        TransactionManager transactionManager = ((SessionFactoryImplementor) sessionFactory)
                .getServiceRegistry()
                .getService(JtaPlatform.class)
                .retrieveTransactionManager();
        try {
            if (transactionManager != null && transactionManager.getTransaction() != null) {
                LOG.warn("The test started a transaction but failed to commit it or roll it back. Going to roll it back.");
                transactionManager.rollback();
            }
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    protected OgmSessionFactory buildSessionFactory() {
        return TestHelper.getDefaultTestSessionFactory(this.getTestSpecificSettings(), this.getConfiguredEntityTypes());
    }

    /**
     * Returns the entity types reported by the first {@link TestEntities}-annotated method.
     *
     * @throws IllegalStateException if no such method exists
     * @throws IllegalArgumentException if the method returns {@code null} or an empty array
     */
    private Class<?>[] getConfiguredEntityTypes() {
        List<FrameworkMethod> entityMethods = this.getTestClass().getAnnotatedMethods(TestEntities.class);
        if (entityMethods.isEmpty()) {
            throw new IllegalStateException("The entities of the test must be retrievable via a parameterless method which is annotated with " + TestEntities.class.getSimpleName() + " and returns Class<?>[].");
        }
        Class<?>[] entityTypes = this.invokeTestEntitiesMethod(entityMethods.get(0));
        if (entityTypes == null || entityTypes.length == 0) {
            throw new IllegalArgumentException("Define at least a single annotated entity");
        }
        return entityTypes;
    }

    /**
     * Invokes a {@link TestEntities} method on a fresh test instance. The instance comes from
     * {@code super.createTest()}, so no session factory is injected into it.
     */
    private Class<?>[] invokeTestEntitiesMethod(FrameworkMethod frameworkMethod) {
        Method method = frameworkMethod.getMethod();
        method.setAccessible(true);
        if (method.getReturnType() != Class[].class || method.getParameterTypes().length > 0) {
            throw new IllegalStateException("Method annotated with " + TestEntities.class.getSimpleName() + " must have no parameters and must return Class<?>[].");
        }
        Class<?>[] entityTypes = null;
        try {
            entityTypes = (Class<?>[]) method.invoke(super.createTest());
        }
        catch (Exception e) {
            Exceptions.sneakyThrow(e);
        }
        return entityTypes;
    }

    /**
     * Passes a shared settings map to every {@link TestSessionFactoryConfiguration} method, each on
     * a fresh non-injected test instance, and returns the populated map.
     */
    private Map<String, Object> getTestSpecificSettings() {
        Map<String, Object> testSpecificSettings = new HashMap<String, Object>();
        try {
            for (FrameworkMethod frameworkMethod : this.getTestClass().getAnnotatedMethods(TestSessionFactoryConfiguration.class)) {
                Method method = frameworkMethod.getMethod();
                method.setAccessible(true);
                method.invoke(super.createTest(), testSpecificSettings);
            }
        }
        catch (Exception e) {
            Exceptions.sneakyThrow(e);
        }
        return testSpecificSettings;
    }

    /**
     * Creates the test instance and injects the current class- and method-scoped factories into
     * its non-static annotated fields.
     */
    @Override
    protected Object createTest() throws Exception {
        Object test = super.createTest();
        if (!this.testScopedFactoryFields.isEmpty()) {
            this.injectSessionFactory(test, this.testScopedFactoryFields, this.testScopedSessionFactory);
        }
        if (!this.testMethodScopedFactoryFields.isEmpty()) {
            this.injectSessionFactory(test, this.testMethodScopedFactoryFields, this.testMethodScopedSessionFactory);
        }
        return test;
    }

    /**
     * Assigns {@code sessionFactory} to the given fields. If {@code test} is {@code null}, only
     * static fields are set; otherwise, only instance fields of {@code test} are set.
     *
     * @throws RuntimeException if a field cannot be assigned (the original failure is the cause)
     */
    private void injectSessionFactory(Object test, Iterable<Field> fields, SessionFactory sessionFactory) {
        boolean injectStaticFields = test == null;
        for (Field field : fields) {
            if (Modifier.isStatic(field.getModifiers()) != injectStaticFields) {
                continue;
            }
            try {
                field.set(test, sessionFactory);
            }
            catch (Exception e) {
                throw new RuntimeException("Can't inject session factory into field " + field, e);
            }
        }
    }
}

