package org.nd4j.parameterserver.distributed.v2.transport.impl;

import io.reactivex.Flowable;
import io.reactivex.functions.Consumer;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedTransferQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TransferQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.LockSupport;
import java.util.stream.Collectors;
import lombok.NonNull;
import org.nd4j.common.primitives.Atomic;
import org.nd4j.common.primitives.AtomicBoolean;
import org.nd4j.common.primitives.Optional;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.v2.chunks.VoidChunk;
import org.nd4j.parameterserver.distributed.v2.enums.MeshBuildMode;
import org.nd4j.parameterserver.distributed.v2.enums.PropagationMode;
import org.nd4j.parameterserver.distributed.v2.messages.BroadcastableMessage;
import org.nd4j.parameterserver.distributed.v2.messages.INDArrayMessage;
import org.nd4j.parameterserver.distributed.v2.messages.MessagesHistoryHolder;
import org.nd4j.parameterserver.distributed.v2.messages.RequestMessage;
import org.nd4j.parameterserver.distributed.v2.messages.ResponseMessage;
import org.nd4j.parameterserver.distributed.v2.messages.VoidMessage;
import org.nd4j.parameterserver.distributed.v2.messages.history.HashHistoryHolder;
import org.nd4j.parameterserver.distributed.v2.messages.impl.MeshUpdateMessage;
import org.nd4j.parameterserver.distributed.v2.messages.pairs.handshake.HandshakeRequest;
import org.nd4j.parameterserver.distributed.v2.messages.pairs.handshake.HandshakeResponse;
import org.nd4j.parameterserver.distributed.v2.messages.pairs.ping.PingMessage;
import org.nd4j.parameterserver.distributed.v2.messages.pairs.ping.PongMessage;
import org.nd4j.parameterserver.distributed.v2.transport.RestartCallback;
import org.nd4j.parameterserver.distributed.v2.transport.Transport;
import org.nd4j.parameterserver.distributed.v2.util.MeshOrganizer;
import org.nd4j.parameterserver.distributed.v2.util.MessageSplitter;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/parameterserver/distributed/v2/transport/impl/BaseTransport.class */
/**
 * Skeleton {@link Transport} shared by the concrete transports.
 *
 * <p>It owns the mesh topology ({@link #mesh}), a queue of incoming messages that is drained by a
 * daemon thread pool ({@link #executorService}), and the logic that relays messages up and down the
 * mesh. The low-level {@code sendMessage(VoidMessage, String)} call is not implemented in this class
 * and is left to subclasses.
 */
public abstract class BaseTransport implements Transport {
    private static final Logger log = LoggerFactory.getLogger(BaseTransport.class);

    /** Sleep between heartbeat rounds started by {@link #launchAsMaster()}, in milliseconds. */
    private static final long HEARTBEAT_DELAY_MS = 120000L;

    /** How long a single heartbeat ping waits for its pong, in milliseconds. */
    private static final long PING_TIMEOUT_MS = 100L;

    /** Locally produced messages; {@link #launch()} subscribes to it and propagates every message. */
    protected final MessageFlow<VoidMessage> outgoingFlow = new MessageFlow<>();
    /** INDArray messages received from the mesh; exposed through {@link #incomingPublisher()}. */
    protected final MessageFlow<INDArrayMessage> incomingFlow = new MessageFlow<>();
    /** Current mesh topology. */
    protected final Atomic<MeshOrganizer> mesh = new Atomic<>();
    protected String id;
    protected String rootId;
    protected boolean masterMode = false;
    /** Replies keyed by request id, polled by the {@code sendMessageBlocking} methods. */
    protected final Map<String, ResponseMessage> replies = new ConcurrentHashMap<>();
    protected RestartCallback restartCallback;
    /** RequestMessage consumers keyed by the canonical class name of the request type. */
    protected Map<String, Consumer> consumers = new HashMap<>();
    protected final VoidConfiguration voidConfiguration;
    protected final MeshBuildMode meshBuildMode = MeshBuildMode.MESH;
    protected final AtomicInteger numerOfNodes = new AtomicInteger(0);
    /** Incoming messages handed over by {@link #processMessage(VoidMessage)}. */
    protected final TransferQueue<VoidMessage> messageQueue = new LinkedTransferQueue<>();
    protected MessageSplitter splitter;
    /** Message ids consulted to skip messages that were already seen. */
    protected MessagesHistoryHolder<String> historyHolder = new HashHistoryHolder(2048);
    protected AtomicBoolean handshakeFlag = new AtomicBoolean(false);
    protected final ThreadPoolExecutor executorService = createExecutorService();

    /**
     * Periodically pings every other node in the mesh.
     *
     * <p>A node that does not answer within {@code PING_TIMEOUT_MS} is remapped and marked offline.
     * If any node was marked offline, the updated mesh is propagated downstream. The loop runs until
     * it is interrupted.
     */
    protected static class HeartbeatThread extends Thread implements Runnable {
        protected final long delay;
        protected final Atomic<MeshOrganizer> mesh;
        protected final Transport transport;

        /**
         * @param delay     sleep between heartbeat rounds, in milliseconds
         * @param transport transport used to ping nodes and to propagate mesh updates
         * @param mesh      shared mesh holder to inspect and update
         */
        protected HeartbeatThread(long delay, @NonNull Transport transport, @NonNull Atomic<MeshOrganizer> mesh) {
            if (transport == null) {
                throw new NullPointerException("transport is marked non-null but is null");
            }
            if (mesh == null) {
                throw new NullPointerException("mesh is marked non-null but is null");
            }
            this.delay = delay;
            this.mesh = mesh;
            this.transport = transport;
        }

        @Override
        public void run() {
            while (true) {
                try {
                    Thread.sleep(this.delay);
                    if (pingAllAndMarkOffline()) {
                        try {
                            this.transport.propagateMessage(new MeshUpdateMessage(this.mesh.get()), PropagationMode.ONLY_DOWN);
                        } catch (IOException e) {
                            // keep the heartbeat alive, but do not hide the failed mesh update
                            log.error("Heartbeat failed to propagate MeshUpdateMessage", e);
                        }
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                }
            }
        }

        /** Pings every node except this one; returns true if at least one node was marked offline. */
        private boolean pingAllAndMarkOffline() throws InterruptedException {
            boolean changed = false;
            for (MeshOrganizer.Node node : this.mesh.get().flatNodes()) {
                if (this.transport.id().equals(node.getId())) {
                    continue;
                }
                PongMessage pong = (PongMessage) this.transport.sendMessageBlocking(new PingMessage(), node.getId(), PING_TIMEOUT_MS, TimeUnit.MILLISECONDS);
                if (pong == null) {
                    this.mesh.get().remapNode(node);
                    this.mesh.get().markNodeOffline(node);
                    changed = true;
                }
            }
            return changed;
        }
    }

    /**
     * Minimal publisher that pushes every value given to {@link #accept(Object)} synchronously to all
     * registered subscribers via {@code onNext}.
     *
     * <p>{@link #subscribe(Subscriber)} only registers the subscriber. It never calls
     * {@code onSubscribe}, {@code onError} or {@code onComplete}.
     */
    public static class MessageFlow<T> implements Consumer<T>, Publisher<T> {
        private final List<Subscriber<? super T>> subscribers = new CopyOnWriteArrayList<>();

        @Override
        public void accept(T t) throws Exception {
            for (Subscriber<? super T> subscriber : this.subscribers) {
                subscriber.onNext(t);
            }
        }

        @Override
        public void subscribe(Subscriber<? super T> subscriber) {
            this.subscribers.add(subscriber);
        }
    }

    /** Creates a transport whose root id is a random UUID and which uses the default configuration. */
    public BaseTransport() {
        this(UUID.randomUUID().toString());
    }

    /**
     * Creates a transport with the default {@link VoidConfiguration}.
     *
     * @param rootId id of the mesh root node; must not be null
     */
    public BaseTransport(@NonNull String rootId) {
        // the delegated constructor performs the null check on rootId
        this(rootId, VoidConfiguration.builder().build());
    }

    /**
     * Creates a transport without an own id.
     *
     * <p>{@link #id} stays unset here. NOTE(review): presumably subclasses assign it; confirm.
     *
     * @param rootId            id of the mesh root node; must not be null
     * @param voidConfiguration configuration supplying the mesh build mode, chunk sizes etc.; must not be null
     */
    public BaseTransport(@NonNull String rootId, @NonNull VoidConfiguration voidConfiguration) {
        if (rootId == null) {
            throw new NullPointerException("rootId is marked non-null but is null");
        }
        if (voidConfiguration == null) {
            throw new NullPointerException("voidConfiguration is marked non-null but is null");
        }
        this.mesh.set(new MeshOrganizer(voidConfiguration.getMeshBuildMode()));
        this.rootId = rootId;
        this.voidConfiguration = voidConfiguration;
    }

    /**
     * Creates a transport with an explicit own id.
     *
     * <p>The transport is in master mode when {@code ownId} equals {@code rootId}, ignoring case. In
     * that case the root node of the mesh takes the root id.
     *
     * @param ownId             id of this node; must not be null
     * @param rootId            id of the mesh root node; must not be null
     * @param voidConfiguration transport configuration; must not be null
     */
    public BaseTransport(@NonNull String ownId, @NonNull String rootId, @NonNull VoidConfiguration voidConfiguration) {
        if (ownId == null) {
            throw new NullPointerException("ownId is marked non-null but is null");
        }
        if (rootId == null) {
            throw new NullPointerException("rootId is marked non-null but is null");
        }
        if (voidConfiguration == null) {
            throw new NullPointerException("voidConfiguration is marked non-null but is null");
        }
        this.mesh.set(new MeshOrganizer(voidConfiguration.getMeshBuildMode()));
        this.id = ownId;
        this.rootId = rootId;
        this.voidConfiguration = voidConfiguration;
        this.masterMode = ownId.equalsIgnoreCase(rootId);
        if (this.masterMode) {
            this.mesh.get().getRootNode().setId(rootId);
        }
    }

    /**
     * Builds the worker pool.
     *
     * <p>The pool is fixed-size with {@code max(2, availableProcessors)} daemon threads. Each thread
     * calls {@code unsafeSetDevice(0)} before running its task.
     */
    private static ThreadPoolExecutor createExecutorService() {
        int threads = Math.max(2, Runtime.getRuntime().availableProcessors());
        return (ThreadPoolExecutor) Executors.newFixedThreadPool(threads, runnable -> {
            if (runnable == null) {
                throw new NullPointerException("r is marked non-null but is null");
            }
            Thread thread = new Thread(() -> {
                Nd4j.getAffinityManager().unsafeSetDevice(0);
                runnable.run();
            });
            thread.setDaemon(true);
            return thread;
        });
    }

    @Override
    public Consumer<VoidMessage> outgoingConsumer() {
        return this.outgoingFlow;
    }

    @Override
    public Publisher<INDArrayMessage> incomingPublisher() {
        return this.incomingFlow;
    }

    /** Returns this node's id if it is the mesh root; otherwise returns the id of its upstream node. */
    @Override
    public String getUpstreamId() {
        MeshOrganizer organizer = this.mesh.get();
        if (organizer.getRootNode().getId().equals(id())) {
            return id();
        }
        return organizer.getNodeById(id()).getUpstreamNode().getId();
    }

    /**
     * Starts the worker threads and wires {@link #outgoingFlow} to {@link #propagateMessage}.
     *
     * <p>Outside master mode, this method then blocks until the handshake with the root node has
     * been answered.
     *
     * @throws ND4JIllegalStateException if the handshake fails
     */
    @Override
    public synchronized void launch() {
        // In master mode one pool thread is left free, presumably for the HeartbeatThread that
        // launchAsMaster() submits right before calling launch().
        int reserved = this.masterMode ? 1 : 0;
        int workers = this.executorService.getMaximumPoolSize() - reserved;
        for (int i = 0; i < workers; i++) {
            this.executorService.submit(this::drainMessageQueue);
        }
        Flowable.fromPublisher(this.outgoingFlow).subscribe(voidMessage -> {
            if (this.mesh.get() == null) {
                log.warn("Mesh wasn't received yet!");
            } else {
                voidMessage.setOriginatorId(this.id);
                propagateMessage(voidMessage, PropagationMode.BOTH_WAYS);
            }
        });
        if (this.masterMode) {
            return;
        }
        try {
            sendMessageBlocking(new HandshakeRequest(), this.rootId);
        } catch (Exception e) {
            throw new ND4JIllegalStateException("Can't proceed with handshake from [" + id() + "] to [" + this.rootId + "]", e);
        }
    }

    /** Worker loop: processes queued messages until the thread is interrupted. */
    private void drainMessageQueue() {
        while (true) {
            try {
                internalProcessMessage(this.messageQueue.take());
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            } catch (Exception e) {
                // fill the placeholder explicitly and pass the throwable last so the stack trace is logged
                log.error("Exception: {}", e.toString(), e);
            }
        }
    }

    /**
     * Makes this node the mesh root, schedules the heartbeat on the worker pool, and then calls
     * {@link #launch()}.
     */
    @Override
    public synchronized void launchAsMaster() {
        if (this.mesh.get() == null) {
            this.mesh.set(new MeshOrganizer(this.meshBuildMode));
        }
        this.masterMode = true;
        this.mesh.get().getRootNode().setId(id());
        this.executorService.submit(new HeartbeatThread(HEARTBEAT_DELAY_MS, this, this.mesh));
        launch();
    }

    /**
     * Shuts down the worker pool.
     *
     * <p>NOTE(review): {@code ThreadPoolExecutor.shutdown()} does not interrupt running tasks, and
     * the worker loops exit only on interrupt. The pool therefore never terminates by itself. Its
     * threads are daemons, so they do not block JVM exit.
     */
    @Override
    public synchronized void shutdown() {
        this.executorService.shutdown();
    }

    /**
     * Splits an INDArray message into chunks no larger than {@code voidConfiguration.getMaxChunkSize()}.
     * The chunks are sent upstream (unless this node is the root) and/or to every downstream node,
     * depending on {@code propagationMode}.
     */
    public void propagateArrayMessage(INDArrayMessage iNDArrayMessage, PropagationMode propagationMode) throws IOException {
        MeshOrganizer.Node self = this.mesh.get().getNodeById(this.id);
        MeshOrganizer.Node upstreamNode = self.getUpstreamNode();
        Collection<MeshOrganizer.Node> downstreamNodes = self.getDownstreamNodes();
        Collection<VoidChunk> chunks = this.splitter.split(iNDArrayMessage, this.voidConfiguration.getMaxChunkSize());
        if (!self.isRootNode() && (PropagationMode.BOTH_WAYS == propagationMode || PropagationMode.ONLY_UP == propagationMode)) {
            chunks.forEach(chunk -> sendMessage(chunk, upstreamNode.getId()));
        }
        if (PropagationMode.BOTH_WAYS == propagationMode || PropagationMode.ONLY_DOWN == propagationMode) {
            downstreamNodes.parallelStream().forEach(node -> chunks.forEach(chunk -> sendMessage(chunk, node.getId())));
        }
    }

    /**
     * Sends a message to this node's upstream and/or downstream neighbours.
     *
     * <p>With a single-node mesh the message is processed locally instead. A broadcastable message
     * gets this node recorded as its relay. An INDArray message is sent in chunks via
     * {@link #propagateArrayMessage}.
     */
    @Override
    public void propagateMessage(@NonNull VoidMessage voidMessage, PropagationMode propagationMode) throws IOException {
        if (voidMessage == null) {
            throw new NullPointerException("voidMessage is marked non-null but is null");
        }
        MeshOrganizer.Node self = this.mesh.get().getNodeById(this.id);
        if (this.mesh.get().totalNodes() == 1) {
            internalProcessMessage(voidMessage);
            return;
        }
        MeshOrganizer.Node upstreamNode = self.getUpstreamNode();
        Collection<MeshOrganizer.Node> downstreamNodes = self.getDownstreamNodes();
        if (voidMessage instanceof BroadcastableMessage) {
            ((BroadcastableMessage) voidMessage).setRelayId(this.id);
        }
        if (voidMessage instanceof INDArrayMessage) {
            propagateArrayMessage((INDArrayMessage) voidMessage, propagationMode);
            return;
        }
        if (!self.isRootNode() && (PropagationMode.BOTH_WAYS == propagationMode || PropagationMode.ONLY_UP == propagationMode)) {
            sendMessage(voidMessage, upstreamNode.getId());
        }
        if (PropagationMode.BOTH_WAYS == propagationMode || PropagationMode.ONLY_DOWN == propagationMode) {
            downstreamNodes.forEach(node -> sendMessage(voidMessage, node.getId()));
        }
    }

    /**
     * Relays a broadcastable message to neighbouring nodes. The message is never sent back to its
     * originator or to the node that relayed it here.
     *
     * <p>The message is skipped in three cases: it is a {@link MeshUpdateMessage}; the history holder
     * reports its id as already known; or this node originated it.
     */
    protected void propagateBroadcastableMessage(@NonNull BroadcastableMessage broadcastableMessage, PropagationMode propagationMode) {
        if (broadcastableMessage == null) {
            throw new NullPointerException("voidMessage is marked non-null but is null");
        }
        if (broadcastableMessage instanceof MeshUpdateMessage
                || this.historyHolder.storeIfUnknownMessageId(broadcastableMessage.getMessageId())) {
            return;
        }
        MeshOrganizer.Node self = this.mesh.get().getNodeById(this.id);
        String originatorId = broadcastableMessage.getOriginatorId();
        if (originatorId != null && this.id != null && originatorId.equals(this.id)) {
            return;
        }
        MeshOrganizer.Node upstreamNode = self.getUpstreamNode();
        Collection<MeshOrganizer.Node> downstreamNodes = self.getDownstreamNodes();
        String relayId = broadcastableMessage.getRelayId();
        broadcastableMessage.setRelayId(id());

        boolean goesUp = PropagationMode.BOTH_WAYS == propagationMode || PropagationMode.ONLY_UP == propagationMode;
        boolean goesDown = PropagationMode.BOTH_WAYS == propagationMode || PropagationMode.ONLY_DOWN == propagationMode;

        // the root node has no upstream
        if (!self.isRootNode() && goesUp && !isLoopedNode(upstreamNode, originatorId, relayId)) {
            sendMessage(broadcastableMessage, upstreamNode.getId());
        }
        if (goesDown) {
            for (MeshOrganizer.Node node : downstreamNodes) {
                if (!isLoopedNode(node, originatorId, relayId)) {
                    sendMessage(broadcastableMessage, node.getId());
                }
            }
        }
    }

    /**
     * Returns {@code true} if {@code node} is the originator of a message or the node that relayed
     * it here. Sending the message to such a node would bounce it back.
     *
     * @param node         candidate destination; must not be null
     * @param originatorId id of the originating node; may be null, in which case it matches nothing
     * @param relayId      id of the relaying node; may be null, in which case it matches nothing
     */
    protected boolean isLoopedNode(@NonNull MeshOrganizer.Node node, String originatorId, String relayId) {
        if (node == null) {
            throw new NullPointerException("node is marked non-null but is null");
        }
        String nodeId = node.getId();
        // String.equals(null) is false, so a missing originator/relay id never matches
        return nodeId.equals(originatorId) || nodeId.equals(relayId);
    }

    /** Pushes an INDArray message to subscribers of {@link #incomingFlow}, wrapping any failure. */
    private void forwardToParameterServer(INDArrayMessage iNDArrayMessage) {
        try {
            this.incomingFlow.accept(iNDArrayMessage);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Handles a single incoming message according to its type.
     *
     * <p>A ping is answered with a pong and nothing else happens. For every other message type, the
     * type-specific handling runs first. After that, broadcastable messages are relayed when the mesh
     * has nodes, and request messages are passed to their registered consumer.
     *
     * @throws ND4JIllegalStateException for unknown message types or request types without a consumer
     */
    public void internalProcessMessage(VoidMessage voidMessage) {
        if (voidMessage instanceof PingMessage) {
            PongMessage pongMessage = new PongMessage();
            pongMessage.setRequestId(((PingMessage) voidMessage).getRequestId());
            sendMessage(pongMessage, voidMessage.getOriginatorId());
            return;
        }

        if (voidMessage instanceof PongMessage) {
            // NOTE(review): pongs are not stored in `replies` here, so a sendMessageBlocking(PingMessage)
            // only sees them if a subclass stores them - confirm.
        } else if (voidMessage instanceof VoidChunk) {
            Optional merged = this.splitter.merge((VoidChunk) voidMessage, this.voidConfiguration.getChunksBufferSize());
            if (merged.isPresent()) {
                internalProcessMessage((VoidMessage) merged.get());
            }
        } else if (voidMessage instanceof INDArrayMessage) {
            if (voidMessage instanceof ResponseMessage) {
                storeReply((ResponseMessage) voidMessage);
            } else if (!this.historyHolder.isKnownMessageId(voidMessage.getMessageId())) {
                forwardToParameterServer((INDArrayMessage) voidMessage);
            }
        } else if (voidMessage instanceof HandshakeRequest) {
            handleHandshakeRequest((HandshakeRequest) voidMessage);
        } else if (voidMessage instanceof HandshakeResponse) {
            handleHandshakeResponse((HandshakeResponse) voidMessage);
        } else if (voidMessage instanceof ResponseMessage) {
            storeReply((ResponseMessage) voidMessage);
        } else if (voidMessage instanceof MeshUpdateMessage) {
            MeshOrganizer received = ((MeshUpdateMessage) voidMessage).getMesh();
            adoptMeshIfNewer(received);
            onMeshUpdate(received);
        } else if (!(voidMessage instanceof RequestMessage)) {
            throw new ND4JIllegalStateException("Unknown message received: [" + voidMessage.getClass().getCanonicalName() + "]");
        } else if (this.consumers.get(voidMessage.getClass().getCanonicalName()) == null) {
            throw new ND4JIllegalStateException("Not supported RequestMessage received: [" + voidMessage.getClass().getCanonicalName() + "]");
        }

        if (voidMessage instanceof BroadcastableMessage) {
            try {
                if (this.numerOfNodes.get() > 0) {
                    propagateBroadcastableMessage((BroadcastableMessage) voidMessage, PropagationMode.BOTH_WAYS);
                } else {
                    log.info("Skipping broadcast due to absence of nodes in mesh");
                }
            } catch (Exception e) {
                log.error("Wasn't able to propagate message [{}] from [{}]", voidMessage.getClass().getSimpleName(), voidMessage.getOriginatorId());
                log.error("BroadcastableMessage propagation exception:", e);
                throw new RuntimeException(e);
            }
        }

        if (voidMessage instanceof RequestMessage) {
            Consumer consumer = this.consumers.get(voidMessage.getClass().getCanonicalName());
            if (consumer != null) {
                try {
                    consumer.accept(voidMessage);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        }
    }

    /** Records a reply so a waiting {@code sendMessageBlocking} call can pick it up; first reply wins. */
    private void storeReply(ResponseMessage response) {
        this.replies.putIfAbsent(response.getRequestId(), response);
    }

    /** Installs {@code candidate} if no mesh is set yet, or if it has a higher version than the current mesh. */
    private void adoptMeshIfNewer(MeshOrganizer candidate) {
        this.mesh.cas(null, candidate);
        synchronized (this.mesh) {
            if (this.mesh.get().getVersion() < candidate.getVersion()) {
                this.mesh.set(candidate);
            }
        }
    }

    /**
     * Handles a handshake from another node.
     *
     * <p>Known nodes are remapped and told to restart; new nodes are added to the mesh. The sender
     * receives a copy of the mesh, and a {@link MeshUpdateMessage} is then sent to every non-root node.
     */
    private void handleHandshakeRequest(HandshakeRequest request) {
        synchronized (this.mesh) {
            if (!this.mesh.get().isKnownNode(id())) {
                this.mesh.get().getRootNode().setId(this.id);
            }
        }
        String senderId = request.getOriginatorId();
        HandshakeResponse response = HandshakeResponse.builder().build();
        synchronized (this.mesh) {
            if (this.mesh.get().isKnownNode(senderId)) {
                log.warn("Got request from known node [{}]. Remapping.", senderId);
                onRemap(senderId);
                this.mesh.get().remapNodeAndDownstreams(senderId);
                response.setRestart(true);
            } else {
                this.mesh.get().addNode(senderId);
                this.numerOfNodes.incrementAndGet();
            }
            response.setMesh(this.mesh.get().m56clone());
        }
        response.setRequestId(request.getRequestId());
        sendMessage(response, senderId);
        try {
            propagateMessageDirect(new MeshUpdateMessage(this.mesh.get()));
        } catch (Exception e) {
            log.error("Wasn't able to propagate message from [{}]", id());
            log.error("MeshUpdateMessage propagation failed:", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Handles the root's answer to our handshake.
     *
     * <p>The received mesh is adopted if it is newer. A restart request is passed to the restart
     * callback. The node is then marked as introduced, and the reply is stored for the waiting caller.
     */
    private void handleHandshakeResponse(HandshakeResponse response) {
        adoptMeshIfNewer(response.getMesh());
        if (response.isRestart()) {
            log.info("Processing restart response...");
            if (this.restartCallback != null) {
                this.restartCallback.call(response);
            } else {
                log.warn("Got restart message from master, but there's no defined RestartCallback");
            }
        }
        this.handshakeFlag.set(true);
        storeReply((ResponseMessage) response);
    }

    /** Sends {@code broadcastableMessage} directly to every non-root node of the current mesh. */
    public void propagateMessageDirect(@NonNull BroadcastableMessage broadcastableMessage) {
        if (broadcastableMessage == null) {
            throw new NullPointerException("message is marked non-null but is null");
        }
        synchronized (this.mesh) {
            for (MeshOrganizer.Node node : this.mesh.get().flatNodes()) {
                if (!node.isRootNode()) {
                    sendMessage(broadcastableMessage, node.getId());
                }
            }
        }
    }

    /**
     * Hands an incoming message to the worker pool. Per the {@code TransferQueue.transfer} contract,
     * this call blocks until a worker has taken the message.
     */
    @Override
    public void processMessage(VoidMessage voidMessage) {
        try {
            this.messageQueue.transfer(voidMessage);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Picks a random downstream node of {@code str}, excluding {@code str2}.
     *
     * @param str  node whose downstreams are considered; must not be null
     * @param str2 id to exclude from the choice; may be null
     * @return a downstream node id, or {@code null} if there is no eligible downstream
     */
    @Override
    public String getRandomDownstreamFrom(@NonNull String str, String str2) {
        if (str == null) {
            throw new NullPointerException("id is marked non-null but is null");
        }
        Collection<MeshOrganizer.Node> downstreams = this.mesh.get().getDownstreamsForNode(str);
        if (downstreams.isEmpty()) {
            return null;
        }
        List<String> candidates = new ArrayList<>(downstreams.size());
        for (MeshOrganizer.Node node : downstreams) {
            candidates.add(node.getId());
        }
        if (str2 != null) {
            candidates.remove(str2);
        }
        // the excluded id may have been the only downstream
        if (candidates.isEmpty()) {
            return null;
        }
        if (candidates.size() > 1) {
            Collections.shuffle(candidates);
        }
        return candidates.get(0);
    }

    /**
     * Sends a request and waits, without a timeout, for the reply carrying the same request id. A
     * random request id is assigned if the request has none.
     *
     * @throws InterruptedException if interrupted while waiting
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T extends ResponseMessage> T sendMessageBlocking(@NonNull RequestMessage requestMessage, @NonNull String str) throws InterruptedException {
        if (requestMessage == null) {
            throw new NullPointerException("message is marked non-null but is null");
        }
        if (str == null) {
            throw new NullPointerException("id is marked non-null but is null");
        }
        if (requestMessage.getRequestId() == null) {
            requestMessage.setRequestId(UUID.randomUUID().toString());
        }
        String requestId = requestMessage.getRequestId();
        sendMessage(requestMessage, str);
        while (true) {
            T reply = (T) this.replies.remove(requestId);
            if (reply != null) {
                return reply;
            }
            Thread.sleep(10L);
        }
    }

    /**
     * Sends a request and waits up to the given timeout for its reply. A random request id is
     * assigned if the request has none.
     *
     * @return the reply, or {@code null} if none arrived in time
     * @throws InterruptedException declared for interface compatibility
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T extends ResponseMessage> T sendMessageBlocking(@NonNull RequestMessage requestMessage, @NonNull String str, long j, @NonNull TimeUnit timeUnit) throws InterruptedException {
        if (requestMessage == null) {
            throw new NullPointerException("message is marked non-null but is null");
        }
        if (str == null) {
            throw new NullPointerException("id is marked non-null but is null");
        }
        if (timeUnit == null) {
            throw new NullPointerException("timeUnit is marked non-null but is null");
        }
        if (requestMessage.getRequestId() == null) {
            requestMessage.setRequestId(UUID.randomUUID().toString());
        }
        String requestId = requestMessage.getRequestId();
        sendMessage(requestMessage, str);
        // use the monotonic clock so wall-clock adjustments cannot stretch or cut the timeout
        long timeoutNanos = timeUnit.toNanos(j);
        long start = System.nanoTime();
        T reply;
        while ((reply = (T) this.replies.get(requestId)) == null && System.nanoTime() - start <= timeoutNanos) {
            LockSupport.parkNanos(5000L);
        }
        this.replies.remove(requestId);
        return reply;
    }

    @Override
    public void setRestartCallback(RestartCallback restartCallback) {
        this.restartCallback = restartCallback;
    }

    /**
     * Registers {@code consumer} for requests of type {@code cls}. Passing a null consumer removes
     * the registration.
     */
    @Override
    public <T extends RequestMessage> void addRequestConsumer(@NonNull Class<T> cls, Consumer<T> consumer) {
        if (cls == null) {
            throw new NullPointerException("cls is marked non-null but is null");
        }
        if (consumer == null) {
            this.consumers.remove(cls.getCanonicalName());
        } else {
            this.consumers.put(cls.getCanonicalName(), consumer);
        }
    }

    /** Updates the cached node count from the received mesh. */
    @Override
    public void onMeshUpdate(MeshOrganizer meshOrganizer) {
        this.numerOfNodes.set((int) meshOrganizer.totalNodes());
    }

    /** Hook called when a known node re-handshakes; does nothing by default. */
    @Override
    public void onRemap(String str) {
    }

    @Override
    public String getRootId() {
        return this.rootId;
    }

    @Override
    public int totalNumberOfNodes() {
        return this.numerOfNodes.get();
    }

    /** Always returns {@code true} in this base implementation. */
    @Override
    public boolean isConnected() {
        return true;
    }

    /** The master is always introduced; other nodes are introduced once a handshake response arrives. */
    @Override
    public boolean isIntroduced() {
        return this.masterMode || this.handshakeFlag.get();
    }

    /** Does nothing in this base implementation. */
    @Override
    public void ensureConnection(String str) {
    }
}
