package com.aire.ux.test.spring.servlet;

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import javax.servlet.Servlet;
import javax.servlet.annotation.WebServlet;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.Extension;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanInitializationException;
import org.springframework.beans.factory.config.AutowireCapableBeanFactory;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.AbstractBeanDefinition;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.boot.web.servlet.ServletRegistrationBean;
import org.springframework.context.ApplicationContext;
import org.springframework.test.context.junit.jupiter.SpringExtension;

/**
 * JUnit Jupiter extension that adds servlet beans to the Spring test {@link ApplicationContext}
 * before a test class runs and removes them again after the class has finished.
 *
 * <p>Servlets are taken from the {@link WithServlets} annotation on the test class and from the
 * {@link WithServlets} annotation of every bean in the context that carries it. Each servlet is
 * registered as a bean named after its class, together with a {@link ServletRegistrationBean}
 * (named after the servlet bean plus {@code "registration"}) that maps it to its URL patterns.
 * A {@link DefaultClient} bean constructed with the application context is registered as well.
 *
 * <p>The name of every bean definition registered here is recorded in the JUnit
 * {@link ExtensionContext.Store} as soon as it is registered, so
 * {@link #afterAll(ExtensionContext)} removes exactly those definitions (including after a
 * partially failed {@link #beforeAll(ExtensionContext)}) and nothing else.
 */
public class ServletDefinitionExtension implements Extension, BeforeAllCallback, AfterAllCallback {

    /** Store key under which the names of the bean definitions registered here are kept. */
    private static final String KEY = "DEFINITION_STORE";

    /** Suffix appended to a servlet's bean name to name its {@link ServletRegistrationBean}. */
    private static final String REGISTRATION_SUFFIX = "registration";

    /** Removes every bean definition that {@link #beforeAll(ExtensionContext)} registered. */
    @Override
    public void afterAll(ExtensionContext extensionContext) throws Exception {
        unregisterServletDefinitions(SpringExtension.getApplicationContext(extensionContext), extensionContext);
    }

    /** Registers the client bean and every servlet declared through {@link WithServlets}. */
    @Override
    public void beforeAll(ExtensionContext extensionContext) throws Exception {
        registerServletDefinitions(SpringExtension.getApplicationContext(extensionContext), extensionContext);
    }

    /**
     * Registers the {@link DefaultClient}, the servlets declared on the test class, and the
     * servlets declared on annotated beans of the context.
     *
     * @throws IllegalStateException if the context's bean factory cannot accept bean definitions
     */
    private void registerServletDefinitions(ApplicationContext applicationContext, ExtensionContext extensionContext) {
        // Validate the factory up front: every step below relies on it being a BeanDefinitionRegistry.
        ConfigurableListableBeanFactory beanFactory = beanFactoryOf(applicationContext);
        ExtensionContext.Store store = storeFor(applicationContext, extensionContext);
        registerClient(store, applicationContext);
        extensionContext.getTestClass()
                .map(testClass -> testClass.getAnnotation(WithServlets.class))
                .ifPresent(withServlets -> defineServlets(store, withServlets, beanFactory));
        postProcessBeanFactory(store, beanFactory);
    }

    /** Registers a {@link DefaultClient} bean, named after its class, that receives the context. */
    @SuppressFBWarnings
    private void registerClient(ExtensionContext.Store store, ApplicationContext applicationContext) {
        AbstractBeanDefinition clientDefinition = BeanDefinitionBuilder.rootBeanDefinition(DefaultClient.class)
                .addConstructorArgValue(applicationContext)
                .getBeanDefinition();
        register(store, beanFactoryOf(applicationContext), beanNameOf(clientDefinition), clientDefinition);
    }

    /**
     * Defines the servlets declared by every bean whose type carries {@link WithServlets}.
     *
     * <p>{@code findAnnotationOnBean} applies the same lookup that {@code getBeanNamesForAnnotation}
     * used to select the bean, so it also works for manually registered singletons, factory-method
     * beans (which have no bean class name) and annotations found on the type hierarchy.
     */
    private void postProcessBeanFactory(ExtensionContext.Store store, ConfigurableListableBeanFactory beanFactory) throws BeansException {
        for (String beanName : beanFactory.getBeanNamesForAnnotation(WithServlets.class)) {
            WithServlets withServlets = beanFactory.findAnnotationOnBean(beanName, WithServlets.class);
            if (withServlets != null) {
                defineServlets(store, withServlets, beanFactory);
            }
        }
    }

    /**
     * Defines each servlet listed in {@code withServlets}. Entries whose type is exactly
     * {@link Servlet} are skipped (presumably the annotation's placeholder default — confirm).
     * Entries in {@code servlets()} carry their own paths; entries in {@code value()} take their
     * paths from {@link WebServlet}.
     */
    private void defineServlets(ExtensionContext.Store store, WithServlets withServlets, ConfigurableListableBeanFactory beanFactory) {
        for (ServletDefinition servletDefinition : withServlets.servlets()) {
            if (!Servlet.class.equals(servletDefinition.type())) {
                defineServlet(store, servletDefinition.type(), beanFactory, servletDefinition.paths());
            }
        }
        for (Class<? extends Servlet> servletType : withServlets.value()) {
            if (!Servlet.class.equals(servletType)) {
                defineServlet(store, servletType, beanFactory, getRequestMappings(servletType));
            }
        }
    }

    /**
     * Registers {@code servletType} as a bean and a non-lazy {@link ServletRegistrationBean}
     * mapping it to {@code urlMappings}. Each name is recorded right after its registration so a
     * failure on the second registration does not leak the first.
     */
    @SuppressFBWarnings
    private void defineServlet(ExtensionContext.Store store, Class<? extends Servlet> servletType, ConfigurableListableBeanFactory beanFactory, String[] urlMappings) {
        AbstractBeanDefinition servletDefinition = BeanDefinitionBuilder.rootBeanDefinition(servletType).getBeanDefinition();
        String servletBeanName = beanNameOf(servletDefinition);
        register(store, beanFactory, servletBeanName, servletDefinition);

        AbstractBeanDefinition registrationDefinition = BeanDefinitionBuilder.rootBeanDefinition(ServletRegistrationBean.class)
                .addConstructorArgReference(servletBeanName)
                .addConstructorArgValue(true)
                .addConstructorArgValue(urlMappings)
                .setLazyInit(false)
                .getBeanDefinition();
        register(store, beanFactory, servletBeanName + REGISTRATION_SUFFIX, registrationDefinition);
    }

    /**
     * Returns the URL patterns declared by {@code cls}'s {@link WebServlet} annotation.
     *
     * @throws UnsupportedOperationException if {@code cls} is not annotated with {@link WebServlet}
     */
    private String[] getRequestMappings(Class<? extends Servlet> cls) {
        WebServlet webServlet = cls.getAnnotation(WebServlet.class);
        if (webServlet == null) {
            throw new UnsupportedOperationException(String.format("Error: must annotate '%s' with an @WebServlet containing request mappings", cls));
        }
        // value() and urlPatterns() are alternative spellings of the same mapping; honor either.
        return webServlet.value().length > 0 ? webServlet.value() : webServlet.urlPatterns();
    }

    /**
     * Removes the bean definitions recorded during registration, newest first so registration
     * beans go before the servlet beans they reference. Names that are no longer present are skipped.
     */
    private void unregisterServletDefinitions(ApplicationContext applicationContext, ExtensionContext extensionContext) {
        List<?> registeredNames = storeFor(applicationContext, extensionContext).remove(KEY, List.class);
        if (registeredNames == null || registeredNames.isEmpty()) {
            return;
        }
        BeanDefinitionRegistry registry = (BeanDefinitionRegistry) beanFactoryOf(applicationContext);
        for (int i = registeredNames.size() - 1; i >= 0; i--) {
            String beanName = (String) registeredNames.get(i);
            if (registry.containsBeanDefinition(beanName)) {
                registry.removeBeanDefinition(beanName);
            }
        }
    }

    /**
     * Returns the context's bean factory, verifying it can both be queried and accept definitions.
     *
     * @throws IllegalStateException if the factory is not a {@link ConfigurableListableBeanFactory}
     *     that is also a {@link BeanDefinitionRegistry}
     */
    private static ConfigurableListableBeanFactory beanFactoryOf(ApplicationContext applicationContext) {
        AutowireCapableBeanFactory beanFactory = applicationContext.getAutowireCapableBeanFactory();
        if (!(beanFactory instanceof ConfigurableListableBeanFactory) || !(beanFactory instanceof BeanDefinitionRegistry)) {
            throw new IllegalStateException(String.format("Wrong type of bean factory (unsupported context): %s", beanFactory.getClass()));
        }
        return (ConfigurableListableBeanFactory) beanFactory;
    }

    /** Returns the store scoped to this application context and extension context. */
    private static ExtensionContext.Store storeFor(ApplicationContext applicationContext, ExtensionContext extensionContext) {
        return extensionContext.getStore(ExtensionContext.Namespace.create(applicationContext, extensionContext));
    }

    /** Registers {@code beanDefinition} under {@code beanName} and records the name for removal. */
    private static void register(ExtensionContext.Store store, ConfigurableListableBeanFactory beanFactory, String beanName, BeanDefinition beanDefinition) {
        ((BeanDefinitionRegistry) beanFactory).registerBeanDefinition(beanName, beanDefinition);
        registeredNames(store).add(beanName);
    }

    /** Returns the mutable list of bean names registered so far, creating it on first use. */
    @SuppressWarnings("unchecked")
    private static List<String> registeredNames(ExtensionContext.Store store) {
        // Only this class writes KEY in this namespace, and it always stores a List<String>.
        return (List<String>) store.getOrComputeIfAbsent(KEY, key -> new ArrayList<String>());
    }

    /** Returns the definition's bean class name, which is used as its bean name. */
    private static String beanNameOf(BeanDefinition beanDefinition) {
        return Objects.requireNonNull(beanDefinition.getBeanClassName(), "bean class name");
    }
}
