/*
 * Decompiled with CFR 0.152.
 */
package io.undertow.protocols.ssl;

import io.undertow.UndertowMessages;
import java.io.IOException;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import javax.net.ssl.SNIHostName;
import javax.net.ssl.SNIServerName;
import javax.net.ssl.SSLException;

/**
 * Inspects the first TLS record of a connection and extracts the SNI server names carried in
 * the ClientHello, without consuming the caller's buffer.
 *
 * <p>All public entry points work on a {@link ByteBuffer#duplicate() duplicate} of the source,
 * so the source buffer's position is never changed. An SSLv2-compatible ClientHello (which
 * carries no extensions) yields an empty list.
 *
 * <p>Error reporting:
 * <ul>
 *   <li>{@link BufferUnderflowException} means more bytes are needed (fewer than
 *       {@link #RECORD_HEADER_SIZE} bytes, or the record body is not fully buffered).</li>
 *   <li>{@link SSLException} means the data is not a parseable single-record ClientHello.</li>
 * </ul>
 */
final class SNISSLExplorer {
    /** Size in bytes of a TLS record header: content type (1), version (2), length (2). */
    public static final int RECORD_HEADER_SIZE = 5;

    /** TLS record content type for handshake messages. */
    private static final byte HANDSHAKE_CONTENT_TYPE = 22;
    /** Handshake message type of a ClientHello. */
    private static final byte CLIENT_HELLO = 1;
    /** Extension type of the server_name (SNI) extension. */
    private static final int EXT_SERVER_NAME = 0;
    /** Extension type of the application_layer_protocol_negotiation extension. */
    private static final int EXT_ALPN = 16;
    /** SNI name type for a DNS host name. */
    private static final int SNI_HOST_NAME = 0;
    /** Length of the fixed-size ClientHello random field. */
    private static final int RANDOM_LENGTH = 32;

    private SNISSLExplorer() {
    }

    /**
     * Returns how many bytes must be buffered before {@link #explore(ByteBuffer)} can succeed.
     *
     * @param source buffer positioned at the start of a record; its position is not modified
     * @return the full TLS record size (header plus body). For an SSLv2-style hello, only the
     *         header size is returned.
     * @throws BufferUnderflowException if fewer than {@link #RECORD_HEADER_SIZE} bytes remain
     */
    public static int getRequiredSize(ByteBuffer source) {
        ByteBuffer input = source.duplicate();
        if (input.remaining() < RECORD_HEADER_SIZE) {
            throw new BufferUnderflowException();
        }
        byte firstByte = input.get();
        input.get();
        byte thirdByte = input.get();
        if (isSSLv2Hello(firstByte, thirdByte)) {
            // Only the header fields are needed to recognise this case.
            return RECORD_HEADER_SIZE;
        }
        return getInt16(input) + RECORD_HEADER_SIZE;
    }

    /**
     * Array variant of {@link #getRequiredSize(ByteBuffer)}.
     *
     * @param source array containing the record bytes
     * @param offset index of the first record byte
     * @param length number of readable bytes starting at {@code offset}
     * @return the required buffered size, as for {@link #getRequiredSize(ByteBuffer)}
     * @throws IOException declared for API compatibility
     */
    public static int getRequiredSize(byte[] source, int offset, int length) throws IOException {
        ByteBuffer byteBuffer = ByteBuffer.wrap(source, offset, length).asReadOnlyBuffer();
        return getRequiredSize(byteBuffer);
    }

    /**
     * Extracts the SNI server names from the ClientHello at the start of {@code source}.
     *
     * @param source buffer positioned at the start of a record; its position is not modified
     * @return the server names in the order they appear, or an empty list if the hello carries
     *         no SNI extension or is SSLv2-style
     * @throws BufferUnderflowException if the header or the record body is not fully buffered
     * @throws SSLException if the record is not a well-formed single-record ClientHello
     */
    public static List<SNIServerName> explore(ByteBuffer source) throws SSLException {
        ByteBuffer input = source.duplicate();
        if (input.remaining() < RECORD_HEADER_SIZE) {
            throw new BufferUnderflowException();
        }
        byte firstByte = input.get();
        byte secondByte = input.get();
        byte thirdByte = input.get();
        if (isSSLv2Hello(firstByte, thirdByte)) {
            return Collections.emptyList();
        }
        return exploreTLSRecord(input, firstByte, secondByte, thirdByte);
    }

    /**
     * Array variant of {@link #explore(ByteBuffer)}.
     *
     * @param source array containing the record bytes
     * @param offset index of the first record byte
     * @param length number of readable bytes starting at {@code offset}
     * @return the server names, as for {@link #explore(ByteBuffer)}
     * @throws IOException if the data is malformed (an {@link SSLException})
     */
    public static List<SNIServerName> explore(byte[] source, int offset, int length) throws IOException {
        ByteBuffer byteBuffer = ByteBuffer.wrap(source, offset, length).asReadOnlyBuffer();
        return explore(byteBuffer);
    }

    /** Recognises the SSLv2-compatible ClientHello layout (high bit set, message type 1). */
    private static boolean isSSLv2Hello(byte firstByte, byte thirdByte) {
        return (firstByte & 0x80) != 0 && thirdByte == CLIENT_HELLO;
    }

    /**
     * Validates the TLS record header and parses its handshake body.
     *
     * <p>{@code input} must be positioned just after the first three header bytes. Any buffer
     * underflow inside the record body means the lengths are inconsistent, so it is reported
     * as an invalid handshake rather than as "need more data".
     */
    private static List<SNIServerName> exploreTLSRecord(ByteBuffer input, byte firstByte, byte secondByte, byte thirdByte) throws SSLException {
        if (firstByte != HANDSHAKE_CONTENT_TYPE) {
            throw UndertowMessages.MESSAGES.notHandshakeRecord();
        }
        int recordLength = getInt16(input);
        if (recordLength > input.remaining()) {
            // The record body is not fully buffered yet; the caller should read more.
            throw new BufferUnderflowException();
        }
        try {
            return exploreHandshake(input, secondByte, thirdByte, recordLength);
        } catch (BufferUnderflowException ignored) {
            // Every length field pointed past the data it describes: malformed, not incomplete.
            throw UndertowMessages.MESSAGES.invalidHandshakeRecord();
        }
    }

    /**
     * Parses the handshake header and limits further parsing to the ClientHello body.
     *
     * @param recordLength length of the enclosing record body
     */
    private static List<SNIServerName> exploreHandshake(ByteBuffer input, byte recordMajorVersion, byte recordMinorVersion, int recordLength) throws SSLException {
        byte handshakeType = input.get();
        if (handshakeType != CLIENT_HELLO) {
            throw UndertowMessages.MESSAGES.expectedClientHello();
        }
        int handshakeLength = getInt24(input);
        // The 4-byte handshake header (type + uint24 length) is part of the record body. A
        // longer message would continue in a following record, which is not reassembled here.
        if (handshakeLength > recordLength - 4) {
            throw UndertowMessages.MESSAGES.multiRecordSSLHandshake();
        }
        // Bound parsing to this message so nested length fields cannot run past it.
        ByteBuffer hello = input.duplicate();
        hello.limit(hello.position() + handshakeLength);
        return exploreClientHello(hello, recordMajorVersion, recordMinorVersion);
    }

    /**
     * Skips the fixed and variable ClientHello fields that precede the extensions block.
     *
     * <p>The skipped fields are client_version, random, session_id, cipher_suites and
     * compression_methods. The block then returns the SNI names found in the extensions.
     */
    private static List<SNIServerName> exploreClientHello(ByteBuffer input, byte recordMajorVersion, byte recordMinorVersion) throws SSLException {
        ignoreByteVector(input, 2);             // client_version
        ignoreByteVector(input, RANDOM_LENGTH); // random
        ignoreByteVector8(input);               // session_id
        ignoreByteVector16(input);              // cipher_suites
        ignoreByteVector8(input);               // compression_methods
        if (!input.hasRemaining()) {
            // A ClientHello without an extensions block carries no SNI.
            return Collections.emptyList();
        }
        return exploreExtensions(input).sni;
    }

    /** Walks the extensions block, decoding SNI and ALPN and skipping every other extension. */
    private static ExtensionInfo exploreExtensions(ByteBuffer input) throws SSLException {
        List<SNIServerName> sni = Collections.emptyList();
        List<String> alpn = Collections.emptyList();
        int length = getInt16(input);
        while (length > 0) {
            int extType = getInt16(input);
            int extLen = getInt16(input);
            if (extType == EXT_SERVER_NAME) {
                sni = exploreSNIExt(input, extLen);
            } else if (extType == EXT_ALPN) {
                alpn = exploreALPN(input, extLen);
            } else {
                ignoreByteVector(input, extLen);
            }
            // 4 = extension type (2) + extension length (2)
            length -= extLen + 4;
        }
        return new ExtensionInfo(sni, alpn);
    }

    /**
     * Decodes the ALPN protocol name list.
     *
     * <p>Each protocol name is a uint8 length followed by that many bytes, decoded as UTF-8.
     *
     * @param extLen declared length of the extension data
     * @return the offered protocols in order, or an empty list if {@code extLen} is 0
     * @throws SSLException if the list length is inconsistent or an entry overruns the list
     */
    private static List<String> exploreALPN(ByteBuffer input, int extLen) throws SSLException {
        List<String> protocols = new ArrayList<String>();
        int remaining = extLen;
        if (extLen >= 2) {
            int listLen = getInt16(input);
            if (listLen == 0 || listLen + 2 != extLen) {
                throw UndertowMessages.MESSAGES.invalidTlsExt();
            }
            remaining -= 2;
            while (remaining > 0) {
                int len = getInt8(input);
                // The entry occupies its 1-byte length prefix plus the name itself.
                if (len + 1 > remaining) {
                    throw UndertowMessages.MESSAGES.notEnoughData();
                }
                byte[] name = new byte[len];
                input.get(name);
                protocols.add(new String(name, StandardCharsets.UTF_8));
                remaining -= len + 1;
            }
        }
        if (remaining != 0) {
            // Only reachable for extLen == 1: a lone byte that is not a valid list.
            throw UndertowMessages.MESSAGES.invalidTlsExt();
        }
        return protocols.isEmpty() ? Collections.<String>emptyList() : protocols;
    }

    /**
     * Decodes the server_name extension into an unmodifiable list of names.
     *
     * <p>Type 0 entries become {@link SNIHostName}; any other type becomes an
     * {@link UnknownServerName}. Each name type may appear at most once.
     *
     * @param extLen declared length of the extension data
     * @throws SSLException if the extension is empty or inconsistent, an entry overruns the
     *         list, a host name is empty or illegal, or a name type is repeated
     */
    private static List<SNIServerName> exploreSNIExt(ByteBuffer input, int extLen) throws SSLException {
        LinkedHashMap<Integer, SNIServerName> sniMap = new LinkedHashMap<Integer, SNIServerName>();
        int remains = extLen;
        if (extLen >= 2) {
            int listLen = getInt16(input);
            if (listLen == 0 || listLen + 2 != extLen) {
                throw UndertowMessages.MESSAGES.invalidTlsExt();
            }
            remains -= 2;
            while (remains > 0) {
                int code = getInt8(input);
                int snLen = getInt16(input);
                // The entry occupies name type (1) + name length (2) + the name itself.
                if (snLen + 3 > remains) {
                    throw UndertowMessages.MESSAGES.notEnoughData();
                }
                byte[] encoded = new byte[snLen];
                input.get(encoded);
                SNIServerName serverName;
                if (code == SNI_HOST_NAME) {
                    if (encoded.length == 0) {
                        throw UndertowMessages.MESSAGES.emptyHostNameSni();
                    }
                    try {
                        serverName = new SNIHostName(encoded);
                    } catch (IllegalArgumentException e) {
                        // Client-supplied host name is not a legal hostname; treat it as a
                        // malformed extension instead of leaking an unchecked exception.
                        throw UndertowMessages.MESSAGES.invalidTlsExt();
                    }
                } else {
                    serverName = new UnknownServerName(code, encoded);
                }
                if (sniMap.put(serverName.getType(), serverName) != null) {
                    throw UndertowMessages.MESSAGES.duplicatedSniServerName(serverName.getType());
                }
                remains -= encoded.length + 3;
            }
        } else if (extLen == 0) {
            throw UndertowMessages.MESSAGES.invalidTlsExt();
        }
        if (remains != 0) {
            throw UndertowMessages.MESSAGES.invalidTlsExt();
        }
        return Collections.unmodifiableList(new ArrayList<SNIServerName>(sniMap.values()));
    }

    /** Reads an unsigned 8-bit value (0..255). */
    private static int getInt8(ByteBuffer input) {
        return input.get() & 0xFF;
    }

    /** Reads an unsigned big-endian 16-bit value. */
    private static int getInt16(ByteBuffer input) {
        return (input.get() & 0xFF) << 8 | input.get() & 0xFF;
    }

    /** Reads an unsigned big-endian 24-bit value. */
    private static int getInt24(ByteBuffer input) {
        return (input.get() & 0xFF) << 16 | (input.get() & 0xFF) << 8 | input.get() & 0xFF;
    }

    /** Skips a vector prefixed by a uint8 length. */
    private static void ignoreByteVector8(ByteBuffer input) {
        ignoreByteVector(input, getInt8(input));
    }

    /** Skips a vector prefixed by a uint16 length. */
    private static void ignoreByteVector16(ByteBuffer input) {
        ignoreByteVector(input, getInt16(input));
    }

    /** Skips a vector prefixed by a uint24 length. */
    private static void ignoreByteVector24(ByteBuffer input) {
        ignoreByteVector(input, getInt24(input));
    }

    /**
     * Advances past {@code length} bytes without copying them.
     *
     * @throws BufferUnderflowException if fewer than {@code length} bytes remain. This is
     *         checked explicitly because {@code Buffer.position(int)} would otherwise throw
     *         {@code IllegalArgumentException}, which the record parser does not map to an
     *         {@link SSLException}.
     */
    private static void ignoreByteVector(ByteBuffer input, int length) {
        if (length > input.remaining()) {
            throw new BufferUnderflowException();
        }
        if (length != 0) {
            input.position(input.position() + length);
        }
    }

    /** Values decoded from the ClientHello extensions block; {@code explore} consumes {@code sni}. */
    static final class ExtensionInfo {
        final List<SNIServerName> sni;
        final List<String> alpn;

        ExtensionInfo(List<SNIServerName> sni, List<String> alpn) {
            this.sni = sni;
            this.alpn = alpn;
        }
    }

    /** An SNI entry whose name type has no dedicated {@link SNIServerName} subclass. */
    static final class UnknownServerName
    extends SNIServerName {
        UnknownServerName(int code, byte[] encoded) {
            super(code, encoded);
        }
    }
}

