package org.nd4j.parameterserver.distributed.v2.transport.impl;

import io.aeron.Aeron;
import io.aeron.FragmentAssembler;
import io.aeron.Publication;
import io.aeron.Subscription;
import io.aeron.driver.MediaDriver;
import io.aeron.logbuffer.Header;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.LinkedTransferQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReentrantLock;
import lombok.NonNull;
import org.agrona.DirectBuffer;
import org.agrona.concurrent.SleepingIdleStrategy;
import org.agrona.concurrent.UnsafeBuffer;
import org.nd4j.aeron.ipc.AeronUtil;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.HashUtil;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.v2.chunks.VoidChunk;
import org.nd4j.parameterserver.distributed.v2.enums.PropagationMode;
import org.nd4j.parameterserver.distributed.v2.enums.TransmissionStatus;
import org.nd4j.parameterserver.distributed.v2.messages.INDArrayMessage;
import org.nd4j.parameterserver.distributed.v2.messages.RequestMessage;
import org.nd4j.parameterserver.distributed.v2.messages.VoidMessage;
import org.nd4j.parameterserver.distributed.v2.transport.MessageCallable;
import org.nd4j.parameterserver.distributed.v2.util.MeshOrganizer;
import org.nd4j.parameterserver.distributed.v2.util.MessageSplitter;
import org.nd4j.shade.guava.math.IntMath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link BaseTransport} implementation that exchanges {@link VoidMessage}s over Aeron UDP.
 *
 * <p>The constructor launches an embedded Aeron {@link MediaDriver} and connects an {@link Aeron}
 * client to it. Nodes are addressed by {@code aeron:udp?endpoint=ip:port} channel strings, and one
 * {@link Publication} per remote node is kept in {@link #remoteConnections}.
 *
 * <p>A subscription is created in {@link #launch()} (non-master mode) or {@link #launchAsMaster()}.
 * Once it exists, a pool of daemon threads handles the traffic:
 * <ul>
 *   <li>{@link #SUBSCRIPTION_THREADS} thread(s) poll {@link #ownSubscription};</li>
 *   <li>{@link #MESSAGE_THREADS} threads drain {@link #messageQueue} into
 *       {@link #processMessage(VoidMessage)};</li>
 *   <li>{@link #SENDER_THREADS} threads drain {@link #propagationQueue}.</li>
 * </ul>
 */
public class AeronUdpTransport extends BaseTransport implements AutoCloseable {
    private static final Logger log = LoggerFactory.getLogger(AeronUdpTransport.class);

    /** Default Aeron term buffer length (2^25 bytes), applied only when the system property is unset. */
    private static final long DEFAULT_TERM_BUFFER_PROP = IntMath.pow(2, 25);

    // Keyed by the canonical class name of the message type. These maps are written by
    // addInterceptor/addPrecursor and read by the message worker threads, so access is synchronized.
    // A HashMap underneath keeps null keys working as before: getCanonicalName() is null for
    // anonymous/local classes.
    protected Map<String, MessageCallable> interceptors = Collections.synchronizedMap(new HashMap<>());
    protected Map<String, MessageCallable> precursors = Collections.synchronizedMap(new HashMap<>());

    /** Outbound connections keyed by the remote node's channel string. */
    protected Map<String, RemoteConnection> remoteConnections = new ConcurrentHashMap<>();

    /** Number of worker threads draining {@link #propagationQueue}. */
    protected final int SENDER_THREADS = 2;
    /** Number of worker threads draining {@link #messageQueue}. */
    protected final int MESSAGE_THREADS = 2;
    /** Number of threads polling {@link #ownSubscription}. */
    protected final int SUBSCRIPTION_THREADS = 1;

    protected Aeron aeron;
    protected Aeron.Context context;
    protected Subscription ownSubscription;
    protected FragmentAssembler messageHandler;
    protected Thread subscriptionThread;
    protected MediaDriver driver;

    /** Decoded inbound messages waiting to be processed (unbounded). */
    protected BlockingQueue<VoidMessage> messageQueue = new LinkedTransferQueue<>();
    /** Array messages waiting to be propagated; bounded to 32, so producers block when it is full. */
    protected BlockingQueue<INDArrayMessage> propagationQueue = new LinkedBlockingQueue<>(32);

    /** Guards creation and removal of entries in {@link #remoteConnections}. */
    protected ReentrantLock aeronLock = new ReentrantLock();

    protected final AtomicBoolean shutdownFlag = new AtomicBoolean(false);
    protected final AtomicBoolean connectedFlag = new AtomicBoolean(false);

    protected ExecutorService messagesExecutorService;

    /**
     * One outbound link to a remote node.
     *
     * <p>It holds the Aeron {@link Publication} for the node's channel, plus a per-connection
     * {@link #getLocker() locker} used by {@link AeronUdpTransport#sendMessage} to serialize
     * {@code offer} calls.
     */
    public static class RemoteConnection {
        private String ip;
        private int port;
        private Publication publication;
        private final Object locker = new Object();
        private final AtomicBoolean activated = new AtomicBoolean(false);
        protected long longHash;

        /** Fluent builder for {@link RemoteConnection}. */
        public static class RemoteConnectionBuilder {
            private String ip;
            private int port;
            private Publication publication;
            private long longHash;

            RemoteConnectionBuilder() {
            }

            public RemoteConnectionBuilder ip(String ip) {
                this.ip = ip;
                return this;
            }

            public RemoteConnectionBuilder port(int port) {
                this.port = port;
                return this;
            }

            public RemoteConnectionBuilder publication(Publication publication) {
                this.publication = publication;
                return this;
            }

            public RemoteConnectionBuilder longHash(long longHash) {
                this.longHash = longHash;
                return this;
            }

            public RemoteConnection build() {
                return new RemoteConnection(this.ip, this.port, this.publication, this.longHash);
            }

            @Override
            public String toString() {
                return "AeronUdpTransport.RemoteConnection.RemoteConnectionBuilder(ip=" + this.ip + ", port=" + this.port + ", publication=" + this.publication + ", longHash=" + this.longHash + ")";
            }
        }

        RemoteConnection(String ip, int port, Publication publication, long longHash) {
            this.ip = ip;
            this.port = port;
            this.publication = publication;
            this.longHash = longHash;
        }

        public static RemoteConnectionBuilder builder() {
            return new RemoteConnectionBuilder();
        }

        public String getIp() {
            return this.ip;
        }

        public int getPort() {
            return this.port;
        }

        public Publication getPublication() {
            return this.publication;
        }

        public Object getLocker() {
            return this.locker;
        }

        public AtomicBoolean getActivated() {
            return this.activated;
        }

        public long getLongHash() {
            return this.longHash;
        }

        public void setIp(String ip) {
            this.ip = ip;
        }

        public void setPort(int port) {
            this.port = port;
        }

        public void setPublication(Publication publication) {
            this.publication = publication;
        }

        public void setLongHash(long longHash) {
            this.longHash = longHash;
        }

        /**
         * Field-wise equality.
         *
         * <p>{@code locker} and {@code activated} are distinct per-instance objects compared with
         * identity-based {@code equals}. In practice, two different instances are therefore never
         * equal.
         */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof RemoteConnection)) {
                return false;
            }
            RemoteConnection other = (RemoteConnection) obj;
            return other.canEqual(this)
                    && getPort() == other.getPort()
                    && getLongHash() == other.getLongHash()
                    && Objects.equals(getIp(), other.getIp())
                    && Objects.equals(getPublication(), other.getPublication())
                    && Objects.equals(getLocker(), other.getLocker())
                    && Objects.equals(getActivated(), other.getActivated());
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof RemoteConnection;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            final int nullHash = 43;
            int result = 1;
            result = result * prime + getPort();
            long hash = getLongHash();
            result = result * prime + (int) ((hash >>> 32) ^ hash);
            result = result * prime + (getIp() == null ? nullHash : getIp().hashCode());
            result = result * prime + (getPublication() == null ? nullHash : getPublication().hashCode());
            result = result * prime + (getLocker() == null ? nullHash : getLocker().hashCode());
            result = result * prime + (getActivated() == null ? nullHash : getActivated().hashCode());
            return result;
        }

        @Override
        public String toString() {
            return "AeronUdpTransport.RemoteConnection(ip=" + getIp() + ", port=" + getPort() + ", publication=" + getPublication() + ", locker=" + getLocker() + ", activated=" + getActivated() + ", longHash=" + getLongHash() + ")";
        }
    }

    /**
     * Creates a transport for a worker node.
     *
     * <p>The own port comes from the configuration's port supplier. The root is reached on the
     * configuration's unicast controller port.
     *
     * @param ownIp         IP of this node
     * @param rootIp        IP of the root (master) node
     * @param configuration transport configuration
     */
    public AeronUdpTransport(@NonNull String ownIp, @NonNull String rootIp, @NonNull VoidConfiguration configuration) {
        this(ownIp, configuration.getPortSupplier().getPort(), rootIp, configuration.getUnicastControllerPort(), configuration);
    }

    /**
     * Creates a transport whose own endpoint and root endpoint are the same {@code rootIp:rootPort}.
     *
     * @param rootIp        IP used for both the own and the root endpoint
     * @param rootPort      UDP port used for both endpoints
     * @param configuration transport configuration
     */
    public AeronUdpTransport(@NonNull String rootIp, int rootPort, @NonNull VoidConfiguration configuration) {
        this(rootIp, rootPort, rootIp, rootPort, configuration);
    }

    /**
     * Creates the transport, launches an embedded Aeron media driver and connects an Aeron client.
     *
     * <p>Side effects: sets the {@code aeron.client.liveness.timeout} system property, and sets
     * {@code aeron.term.buffer.length} if it is unset. Also registers a JVM shutdown hook that calls
     * {@link #shutdown()}.
     *
     * @param ownIp         IP of this node
     * @param ownPort       UDP port of this node, in {@code [1, 65535]}
     * @param rootIp        IP of the root node
     * @param rootPort      UDP port of the root node, in {@code [1, 65535]}
     * @param configuration transport configuration
     * @throws NullPointerException     if any reference argument is null
     * @throws IllegalArgumentException if a port is out of range
     */
    public AeronUdpTransport(@NonNull String ownIp, int ownPort, @NonNull String rootIp, int rootPort, @NonNull VoidConfiguration configuration) {
        super("aeron:udp?endpoint=" + ownIp + ":" + ownPort, "aeron:udp?endpoint=" + rootIp + ":" + rootPort, configuration);
        Objects.requireNonNull(ownIp, "ownIp is marked non-null but is null");
        Objects.requireNonNull(rootIp, "rootIp is marked non-null but is null");
        Objects.requireNonNull(configuration, "configuration is marked non-null but is null");
        Preconditions.checkArgument(ownPort > 0 && ownPort < 65536, "Own UDP port should be positive value in range of 1 and 65535");
        Preconditions.checkArgument(rootPort > 0 && rootPort < 65536, "Master node UDP port should be positive value in range of 1 and 65535");

        // Daemon worker threads, each pinned to device 0 before running its task.
        ThreadFactory workerFactory = runnable -> {
            Objects.requireNonNull(runnable, "r is marked non-null but is null");
            Thread thread = new Thread(() -> {
                Nd4j.getAffinityManager().unsafeSetDevice(0);
                runnable.run();
            });
            thread.setDaemon(true);
            thread.setName("MessagesExecutorService thread");
            return thread;
        };
        // One thread for every long-running loop started in createSubscription().
        this.messagesExecutorService = Executors.newFixedThreadPool(SUBSCRIPTION_THREADS + MESSAGE_THREADS + SENDER_THREADS, workerFactory);

        System.setProperty("aeron.client.liveness.timeout", "30000000000");
        if (System.getProperty("aeron.term.buffer.length") == null) {
            System.setProperty("aeron.term.buffer.length", String.valueOf(DEFAULT_TERM_BUFFER_PROP));
        }
        this.splitter = MessageSplitter.getInstance();

        this.context = new Aeron.Context().driverTimeoutMs(30000L).keepAliveIntervalNs(100000000L);
        AeronUtil.setDaemonizedThreadFactories(this.context);

        MediaDriver.Context driverContext = new MediaDriver.Context();
        AeronUtil.setDaemonizedThreadFactories(driverContext);
        this.driver = MediaDriver.launchEmbedded(driverContext);
        this.context.aeronDirectoryName(this.driver.aeronDirectoryName());
        try {
            this.aeron = Aeron.connect(this.context);
        } catch (RuntimeException e) {
            // Don't leave an orphaned embedded driver running if the client can't attach to it.
            this.driver.close();
            throw e;
        }

        Runtime.getRuntime().addShutdownHook(new Thread(this::shutdown));
    }

    /**
     * Subscribes to this node's own channel and starts the worker loops on
     * {@link #messagesExecutorService}.
     */
    protected void createSubscription() {
        this.ownSubscription = this.aeron.addSubscription(id(), this.voidConfiguration.getStreamId());
        this.messageHandler = new FragmentAssembler(this::jointMessageHandler);

        for (int t = 0; t < SUBSCRIPTION_THREADS; t++) {
            this.messagesExecutorService.execute(this::pollSubscription);
        }
        for (int t = 0; t < MESSAGE_THREADS; t++) {
            this.messagesExecutorService.execute(this::drainMessageQueue);
        }
        for (int t = 0; t < SENDER_THREADS; t++) {
            this.messagesExecutorService.execute(this::drainPropagationQueue);
        }
    }

    /**
     * Polls {@link #ownSubscription} (up to 1024 fragments per poll) until shutdown or interruption.
     * It idles for 1000 ns between empty polls.
     */
    private void pollSubscription() {
        final SleepingIdleStrategy idler = new SleepingIdleStrategy(1000L);
        while (!this.shutdownFlag.get() && !Thread.currentThread().isInterrupted()) {
            idler.idle(this.ownSubscription.poll(this.messageHandler, 1024));
        }
    }

    /**
     * Feeds {@link #messageQueue} into {@link #processMessage(VoidMessage)} until interrupted.
     *
     * <p>Processing failures are logged rather than propagated. Otherwise a single bad message would
     * permanently remove one of the few consumer threads.
     */
    private void drainMessageQueue() {
        while (true) {
            final VoidMessage message;
            try {
                message = this.messageQueue.take();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
            try {
                processMessage(message);
            } catch (RuntimeException e) {
                log.error("Failed to process [{}] message", message.getClass().getSimpleName(), e);
            }
        }
    }

    /**
     * Feeds {@link #propagationQueue} into {@link #redirectedPropagateArrayMessage(INDArrayMessage)}
     * until interrupted. An {@link IOException} is logged and terminates this worker.
     */
    private void drainPropagationQueue() {
        while (true) {
            try {
                redirectedPropagateArrayMessage(this.propagationQueue.take());
            } catch (IOException e) {
                log.error("Exception while propagating array message", e);
                throw new RuntimeException(e);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
        }
    }

    /**
     * Enqueues the message for asynchronous propagation by a sender thread. The call blocks while
     * {@link #propagationQueue} is full.
     *
     * <p>{@code mode} is not used here: queued messages are always propagated with
     * {@link PropagationMode#BOTH_WAYS}.
     */
    @Override
    public void propagateArrayMessage(INDArrayMessage message, PropagationMode mode) throws IOException {
        try {
            this.propagationQueue.put(message);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /** Delegates a dequeued array message to {@link BaseTransport} with {@link PropagationMode#BOTH_WAYS}. */
    protected void redirectedPropagateArrayMessage(INDArrayMessage message) throws IOException {
        super.propagateArrayMessage(message, PropagationMode.BOTH_WAYS);
    }

    /**
     * Aeron fragment callback: decodes a message and queues it for processing.
     *
     * <p>If the originator is not yet known, a connection back to it is opened first.
     */
    protected void jointMessageHandler(DirectBuffer buffer, int offset, int length, Header header) {
        byte[] bytes = new byte[length];
        buffer.getBytes(offset, bytes);
        VoidMessage message = VoidMessage.fromBytes(bytes);

        if (!this.remoteConnections.containsKey(message.getOriginatorId())) {
            addConnection(message.getOriginatorId());
        }
        log.debug("Got [{}] message from [{}]", message.getClass().getSimpleName(), message.getOriginatorId());

        try {
            this.messageQueue.put(message);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Drops the existing connection to {@code nodeId} (closing its publication) and opens a fresh
     * one.
     */
    @Override
    public void onRemap(String nodeId) {
        this.aeronLock.lock();
        try {
            log.info("Trying to disconnect failed node: [{}]", nodeId);
            RemoteConnection stale = this.remoteConnections.get(nodeId);
            if (stale != null) {
                try {
                    stale.getPublication().close();
                } catch (Exception e) {
                    // Best effort: the node is being replaced anyway.
                    log.debug("Failed to close publication for [{}]", nodeId, e);
                }
                this.remoteConnections.remove(nodeId);
            }
            log.info("Trying to add failed node back again: [{}]", nodeId);
            addConnection(nodeId);
        } finally {
            this.aeronLock.unlock();
        }
    }

    @Override
    public void ensureConnection(String id) {
        addConnection(id);
    }

    /**
     * Opens a publication to {@code ipAndPort} unless one is already registered.
     *
     * <p>The call blocks, while holding {@link #aeronLock}, until the publication reports connected
     * or roughly 10 seconds elapse.
     *
     * @throws ND4JIllegalStateException if the connection cannot be established in time
     */
    protected void addConnection(@NonNull String ipAndPort) {
        Objects.requireNonNull(ipAndPort, "ipAndPort is marked non-null but is null");
        this.aeronLock.lock();
        try {
            if (this.remoteConnections.containsKey(ipAndPort)) {
                return;
            }
            log.info("Adding UDP connection: [{}]", ipAndPort);
            Publication publication = this.aeron.addPublication(ipAndPort, this.voidConfiguration.getStreamId());
            awaitConnected(publication);
            this.remoteConnections.put(ipAndPort, RemoteConnection.builder()
                    .ip(ipAndPort)
                    .port(0)
                    .longHash(HashUtil.getLongHash(ipAndPort))
                    .publication(publication)
                    .build());
        } finally {
            this.aeronLock.unlock();
        }
    }

    /**
     * Waits in 100 ms steps, about 100 times, for {@code publication} to connect.
     *
     * <p>On timeout the publication is closed and an exception is thrown. An interrupt does not
     * abort the wait; the thread's interrupt status is restored before returning.
     */
    private static void awaitConnected(Publication publication) {
        boolean interrupted = false;
        try {
            int attempts = 0;
            while (!publication.isConnected()) {
                if (attempts++ > 100) {
                    publication.close();
                    throw new ND4JIllegalStateException("Can't establish connection after 10 seconds. Terminating...");
                }
                try {
                    Thread.sleep(100L);
                } catch (InterruptedException e) {
                    interrupted = true;
                }
            }
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    @Override
    public void close() throws Exception {
        shutdown();
    }

    @Override
    public synchronized void launch() {
        if (!this.masterMode) {
            addConnection(this.rootId);
            createSubscription();
        }
        super.launch();
    }

    @Override
    public synchronized void launchAsMaster() {
        createSubscription();
        super.launchAsMaster();
    }

    @Override
    public String id() {
        return this.id;
    }

    /**
     * Reports whether this node is connected.
     *
     * <p>A master is always connected. Any other node is connected once it has connections to the
     * root, to its mesh upstream, and to all of its mesh downstreams. A positive answer is cached
     * in {@link #connectedFlag}.
     */
    @Override
    public boolean isConnected() {
        if (this.connectedFlag.get() || this.masterMode) {
            return true;
        }
        if (!this.remoteConnections.containsKey(this.rootId)) {
            return false;
        }
        synchronized (this.mesh) {
            MeshOrganizer organizer = (MeshOrganizer) this.mesh.get();
            if (!this.remoteConnections.containsKey(organizer.getUpstreamForNode(id()).getId())) {
                return false;
            }
            for (MeshOrganizer.Node downstream : organizer.getDownstreamsForNode(id())) {
                if (!this.remoteConnections.containsKey(downstream.getId())) {
                    return false;
                }
            }
            this.connectedFlag.set(true);
            return true;
        }
    }

    /**
     * Sends {@code message} to the node identified by {@code targetId}.
     *
     * <p>Behavior by case:
     * <ul>
     *   <li>Messages addressed to their own originator are processed locally.</li>
     *   <li>{@link INDArrayMessage}s are split into chunks that are sent individually.</li>
     *   <li>Messages to non-root targets wait until {@link #isConnected()} is true.</li>
     *   <li>Rejected offers are retried after the configured retransmit timeout. A closed
     *       connection logs a warning and drops the message.</li>
     * </ul>
     *
     * @throws ND4JIllegalStateException if no connection to {@code targetId} exists
     */
    @Override
    public void sendMessage(@NonNull VoidMessage message, @NonNull String targetId) {
        Objects.requireNonNull(message, "message is marked non-null but is null");
        Objects.requireNonNull(targetId, "id is marked non-null but is null");

        if (message.getOriginatorId() == null) {
            message.setOriginatorId(id());
        }
        if (message instanceof RequestMessage && ((RequestMessage) message).getRequestId() == null) {
            ((RequestMessage) message).setRequestId(UUID.randomUUID().toString());
        }
        if (message.getOriginatorId().equals(targetId)) {
            processMessage(message);
            return;
        }
        if (message instanceof INDArrayMessage) {
            try {
                for (VoidChunk chunk : this.splitter.split(message, this.voidConfiguration.getMaxChunkSize())) {
                    sendMessage(chunk, targetId);
                }
                return;
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }

        UnsafeBuffer buffer = message.asUnsafeBuffer();
        if (!targetId.equals(this.rootId)) {
            // Wait (10 ms steps) until the mesh links required by isConnected() are in place.
            while (!isConnected()) {
                LockSupport.parkNanos(10000000L);
            }
        }

        RemoteConnection connection = this.remoteConnections.get(targetId);
        if (connection == null) {
            throw new ND4JIllegalStateException("Unknown target ID specified: [" + targetId + "]");
        }
        offerUntilAccepted(connection, buffer, targetId);
    }

    /**
     * Offers {@code buffer} on the connection's publication until it returns
     * {@link TransmissionStatus#OK}, or until the connection is reported {@code CLOSED}.
     *
     * <p>Interrupts during the retry pauses are remembered and re-asserted on exit.
     */
    private void offerUntilAccepted(RemoteConnection connection, UnsafeBuffer buffer, String targetId) {
        boolean interrupted = false;
        try {
            TransmissionStatus status = TransmissionStatus.UNKNOWN;
            while (status != TransmissionStatus.OK) {
                synchronized (connection.getLocker()) {
                    status = TransmissionStatus.fromLong(connection.getPublication().offer(buffer));
                }
                switch (status) {
                    case MAX_POSITION_EXCEEDED:
                        log.warn("MaxPosition hit: [{}]", targetId);
                        interrupted |= !pauseBeforeRetry();
                        break;
                    case CLOSED:
                        log.warn(" Connection was closed: [{}]", targetId);
                        return;
                    case ADMIN_ACTION:
                        log.info("ADMIN_ACTION: [{}]", targetId);
                        interrupted |= !pauseBeforeRetry();
                        break;
                    case NOT_CONNECTED:
                        log.info("NOT_CONNECTED: [{}]", targetId);
                        addConnection(targetId);
                        interrupted |= !pauseBeforeRetry();
                        break;
                    case BACK_PRESSURED:
                        log.info("BACK_PRESSURED: [{}]", targetId);
                        interrupted |= !pauseBeforeRetry();
                        break;
                    default:
                        // OK ends the loop; any other status is retried immediately.
                        break;
                }
            }
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Sleeps for the configured retransmit timeout.
     *
     * @return {@code false} if the sleep was interrupted
     */
    private boolean pauseBeforeRetry() {
        try {
            Thread.sleep(this.voidConfiguration.getRetransmitTimeout());
            return true;
        } catch (InterruptedException e) {
            return false;
        }
    }

    /**
     * Releases Aeron resources and stops the worker threads. Safe to call even if the transport was
     * never launched.
     */
    protected void shutdownSilent() {
        if (this.ownSubscription != null) {
            this.ownSubscription.close();
        }
        for (RemoteConnection connection : this.remoteConnections.values()) {
            connection.getPublication().close();
        }
        this.aeron.close();
        this.context.close();
        this.driver.close();
        // shutdownNow() interrupts workers blocked in take() so their loops exit. It runs last so that
        // interrupting the calling thread (if it is itself a pool worker) cannot disturb the Aeron teardown.
        this.messagesExecutorService.shutdownNow();
    }

    /** Shuts the transport down once; subsequent calls are no-ops. */
    @Override
    public void shutdown() {
        if (this.shutdownFlag.compareAndSet(false, true)) {
            shutdownSilent();
            super.shutdown();
        }
    }

    /** Ensures a connection to every node of the updated mesh before delegating to {@link BaseTransport}. */
    @Override
    public void onMeshUpdate(MeshOrganizer meshOrganizer) {
        meshOrganizer.flatNodes().forEach(node -> addConnection(node.getId()));
        super.onMeshUpdate(meshOrganizer);
    }

    /**
     * Registers a callable that replaces default processing for messages of exactly class {@code cls}.
     */
    public <T extends VoidMessage> void addInterceptor(@NonNull Class<T> cls, @NonNull MessageCallable<T> callable) {
        Objects.requireNonNull(cls, "cls is marked non-null but is null");
        Objects.requireNonNull(callable, "callable is marked non-null but is null");
        this.interceptors.put(cls.getCanonicalName(), callable);
    }

    /**
     * Registers a callable that runs before default processing for messages of exactly class
     * {@code cls}.
     */
    public <T extends VoidMessage> void addPrecursor(@NonNull Class<T> cls, @NonNull MessageCallable<T> callable) {
        Objects.requireNonNull(cls, "cls is marked non-null but is null");
        Objects.requireNonNull(callable, "callable is marked non-null but is null");
        this.precursors.put(cls.getCanonicalName(), callable);
    }

    /**
     * Dispatches a message.
     *
     * <p>A registered interceptor handles it exclusively. Otherwise any registered precursor runs
     * first, followed by {@link BaseTransport} processing.
     */
    @Override
    @SuppressWarnings("unchecked")
    public void processMessage(@NonNull VoidMessage message) {
        Objects.requireNonNull(message, "message is marked non-null but is null");
        if (this.interceptors.isEmpty() && this.precursors.isEmpty()) {
            super.processMessage(message);
            return;
        }

        String key = message.getClass().getCanonicalName();
        MessageCallable interceptor = this.interceptors.get(key);
        if (interceptor != null) {
            interceptor.apply(message);
            return;
        }

        MessageCallable precursor = this.precursors.get(key);
        if (precursor != null) {
            precursor.apply(message);
        }
        super.processMessage(message);
    }

    /** Returns the current mesh, read under the mesh monitor. */
    protected MeshOrganizer getMesh() {
        synchronized (this.mesh) {
            return (MeshOrganizer) this.mesh.get();
        }
    }
}
