package cn.taketoday.web.socket.server.support;

import cn.taketoday.http.HttpMethod;
import cn.taketoday.http.HttpStatus;
import cn.taketoday.lang.Assert;
import cn.taketoday.lang.Nullable;
import cn.taketoday.lang.TodayStrategies;
import cn.taketoday.logging.Logger;
import cn.taketoday.logging.LoggerFactory;
import cn.taketoday.util.ClassUtils;
import cn.taketoday.util.LogFormatUtils;
import cn.taketoday.util.StringUtils;
import cn.taketoday.web.RequestContext;
import cn.taketoday.web.socket.SubProtocolCapable;
import cn.taketoday.web.socket.WebSocketExtension;
import cn.taketoday.web.socket.WebSocketHandler;
import cn.taketoday.web.socket.WebSocketHttpHeaders;
import cn.taketoday.web.socket.WebSocketSession;
import cn.taketoday.web.socket.server.HandshakeFailureException;
import cn.taketoday.web.socket.server.HandshakeHandler;
import cn.taketoday.web.socket.server.RequestUpgradeStrategy;
import cn.taketoday.web.socket.server.jetty.JettyRequestUpgradeStrategy;
import cn.taketoday.web.socket.server.standard.StandardWebSocketUpgradeStrategy;
import cn.taketoday.web.socket.server.standard.TomcatRequestUpgradeStrategy;
import cn.taketoday.web.socket.server.standard.UndertowRequestUpgradeStrategy;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Supplier;

/* loaded from: input_file:cn/taketoday/web/socket/server/support/AbstractHandshakeHandler.class */
public abstract class AbstractHandshakeHandler implements HandshakeHandler {
    private static final boolean tomcatWsPresent = ClassUtils.isPresent("org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", AbstractHandshakeHandler.class);
    private static final boolean jettyWsPresent = ClassUtils.isPresent("org.eclipse.jetty.ee10.websocket.server.JettyWebSocketServerContainer", AbstractHandshakeHandler.class);
    private static final boolean undertowWsPresent = ClassUtils.isPresent("io.undertow.websockets.jsr.ServerWebSocketContainer", AbstractHandshakeHandler.class);
    protected final Logger logger;
    private final RequestUpgradeStrategy requestUpgradeStrategy;
    private final ArrayList<String> supportedProtocols;

    /* JADX INFO: Access modifiers changed from: protected */
    public AbstractHandshakeHandler() {
        this(initRequestUpgradeStrategy());
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public AbstractHandshakeHandler(RequestUpgradeStrategy requestUpgradeStrategy) {
        this.logger = LoggerFactory.getLogger(getClass());
        this.supportedProtocols = new ArrayList<>();
        Assert.notNull(requestUpgradeStrategy, "RequestUpgradeStrategy is required");
        this.requestUpgradeStrategy = requestUpgradeStrategy;
    }

    public RequestUpgradeStrategy getRequestUpgradeStrategy() {
        return this.requestUpgradeStrategy;
    }

    public void setSupportedProtocols(String... strArr) {
        this.supportedProtocols.clear();
        for (String str : strArr) {
            this.supportedProtocols.add(str.toLowerCase());
        }
    }

    public String[] getSupportedProtocols() {
        return StringUtils.toStringArray(this.supportedProtocols);
    }

    @Override // cn.taketoday.web.socket.server.HandshakeHandler
    public final WebSocketSession doHandshake(RequestContext requestContext, WebSocketHandler webSocketHandler, Map<String, Object> map) throws HandshakeFailureException {
        WebSocketHttpHeaders webSocketHttpHeaders = new WebSocketHttpHeaders(requestContext.getHeaders());
        if (this.logger.isTraceEnabled()) {
            this.logger.trace("Processing request {} with headers={}", requestContext.getURI(), webSocketHttpHeaders);
        }
        try {
            if (HttpMethod.GET != requestContext.getMethod()) {
                requestContext.setStatus(HttpStatus.METHOD_NOT_ALLOWED);
                requestContext.responseHeaders().setAllow(Collections.singleton(HttpMethod.GET));
                if (!this.logger.isErrorEnabled()) {
                    return null;
                }
                this.logger.error("Handshake failed due to unexpected HTTP method: {}", requestContext.getMethod());
                return null;
            }
            if (!"WebSocket".equalsIgnoreCase(webSocketHttpHeaders.getUpgrade())) {
                handleInvalidUpgradeHeader(requestContext);
                return null;
            }
            List connection = webSocketHttpHeaders.getConnection();
            if (!connection.contains("Upgrade") && !connection.contains("upgrade")) {
                handleInvalidConnectHeader(requestContext);
                return null;
            }
            if (!isWebSocketVersionSupported(webSocketHttpHeaders)) {
                handleWebSocketVersionNotSupported(requestContext);
                return null;
            }
            if (!isValidOrigin(requestContext)) {
                requestContext.setStatus(HttpStatus.FORBIDDEN);
                return null;
            }
            if (webSocketHttpHeaders.getSecWebSocketKey() == null) {
                if (this.logger.isErrorEnabled()) {
                    this.logger.error("Missing \"Sec-WebSocket-Key\" header");
                }
                requestContext.setStatus(HttpStatus.BAD_REQUEST);
                return null;
            }
            String selectProtocol = selectProtocol(webSocketHttpHeaders.getSecWebSocketProtocol(), webSocketHandler);
            List<WebSocketExtension> filterRequestedExtensions = filterRequestedExtensions(requestContext, webSocketHttpHeaders.getSecWebSocketExtensions(), this.requestUpgradeStrategy.getSupportedExtensions(requestContext));
            if (this.logger.isTraceEnabled()) {
                this.logger.trace("Upgrading to WebSocket, subProtocol={}, extensions={}", selectProtocol, filterRequestedExtensions);
            }
            return this.requestUpgradeStrategy.upgrade(requestContext, selectProtocol, filterRequestedExtensions, webSocketHandler, map);
        } catch (IOException e) {
            throw new HandshakeFailureException("Response update failed during upgrade to WebSocket: " + requestContext.getURI(), e);
        }
    }

    protected void handleInvalidUpgradeHeader(RequestContext requestContext) throws IOException {
        if (this.logger.isErrorEnabled()) {
            this.logger.error(LogFormatUtils.formatValue("Handshake failed due to invalid Upgrade header: " + requestContext.getHeaders().getUpgrade(), -1, true));
        }
        requestContext.setStatus(HttpStatus.BAD_REQUEST);
        requestContext.getOutputStream().write("Can \"Upgrade\" only to \"WebSocket\".".getBytes(StandardCharsets.UTF_8));
    }

    protected void handleInvalidConnectHeader(RequestContext requestContext) throws IOException {
        if (this.logger.isErrorEnabled()) {
            this.logger.error(LogFormatUtils.formatValue("Handshake failed due to invalid Connection header" + requestContext.getHeaders().getConnection(), -1, true));
        }
        requestContext.setStatus(HttpStatus.BAD_REQUEST);
        requestContext.getOutputStream().write("\"Connection\" must be \"upgrade\".".getBytes(StandardCharsets.UTF_8));
    }

    protected boolean isWebSocketVersionSupported(WebSocketHttpHeaders webSocketHttpHeaders) {
        String secWebSocketVersion = webSocketHttpHeaders.getSecWebSocketVersion();
        for (String str : getSupportedVersions()) {
            if (str.trim().equals(secWebSocketVersion)) {
                return true;
            }
        }
        return false;
    }

    protected String[] getSupportedVersions() {
        return this.requestUpgradeStrategy.getSupportedVersions();
    }

    protected void handleWebSocketVersionNotSupported(RequestContext requestContext) {
        if (this.logger.isErrorEnabled()) {
            this.logger.error(LogFormatUtils.formatValue("Handshake failed due to unsupported WebSocket version: " + requestContext.getHeaders().getFirst("Sec-WebSocket-Version") + ". Supported versions: " + Arrays.toString(getSupportedVersions()), -1, true));
        }
        requestContext.setStatus(HttpStatus.UPGRADE_REQUIRED);
        requestContext.setHeader("Sec-WebSocket-Version", StringUtils.arrayToCommaDelimitedString(getSupportedVersions()));
    }

    protected boolean isValidOrigin(RequestContext requestContext) {
        return true;
    }

    @Nullable
    protected String selectProtocol(List<String> list, WebSocketHandler webSocketHandler) {
        List<String> determineHandlerSupportedProtocols = determineHandlerSupportedProtocols(webSocketHandler);
        Iterator<String> it = list.iterator();
        while (it.hasNext()) {
            String next = it.next();
            if (!determineHandlerSupportedProtocols.contains(next.toLowerCase()) && !this.supportedProtocols.contains(next.toLowerCase())) {
            }
            return next;
        }
        return null;
    }

    protected final List<String> determineHandlerSupportedProtocols(WebSocketHandler webSocketHandler) {
        Object rawHandler = webSocketHandler.getRawHandler();
        List<String> list = null;
        if (rawHandler instanceof SubProtocolCapable) {
            list = ((SubProtocolCapable) rawHandler).getSubProtocols();
        }
        return list != null ? list : Collections.emptyList();
    }

    protected List<WebSocketExtension> filterRequestedExtensions(RequestContext requestContext, List<WebSocketExtension> list, List<WebSocketExtension> list2) {
        ArrayList arrayList = new ArrayList(list.size());
        for (WebSocketExtension webSocketExtension : list) {
            if (list2.contains(webSocketExtension)) {
                arrayList.add(webSocketExtension);
            }
        }
        return arrayList;
    }

    private static RequestUpgradeStrategy initRequestUpgradeStrategy() {
        RequestUpgradeStrategy requestUpgradeStrategy = (RequestUpgradeStrategy) TodayStrategies.findFirst(RequestUpgradeStrategy.class, (Supplier) null);
        return requestUpgradeStrategy != null ? requestUpgradeStrategy : tomcatWsPresent ? new TomcatRequestUpgradeStrategy() : jettyWsPresent ? new JettyRequestUpgradeStrategy() : undertowWsPresent ? new UndertowRequestUpgradeStrategy() : new StandardWebSocketUpgradeStrategy();
    }
}
