package eu.clarussecure.proxy.protocol.plugins.pgsql.message;

import eu.clarussecure.proxy.protocol.plugins.pgsql.PgsqlConfiguration;
import eu.clarussecure.proxy.protocol.plugins.pgsql.PgsqlConstants;
import eu.clarussecure.proxy.protocol.plugins.pgsql.PgsqlSession;
import eu.clarussecure.proxy.protocol.plugins.pgsql.message.PgsqlMessage;
import eu.clarussecure.proxy.protocol.plugins.pgsql.message.parser.PgsqlMessageParser;
import eu.clarussecure.proxy.protocol.plugins.pgsql.message.sql.EventProcessor;
import eu.clarussecure.proxy.protocol.plugins.pgsql.message.sql.SQLSession;
import eu.clarussecure.proxy.protocol.plugins.pgsql.message.writer.PgsqlMessageWriter;
import eu.clarussecure.proxy.protocol.plugins.pgsql.raw.handler.DefaultFullPgsqlRawMessage;
import eu.clarussecure.proxy.protocol.plugins.pgsql.raw.handler.DefaultLastPgsqlRawContent;
import eu.clarussecure.proxy.protocol.plugins.pgsql.raw.handler.FullPgsqlRawMessage;
import eu.clarussecure.proxy.protocol.plugins.pgsql.raw.handler.PgsqlRawMessage;
import eu.clarussecure.proxy.protocol.plugins.pgsql.raw.handler.codec.MutablePgsqlRawMessage;
import eu.clarussecure.proxy.protocol.plugins.tcp.TCPConstants;
import eu.clarussecure.proxy.protocol.plugins.tcp.handler.forwarder.DirectedMessage;
import eu.clarussecure.proxy.spi.CString;
import eu.clarussecure.proxy.spi.buffer.MutableByteBufInputStream;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.util.ReferenceCountUtil;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:eu/clarussecure/proxy/protocol/plugins/pgsql/message/PgsqlMessageHandler.class */
/**
 * Base Netty decoder for PostgreSQL protocol messages.
 *
 * <p>Raw messages whose type byte is registered in {@link #msgTypes} are parsed into {@code T}
 * instances and handed to {@link #process} (at most one peer channel) or {@link #directedProcess}
 * (several peer channels). The result is forwarded down the pipeline:
 * <ul>
 * <li>an unchanged message is forwarded as the original (retained) raw message;</li>
 * <li>a modified message is re-encoded.</li>
 * </ul>
 * Parsers and writers are located by naming convention
 * ({@code <message package>.parser.<Message>Parser} / {@code <message package>.writer.<Message>Writer})
 * and cached in channel attributes.
 *
 * @param <T> the message type handled by this decoder
 */
public abstract class PgsqlMessageHandler<T extends PgsqlMessage> extends MessageToMessageDecoder<PgsqlRawMessage> {
    private static final Logger LOGGER = LoggerFactory.getLogger(PgsqlMessageHandler.class);

    /** Message classes handled by this decoder, keyed by the value of their public static {@code TYPE} field. */
    protected final Map<Byte, Class<? extends T>> msgTypes;
    /** Cached number of peer channels; {@code 0} means "not computed yet". */
    protected int numberOfPeerChannels = 0;
    /** Cached index of the preferred peer channel; {@link Integer#MIN_VALUE} means "not computed yet". */
    protected int preferredPeerChannel = Integer.MIN_VALUE;

    /**
     * Registers the message classes accepted by this decoder.
     *
     * @param clsArr message classes, each declaring a public static byte field named {@code TYPE}
     * @throws IllegalArgumentException if the {@code TYPE} field of a class cannot be read
     * @throws IllegalStateException if two classes declare the same {@code TYPE} value
     *         (raised by {@link Collectors#toMap})
     */
    @SafeVarargs
    public PgsqlMessageHandler(Class<? extends T>... clsArr) {
        this.msgTypes = Arrays.stream(clsArr)
                .collect(Collectors.toMap(cls -> Byte.valueOf(readTypeField(cls)), cls -> cls));
    }

    /** Reads the public static {@code TYPE} byte field of a message class. */
    private static byte readTypeField(Class<?> cls) {
        try {
            return cls.getField("TYPE").getByte(null);
        } catch (IllegalAccessException | IllegalArgumentException | NoSuchFieldException | SecurityException e) {
            LOGGER.error("Cannot read TYPE field of message class {}: ", cls.getSimpleName(), e);
            // Chain the cause instead of passing it as an ignored format argument.
            throw new IllegalArgumentException(
                    String.format("Cannot read TYPE field of message class %s: %s", cls.getSimpleName(), e), e);
        }
    }

    /**
     * Accepts only full or mutable raw messages whose type byte is registered in {@link #msgTypes}.
     */
    @Override
    public boolean acceptInboundMessage(Object obj) throws Exception {
        if (!super.acceptInboundMessage(obj)) {
            return false;
        }
        if (!(obj instanceof FullPgsqlRawMessage) && !(obj instanceof MutablePgsqlRawMessage)) {
            return false;
        }
        return this.msgTypes.containsKey(((PgsqlRawMessage) obj).getType());
    }

    /**
     * Unwraps a {@link DirectedMessage} whose payload this decoder accepts, so that the payload is
     * decoded. Any other message is passed to the superclass unchanged.
     */
    @Override
    public void channelRead(ChannelHandlerContext channelHandlerContext, Object obj) throws Exception {
        Object msg = obj;
        if (obj instanceof DirectedMessage) {
            Object payload = ((DirectedMessage<?>) obj).getMsg();
            if (acceptInboundMessage(payload)) {
                msg = payload;
            }
        }
        super.channelRead(channelHandlerContext, msg);
    }

    /**
     * Decodes a raw message, processes it and adds the resulting raw message(s) to {@code list}.
     *
     * <p>With at most one peer channel, the (possibly modified) message is added as-is. Otherwise
     * each result of {@link #directedProcess} is wrapped in a {@link DirectedMessage}.
     */
    @Override
    @SuppressWarnings("unchecked")
    protected void decode(ChannelHandlerContext channelHandlerContext, PgsqlRawMessage pgsqlRawMessage, List<Object> list) throws Exception {
        if (isStreamingSupported(pgsqlRawMessage.getType()) && pgsqlRawMessage instanceof MutablePgsqlRawMessage
                && !((MutablePgsqlRawMessage) pgsqlRawMessage).isComplete()) {
            LOGGER.trace("Decoding raw message in streaming mode: {}...", pgsqlRawMessage);
            decodeStream(channelHandlerContext, pgsqlRawMessage);
            LOGGER.trace("Full raw message decoded: {}", pgsqlRawMessage);
            return;
        }
        LOGGER.trace("Decoding full raw message: {}...", pgsqlRawMessage);
        T decoded = decode(channelHandlerContext, pgsqlRawMessage.getType(), pgsqlRawMessage.getContent());
        LOGGER.trace("PGSQL message decoded: {}", decoded);

        if (getNumberOfPeerChannels(channelHandlerContext) <= 1) {
            T processed = process(channelHandlerContext, decoded);
            if (processed == null) {
                LOGGER.trace("Full raw message consumed {}...", pgsqlRawMessage);
                return;
            }
            PgsqlRawMessage outgoing;
            if (processed != decoded) {
                LOGGER.trace("Encoding modified PGSQL message {}...", processed);
                // The original buffer is no longer needed, so the writer may reuse it.
                ByteBuf encoded = encode(channelHandlerContext, processed,
                        allocate(channelHandlerContext, processed, pgsqlRawMessage.getBytes()));
                outgoing = new DefaultFullPgsqlRawMessage(encoded, processed.getType(), encoded.capacity());
                LOGGER.trace("Full raw message encoded: {}", outgoing);
            } else {
                // The decoder releases its input after decode(); retain it because it is forwarded.
                ReferenceCountUtil.retain(pgsqlRawMessage);
                outgoing = pgsqlRawMessage;
            }
            outgoing.filter(false);
            list.add(outgoing);
            LOGGER.trace("Full raw message retained in the pipeline : {}", outgoing);
            return;
        }

        List<DirectedMessage<T>> directed = directedProcess(channelHandlerContext, decoded);
        if (directed == null) {
            LOGGER.trace("Full raw message consumed {}...", pgsqlRawMessage);
            return;
        }
        // The original buffer may be handed to the writer for reuse only if no peer still receives
        // the original, unmodified raw message; otherwise its content could be overwritten.
        boolean originalForwarded = directed.stream().anyMatch(dm -> dm.getMsg() == decoded);
        ByteBuf reusableBytes = originalForwarded ? null : pgsqlRawMessage.getBytes();
        for (DirectedMessage<T> directedMessage : directed) {
            T msg = (T) directedMessage.getMsg();
            PgsqlRawMessage outgoing;
            if (msg != decoded) {
                LOGGER.trace("Encoding modified PGSQL message {}...", msg);
                ByteBuf encoded = encode(channelHandlerContext, msg, allocate(channelHandlerContext, msg, reusableBytes));
                outgoing = new DefaultFullPgsqlRawMessage(encoded, msg.getType(), encoded.capacity());
                LOGGER.trace("Full raw message encoded: {}", outgoing);
                // The original buffer can be recycled at most once.
                reusableBytes = null;
            } else {
                // Always forward the ORIGINAL raw message for unmodified targets (never a re-encoded one).
                ReferenceCountUtil.retain(pgsqlRawMessage);
                outgoing = pgsqlRawMessage;
            }
            list.add(new DirectedMessage<>(directedMessage.getTo(), outgoing));
            LOGGER.trace("Full raw message retained in the pipeline : {}", outgoing);
        }
    }

    /**
     * Wraps the bytes of a not-yet-complete raw message in an input stream and delegates to
     * {@link #decodeStream(ChannelHandlerContext, byte, MutableByteBufInputStream)}.
     * The stream is closed afterwards, even if decoding fails.
     */
    protected void decodeStream(ChannelHandlerContext channelHandlerContext, PgsqlRawMessage pgsqlRawMessage) throws IOException {
        LOGGER.trace("Creating input stream...");
        try (MutableByteBufInputStream in = new MutableByteBufInputStream(pgsqlRawMessage.getBytes(), pgsqlRawMessage.getTotalLength())) {
            LOGGER.trace("Input stream created to read from {}", pgsqlRawMessage);
            decodeStream(channelHandlerContext, pgsqlRawMessage.getType(), in);
        }
    }

    /**
     * Tells whether messages of the given type may be decoded in streaming mode.
     * Returns {@code false} here; subclasses override it.
     */
    protected boolean isStreamingSupported(byte b) {
        return false;
    }

    /**
     * Decodes a message of type {@code b} from a stream.
     * Subclasses that return {@code true} from {@link #isStreamingSupported} must override this.
     *
     * @throws UnsupportedOperationException always, in this base implementation
     */
    protected void decodeStream(ChannelHandlerContext channelHandlerContext, byte b, MutableByteBufInputStream mutableByteBufInputStream) throws IOException {
        throw new UnsupportedOperationException("Unsupported decoding from input stream");
    }

    /**
     * Parses the content of a full raw message of type {@code b} into a {@code T}.
     * The reader index of {@code byteBuf} is restored afterwards, even when parsing fails.
     *
     * @throws UnsupportedOperationException if {@code b} is not registered in {@link #msgTypes}
     */
    @SuppressWarnings("unchecked")
    protected T decode(ChannelHandlerContext channelHandlerContext, byte b, ByteBuf byteBuf) throws IOException {
        Class<? extends T> cls = this.msgTypes.get(Byte.valueOf(b));
        if (cls == null) {
            LOGGER.error("Unsupported decoding of full raw message for type {}", Byte.valueOf(b));
            throw new UnsupportedOperationException(String.format("Unsupported decoding of full raw message for type %d", Byte.valueOf(b)));
        }
        PgsqlMessageParser<T> parser = getParser(channelHandlerContext, cls);
        byteBuf.markReaderIndex();
        try {
            return (T) parser.parse(byteBuf);
        } finally {
            byteBuf.resetReaderIndex();
        }
    }

    /**
     * Processes a decoded message for multi-peer forwarding.
     *
     * <p>The default sends the result of {@link #process} to the preferred peer channel.
     * {@code null} means the message is consumed.
     *
     * <p>This method and {@link #process} call each other by default. Subclasses must override at
     * least one of them, otherwise the calls recurse until a {@link StackOverflowError}.
     */
    protected List<DirectedMessage<T>> directedProcess(ChannelHandlerContext channelHandlerContext, T t) throws IOException {
        T processed = process(channelHandlerContext, t);
        if (processed == null) {
            return null;
        }
        return Collections.singletonList(new DirectedMessage<>(getPreferredPeerChannel(channelHandlerContext), processed));
    }

    /**
     * Processes a decoded message for single-peer forwarding.
     *
     * <p>The default delegates to {@link #directedProcess}. A {@code null} or empty result means
     * the message is consumed; more than one result is an error.
     *
     * @return the message to forward, or {@code null} if consumed
     * @throws IllegalStateException if {@link #directedProcess} returns more than one message
     */
    @SuppressWarnings("unchecked")
    protected T process(ChannelHandlerContext channelHandlerContext, T t) throws IOException {
        List<DirectedMessage<T>> directed = directedProcess(channelHandlerContext, t);
        if (directed == null || directed.isEmpty()) {
            return null;
        }
        if (directed.size() > 1) {
            throw new IllegalStateException(String.format("%d new messages, 1 expected", Integer.valueOf(directed.size())));
        }
        return (T) directed.get(0).getMsg();
    }

    /**
     * Allocates the output buffer for encoding {@code t}, delegating to the message's writer.
     *
     * @param byteBuf candidate buffer the writer may reuse, or {@code null}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    protected ByteBuf allocate(ChannelHandlerContext channelHandlerContext, T t, ByteBuf byteBuf) {
        PgsqlMessageWriter writer = getWriter(channelHandlerContext, t.getClass());
        if (writer != null) {
            return writer.allocate(t, byteBuf);
        }
        LOGGER.error("Unsupported allocating buffer for {} message", t.getClass().getSimpleName());
        throw new UnsupportedOperationException(String.format("Unsupported allocating buffer for %s message", t.getClass().getSimpleName()));
    }

    /** Encodes {@code t}, passing a {@code null} target buffer to the writer. */
    protected ByteBuf encode(ChannelHandlerContext channelHandlerContext, T t) throws IOException {
        return encode(channelHandlerContext, t, null);
    }

    /** Encodes {@code t} into {@code byteBuf} using the message's writer. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    protected ByteBuf encode(ChannelHandlerContext channelHandlerContext, T t, ByteBuf byteBuf) throws IOException {
        PgsqlMessageWriter writer = getWriter(channelHandlerContext, t.getClass());
        if (writer != null) {
            return writer.write(t, byteBuf);
        }
        LOGGER.error("Unsupported encoding of {} message", t.getClass().getSimpleName());
        throw new UnsupportedOperationException(String.format("Unsupported encoding of %s message", t.getClass().getSimpleName()));
    }

    /**
     * Returns the parser for {@code cls}, building it by naming convention on first use and caching
     * it in the channel's {@code MSG_PARSERS_KEY} attribute.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    protected <M extends T> PgsqlMessageParser<M> getParser(ChannelHandlerContext channelHandlerContext, Class<? extends T> cls) {
        Map map = (Map) channelHandlerContext.channel().attr(PgsqlConstants.MSG_PARSERS_KEY).get();
        if (map == null) {
            map = new HashMap();
            channelHandlerContext.channel().attr(PgsqlConstants.MSG_PARSERS_KEY).set(map);
        }
        PgsqlMessageParser<M> parser = (PgsqlMessageParser) map.get(cls);
        if (parser == null) {
            parser = (PgsqlMessageParser) buildParserWriter(cls, true);
            map.put(cls, parser);
        }
        return parser;
    }

    /**
     * Returns the writer for {@code cls}, building it by naming convention on first use and caching
     * it in the channel's {@code MSG_WRITERS_KEY} attribute.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    protected <M extends PgsqlMessage> PgsqlMessageWriter<M> getWriter(ChannelHandlerContext channelHandlerContext, Class<? extends PgsqlMessage> cls) {
        Map map = (Map) channelHandlerContext.channel().attr(PgsqlConstants.MSG_WRITERS_KEY).get();
        if (map == null) {
            map = new HashMap();
            channelHandlerContext.channel().attr(PgsqlConstants.MSG_WRITERS_KEY).set(map);
        }
        PgsqlMessageWriter<M> writer = (PgsqlMessageWriter) map.get(cls);
        if (writer == null) {
            writer = (PgsqlMessageWriter) buildParserWriter(cls, false);
            map.put(cls, writer);
        }
        return writer;
    }

    /**
     * Instantiates {@code <message package>.parser.<Name>Parser} (when {@code z} is true) or
     * {@code <message package>.writer.<Name>Writer} through its no-arg constructor.
     *
     * @throws IllegalArgumentException if the class cannot be found or instantiated
     */
    @SuppressWarnings("unchecked")
    private static <WP> WP buildParserWriter(Class<? extends PgsqlMessage> cls, boolean z) {
        // Literal sub-package names avoid a locale-dependent toLowerCase().
        String subPackage = z ? "parser" : "writer";
        String suffix = z ? "Parser" : "Writer";
        String className = PgsqlMessage.class.getPackage().getName() + "." + subPackage + "." + cls.getSimpleName() + suffix;
        try {
            return (WP) PgsqlMessage.class.getClassLoader().loadClass(className).getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new IllegalArgumentException(e);
        }
    }

    /**
     * Returns, and caches, the number of peer channels. For the client-side channel this is the
     * number of configured server endpoints; for any other channel it is 1.
     */
    protected int getNumberOfPeerChannels(ChannelHandlerContext channelHandlerContext) {
        if (this.numberOfPeerChannels == 0) {
            if (channelHandlerContext.channel() == getPgsqlSession(channelHandlerContext).getClientSideChannel()) {
                this.numberOfPeerChannels = getPsqlConfiguration(channelHandlerContext).getServerEndpoints().size();
            } else {
                this.numberOfPeerChannels = 1;
            }
        }
        return this.numberOfPeerChannels;
    }

    /**
     * Returns, and caches, the index of the preferred peer channel.
     *
     * <p>For the client-side channel the index is read from the channel's
     * {@code PREFERRED_SERVER_ENDPOINT_KEY} attribute; for any other channel it is 0.
     *
     * @throws NullPointerException if the attribute is not set
     * @throws IndexOutOfBoundsException if the attribute is outside the server-side channel range
     */
    protected int getPreferredPeerChannel(ChannelHandlerContext channelHandlerContext) {
        if (this.preferredPeerChannel == Integer.MIN_VALUE) {
            PgsqlSession pgsqlSession = getPgsqlSession(channelHandlerContext);
            if (channelHandlerContext.channel() == pgsqlSession.getClientSideChannel()) {
                Integer num = (Integer) channelHandlerContext.channel().attr(TCPConstants.PREFERRED_SERVER_ENDPOINT_KEY).get();
                if (num == null) {
                    throw new NullPointerException(TCPConstants.PREFERRED_SERVER_ENDPOINT_KEY.name() + " is not set");
                }
                int serverCount = pgsqlSession.getServerSideChannels().size();
                if (num.intValue() < 0 || num.intValue() >= serverCount) {
                    throw new IndexOutOfBoundsException(String.format("invalid %s: value: %d, number of server endpoints: %d ", TCPConstants.PREFERRED_SERVER_ENDPOINT_KEY.name(), num, Integer.valueOf(serverCount)));
                }
                this.preferredPeerChannel = num.intValue();
            } else {
                this.preferredPeerChannel = 0;
            }
        }
        return this.preferredPeerChannel;
    }

    /** Returns the plugin configuration stored in the channel's {@code CONFIGURATION_KEY} attribute. */
    protected PgsqlConfiguration getPsqlConfiguration(ChannelHandlerContext channelHandlerContext) {
        return (PgsqlConfiguration) channelHandlerContext.channel().attr(PgsqlConstants.CONFIGURATION_KEY).get();
    }

    /** Returns the session stored in the channel's {@code SESSION_KEY} attribute. */
    public PgsqlSession getPgsqlSession(ChannelHandlerContext channelHandlerContext) {
        return (PgsqlSession) channelHandlerContext.channel().attr(PgsqlConstants.SESSION_KEY).get();
    }

    /** Returns the SQL session of the current PGSQL session. */
    public SQLSession getSqlSession(ChannelHandlerContext channelHandlerContext) {
        return getPgsqlSession(channelHandlerContext).getSqlSession();
    }

    /** Returns the event processor of the current PGSQL session. */
    public EventProcessor getEventProcessor(ChannelHandlerContext channelHandlerContext) {
        return getPgsqlSession(channelHandlerContext).getEventProcessor();
    }

    /** Writes an error message built from {@code map} back on the current channel. */
    public void sendErrorResponse(ChannelHandlerContext channelHandlerContext, Map<Byte, CString> map) throws IOException {
        sendResponse(channelHandlerContext, new PgsqlErrorMessage(map));
    }

    /** Encodes {@code m} and writes it back on the current channel. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public <M extends PgsqlMessage> void sendResponse(ChannelHandlerContext channelHandlerContext, M m) throws IOException {
        PgsqlMessageWriter writer = getWriter(channelHandlerContext, m.getClass());
        channelHandlerContext.channel().writeAndFlush(new DefaultLastPgsqlRawContent(writer.write(m, writer.allocate(m))));
    }

    /**
     * Encodes {@code m} and writes it to the server side.
     *
     * @param i index of the target server-side channel, or {@code -1} to send to every server-side
     *          channel. Each channel except the last receives a retained duplicate.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public <M extends PgsqlMessage> void sendRequest(ChannelHandlerContext channelHandlerContext, M m, int i) throws IOException {
        PgsqlMessageWriter writer = getWriter(channelHandlerContext, m.getClass());
        DefaultLastPgsqlRawContent content = new DefaultLastPgsqlRawContent(writer.write(m, writer.allocate(m)));
        PgsqlSession session = getPgsqlSession(channelHandlerContext);
        if (i != -1) {
            session.getServerSideChannel(i).writeAndFlush(content);
            return;
        }
        int serverCount = session.getServerSideChannels().size();
        if (serverCount == 0) {
            // Nobody to send to: release the encoded content instead of leaking it.
            ReferenceCountUtil.release(content);
            return;
        }
        for (int idx = 0; idx < serverCount; idx++) {
            Channel serverSideChannel = session.getServerSideChannel(idx);
            if (idx < serverCount - 1) {
                serverSideChannel.writeAndFlush(content.retainedDuplicate());
            } else {
                // The last channel takes ownership of the original content.
                serverSideChannel.writeAndFlush(content);
            }
        }
    }
}
