package com.graphql.spring.boot.test;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Predicate;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.CloseReason;
import javax.websocket.ContainerProvider;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import lombok.NonNull;
import org.assertj.core.api.AbstractBooleanAssert;
import org.assertj.core.api.Assertions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.env.Environment;
import org.springframework.http.ResponseEntity;
import org.springframework.lang.Nullable;
import org.springframework.util.ResourceUtils;
import org.springframework.web.util.DefaultUriBuilderFactory;
import org.springframework.web.util.UriBuilderFactory;

/* loaded from: input_file:com/graphql/spring/boot/test/GraphQLTestSubscription.class */
/**
 * Client for testing GraphQL subscriptions against the locally running server.
 *
 * <p>Connects a web socket to {@code ws://localhost:${local.server.port}<subscriptionPath>} using the
 * {@code graphql-ws} sub-protocol, sends {@code connection_init}, {@code start} and {@code stop}
 * messages, and collects {@code data}/{@code error} payloads as {@link GraphQLResponse}s. Call
 * {@link #reset()} between test cases to obtain a fresh subscription state.
 *
 * <p>Responses are recorded by a {@link MessageHandler} registered on the web socket session;
 * recording and draining of the response queue are serialized through {@code STATE_LOCK}.
 */
public class GraphQLTestSubscription {
    /** Polling interval, in milliseconds, used while waiting for responses or state changes. */
    private static final int SLEEP_INTERVAL_MS = 100;
    /** Max wait (ms) for acknowledgement/stop; also passed to Tomcat as the web socket IO timeout. */
    private static final int ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT = 60000;
    private static final Logger log = LoggerFactory.getLogger(GraphQLTestSubscription.class);
    private static final WebSocketContainer WEB_SOCKET_CONTAINER = ContainerProvider.getWebSocketContainer();
    /** Source of operation ids; every new subscription state takes the next value. */
    private static final AtomicInteger ID_COUNTER = new AtomicInteger(1);
    private static final UriBuilderFactory URI_BUILDER_FACTORY = new DefaultUriBuilderFactory();
    /** Serializes adding responses (handler side) with draining them (test side). */
    private static final Object STATE_LOCK = new Object();

    private final Environment environment;
    private final ObjectMapper objectMapper;
    private final String subscriptionPath;
    private Session session;
    private SubscriptionState state = SubscriptionState.builder().id(ID_COUNTER.incrementAndGet()).build();

    /**
     * @param environment used to look up {@code local.server.port}
     * @param objectMapper used to build outgoing and parse incoming messages
     * @param subscriptionPath path of the GraphQL subscription web socket endpoint
     */
    public GraphQLTestSubscription(Environment environment, ObjectMapper objectMapper, String subscriptionPath) {
        this.environment = environment;
        this.objectMapper = objectMapper;
        this.subscriptionPath = subscriptionPath;
    }

    /** Parses incoming {@code graphql-ws} messages and records them in a {@link SubscriptionState}. */
    public static class TestMessageHandler implements MessageHandler.Whole<String> {
        private final ObjectMapper objectMapper;
        private final SubscriptionState state;

        public TestMessageHandler(ObjectMapper objectMapper, SubscriptionState subscriptionState) {
            this.objectMapper = objectMapper;
            this.state = subscriptionState;
        }

        /**
         * Handles one text message.
         *
         * <p>{@code complete} and {@code connection_ack} set the corresponding state flags.
         * {@code data}/{@code error} payloads are queued as {@link GraphQLResponse}s unless the
         * subscription was already stopped or completed. Any other message type is ignored.
         * Unparseable JSON fails the test.
         */
        @Override
        public void onMessage(String message) {
            try {
                log.debug("Received message from web socket: {}", message);
                JsonNode messageNode = objectMapper.readTree(message);
                JsonNode typeNode = messageNode.get("type");
                Assertions.assertThat(typeNode).as("GraphQL messages should have a type field.").isNotNull();
                Assertions.assertThat(typeNode.isNull()).as("GraphQL messages type should not be null.").isFalse();
                switch (typeNode.asText()) {
                    case "complete":
                        state.setCompleted(true);
                        log.debug("Subscription completed.");
                        break;
                    case "connection_ack":
                        state.setAcknowledged(true);
                        log.debug("WebSocket connection acknowledged by the GraphQL Server.");
                        break;
                    case "data":
                    case "error":
                        JsonNode payload = messageNode.get("payload");
                        Assertions.assertThat(payload).as("Data/error messages must have a payload.").isNotNull();
                        GraphQLResponse response =
                                new GraphQLResponse(ResponseEntity.ok(objectMapper.writeValueAsString(payload)), objectMapper);
                        if (state.isStopped() || state.isCompleted()) {
                            log.debug("Response discarded because subscription was stopped or completed in the meanwhile.");
                        } else {
                            synchronized (STATE_LOCK) {
                                state.getResponses().add(response);
                            }
                            log.debug("New response recorded.");
                        }
                        break;
                    default:
                        // Other message types carry nothing the test client needs to record.
                        break;
                }
            } catch (JsonProcessingException e) {
                org.junit.jupiter.api.Assertions.fail(
                        "Exception while parsing server response. Response is not a valid GraphQL response.", e);
            }
        }
    }

    /** Web socket endpoint that marks the subscription as stopped when the session closes. */
    public static class TestWebSocketClient extends Endpoint {
        private final SubscriptionState state;

        public TestWebSocketClient(SubscriptionState subscriptionState) {
            this.state = subscriptionState;
        }

        @Override
        public void onOpen(Session session, EndpointConfig endpointConfig) {
            log.debug("Connection established.");
        }

        @Override
        public void onClose(Session session, CloseReason closeReason) {
            super.onClose(session, closeReason);
            state.setStopped(true);
        }
    }

    /** Adds the {@code graphql-ws} sub-protocol header to the web socket handshake request. */
    public static class TestWebSocketClientConfigurator extends ClientEndpointConfig.Configurator {
        TestWebSocketClientConfigurator() {
        }

        @Override
        public void beforeRequest(Map<String, List<String>> headers) {
            super.beforeRequest(headers);
            headers.put("sec-websocket-protocol", Collections.singletonList("graphql-ws"));
        }
    }

    public boolean isInitialized() {
        return state.isInitialized();
    }

    public boolean isAcknowledged() {
        return state.isAcknowledged();
    }

    public boolean isStarted() {
        return state.isStarted();
    }

    public boolean isStopped() {
        return state.isStopped();
    }

    public boolean isCompleted() {
        return state.isCompleted();
    }

    /** Initializes the connection with an empty {@code connection_init} payload. */
    public GraphQLTestSubscription init() {
        return init(null);
    }

    /**
     * Opens the web socket and sends {@code connection_init}.
     *
     * <p>Waits up to {@code ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT} ms for {@code connection_ack}.
     *
     * @param payload connection payload; {@code null} sends an empty JSON object
     * @return this client
     */
    public GraphQLTestSubscription init(@Nullable Object payload) {
        if (isInitialized()) {
            org.junit.jupiter.api.Assertions.fail("Subscription already initialized.");
        }
        try {
            initClient();
        } catch (Exception e) {
            org.junit.jupiter.api.Assertions.fail("Could not initialize test subscription client. No subscription defined?", e);
        }
        ObjectNode message = objectMapper.createObjectNode();
        message.put("type", "connection_init");
        message.set("payload", getFinalPayload(payload));
        sendMessage(message);
        state.setInitialized(true);
        awaitAcknowledgement();
        log.debug("Subscription successfully initialized.");
        return this;
    }

    /** Starts the subscription defined in the given classpath resource, without variables. */
    public GraphQLTestSubscription start(@NonNull String graphQLResource) {
        if (graphQLResource == null) {
            throw new NullPointerException("graphQLResource is marked non-null but is null");
        }
        return start(graphQLResource, null);
    }

    /**
     * Sends a {@code start} message for the query loaded from the classpath resource.
     *
     * <p>Initializes the connection first if needed.
     *
     * @param graphQLResource classpath-relative path of the GraphQL query file
     * @param variables query variables; {@code null} sends an empty JSON object
     * @return this client
     */
    public GraphQLTestSubscription start(@NonNull String graphQLResource, @Nullable Object variables) {
        if (graphQLResource == null) {
            throw new NullPointerException("graphQLResource is marked non-null but is null");
        }
        if (!isInitialized()) {
            init();
        }
        if (isStarted()) {
            org.junit.jupiter.api.Assertions.fail("Start message already sent. To start a new subscription, please call reset first.");
        }
        state.setStarted(true);
        ObjectNode payload = objectMapper.createObjectNode();
        payload.put("query", loadQuery(graphQLResource));
        payload.set("variables", getFinalPayload(variables));
        ObjectNode message = objectMapper.createObjectNode();
        message.put("type", "start");
        message.put("id", state.getId());
        message.set("payload", payload);
        log.debug("Sending start message.");
        sendMessage(message);
        return this;
    }

    /**
     * Sends {@code stop} and closes the web socket session.
     *
     * <p>Waits up to {@code ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT} ms for the close callback to
     * mark the subscription stopped.
     */
    public GraphQLTestSubscription stop() {
        if (!isInitialized()) {
            org.junit.jupiter.api.Assertions.fail("Subscription not yet initialized.");
        }
        if (isStopped()) {
            org.junit.jupiter.api.Assertions.fail("Subscription already stopped.");
        }
        ObjectNode message = objectMapper.createObjectNode();
        message.put("type", "stop");
        message.put("id", state.getId());
        log.debug("Sending stop message.");
        sendMessage(message);
        try {
            log.debug("Closing web socket session.");
            session.close();
            awaitStop();
            log.debug("Web socket session closed.");
        } catch (IOException e) {
            org.junit.jupiter.api.Assertions.fail("Could not close web socket session", e);
        }
        return this;
    }

    /** Stops an active subscription (if any) and replaces the state with a fresh one using a new id. */
    public void reset() {
        if (isInitialized() && !isStopped()) {
            stop();
        }
        state = SubscriptionState.builder().id(ID_COUNTER.incrementAndGet()).build();
        session = null;
        log.debug("Test subscription client reset.");
    }

    /** Awaits exactly one response, then stops the subscription. */
    public GraphQLResponse awaitAndGetNextResponse(int timeout) {
        return awaitAndGetNextResponses(timeout, 1, true).get(0);
    }

    /** Awaits exactly one response; stops the subscription afterwards if {@code stopAfter}. */
    public GraphQLResponse awaitAndGetNextResponse(int timeout, boolean stopAfter) {
        return awaitAndGetNextResponses(timeout, 1, stopAfter).get(0);
    }

    /** Waits the full timeout and returns every recorded response, then stops the subscription. */
    public List<GraphQLResponse> awaitAndGetAllResponses(int timeout) {
        return awaitAndGetNextResponses(timeout, -1, true);
    }

    /** Waits the full timeout and returns every recorded response; stops if {@code stopAfter}. */
    public List<GraphQLResponse> awaitAndGetAllResponses(int timeout, boolean stopAfter) {
        return awaitAndGetNextResponses(timeout, -1, stopAfter);
    }

    /** Awaits {@code numExpectedResponses} responses, then stops the subscription. */
    public List<GraphQLResponse> awaitAndGetNextResponses(int timeout, int numExpectedResponses) {
        return awaitAndGetNextResponses(timeout, numExpectedResponses, true);
    }

    /**
     * Polls for responses every {@code SLEEP_INTERVAL_MS} ms until enough have arrived or
     * {@code timeout} ms have elapsed.
     *
     * @param timeout maximum wait in milliseconds
     * @param numExpectedResponses positive: assert at least that many arrived and return exactly
     *     that many (oldest first); {@code 0}: wait the full timeout and assert none arrived;
     *     negative: wait the full timeout and return all recorded responses
     * @param stopAfter whether to {@link #stop()} the subscription before collecting the results
     * @return the consumed responses, removed from the internal queue
     */
    public List<GraphQLResponse> awaitAndGetNextResponses(int timeout, int numExpectedResponses, boolean stopAfter) {
        if (!isStarted()) {
            org.junit.jupiter.api.Assertions.fail("Start message not sent. Please send start message first.");
        }
        if (isStopped()) {
            org.junit.jupiter.api.Assertions.fail("Subscription already stopped. Forgot to call reset after test case?");
        }
        int elapsed = 0;
        // Non-positive expectations deliberately wait the whole timeout to collect everything.
        while ((numExpectedResponses <= 0 || state.getResponses().size() < numExpectedResponses)
                && elapsed < timeout) {
            sleepOneInterval();
            elapsed += SLEEP_INTERVAL_MS;
        }
        if (stopAfter) {
            stop();
        }
        synchronized (STATE_LOCK) {
            Queue<GraphQLResponse> responses = state.getResponses();
            int count = responses.size();
            if (numExpectedResponses == 0) {
                Assertions.assertThat(responses)
                        .as(String.format("Expected no responses in %s MS, but received %s", timeout, responses.size()))
                        .isEmpty();
            }
            if (numExpectedResponses > 0) {
                Assertions.assertThat(responses)
                        .as("Expected at least %s message(s) in %d MS, but %d received.",
                                numExpectedResponses, timeout, responses.size())
                        .hasSizeGreaterThanOrEqualTo(numExpectedResponses);
                count = numExpectedResponses;
            }
            List<GraphQLResponse> result = new ArrayList<>(count);
            for (int taken = 0; taken < count; taken++) {
                result.add(responses.poll());
            }
            log.debug("Returning {} responses.", result.size());
            return result;
        }
    }

    /** Waits {@code timeout} ms asserting no response arrives; stops afterwards if {@code stopAfter}. */
    public GraphQLTestSubscription waitAndExpectNoResponse(int timeout, boolean stopAfter) {
        awaitAndGetNextResponses(timeout, 0, stopAfter);
        return this;
    }

    /** Waits {@code timeout} ms asserting no response arrives, then stops the subscription. */
    public GraphQLTestSubscription waitAndExpectNoResponse(int timeout) {
        awaitAndGetNextResponses(timeout, 0, true);
        return this;
    }

    /**
     * Returns and removes all responses still queued.
     *
     * <p>Only valid after the subscription was stopped.
     */
    public List<GraphQLResponse> getRemainingResponses() {
        if (!isStopped()) {
            org.junit.jupiter.api.Assertions.fail("getRemainingResponses should only be called after the subscription was stopped.");
        }
        // Copy and clear atomically so a concurrently recorded response is not lost in between.
        synchronized (STATE_LOCK) {
            List<GraphQLResponse> remaining = new ArrayList<>(state.getResponses());
            state.getResponses().clear();
            return remaining;
        }
    }

    /** Connects the web socket to the local server and registers the message handler. */
    private void initClient() throws Exception {
        URI uri = URI_BUILDER_FACTORY.builder()
                .scheme("ws")
                .host("localhost")
                .port(environment.getProperty("local.server.port"))
                .path(subscriptionPath)
                .build();
        log.debug("Connecting to client at {}", uri);
        ClientEndpointConfig config = ClientEndpointConfig.Builder.create()
                .configurator(new TestWebSocketClientConfigurator())
                .build();
        config.getUserProperties().put(
                "org.apache.tomcat.websocket.IO_TIMEOUT_MS", String.valueOf(ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT));
        session = WEB_SOCKET_CONTAINER.connectToServer(new TestWebSocketClient(state), config, uri);
        session.addMessageHandler(new TestMessageHandler(objectMapper, state));
    }

    /** Converts a payload to a JSON tree; absent payloads become an empty JSON object. */
    private JsonNode getFinalPayload(@Nullable Object payload) {
        return Optional.ofNullable(payload)
                .<JsonNode>map(objectMapper::valueToTree)
                .orElseGet(objectMapper::createObjectNode);
    }

    /** Reads a classpath resource as UTF-8; fails the test if it cannot be loaded. */
    private String loadQuery(String graphQLResource) {
        try {
            return new String(
                    Files.readAllBytes(ResourceUtils.getFile("classpath:" + graphQLResource).toPath()),
                    StandardCharsets.UTF_8);
        } catch (IOException e) {
            org.junit.jupiter.api.Assertions.fail(
                    String.format("Test setup failure - could not load GraphQL resource: %s", graphQLResource), e);
            return "";
        }
    }

    /** Serializes the message to JSON and sends it synchronously over the web socket. */
    private void sendMessage(Object message) {
        try {
            session.getBasicRemote().sendText(objectMapper.writeValueAsString(message));
        } catch (IOException e) {
            org.junit.jupiter.api.Assertions.fail("Test setup failure - cannot serialize subscription payload.", e);
        }
    }

    private void awaitAcknowledgement() {
        await(GraphQLTestSubscription::isAcknowledged, "Connection was not acknowledged by the GraphQL server.");
    }

    private void awaitStop() {
        await(GraphQLTestSubscription::isStopped, "Connection was not stopped in time.");
    }

    /** Polls {@code condition} until it holds or the acknowledgement timeout elapses, then fails. */
    private void await(Predicate<GraphQLTestSubscription> condition, String timeoutMessage) {
        int elapsed = 0;
        while (!condition.test(this) && elapsed < ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT) {
            sleepOneInterval();
            elapsed += SLEEP_INTERVAL_MS;
        }
        if (!condition.test(this)) {
            org.junit.jupiter.api.Assertions.fail("Timeout: " + timeoutMessage);
        }
    }

    /** Sleeps one polling interval; on interruption restores the interrupt flag and fails the test. */
    private static void sleepOneInterval() {
        try {
            Thread.sleep(SLEEP_INTERVAL_MS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            org.junit.jupiter.api.Assertions.fail("Test execution error - Thread.sleep failed.", e);
        }
    }

    public Session getSession() {
        return session;
    }
}
