package io.opentelemetry.testing.internal.armeria.server.websocket;

import io.opentelemetry.testing.internal.armeria.common.HttpData;
import io.opentelemetry.testing.internal.armeria.common.HttpHeaderNames;
import io.opentelemetry.testing.internal.armeria.common.HttpObject;
import io.opentelemetry.testing.internal.armeria.common.HttpRequest;
import io.opentelemetry.testing.internal.armeria.common.HttpResponse;
import io.opentelemetry.testing.internal.armeria.common.HttpStatus;
import io.opentelemetry.testing.internal.armeria.common.MediaType;
import io.opentelemetry.testing.internal.armeria.common.RequestHeaders;
import io.opentelemetry.testing.internal.armeria.common.ResponseHeaders;
import io.opentelemetry.testing.internal.armeria.common.ResponseHeadersBuilder;
import io.opentelemetry.testing.internal.armeria.common.SessionProtocol;
import io.opentelemetry.testing.internal.armeria.common.annotation.Nullable;
import io.opentelemetry.testing.internal.armeria.common.annotation.UnstableApi;
import io.opentelemetry.testing.internal.armeria.common.stream.ClosedStreamException;
import io.opentelemetry.testing.internal.armeria.common.stream.HttpDecoder;
import io.opentelemetry.testing.internal.armeria.common.stream.StreamMessage;
import io.opentelemetry.testing.internal.armeria.common.websocket.WebSocket;
import io.opentelemetry.testing.internal.armeria.internal.common.websocket.WebSocketFrameEncoder;
import io.opentelemetry.testing.internal.armeria.internal.common.websocket.WebSocketUtil;
import io.opentelemetry.testing.internal.armeria.internal.common.websocket.WebSocketWrapper;
import io.opentelemetry.testing.internal.armeria.internal.shaded.guava.base.Splitter;
import io.opentelemetry.testing.internal.armeria.internal.shaded.guava.net.HostAndPort;
import io.opentelemetry.testing.internal.armeria.server.AbstractHttpService;
import io.opentelemetry.testing.internal.armeria.server.ServiceRequestContext;
import io.opentelemetry.testing.internal.io.netty.handler.codec.http.HttpHeaderValues;
import io.opentelemetry.testing.internal.io.netty.handler.codec.http.websocketx.WebSocketVersion;
import java.util.Set;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An HTTP service that turns WebSocket handshake requests into WebSocket sessions.
 *
 * <p>{@code GET} requests are accepted only on an explicit HTTP/1 session protocol and
 * {@code CONNECT} requests only on an explicit HTTP/2 session protocol. After the origin and
 * version checks pass, the inbound request body is decoded into WebSocket frames, handed to the
 * configured {@link WebSocketServiceHandler}, and the frames it produces are encoded back into
 * the HTTP response body.
 */
@UnstableApi
public final class WebSocketService extends AbstractHttpService {
    private static final Logger logger = LoggerFactory.getLogger(WebSocketService.class);

    /** Subprotocol token that is accepted regardless of the configured subprotocol set. */
    private static final String SUB_PROTOCOL_WILDCARD = "*";

    /** Response headers sent when the request does not ask for WebSocket version 13. */
    private static final ResponseHeaders UNSUPPORTED_WEB_SOCKET_VERSION =
            ResponseHeaders.builder(HttpStatus.BAD_REQUEST)
                    .add(HttpHeaderNames.SEC_WEBSOCKET_VERSION, WebSocketVersion.V13.toHttpHeaderValue())
                    .contentType(MediaType.PLAIN_TEXT_UTF_8)
                    .build();

    /** Splits a comma-separated header value, trimming entries and dropping empty ones. */
    private static final Splitter commaSplitter = Splitter.on(',').trimResults().omitEmptyStrings();

    // NOTE(review): the boolean argument presumably disables payload masking for
    // server-to-client frames -- confirm against WebSocketFrameEncoder.of.
    private static final WebSocketFrameEncoder encoder = WebSocketFrameEncoder.of(false);

    private final WebSocketServiceHandler handler;
    private final int maxFramePayloadLength;
    private final boolean allowMaskMismatch;
    private final Set<String> subprotocols;
    private final Set<String> allowedOrigins;
    private final boolean allowAnyOrigin;

    /**
     * Creates a service for the given handler using the defaults of {@link WebSocketServiceBuilder}.
     *
     * @param handler the handler invoked for every accepted WebSocket session
     * @return a new service
     */
    public static WebSocketService of(WebSocketServiceHandler handler) {
        return new WebSocketServiceBuilder(handler).build();
    }

    /**
     * Returns a builder that creates a customized service for the given handler.
     *
     * @param handler the handler invoked for every accepted WebSocket session
     * @return a new builder
     */
    public static WebSocketServiceBuilder builder(WebSocketServiceHandler handler) {
        return new WebSocketServiceBuilder(handler);
    }

    /**
     * Creates a new instance. Intended to be called by {@link WebSocketServiceBuilder}.
     *
     * @param handler the handler invoked for every accepted WebSocket session
     * @param maxFramePayloadLength forwarded to the inbound frame decoder
     * @param allowMaskMismatch forwarded to the inbound frame decoder
     * @param subprotocols the subprotocols this service is willing to select
     * @param allowedOrigins the accepted {@code Origin} values; when empty, only same-origin
     *                       requests are accepted
     * @param allowAnyOrigin when {@code true}, the origin check is skipped entirely
     */
    public WebSocketService(WebSocketServiceHandler handler, int maxFramePayloadLength,
                            boolean allowMaskMismatch, Set<String> subprotocols,
                            Set<String> allowedOrigins, boolean allowAnyOrigin) {
        this.handler = handler;
        this.maxFramePayloadLength = maxFramePayloadLength;
        this.allowMaskMismatch = allowMaskMismatch;
        this.subprotocols = subprotocols;
        this.allowedOrigins = allowedOrigins;
        this.allowAnyOrigin = allowAnyOrigin;
    }

    /**
     * Handles a {@code GET} handshake on an explicit HTTP/1 session protocol.
     *
     * <p>Responds with {@code 405} for other session protocols, {@code 400} when the upgrade
     * headers, the version or {@code Sec-WebSocket-Key} are missing or unsupported, {@code 403}
     * when the origin is rejected, and otherwise with {@code 101 Switching Protocols} followed by
     * the encoded outbound frames.
     */
    @Override
    protected HttpResponse doGet(ServiceRequestContext serviceRequestContext, HttpRequest httpRequest) throws Exception {
        if (!serviceRequestContext.sessionProtocol().isExplicitHttp1()) {
            return HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED);
        }

        final RequestHeaders headers = httpRequest.headers();
        if (!WebSocketUtil.isHttp1WebSocketUpgradeRequest(headers)) {
            // Same diagnostic as the CONNECT path, so rejected handshakes are traceable on both.
            logger.trace("RequestHeaders does not contain headers for WebSocket upgrade. headers: {}", headers);
            return HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8, "The upgrade header must contain:\n  Upgrade: websocket\n  Connection: Upgrade");
        }

        final HttpResponse originFailure = checkOrigin(serviceRequestContext, headers);
        if (originFailure != null) {
            return originFailure;
        }

        final HttpResponse versionFailure = checkVersion(headers);
        if (versionFailure != null) {
            return versionFailure;
        }

        final String webSocketKey = headers.get(HttpHeaderNames.SEC_WEBSOCKET_KEY, "");
        if (webSocketKey.isEmpty()) {
            return HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8, "missing Sec-WebSocket-Key header");
        }

        final ResponseHeadersBuilder responseHeaders =
                ResponseHeaders.builder(HttpStatus.SWITCHING_PROTOCOLS)
                        .add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET.toString())
                        .add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE.toString())
                        .add(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, WebSocketUtil.generateSecWebSocketAccept(webSocketKey));
        maybeAddSubprotocol(headers, responseHeaders);
        return handleUpgradeRequest(serviceRequestContext, httpRequest, responseHeaders.build());
    }

    /**
     * Adds a {@code Sec-WebSocket-Protocol} response header for the first acceptable subprotocol
     * offered by the client, if any. Nothing is added when the request header is absent or empty,
     * or when no offered entry is acceptable.
     */
    private void maybeAddSubprotocol(RequestHeaders requestHeaders, ResponseHeadersBuilder responseHeadersBuilder) {
        final String offered = requestHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, "");
        if (offered.isEmpty()) {
            return;
        }
        // NOTE(review): the wildcard is matched against the *client-offered* token, so a client
        // sending "*" gets "*" echoed back. Confirm this is intended rather than a wildcard in
        // the configured subprotocol set.
        commaSplitter.splitToStream(offered)
                .filter(candidate -> SUB_PROTOCOL_WILDCARD.equals(candidate) || subprotocols.contains(candidate))
                .findFirst()
                .ifPresent(selected -> responseHeadersBuilder.add(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, selected));
    }

    /**
     * Wires the inbound and outbound frame streams together and builds the streaming response.
     *
     * <p>If the outbound stream fails with a {@link ClosedStreamException}, the response is
     * aborted with that cause. Any other failure is recorded as the response cause and replaced
     * by a single close frame derived from it.
     */
    private HttpResponse handleUpgradeRequest(ServiceRequestContext serviceRequestContext, HttpRequest httpRequest, ResponseHeaders responseHeaders) {
        final WebSocketServiceFrameDecoder decoder =
                new WebSocketServiceFrameDecoder(serviceRequestContext, maxFramePayloadLength, allowMaskMismatch);
        final WebSocket outbound = handler.handle(
                serviceRequestContext,
                new WebSocketWrapper(httpRequest.decode(decoder, serviceRequestContext.alloc())));
        // The decoder is given the outbound stream only after the handler has produced it.
        decoder.setOutboundWebSocket(outbound);

        final StreamMessage<HttpData> body = outbound.recoverAndResume(cause -> {
            if (cause instanceof ClosedStreamException) {
                return StreamMessage.aborted(cause);
            }
            serviceRequestContext.logBuilder().responseCause(cause);
            return StreamMessage.of(WebSocketUtil.newCloseWebSocketFrame(cause));
        }).map(frame -> HttpData.wrap(encoder.encode(serviceRequestContext, frame)));
        return HttpResponse.of(responseHeaders, body);
    }

    /**
     * Handles a {@code CONNECT} handshake on an explicit HTTP/2 session protocol.
     *
     * <p>Responds with {@code 405} for other session protocols, {@code 400} when the
     * {@code :protocol} pseudo-header or the version is missing or unsupported, {@code 403} when
     * the origin is rejected, and otherwise with {@code 200 OK} followed by the encoded outbound
     * frames.
     */
    @Override
    protected HttpResponse doConnect(ServiceRequestContext serviceRequestContext, HttpRequest httpRequest) throws Exception {
        if (!serviceRequestContext.sessionProtocol().isExplicitHttp2()) {
            return HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED);
        }

        final RequestHeaders headers = httpRequest.headers();
        if (!WebSocketUtil.isHttp2WebSocketUpgradeRequest(headers)) {
            logger.trace("RequestHeaders does not contain headers for WebSocket upgrade. headers: {}", headers);
            return HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8, "The upgrade header must contain:\n  :protocol = websocket");
        }

        final HttpResponse originFailure = checkOrigin(serviceRequestContext, headers);
        if (originFailure != null) {
            return originFailure;
        }

        final HttpResponse versionFailure = checkVersion(headers);
        if (versionFailure != null) {
            return versionFailure;
        }

        final ResponseHeadersBuilder responseHeaders = ResponseHeaders.builder(HttpStatus.OK);
        maybeAddSubprotocol(headers, responseHeaders);
        return handleUpgradeRequest(serviceRequestContext, httpRequest, responseHeaders.build());
    }

    /**
     * Validates the {@code Origin} request header.
     *
     * @return {@code null} when the origin is acceptable, or a {@code 403 Forbidden} response
     *         describing why it was rejected
     */
    @Nullable
    private HttpResponse checkOrigin(ServiceRequestContext serviceRequestContext, RequestHeaders requestHeaders) {
        if (allowAnyOrigin) {
            return null;
        }

        final String origin = requestHeaders.get(HttpHeaderNames.ORIGIN, "");
        if (origin.isEmpty()) {
            return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8, "missing the origin header");
        }

        if (allowedOrigins.isEmpty()) {
            // No explicit allow-list: fall back to requiring the same origin as the request.
            if (isSameOrigin(serviceRequestContext, requestHeaders, origin)) {
                return null;
            }
            return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8, "not allowed origin: " + origin);
        }

        if (allowedOrigins.contains(origin)) {
            return null;
        }
        return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8, "not allowed origin: " + origin + ", allowed: " + allowedOrigins);
    }

    /**
     * Tells whether {@code origin} has a scheme of the same kind (HTTP or HTTPS) as the current
     * session protocol and the same host and port as the request authority. A missing port is
     * replaced by the default port of the respective protocol.
     *
     * <p>An origin that has no {@code "://"}, an unknown scheme, or a host/port part that cannot
     * be parsed is never considered the same origin.
     */
    private static boolean isSameOrigin(ServiceRequestContext serviceRequestContext, RequestHeaders requestHeaders, String origin) {
        final int schemeEnd = origin.indexOf("://");
        if (schemeEnd < 0) {
            return false;
        }
        final SessionProtocol originProtocol = SessionProtocol.find(origin.substring(0, schemeEnd));
        if (originProtocol == null) {
            return false;
        }

        final SessionProtocol requestProtocol = serviceRequestContext.sessionProtocol();
        final boolean bothHttp = requestProtocol.isHttp() && originProtocol.isHttp();
        final boolean bothHttps = requestProtocol.isHttps() && originProtocol.isHttps();
        if (!bothHttp && !bothHttps) {
            return false;
        }

        final String authority = requestHeaders.authority();
        assert authority != null;
        final HostAndPort requestHostAndPort = HostAndPort.fromString(authority);

        final HostAndPort originHostAndPort;
        try {
            originHostAndPort = HostAndPort.fromString(origin.substring(schemeEnd + 3));
        } catch (IllegalArgumentException ignored) {
            // The Origin header is client-controlled; a value such as "http://host:80/path" or
            // "http://host:abc" cannot be parsed and must be rejected rather than propagated as
            // an unhandled exception.
            return false;
        }

        final int requestPort = requestHostAndPort.getPortOrDefault(requestProtocol.defaultPort());
        final int originPort = originHostAndPort.getPortOrDefault(originProtocol.defaultPort());
        return requestPort == originPort && requestHostAndPort.getHost().equals(originHostAndPort.getHost());
    }

    /**
     * Validates the {@code Sec-WebSocket-Version} request header.
     *
     * @return {@code null} when the header equals version 13 (ignoring case), or a
     *         {@code 400 Bad Request} response advertising version 13 otherwise, including when
     *         the header is absent
     */
    @Nullable
    private static HttpResponse checkVersion(RequestHeaders requestHeaders) {
        final String requestedVersion = requestHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_VERSION);
        if (WebSocketVersion.V13.toHttpHeaderValue().equalsIgnoreCase(requestedVersion)) {
            return null;
        }
        return HttpResponse.of(UNSUPPORTED_WEB_SOCKET_VERSION, HttpData.ofUtf8("Only 13 version is supported."));
    }
}
