package org.logdoc.fairhttp.service.http;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.function.Consumer;
import javax.xml.XMLConstants;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerException;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;
import org.logdoc.fairhttp.service.api.helpers.Headers;
import org.logdoc.fairhttp.service.http.Http;
import org.logdoc.fairhttp.service.tools.websocket.Opcode;
import org.logdoc.fairhttp.service.tools.websocket.extension.DefaultExtension;
import org.logdoc.fairhttp.service.tools.websocket.extension.ExtensionRequestData;
import org.logdoc.fairhttp.service.tools.websocket.extension.IExtension;
import org.logdoc.fairhttp.service.tools.websocket.frames.AFrame;
import org.logdoc.fairhttp.service.tools.websocket.frames.BinaryFrame;
import org.logdoc.fairhttp.service.tools.websocket.frames.CloseFrame;
import org.logdoc.fairhttp.service.tools.websocket.frames.Frame;
import org.logdoc.fairhttp.service.tools.websocket.frames.PingFrame;
import org.logdoc.fairhttp.service.tools.websocket.frames.TextFrame;
import org.logdoc.fairhttp.service.tools.websocket.protocol.IProtocol;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Document;
import org.xml.sax.InputSource;

/* loaded from: input_file:org/logdoc/fairhttp/service/http/WebSocket.class */
public abstract class WebSocket extends Response implements Consumer<Byte> {
    private static final Logger logger = LoggerFactory.getLogger(WebSocket.class);
    private IExtension extension;
    private IProtocol protocol;
    private ObjectMapper om;
    private DocumentBuilder xb;
    private Transformer tr;
    private int frameStage;
    private int payloadlength;
    private AFrame frame;
    private Frame incompleteframe;
    private Opcode optcode;
    private boolean mask;
    private Http.Drive drive;
    private byte[] payload;
    private byte[] maskkey;
    private Consumer<byte[]> writeConsumer;

    public WebSocket(Request request) {
        super(101, "Websocket Connection Upgrade");
        header(Headers.Upgrade, "websocket");
        header(Headers.Connection, Headers.Upgrade);
        prepare(request);
    }

    public WebSocket(Request request, IExtension iExtension) {
        this(request, iExtension, null);
    }

    public WebSocket(Request request, IExtension iExtension, IProtocol iProtocol) {
        super(101, "Websocket Connection Upgrade");
        header(Headers.Upgrade, "websocket");
        header(Headers.Connection, Headers.Upgrade);
        this.extension = iExtension;
        this.protocol = iProtocol;
        prepare(request);
    }

    private void prepare(Request request) {
        if (!IProtocol.WS_VERSION.equals(request.header(Headers.SecWebsocketVersion))) {
            throw new IllegalStateException("Wrong websocket version: " + request.header(Headers.SecWebsocketVersion) + ", expected: 13");
        }
        try {
            header(Headers.SecWebsocketAccept, Base64.getEncoder().encodeToString(MessageDigest.getInstance("SHA1").digest((request.header(Headers.SecWebsocketKey) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").getBytes())));
            if (this.extension != null && this.extension.acceptProvidedExtensionAsServer(request.header(Headers.SecWebsocketExtensions))) {
                header(Headers.SecWebsocketExtensions, this.extension.getProvidedExtensionAsServer());
            } else if (this.extension != null) {
                throw new IllegalStateException("Cant accept requested extenstion(s): " + request.header(Headers.SecWebsocketExtensions));
            }
            if (this.protocol != null && this.protocol.acceptProtocol(request.header(Headers.SecWebsocketProtocols))) {
                header(Headers.SecWebsocketProtocols, this.protocol.getProvidedProtocol());
            } else if (this.protocol != null) {
                throw new IllegalStateException("Cant accept requested protocol(s): " + request.header(Headers.SecWebsocketProtocols));
            }
            if (this.extension == null) {
                this.extension = new DefaultExtension();
            }
        } catch (NoSuchAlgorithmException e) {
            throw new RuntimeException(e);
        }
    }

    @Override // java.util.function.Consumer
    public final void accept(Byte b) {
        if (b == null || b.byteValue() == -1) {
            return;
        }
        int i = this.frameStage;
        this.frameStage = i + 1;
        switch (i) {
            case CloseFrame.NEVER_CONNECTED /* -1 */:
                this.frameStage = -1;
                this.drive.accept(b);
                return;
            case 0:
                this.optcode = toOpcode((byte) (b.byteValue() & 15));
                this.frame = AFrame.get(this.optcode);
                this.frame.setFin((b.byteValue() >> 8) != 0);
                this.frame.setRSV1((b.byteValue() & 64) != 0);
                this.frame.setRSV2((b.byteValue() & 32) != 0);
                this.frame.setRSV3((b.byteValue() & 16) != 0);
                return;
            case 1:
                this.mask = (b.byteValue() & Byte.MIN_VALUE) != 0;
                this.payloadlength = (byte) (b.byteValue() & Byte.MAX_VALUE);
                if (this.payloadlength > 125) {
                    if (this.optcode == Opcode.PING || this.optcode == Opcode.PONG || this.optcode == Opcode.CLOSING) {
                        throw new IllegalArgumentException("more than 125 octets");
                    }
                    if (this.payloadlength == 126) {
                        this.drive = new Http.Drive(2, bArr -> {
                            this.frameStage = 2;
                            this.payloadlength = new BigInteger(new byte[]{0, bArr[0], bArr[1]}).intValue();
                        });
                    } else {
                        this.drive = new Http.Drive(8, bArr2 -> {
                            this.frameStage = 2;
                            this.payloadlength = (int) new BigInteger(bArr2).longValue();
                        });
                    }
                    this.frameStage = -1;
                    return;
                }
                return;
            case 2:
                this.payload = new byte[this.payloadlength];
                this.frameStage = -1;
                if (this.mask) {
                    this.drive = new Http.Drive(4, bArr3 -> {
                        this.maskkey = bArr3;
                        this.drive = new Http.Drive(this.payloadlength, bArr3 -> {
                            this.frameStage = 3;
                            for (int i2 = 0; i2 < this.payloadlength; i2++) {
                                this.payload[i2] = (byte) (bArr3[i2] ^ this.maskkey[i2 % 4]);
                            }
                        });
                    });
                    return;
                } else {
                    this.drive = new Http.Drive(this.payloadlength, bArr4 -> {
                        this.frameStage = 3;
                        this.payload = bArr4;
                    });
                    return;
                }
            case 3:
                this.frameStage = 0;
                this.frame.setPayload(this.payload);
                IExtension iExtension = null;
                if (this.frame.getOpcode() != Opcode.CONTINUOUS && (this.frame.isRSV1() || this.frame.isRSV2() || this.frame.isRSV3())) {
                    iExtension = this.extension;
                }
                if (iExtension == null) {
                    iExtension = new DefaultExtension();
                }
                if (!iExtension.isFrameValid(this.frame)) {
                    logger.error("Extension cant decode frame: " + this.frame);
                    return;
                }
                try {
                    iExtension.decodeFrame(this.frame);
                    if (this.frame.isValid()) {
                        process(this.frame);
                    } else {
                        logger.error("Invalid frame catched: " + this.frame);
                    }
                    return;
                } catch (Exception e) {
                    logger.error("Frame processing error: " + this.frame + " :: " + e.getMessage(), e);
                    return;
                }
            default:
                return;
        }
    }

    private void process(Frame frame) {
        Opcode opcode = frame.getOpcode();
        if (opcode == Opcode.CLOSING) {
            int i = 1005;
            String str = ExtensionRequestData.EMPTY_VALUE;
            if (frame instanceof CloseFrame) {
                i = ((CloseFrame) frame).getCloseCode();
                str = ((CloseFrame) frame).getMessage();
            }
            close(i, str, true);
            return;
        }
        if (opcode == Opcode.PING) {
            onPing();
            return;
        }
        if (opcode == Opcode.PONG) {
            onPong();
            return;
        }
        if (!frame.isFin() || opcode == Opcode.CONTINUOUS) {
            processFrameContinuousAndNonFin(frame, opcode);
        } else {
            if (this.incompleteframe != null) {
                throw new IllegalStateException("Continuous frame sequence not completed.");
            }
            frameReady(frame);
        }
    }

    private void frameReady(Frame frame) {
        byte[] payloadData = frame.getPayloadData();
        if (frame.getOpcode() != Opcode.TEXT) {
            if (frame.getOpcode() == Opcode.BINARY) {
                onBytes(payloadData);
                return;
            }
            return;
        }
        String trim = new String(payloadData, StandardCharsets.UTF_8).trim();
        if ((trim.startsWith("{") && trim.endsWith("}")) || (trim.startsWith("[") && trim.endsWith("]"))) {
            if (this.om == null) {
                this.om = new ObjectMapper();
            }
            try {
                onJson(this.om.readTree(trim));
                return;
            } catch (Exception e) {
            }
        }
        if (trim.toLowerCase().startsWith("<") && trim.endsWith(">")) {
            if (this.xb == null) {
                try {
                    this.xb = DocumentBuilderFactory.newInstance().newDocumentBuilder();
                } catch (Exception e2) {
                }
            }
            if (this.xb != null) {
                try {
                    onXml(this.xb.parse(new InputSource(new ByteArrayInputStream(payloadData))));
                    return;
                } catch (Exception e3) {
                }
            }
        }
        onText(trim);
    }

    private void processFrameContinuousAndNonFin(Frame frame, Opcode opcode) {
        if (opcode != Opcode.CONTINUOUS) {
            this.incompleteframe = frame;
        } else if (frame.isFin()) {
            if (this.incompleteframe == null) {
                throw new IllegalStateException("Continuous frame sequence was not started.");
            }
            this.incompleteframe.append(frame);
            ((AFrame) this.incompleteframe).isValid();
            frameReady(this.incompleteframe);
            this.incompleteframe = null;
        } else if (this.incompleteframe == null) {
            throw new IllegalStateException("Continuous frame sequence was not started.");
        }
        if (opcode != Opcode.CONTINUOUS || this.incompleteframe == null) {
            return;
        }
        this.incompleteframe.append(frame);
    }

    private Opcode toOpcode(byte b) {
        switch (b) {
            case 0:
                return Opcode.CONTINUOUS;
            case 1:
                return Opcode.TEXT;
            case 2:
                return Opcode.BINARY;
            case 3:
            case 4:
            case 5:
            case 6:
            case 7:
            default:
                throw new IllegalArgumentException("Unknown opcode " + b);
            case 8:
                return Opcode.CLOSING;
            case 9:
                return Opcode.PING;
            case 10:
                return Opcode.PONG;
        }
    }

    public final void close() {
        close(CloseFrame.NORMAL, null, false);
    }

    public final void close(int i, String str) {
        close(i, str, false);
    }

    private void close(int i, String str, boolean z) {
        if (!z) {
            try {
                sendFrame(new CloseFrame(i, str));
            } catch (Exception e) {
            }
        }
        try {
            onClose(i, str, z);
        } catch (Exception e2) {
        }
    }

    public abstract void onJson(JsonNode jsonNode);

    public abstract void onXml(Document document);

    public abstract void onText(String str);

    public abstract void onBytes(byte[] bArr);

    public abstract void onPing();

    public abstract void onPong();

    public abstract void onClose(int i, String str, boolean z);

    public void ping() {
        sendFrame(new PingFrame());
    }

    public void send(JsonNode jsonNode) {
        if (jsonNode == null) {
            throw new NullPointerException("Message");
        }
        if (this.om == null) {
            this.om = new ObjectMapper();
        }
        try {
            TextFrame textFrame = new TextFrame();
            textFrame.setPayload(jsonNode.toString().getBytes(StandardCharsets.UTF_8));
            textFrame.setMasked(true);
            sendFrame(textFrame);
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }

    public void send(Document document) {
        if (document == null) {
            throw new NullPointerException("Message");
        }
        if (this.tr == null) {
            try {
                this.tr = TransformerFactory.newInstance().newTransformer();
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        }
        try {
            ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(16384);
            try {
                this.tr.transform(new DOMSource(document), new StreamResult(byteArrayOutputStream));
                byteArrayOutputStream.flush();
                TextFrame textFrame = new TextFrame();
                textFrame.setPayload(byteArrayOutputStream.toByteArray());
                textFrame.setMasked(true);
                sendFrame(textFrame);
                byteArrayOutputStream.close();
            } finally {
            }
        } catch (IOException | TransformerException e2) {
            throw new IllegalStateException(e2);
        }
    }

    public void send(String str) {
        if (str == null) {
            throw new NullPointerException("Message");
        }
        TextFrame textFrame = new TextFrame();
        textFrame.setPayload(str.getBytes(StandardCharsets.UTF_8));
        textFrame.setMasked(true);
        sendFrame(textFrame);
    }

    public void send(byte[] bArr) {
        if (bArr == null) {
            throw new NullPointerException("Message");
        }
        BinaryFrame binaryFrame = new BinaryFrame();
        binaryFrame.setMasked(true);
        binaryFrame.setPayload(bArr);
        sendFrame(binaryFrame);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public synchronized void sendFrame(AFrame aFrame) {
        if (aFrame == null) {
            throw new NullPointerException("Frame");
        }
        if (!aFrame.isValid()) {
            throw new IllegalStateException("Invalid frame");
        }
        try {
            ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(4096);
            try {
                this.extension.encodeFrame(aFrame);
                byte[] payloadData = aFrame.getPayloadData();
                int sizeBytes = getSizeBytes(payloadData);
                byte fromOpcode = (byte) (((byte) (aFrame.isFin() ? -128 : 0)) | fromOpcode(aFrame.getOpcode()));
                if (aFrame.isRSV1()) {
                    fromOpcode = (byte) (fromOpcode | getRSVByte(1));
                }
                if (aFrame.isRSV2()) {
                    fromOpcode = (byte) (fromOpcode | getRSVByte(2));
                }
                if (aFrame.isRSV3()) {
                    fromOpcode = (byte) (fromOpcode | getRSVByte(3));
                }
                byteArrayOutputStream.write(fromOpcode);
                byte[] byteArray = toByteArray(payloadData.length, sizeBytes);
                if (sizeBytes == 1) {
                    byteArrayOutputStream.write(byteArray[0]);
                } else if (sizeBytes == 2) {
                    byteArrayOutputStream.write(126);
                    byteArrayOutputStream.write(byteArray);
                } else {
                    if (sizeBytes != 8) {
                        throw new IllegalStateException("Size representation not supported/specified");
                    }
                    byteArrayOutputStream.write(127);
                    byteArrayOutputStream.write(byteArray);
                }
                byteArrayOutputStream.write(payloadData);
                byteArrayOutputStream.flush();
                this.writeConsumer.accept(byteArrayOutputStream.toByteArray());
                byteArrayOutputStream.close();
            } catch (Throwable th) {
                try {
                    byteArrayOutputStream.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
                throw th;
            }
        } catch (IOException e) {
            close(CloseFrame.ABNORMAL_CLOSE, e.getMessage(), false);
            throw new IllegalStateException(e);
        } catch (IllegalStateException e2) {
            throw e2;
        } catch (Exception e3) {
            throw new IllegalStateException(e3);
        }
    }

    private int getSizeBytes(byte[] bArr) {
        if (bArr.length <= 125) {
            return 1;
        }
        return bArr.length <= 65535 ? 2 : 8;
    }

    private byte getRSVByte(int i) {
        switch (i) {
            case 1:
                return (byte) 64;
            case 2:
                return (byte) 32;
            case 3:
                return (byte) 16;
            default:
                return (byte) 0;
        }
    }

    private byte[] toByteArray(long j, int i) {
        byte[] bArr = new byte[i];
        int i2 = (8 * i) - 8;
        for (int i3 = 0; i3 < i; i3++) {
            bArr[i3] = (byte) (j >>> (i2 - (8 * i3)));
        }
        return bArr;
    }

    private byte fromOpcode(Opcode opcode) {
        switch (opcode) {
            case CONTINUOUS:
                return (byte) 0;
            case TEXT:
                return (byte) 1;
            case BINARY:
                return (byte) 2;
            case CLOSING:
                return (byte) 8;
            case PING:
                return (byte) 9;
            case PONG:
                return (byte) 10;
            default:
                throw new IllegalArgumentException("Don't know how to handle " + opcode);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public void setWriteHandler(Consumer<byte[]> consumer) {
        this.writeConsumer = consumer;
    }
}
