package org.snf4j.websocket.handshake;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.snf4j.core.SSLSession;
import org.snf4j.core.codec.ICodecPipeline;
import org.snf4j.core.session.ISession;
import org.snf4j.websocket.IWebSocketSessionConfig;
import org.snf4j.websocket.extensions.IExtension;
import org.snf4j.websocket.extensions.InvalidExtensionException;

/* loaded from: input_file:org/snf4j/websocket/handshake/Handshaker.class */
/**
 * Default {@link IHandshaker} implementation of the WebSocket opening handshake.
 * <p>
 * In client mode it builds the {@link HandshakeRequest} ({@link #handshake()}) and validates
 * the server's {@link HandshakeResponse}. In server mode it accepts an incoming
 * {@link HandshakeRequest} and produces the matching response. Negotiated state
 * (sub-protocol, extensions, URI) is stored in volatile fields.
 */
public class Handshaker implements IHandshaker {
    private static final IExtension[] EMPTY = new IExtension[0];

    /** The only WebSocket protocol version supported (RFC 6455). */
    private static final int VERSION = 13;

    private final boolean clientMode;
    private final IWebSocketSessionConfig config;
    private volatile boolean finished;
    private boolean closing;
    /** Client-generated Sec-WebSocket-Key, kept to verify the server's accept value. */
    private String key;
    private volatile String subProtocol;
    private volatile URI uri;
    private ISession session;
    /** Human-readable reason of the last handshake failure, or {@code null}. */
    private String cause;
    /** Negotiated extensions, or {@code null} when none have been negotiated. */
    private volatile List<IExtension> extensions;

    /**
     * Creates a handshaker.
     *
     * @param config     session configuration supplying URI, origin, sub-protocols and extensions
     * @param clientMode {@code true} for the client side, {@code false} for the server side
     */
    public Handshaker(IWebSocketSessionConfig config, boolean clientMode) {
        this.config = config;
        this.clientMode = clientMode;
    }

    @Override // org.snf4j.websocket.handshake.IHandshaker
    public void setSession(ISession session) {
        this.session = session;
    }

    @Override // org.snf4j.websocket.handshake.IHandshaker
    public boolean isClientMode() {
        return this.clientMode;
    }

    @Override // org.snf4j.websocket.handshake.IHandshaker
    public boolean isFinished() {
        return this.finished;
    }

    @Override // org.snf4j.websocket.handshake.IHandshaker
    public boolean isClosing() {
        return this.closing;
    }

    @Override // org.snf4j.websocket.handshake.IHandshaker
    public String getSubProtocol() {
        return this.subProtocol;
    }

    @Override // org.snf4j.websocket.handshake.IHandshaker
    public boolean hasExtensions() {
        return this.extensions != null;
    }

    /**
     * Returns the negotiated extension with the given name.
     *
     * @param name the extension name
     * @return the extension, or {@code null} if no such extension was negotiated
     */
    @Override // org.snf4j.websocket.handshake.IHandshaker
    public IExtension getExtension(String name) {
        // Read the volatile field once so a concurrent reset cannot cause an NPE.
        List<IExtension> list = this.extensions;
        if (list != null) {
            for (IExtension extension : list) {
                if (extension.getName().equals(name)) {
                    return extension;
                }
            }
        }
        return null;
    }

    /**
     * Returns a copy of the negotiated extensions.
     *
     * @return the extensions, or an empty array if none were negotiated
     */
    @Override // org.snf4j.websocket.handshake.IHandshaker
    public IExtension[] getExtensions() {
        List<IExtension> list = this.extensions;
        return list == null ? EMPTY : list.toArray(new IExtension[list.size()]);
    }

    /**
     * Returns the names of the negotiated extensions in negotiation order.
     *
     * @return the names, or an empty array if none were negotiated
     */
    @Override // org.snf4j.websocket.handshake.IHandshaker
    public String[] getExtensionNames() {
        List<IExtension> list = this.extensions;
        if (list == null) {
            return new String[0];
        }
        IExtension[] snapshot = list.toArray(new IExtension[list.size()]);
        String[] names = new String[snapshot.length];
        for (int i = 0; i < snapshot.length; i++) {
            names[i] = snapshot[i].getName();
        }
        return names;
    }

    /**
     * Lets every negotiated extension update the encoders in the given pipeline.
     * Does nothing when no extensions were negotiated.
     *
     * @param pipeline the codec pipeline to update
     */
    @Override // org.snf4j.websocket.handshake.IHandshaker
    public void updateExtensionEncoders(ICodecPipeline pipeline) {
        List<IExtension> list = this.extensions;
        if (list != null) {
            for (IExtension extension : list) {
                extension.updateEncoders(pipeline);
            }
        }
    }

    /**
     * Lets every negotiated extension update the decoders in the given pipeline.
     * Does nothing when no extensions were negotiated.
     *
     * @param pipeline the codec pipeline to update
     */
    @Override // org.snf4j.websocket.handshake.IHandshaker
    public void updateExtensionDecoders(ICodecPipeline pipeline) {
        List<IExtension> list = this.extensions;
        if (list != null) {
            for (IExtension extension : list) {
                extension.updateDecoders(pipeline);
            }
        }
    }

    @Override // org.snf4j.websocket.handshake.IHandshaker
    public URI getUri() {
        return this.uri;
    }

    @Override // org.snf4j.websocket.handshake.IHandshaker
    public String getClosingReason() {
        return this.cause;
    }

    /** Records the reason of a handshake failure (reported by {@link #getClosingReason()}). */
    void cause(String reason) {
        this.cause = reason;
    }

    /**
     * Builds the client opening handshake request from the session configuration.
     * A fresh Sec-WebSocket-Key is generated and remembered for {@link #validateKeyChallenge}.
     *
     * @return the request to send to the server
     */
    HandshakeRequest request() {
        String[] supportedSubProtocols = this.config.getSupportedSubProtocols();
        IExtension[] supportedExtensions = this.config.getSupportedExtensions();
        String requestOrigin = this.config.getRequestOrigin();

        this.uri = this.config.getRequestUri();
        HandshakeRequest request = new HandshakeRequest(HandshakeUtils.requestUri(this.uri));
        request.addValue("Host", HandshakeUtils.host(this.uri));
        request.addValue("Upgrade", "websocket");
        request.addValue("Connection", "Upgrade");
        this.key = HandshakeUtils.generateKey();
        request.addValue("Sec-WebSocket-Key", this.key);
        if (requestOrigin != null) {
            request.addValue("Origin", requestOrigin);
        }
        request.addValue("Sec-WebSocket-Version", VERSION);
        if (supportedSubProtocols != null && supportedSubProtocols.length > 0) {
            request.addValue("Sec-WebSocket-Protocol", HttpUtils.values(supportedSubProtocols));
        }
        if (supportedExtensions != null && supportedExtensions.length > 0) {
            String[] offers = new String[supportedExtensions.length];
            for (int i = 0; i < supportedExtensions.length; i++) {
                offers[i] = HandshakeUtils.extension(supportedExtensions[i].offer());
            }
            request.addValue("Sec-WebSocket-Extensions", HttpUtils.values(offers));
        }
        this.config.customizeHeaders(request);
        return request;
    }

    /**
     * Checks that the request offers the supported protocol version.
     *
     * @param request the client request
     * @throws HandshakeAcceptException with 400 when the version is missing or malformed,
     *         or with an Upgrade Required response advertising the supported version
     *         when none of the offered versions is supported
     */
    void acceptVersion(HandshakeRequest request) throws HandshakeAcceptException {
        String value = request.getValue("Sec-WebSocket-Version");
        if (value == null) {
            cause("Missing websocket version");
            throw new HandshakeAcceptException(HttpStatus.BAD_REQUEST);
        }
        for (String version : HttpUtils.values(value)) {
            int parsed;
            try {
                parsed = Integer.parseInt(version);
            } catch (NumberFormatException e) {
                cause("Incorrect websocket version: " + version);
                throw new HandshakeAcceptException(HttpStatus.BAD_REQUEST);
            }
            if (parsed == VERSION) {
                return;
            }
        }
        HandshakeResponse response = new HandshakeResponse(HttpStatus.UPGRADE_REQUIRED);
        response.addValue("Sec-WebSocket-Version", VERSION);
        cause("Unsupported websocket version: " + value);
        throw new HandshakeAcceptException(response);
    }

    /**
     * Checks the Upgrade and Connection header fields of the request.
     *
     * @throws HandshakeAcceptException with 400 when they are missing or invalid
     */
    void acceptBasicFields(HandshakeRequest request) throws HandshakeAcceptException {
        if (!validateBasicFields(request)) {
            throw new HandshakeAcceptException(HttpStatus.BAD_REQUEST);
        }
    }

    /**
     * Parses the Sec-WebSocket-Key of the request.
     *
     * @return the decoded key
     * @throws HandshakeAcceptException with 400 when the key is missing or invalid
     */
    byte[] acceptKey(HandshakeRequest request) throws HandshakeAcceptException {
        String value = request.getValue("Sec-WebSocket-Key");
        if (value == null) {
            cause("Missing websocket key");
        } else {
            byte[] parsed = HandshakeUtils.parseKey(value);
            if (parsed != null) {
                return parsed;
            }
            cause("Invalid websocket key: " + value);
        }
        throw new HandshakeAcceptException(HttpStatus.BAD_REQUEST);
    }

    /**
     * Selects the first requested sub-protocol that is supported ("*" in the configuration
     * matches any). Leaves the sub-protocol unset when nothing matches.
     */
    void acceptSubProtocol(HandshakeRequest request) {
        String value = request.getValue("Sec-WebSocket-Protocol");
        if (value == null || value.isEmpty()) {
            return;
        }
        String[] supported = this.config.getSupportedSubProtocols();
        if (supported == null) {
            return;
        }
        for (String requested : HttpUtils.values(value)) {
            for (String candidate : supported) {
                if ("*".equals(candidate) || requested.equals(candidate)) {
                    this.subProtocol = requested;
                    return;
                }
            }
        }
    }

    /**
     * Adds a negotiated extension unless one with the same name is already present.
     * A new list is fully populated before it is published to the volatile field.
     *
     * @return {@code true} if the extension was added
     */
    private boolean addExtension(IExtension extension) {
        List<IExtension> list = this.extensions;
        if (list == null) {
            list = new ArrayList<IExtension>();
            list.add(extension);
            this.extensions = list;
            return true;
        }
        if (getExtension(extension.getName()) != null) {
            return false;
        }
        list.add(extension);
        return true;
    }

    /**
     * Negotiates the extensions offered by the client against the supported ones.
     * At most one extension per group id is accepted.
     *
     * @throws HandshakeAcceptException with 400 when an offer is invalid
     */
    void acceptExtensions(HandshakeRequest request) throws HandshakeAcceptException {
        String value = request.getValue("Sec-WebSocket-Extensions");
        if (value == null || value.isEmpty()) {
            return;
        }
        try {
            List<String> offers = HttpUtils.values(value);
            IExtension[] supported = this.config.getSupportedExtensions();
            if (supported == null) {
                return;
            }
            HashSet<Object> groupIds = new HashSet<Object>();
            for (String offer : offers) {
                List<String> items = HandshakeUtils.extension(offer);
                for (IExtension extension : supported) {
                    IExtension accepted = extension.acceptOffer(items);
                    // HashSet.add returns false when the group is already taken
                    if (accepted != null && groupIds.add(accepted.getGroupId())) {
                        addExtension(accepted);
                    }
                }
            }
        } catch (InvalidExtensionException e) {
            this.extensions = null;
            cause("Invalid websocket request extension");
            throw new HandshakeAcceptException(HttpStatus.BAD_REQUEST);
        }
    }

    /**
     * Rebuilds an absolute http(s) request URI with the given ws(s) scheme.
     * Host, port, path and query are preserved in their raw (still encoded) form.
     * The fragment is dropped because RFC 6455 does not allow it in WebSocket URIs.
     */
    private static String webSocketUri(String scheme, URI httpUri) {
        StringBuilder sb = new StringBuilder(scheme).append("://").append(httpUri.getHost());
        if (httpUri.getPort() != -1) {
            sb.append(':').append(httpUri.getPort());
        }
        if (httpUri.getRawPath() != null) {
            sb.append(httpUri.getRawPath());
        }
        if (httpUri.getRawQuery() != null) {
            sb.append('?').append(httpUri.getRawQuery());
        }
        return sb.toString();
    }

    /**
     * Resolves the requested WebSocket URI and asks the configuration to accept it.
     * An absolute http/https URI is converted to ws/wss. A relative URI is combined
     * with the Host header and a scheme chosen by whether the session is an SSL session.
     *
     * @throws HandshakeAcceptException with 400 for a malformed URI or a missing Host
     *         (unless ignored by configuration), or with 404 when the URI is not accepted
     */
    void acceptUri(HandshakeRequest request) throws HandshakeAcceptException {
        try {
            URI requestUri = new URI(request.getUri());
            if (requestUri.isAbsolute()) {
                String scheme = HandshakeUtils.isHttp(requestUri) ? "ws"
                        : HandshakeUtils.isHttps(requestUri) ? "wss" : null;
                if (scheme != null) {
                    requestUri = new URI(webSocketUri(scheme, requestUri));
                }
            } else {
                String host = request.getValue("Host");
                if (host == null && !this.config.ignoreHostHeaderField()) {
                    cause("Missing websocket request host");
                    throw new HandshakeAcceptException(HttpStatus.BAD_REQUEST);
                }
                // NOTE(review): when a missing Host is ignored, the literal text "null"
                // ends up as the host part; confirm this is the intended behavior.
                requestUri = new URI((this.session instanceof SSLSession ? "wss" : "ws")
                        + "://" + host + request.getUri());
            }
            if (!this.config.acceptRequestUri(requestUri)) {
                cause("Unacceptable websocket request uri: " + requestUri);
                throw new HandshakeAcceptException(HttpStatus.NOT_FOUND);
            }
            this.uri = requestUri;
        } catch (URISyntaxException e) {
            cause("Invalid websocket request uri: " + e.getMessage());
            throw new HandshakeAcceptException(HttpStatus.BAD_REQUEST);
        }
    }

    /**
     * Server side: validates the client request and builds the response.
     *
     * @param request the client request
     * @return a 101 Switching Protocols response on success, otherwise the error response
     */
    HandshakeResponse accept(HandshakeRequest request) {
        try {
            acceptVersion(request);
            acceptBasicFields(request);
            acceptUri(request);
            acceptKey(request);
            acceptSubProtocol(request);
            acceptExtensions(request);

            HandshakeResponse response = new HandshakeResponse(HttpStatus.SWITCHING_PROTOCOLS);
            response.addValue("Upgrade", "websocket");
            response.addValue("Connection", "Upgrade");
            response.addValue("Sec-WebSocket-Accept",
                    HandshakeUtils.generateAnswerKey(request.getValue("Sec-WebSocket-Key")));
            if (this.subProtocol != null) {
                response.addValue("Sec-WebSocket-Protocol", this.subProtocol);
            }
            List<IExtension> list = this.extensions;
            if (list != null) {
                String[] answers = new String[list.size()];
                for (int i = 0; i < answers.length; i++) {
                    answers[i] = HandshakeUtils.extension(list.get(i).response());
                }
                response.addValue("Sec-WebSocket-Extensions", HttpUtils.values(answers));
            }
            this.config.customizeHeaders(response);
            return response;
        } catch (HandshakeAcceptException e) {
            return e.getResponse();
        }
    }

    /**
     * Tells whether a comma-separated header value contains the given token,
     * ignoring case.
     */
    static boolean contains(String values, String token) {
        for (String value : HttpUtils.values(values)) {
            if (value.equalsIgnoreCase(token)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Checks that the Upgrade field contains "websocket" and the Connection field
     * contains "Upgrade".
     *
     * @return {@code true} if both are present and valid; otherwise the cause is set
     */
    boolean validateBasicFields(HandshakeFrame frame) {
        String upgrade = frame.getValue("Upgrade");
        String connection = frame.getValue("Connection");
        if (upgrade == null) {
            cause("Missing websocket upgrade");
            return false;
        }
        if (connection == null) {
            cause("Missing websocket connection");
            return false;
        }
        if (!contains(upgrade, "websocket")) {
            cause("Invalid websocket upgrade: " + upgrade);
            return false;
        }
        if (!contains(connection, "Upgrade")) {
            cause("Invalid websocket connection: " + connection);
            return false;
        }
        return true;
    }

    /**
     * Checks that Sec-WebSocket-Accept matches the value expected for the key sent
     * in {@link #request()}.
     */
    boolean validateKeyChallenge(HandshakeResponse response) {
        String actual = response.getValue("Sec-WebSocket-Accept");
        if (actual == null) {
            cause("Missing websocket key challenge");
            return false;
        }
        String expected = HandshakeUtils.generateAnswerKey(this.key);
        if (!actual.equals(expected)) {
            cause("Invalid websocket key challenge. Actual: " + actual + ". Expected: " + expected);
            return false;
        }
        return true;
    }

    /**
     * Checks the sub-protocol selected by the server. A sub-protocol is required
     * when any were requested and must be one of them. None is allowed otherwise.
     */
    boolean validateSubProtocol(HandshakeResponse response) {
        String value = response.getValue("Sec-WebSocket-Protocol");
        String[] supported = this.config.getSupportedSubProtocols();
        if (supported == null || supported.length == 0) {
            if (value == null) {
                return true;
            }
        } else {
            if (value == null) {
                cause("Missing websocket sub protocol");
                return false;
            }
            for (String candidate : supported) {
                if (value.equals(candidate)) {
                    this.subProtocol = value;
                    return true;
                }
            }
        }
        cause("Invalid websocket sub protocol: " + value);
        return false;
    }

    /**
     * Checks the extensions selected by the server. Each entry must be validated by
     * exactly one supported extension, with no repeated group id or name.
     * On failure the negotiated extensions are cleared and the cause is set.
     */
    boolean validateExtensions(HandshakeResponse response) {
        String value = response.getValue("Sec-WebSocket-Extensions");
        if (value == null) {
            return true;
        }
        IExtension[] supported = this.config.getSupportedExtensions();
        if (supported == null || supported.length == 0) {
            cause("Unexpected websocket extensions: " + value);
            return false;
        }
        HashSet<Object> groupIds = new HashSet<Object>();
        for (String entry : HttpUtils.values(value)) {
            List<String> items = HandshakeUtils.extension(entry);
            IExtension validated = null;
            for (IExtension extension : supported) {
                try {
                    validated = extension.validateResponse(items);
                } catch (InvalidExtensionException e) {
                    cause("Invalid extension: " + e.getMessage());
                    this.extensions = null;
                    return false;
                }
                if (validated != null) {
                    // Stop at the first match. Re-validating the same entry would
                    // report its own group id as a duplicate.
                    break;
                }
            }
            if (validated == null) {
                cause("Unsupported websocket extension: " + entry);
                this.extensions = null;
                return false;
            }
            if (!groupIds.add(validated.getGroupId()) || !addExtension(validated)) {
                cause("Duplicate websocket extension: " + entry);
                this.extensions = null;
                return false;
            }
        }
        return true;
    }

    /**
     * Client side: validates the server response (status, basic fields, key
     * challenge, sub-protocol and extensions, in this order).
     */
    boolean validate(HandshakeResponse response) {
        if (response.getStatus() != HttpStatus.SWITCHING_PROTOCOLS.getStatus()) {
            cause("Invalid websocket response status: " + response.getStatus());
            return false;
        }
        return validateBasicFields(response)
                && validateKeyChallenge(response)
                && validateSubProtocol(response)
                && validateExtensions(response);
    }

    /**
     * Starts the handshake.
     *
     * @return the client request in client mode, {@code null} in server mode
     */
    @Override // org.snf4j.websocket.handshake.IHandshaker
    public HandshakeFrame handshake() {
        return this.clientMode ? request() : null;
    }

    /**
     * Processes a received handshake frame. On success the handshake becomes
     * finished, otherwise it becomes closing.
     *
     * @param frame the received frame
     * @return the server response in server mode, {@code null} in client mode
     * @throws InvalidHandshakeException if the frame type does not match the mode
     */
    @Override // org.snf4j.websocket.handshake.IHandshaker
    public HandshakeFrame handshake(HandshakeFrame frame) throws InvalidHandshakeException {
        if (this.clientMode) {
            if (frame instanceof HandshakeResponse) {
                if (validate((HandshakeResponse) frame)) {
                    this.finished = true;
                } else {
                    this.closing = true;
                }
                return null;
            }
        } else if (frame instanceof HandshakeRequest) {
            HandshakeResponse response = accept((HandshakeRequest) frame);
            if (response.getStatus() == HttpStatus.SWITCHING_PROTOCOLS.getStatus()) {
                this.finished = true;
            } else {
                this.closing = true;
            }
            return response;
        }
        throw new InvalidHandshakeException();
    }
}
