package org.drasyl.node.handler.crypto;

import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.function.LongSupplier;
import java.util.function.Supplier;
import org.drasyl.crypto.Crypto;
import org.drasyl.crypto.CryptoException;
import org.drasyl.identity.Identity;
import org.drasyl.identity.IdentityPublicKey;
import org.drasyl.identity.KeyAgreementPublicKey;
import org.drasyl.node.event.LongTimeEncryptionEvent;
import org.drasyl.node.event.Peer;
import org.drasyl.node.event.PerfectForwardSecrecyEncryptionEvent;
import org.drasyl.util.logging.Logger;
import org.drasyl.util.logging.LoggerFactory;

/**
 * Arm handler that negotiates ephemeral ("perfect forward secrecy") key agreements on top of the
 * long-time key agreement of a peer session.
 * <p>
 * Key exchange and acknowledgement messages are armed with the session's long-time agreement. Once
 * the peer acknowledges a pending agreement, it is built, registered and made the active agreement,
 * and a {@link PerfectForwardSecrecyEncryptionEvent} is fired. When the active agreement becomes
 * stale, it is dropped and a {@link LongTimeEncryptionEvent} is fired.
 */
public class PFSArmHandler extends AbstractArmHandler {
    private static final Logger LOG = LoggerFactory.getLogger(PFSArmHandler.class);
    /** Minimum time between two key exchange messages and between two renew attempts. */
    private final Duration retryInterval;
    /** Supplies the expiration timestamp (epoch millis) passed to newly built agreements. */
    private final LongSupplier expireProvider;
    /** Encryption mode most recently reported to the pipeline via a user event. */
    private State state;

    /**
     * Encryption mode of the session as last reported to the pipeline.
     */
    protected enum State {
        /** No ephemeral agreement has been reported as active. */
        LONG_TIME,
        /** An ephemeral agreement has been activated and reported. */
        PFS
    }

    /**
     * Creates a handler for an existing {@code session} with explicitly supplied collaborators.
     *
     * @param crypto            crypto implementation used to build agreements
     * @param identity          own identity, forwarded to {@link AbstractArmHandler}
     * @param identityPublicKey the peer's identity, forwarded to {@link AbstractArmHandler}
     * @param session           session state shared with the parent handler
     * @param expireProvider    supplies the expiration time for newly built agreements
     * @param retryInterval     minimum time between key exchange messages / renew attempts
     * @param state             initial encryption mode
     */
    protected PFSArmHandler(final Crypto crypto,
                            final Identity identity,
                            final IdentityPublicKey identityPublicKey,
                            final Session session,
                            final LongSupplier expireProvider,
                            final Duration retryInterval,
                            final State state) {
        super(crypto, identity, identityPublicKey, session);
        this.expireProvider = expireProvider;
        this.retryInterval = retryInterval;
        this.state = state;
    }

    /**
     * Creates a handler that starts in {@link State#LONG_TIME} mode.
     *
     * @param crypto            crypto implementation used to build agreements
     * @param expireAfter       lifetime of newly built agreements; also forwarded to
     *                          {@link AbstractArmHandler}
     * @param retryInterval     minimum time between key exchange messages / renew attempts
     * @param maxAgreements     forwarded to {@link AbstractArmHandler}; presumably a limit on
     *                          stored agreements — confirm against the parent class
     * @param identity          own identity
     * @param identityPublicKey the peer's identity
     * @throws CryptoException if the parent handler cannot set up the session
     */
    public PFSArmHandler(final Crypto crypto,
                         final Duration expireAfter,
                         final Duration retryInterval,
                         final int maxAgreements,
                         final Identity identity,
                         final IdentityPublicKey identityPublicKey) throws CryptoException {
        super(crypto, expireAfter, maxAgreements, identity, identityPublicKey);
        this.retryInterval = retryInterval;
        // each agreement built on ack expires "expireAfter" after the moment it is built
        this.expireProvider = () -> System.currentTimeMillis() + expireAfter.toMillis();
        this.state = State.LONG_TIME;
    }

    /**
     * Schedules a renew check on the channel's executor, then arms the outbound message via the
     * parent handler.
     */
    @Override
    public void encode(final ChannelHandlerContext ctx,
                       final ByteBuf msg,
                       final List<Object> out) throws Exception {
        ctx.executor().execute(() -> checkForRenewAgreement(ctx));
        super.encode(ctx, msg, out);
    }

    /**
     * Schedules a renew check, treats the header's agreement id as an acknowledgement of the
     * pending agreement, and then unarms the message via the parent handler.
     */
    @Override
    public void decode(final ChannelHandlerContext ctx,
                       final ArmHeader msg,
                       final List<Object> out) throws Exception {
        ctx.executor().execute(() -> checkForRenewAgreement(ctx));
        receivedAck(ctx, msg.getAgreementId());
        super.decode(ctx, msg, out);
    }

    /**
     * Dispatches handler-internal control messages. Other message types are ignored.
     */
    @Override
    protected void inboundArmMessage(final ChannelHandlerContext ctx, final Object msg) {
        if (msg instanceof AcknowledgementMessage) {
            receivedAck(ctx, ((AcknowledgementMessage) msg).getAgreementId());
        }
        else if (msg instanceof KeyExchangeMessage) {
            receivedKeyExchangeMessage(ctx, ((KeyExchangeMessage) msg).getSessionKey());
        }
    }

    /**
     * Starts a key exchange (on the channel's executor) when the parent handler reports that no
     * agreement is available.
     */
    @Override
    protected void onNonAgreement(final ChannelHandlerContext ctx) {
        ctx.executor().execute(() -> doKeyExchange(ctx));
    }

    /**
     * Drops the active agreement if it became stale. If the (remaining) active agreement is
     * renewable and no renew was attempted within {@link #retryInterval}, starts a key exchange.
     */
    private void checkForRenewAgreement(final ChannelHandlerContext ctx) {
        final Agreement activeAgreement = session.getCurrentActiveAgreement().computeOnCondition(
                agreement -> agreement != null && agreement.isStale(),
                staleAgreement -> {
                    session.getInitializedAgreements().remove(staleAgreement.getAgreementId());
                    switchToLongTime(ctx);
                    return null; // clears the active agreement
                }).orElse(null);

        final long now = System.currentTimeMillis();
        if (activeAgreement != null
                && activeAgreement.isRenewable()
                && session.getLastRenewAttemptAt() < now - retryInterval.toMillis()) {
            session.setLastRenewAttemptAt(now);
            doKeyExchange(ctx);
        }
    }

    /**
     * Handles a peer's key exchange message on the channel's executor. Stores the peer's
     * ephemeral key in the pending agreement, sends our own key (rate limited) and an ack.
     * <p>
     * Keys equal to the peer's long-time key-agreement key are rejected.
     */
    private void receivedKeyExchangeMessage(final ChannelHandlerContext ctx,
                                            final KeyAgreementPublicKey sessionKey) {
        ctx.executor().execute(() -> {
            LOG.trace("[{}] Received key exchange message", ctx.channel()::id);

            if (sessionKey.equals(peerIdentity.getLongTimeKeyAgreementKey())) {
                LOG.debug("[{}] Received key exchange message with long time key. This is invalid and may be a sign for an MITM attack.", ctx.channel()::id);
                return;
            }

            final PendingAgreement pendingAgreement = computeInactiveAgreementIfNeeded();
            if (pendingAgreement == null) {
                // ephemeral key generation failed (already logged); nothing to negotiate with
                return;
            }
            pendingAgreement.setRecipientsKeyAgreementKey(sessionKey);
            doKeyExchange(ctx);
            sendAck(ctx);
        });
    }

    /**
     * Activates the pending agreement if {@code agreementId} acknowledges it. Otherwise (the id
     * refers to a different agreement) triggers a rate-limited key exchange.
     */
    private void receivedAck(final ChannelHandlerContext ctx, final AgreementId agreementId) {
        final PendingAgreement pendingAgreement = session.getCurrentInactiveAgreement().computeOnCondition(
                pending -> pending != null && Objects.equals(agreementId, pending.getAgreementId()),
                pending -> activatePendingAgreement(ctx, agreementId, pending))
                // lazily: only create a fresh pending agreement when the slot is empty
                .orElseGet(this::computeInactiveAgreementIfNeeded);

        if (pendingAgreement == null) {
            // ephemeral key generation failed (already logged); no agreement to compare against
            return;
        }
        if (!Objects.equals(agreementId, pendingAgreement.getAgreementId())) {
            doKeyExchange(ctx);
        }
    }

    /**
     * Builds the acknowledged pending agreement, registers it and makes it the active agreement.
     *
     * @return {@code null} to clear the inactive slot on success, or {@code pendingAgreement} to
     *         keep it when the agreement could not be built
     */
    private PendingAgreement activatePendingAgreement(final ChannelHandlerContext ctx,
                                                      final AgreementId agreementId,
                                                      final PendingAgreement pendingAgreement) {
        try {
            LOG.trace("[{}] Received ack message", ctx.channel()::id);
            final Agreement agreement = pendingAgreement.buildAgreement(crypto, expireProvider.getAsLong());
            session.getInitializedAgreements().put(agreementId, agreement);
            session.getCurrentActiveAgreement().computeOnCondition(current -> true, current -> agreement);
            session.setLastRenewAttemptAt(System.currentTimeMillis());
            switchToPfs(ctx);
            return null;
        }
        catch (final CryptoException e) {
            LOG.debug("Can't compute new agreement: ", e);
            return pendingAgreement;
        }
    }

    /**
     * Sends an acknowledgement for the current pending agreement, armed with the long-time
     * agreement. Does nothing if there is no pending agreement.
     */
    private void sendAck(final ChannelHandlerContext ctx) {
        final PendingAgreement pendingAgreement = session.getCurrentInactiveAgreement().getValue().orElse(null);
        if (pendingAgreement == null) {
            return;
        }

        final ByteBuf buffer = ctx.alloc().buffer();
        try {
            AcknowledgementMessage.of(pendingAgreement.getAgreementId()).writeTo(buffer);
            ctx.writeAndFlush(arm(ctx, session.getLongTimeAgreement(), buffer));
            LOG.trace("[{}] Send ack message for session {}", ctx.channel()::id, pendingAgreement::getAgreementId);
        }
        catch (final CryptoException e) {
            LOG.trace("[{}] Error on sending ack message for session {}: {}", new Supplier<?>[]{ ctx.channel()::id, pendingAgreement::getAgreementId, e::toString });
        }
        finally {
            // arm() does not take ownership of the plaintext buffer; release it on every path
            buffer.release();
        }
    }

    /**
     * Ensures a pending agreement exists and, unless a key exchange message was already sent
     * within {@link #retryInterval}, sends its public key armed with the long-time agreement.
     */
    private void doKeyExchange(final ChannelHandlerContext ctx) {
        // called unconditionally so a pending agreement exists even when rate limited
        final PendingAgreement pendingAgreement = computeInactiveAgreementIfNeeded();
        if (session.getLastKeyExchangeAt() >= System.currentTimeMillis() - retryInterval.toMillis()) {
            return; // rate limited
        }
        if (pendingAgreement == null) {
            return; // ephemeral key generation failed (already logged)
        }

        LOG.trace("[{}] Send key exchange message, do to key exchange overdue", ctx.channel()::id);
        final ByteBuf buffer = ctx.alloc().buffer();
        try {
            KeyExchangeMessage.of(pendingAgreement.getKeyPair().getPublicKey()).writeTo(buffer);
            ctx.writeAndFlush(arm(ctx, session.getLongTimeAgreement(), buffer));
        }
        catch (final CryptoException e) {
            LOG.debug("Can't arm key exchange message: ", e);
        }
        finally {
            buffer.release();
        }
        // updated even on failure so a failing exchange is also rate limited
        session.setLastKeyExchangeAt(System.currentTimeMillis());
    }

    /**
     * Returns the pending (inactive) agreement, creating one with a fresh ephemeral key pair if
     * none exists.
     *
     * @return the pending agreement, or {@code null} if ephemeral key generation failed
     */
    private PendingAgreement computeInactiveAgreementIfNeeded() {
        return session.getCurrentInactiveAgreement().computeIfAbsent(() -> {
            try {
                return new PendingAgreement(crypto.generateEphemeralKeyPair());
            }
            catch (final CryptoException e) {
                LOG.debug("Could not generate ephemeral key: ", e);
                return null;
            }
        });
    }

    /**
     * Removes {@code agreement} if it is stale. If it was the active agreement, also clears the
     * active slot and falls back to long-time mode.
     */
    @Override
    protected void removeStaleAgreement(final ChannelHandlerContext ctx, final Agreement agreement) {
        if (!agreement.isStale()) {
            return;
        }
        session.getInitializedAgreements().remove(agreement.getAgreementId());
        session.getCurrentActiveAgreement().computeOnCondition(
                active -> active != null && active.getAgreementId().equals(agreement.getAgreementId()),
                active -> {
                    switchToLongTime(ctx);
                    return null; // clears the active agreement
                });
    }

    /**
     * Resolves an agreement id to the long-time agreement or to one of the initialized agreements.
     *
     * @return the matching agreement, or whatever the initialized-agreement lookup yields
     *         (presumably {@code null}) if unknown
     */
    @Override
    protected Agreement getAgreement(final AgreementId agreementId) {
        final Agreement longTimeAgreement = session.getLongTimeAgreement();
        if (Objects.equals(agreementId, longTimeAgreement.getAgreementId())) {
            return longTimeAgreement;
        }
        return session.getInitializedAgreements().get(agreementId);
    }

    /** Fires {@link LongTimeEncryptionEvent} and switches to {@link State#LONG_TIME} if in PFS mode. */
    private void switchToLongTime(final ChannelHandlerContext ctx) {
        if (state == State.PFS) {
            ctx.fireUserEventTriggered(LongTimeEncryptionEvent.of(Peer.of(peerIdentity)));
            state = State.LONG_TIME;
        }
    }

    /** Fires {@link PerfectForwardSecrecyEncryptionEvent} and switches to {@link State#PFS} if in long-time mode. */
    private void switchToPfs(final ChannelHandlerContext ctx) {
        if (state == State.LONG_TIME) {
            ctx.fireUserEventTriggered(PerfectForwardSecrecyEncryptionEvent.of(Peer.of(peerIdentity)));
            state = State.PFS;
        }
    }
}
