package org.neo4j.bolt.runtime;

import java.io.IOException;
import java.io.PrintWriter;
import java.net.SocketAddress;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.neo4j.bolt.dbapi.CustomBookmarkFormatParser;
import org.neo4j.bolt.dbapi.impl.BoltKernelDatabaseManagementServiceProvider;
import org.neo4j.bolt.negotiation.ProtocolVersion;
import org.neo4j.bolt.protocol.BoltProtocolRegistry;
import org.neo4j.bolt.protocol.common.BoltProtocol;
import org.neo4j.bolt.protocol.common.connector.connection.Connection;
import org.neo4j.bolt.protocol.common.connector.connection.authentication.AuthenticationFlag;
import org.neo4j.bolt.protocol.common.fsm.StateMachine;
import org.neo4j.bolt.protocol.v40.BoltProtocolV40;
import org.neo4j.bolt.protocol.v40.bookmark.BookmarkParserV40;
import org.neo4j.bolt.protocol.v41.BoltProtocolV41;
import org.neo4j.bolt.protocol.v43.BoltProtocolV43;
import org.neo4j.bolt.protocol.v44.BoltProtocolV44;
import org.neo4j.bolt.security.Authentication;
import org.neo4j.bolt.security.AuthenticationResult;
import org.neo4j.bolt.security.basic.BasicAuthentication;
import org.neo4j.bolt.security.error.AuthenticationException;
import org.neo4j.bolt.transaction.StatementProcessorTxManager;
import org.neo4j.common.DependencyResolver;
import org.neo4j.configuration.Config;
import org.neo4j.configuration.GraphDatabaseSettings;
import org.neo4j.dbms.api.DatabaseManagementService;
import org.neo4j.dbms.database.DatabaseContextProvider;
import org.neo4j.graphdb.config.Setting;
import org.neo4j.io.IOUtils;
import org.neo4j.kernel.api.security.AuthManager;
import org.neo4j.kernel.database.DatabaseIdRepository;
import org.neo4j.kernel.impl.query.clientconnection.BoltConnectionInfo;
import org.neo4j.kernel.internal.GraphDatabaseAPI;
import org.neo4j.logging.internal.LogService;
import org.neo4j.monitoring.Monitors;
import org.neo4j.server.security.systemgraph.CommunityDefaultDatabaseResolver;
import org.neo4j.storageengine.api.TransactionIdStore;
import org.neo4j.test.TestDatabaseManagementServiceBuilder;
import org.neo4j.time.Clocks;
import org.neo4j.time.SystemNanoClock;

/* loaded from: input_file:org/neo4j/bolt/runtime/SessionExtension.class */
/**
 * JUnit 5 extension that starts an impermanent DBMS before each test and shuts it down afterwards.
 *
 * <p>While a test is running, {@link #newMachine(ProtocolVersion)} hands out Bolt state machines for
 * protocol versions 4.0, 4.1, 4.3 and 4.4, each bound to a Mockito-mocked {@link Connection}. Every
 * machine created during a test is closed in {@link #afterEach(ExtensionContext)}.
 */
public class SessionExtension implements BeforeEachCallback, AfterEachCallback {
    private final Supplier<TestDatabaseManagementServiceBuilder> builderFactory;
    // State machines handed out during the current test; closed and cleared in afterEach.
    private final List<StateMachine> runningMachines = new ArrayList<>();
    private GraphDatabaseAPI gdb;
    private BoltProtocolRegistry protocolRegistry;
    private DatabaseManagementService managementService;
    private Authentication authentication;
    private boolean authEnabled;

    /** Creates an extension that builds its DBMS with a default {@link TestDatabaseManagementServiceBuilder}. */
    public SessionExtension() {
        this(TestDatabaseManagementServiceBuilder::new);
    }

    /**
     * Creates an extension that obtains a fresh DBMS builder from {@code builderFactory} before each test.
     *
     * @param builderFactory supplier of the builder used to construct the per-test DBMS
     */
    public SessionExtension(Supplier<TestDatabaseManagementServiceBuilder> builderFactory) {
        this.builderFactory = builderFactory;
    }

    /**
     * Creates a state machine for the given protocol version, bound to a fresh mocked connection.
     * The machine is tracked and closed automatically after the test.
     *
     * @param protocolVersion the Bolt protocol version to instantiate
     * @return the new state machine
     * @throws IllegalArgumentException if no protocol is registered for {@code protocolVersion}
     * @throws IllegalStateException if called outside of a running test
     */
    public StateMachine newMachine(ProtocolVersion protocolVersion) {
        assertTestStarted();
        BoltProtocol protocol = this.protocolRegistry
                .get(protocolVersion)
                .orElseThrow(() -> new IllegalArgumentException("Unsupported protocol version: " + protocolVersion));
        StateMachine machine = protocol.createStateMachine(createConnection());
        this.runningMachines.add(machine);
        return machine;
    }

    /**
     * Builds a mocked {@link Connection}. Its interrupt/reset bookkeeping, login context and
     * authentication are emulated in memory.
     */
    private Connection createConnection() {
        Connection connection = Mockito.mock(Connection.class, Mockito.RETURNS_MOCKS);
        Mockito.when(connection.id()).thenReturn("bolt-test");
        Mockito.when(connection.selectedDefaultDatabase()).thenAnswer(invocation -> defaultDatabaseName());

        // Number of interrupt() calls not yet acknowledged by a reset().
        AtomicInteger pendingInterrupts = new AtomicInteger();
        Mockito.doAnswer(invocation -> {
                    pendingInterrupts.incrementAndGet();
                    return null;
                })
                .when(connection)
                .interrupt();
        // Each reset() acknowledges exactly one pending interrupt and returns true once none remain.
        // The CAS retry loop only repeats when another thread changed the counter concurrently.
        Mockito.doAnswer(invocation -> {
                    int current;
                    do {
                        current = pendingInterrupts.get();
                        if (current == 0) {
                            return true;
                        }
                    } while (!pendingInterrupts.compareAndSet(current, current - 1));
                    return current == 1;
                })
                .when(connection)
                .reset();
        Mockito.when(connection.isInterrupted()).thenAnswer(invocation -> pendingInterrupts.get() != 0);

        // Login context produced by the most recent authenticate() call that completed without throwing.
        AtomicReference<Object> loginContext = new AtomicReference<>();
        Mockito.when(connection.loginContext()).thenAnswer(invocation -> loginContext.get());

        try {
            Mockito.when(connection.authenticate((Map) ArgumentMatchers.any(), (String) ArgumentMatchers.any()))
                    .thenAnswer(invocation -> {
                        AuthenticationResult result = this.authentication.authenticate(
                                invocation.getArgument(0),
                                new BoltConnectionInfo(
                                        "bolt-test",
                                        "bolt-test",
                                        Mockito.mock(SocketAddress.class),
                                        Mockito.mock(SocketAddress.class)));
                        loginContext.set(result.getLoginContext());
                        return result.credentialsExpired() ? AuthenticationFlag.CREDENTIALS_EXPIRED : null;
                    });
        } catch (AuthenticationException e) {
            // Calling authenticate() on the mock while stubbing does not run real authentication; the
            // catch only satisfies the checked exception. Surface it rather than hide it if it ever fires.
            throw new IllegalStateException("Unexpected exception while stubbing authenticate()", e);
        }
        return connection;
    }

    /**
     * Returns the DBMS started for the current test.
     *
     * @throws IllegalStateException if called outside of a running test
     */
    public DatabaseManagementService managementService() {
        assertTestStarted();
        return this.managementService;
    }

    /**
     * Returns the configured {@link GraphDatabaseSettings#initial_default_database} of the running DBMS.
     *
     * @throws IllegalStateException if called outside of a running test
     */
    public String defaultDatabaseName() {
        assertTestStarted();
        return this.gdb
                .getDependencyResolver()
                .resolveDependency(Config.class)
                .get(GraphDatabaseSettings.initial_default_database);
    }

    /**
     * Returns the database id repository of the running DBMS.
     *
     * @throws IllegalStateException if called outside of a running test
     */
    public DatabaseIdRepository databaseIdRepository() {
        assertTestStarted();
        return this.gdb
                .getDependencyResolver()
                .resolveDependency(DatabaseContextProvider.class)
                .databaseIdRepository();
    }

    /**
     * Starts an impermanent DBMS, honouring {@link #withAuthEnabled(boolean)}. Then registers the
     * Bolt 4.0/4.1/4.3/4.4 protocols and the basic authentication used by mocked connections.
     */
    @Override
    public void beforeEach(ExtensionContext extensionContext) {
        this.managementService = this.builderFactory
                .get()
                .impermanent()
                .setConfig(GraphDatabaseSettings.auth_enabled, this.authEnabled)
                .build();
        this.gdb = (GraphDatabaseAPI) this.managementService.database("neo4j");

        DependencyResolver dependencyResolver = this.gdb.getDependencyResolver();
        Config config = dependencyResolver.resolveDependency(Config.class);
        LogService logService = dependencyResolver.resolveDependency(LogService.class);
        SystemNanoClock clock = Clocks.nanoClock();
        StatementProcessorTxManager txManager = new StatementProcessorTxManager();
        CommunityDefaultDatabaseResolver defaultDatabaseResolver =
                new CommunityDefaultDatabaseResolver(config, () -> this.managementService.database("system"));
        BoltKernelDatabaseManagementServiceProvider databaseServiceProvider =
                new BoltKernelDatabaseManagementServiceProvider(
                        this.managementService, new Monitors(), clock, Duration.ofSeconds(30L));

        this.protocolRegistry = BoltProtocolRegistry.builder()
                .register(new BoltProtocolV40(
                        logService, databaseServiceProvider, defaultDatabaseResolver, txManager, clock))
                .register(new BoltProtocolV41(
                        logService, databaseServiceProvider, defaultDatabaseResolver, txManager, clock))
                .register(new BoltProtocolV43(
                        logService, databaseServiceProvider, defaultDatabaseResolver, txManager, clock))
                .register(new BoltProtocolV44(
                        logService, databaseServiceProvider, defaultDatabaseResolver, txManager, clock))
                .build();
        this.authentication = authentication(dependencyResolver.resolveDependency(AuthManager.class));
    }

    /**
     * Closes every state machine created during the test, then shuts down the DBMS.
     *
     * <p>Failures while closing machines are printed and otherwise ignored, so that shutdown still
     * happens. The list of tracked machines is always cleared.
     */
    @Override
    public void afterEach(ExtensionContext extensionContext) {
        try {
            IOUtils.closeAll(this.runningMachines);
        } catch (Throwable th) {
            th.printStackTrace();
        } finally {
            // Clear even on failure so machines from this test never leak into the next one.
            this.runningMachines.clear();
        }
        // beforeEach may have failed before the DBMS was built; avoid masking that error with an NPE.
        if (this.managementService != null) {
            this.managementService.shutdown();
        }
    }

    /** Throws unless {@link #beforeEach(ExtensionContext)} has initialised the test environment. */
    private void assertTestStarted() {
        if (this.protocolRegistry == null || this.gdb == null) {
            throw new IllegalStateException("Cannot access test environment before test is running.");
        }
    }

    /** Creates the basic authentication used by mocked connections. */
    private static Authentication authentication(AuthManager authManager) {
        return new BasicAuthentication(authManager);
    }

    /**
     * Returns the last closed transaction id, as reported by the {@link TransactionIdStore} of the
     * database opened in {@link #beforeEach(ExtensionContext)}.
     *
     * @throws IllegalStateException if called outside of a running test
     */
    public long lastClosedTxId() {
        assertTestStarted();
        return this.gdb
                .getDependencyResolver()
                .resolveDependency(TransactionIdStore.class)
                .getLastClosedTransactionId();
    }

    /**
     * Writes {@code contents} followed by a line separator into a new UTF-8 temporary file. The file
     * is deleted on JVM exit.
     *
     * @param prefix temporary file name prefix (may be {@code null})
     * @param suffix temporary file name suffix (may be {@code null})
     * @param contents text to write
     * @return a {@code file:} URL pointing at the created file
     * @throws IOException if the file cannot be created or written
     */
    public static URL putTmpFile(String prefix, String suffix, String contents) throws IOException {
        Path file = Files.createTempFile(prefix, suffix);
        file.toFile().deleteOnExit();
        try (PrintWriter writer = new PrintWriter(Files.newOutputStream(file), false, StandardCharsets.UTF_8)) {
            writer.println(contents);
            // PrintWriter never throws on write failures; checkError() flushes and reports them.
            if (writer.checkError()) {
                throw new IOException("Failed to write temporary file " + file);
            }
        }
        return file.toUri().toURL();
    }

    /**
     * Configures whether the DBMS started in {@link #beforeEach(ExtensionContext)} has authentication
     * enabled. Must be called before the test starts to take effect.
     *
     * @param authEnabled value for {@link GraphDatabaseSettings#auth_enabled}
     * @return this extension, for chaining
     */
    public SessionExtension withAuthEnabled(boolean authEnabled) {
        this.authEnabled = authEnabled;
        return this;
    }
}
