package org.nd4j.parameterserver.distributed.logic.routing;

import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.apache.commons.lang3.RandomUtils;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.messages.Frame;
import org.nd4j.parameterserver.distributed.messages.TrainingMessage;
import org.nd4j.parameterserver.distributed.messages.VoidMessage;
import org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage;
import org.nd4j.parameterserver.distributed.transport.Transport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Client-side router that spreads outgoing messages across the configured shards.
 *
 * <p>Routing rules:
 * <ul>
 *   <li>{@link SkipGramRequestMessage}: pinned to shard {@code w1 mod numberOfShards}, so the same
 *       {@code w1} always lands on the same shard.</li>
 *   <li>any other {@link TrainingMessage}, and any {@link Frame}: round-robin, driven by a single
 *       counter shared by both {@code assignTarget} overloads.</li>
 *   <li>any other {@link VoidMessage}: the fixed {@link #targetIndex}.</li>
 * </ul>
 *
 * @deprecated this router is deprecated. NOTE(review): document the replacement router here.
 */
@Deprecated
public class InterleavedRouter extends BaseRouter {
    private static final Logger log = LoggerFactory.getLogger(InterleavedRouter.class);

    /** Fixed shard for non-{@link Frame} VoidMessages; a negative value means "pick randomly in init". */
    protected short targetIndex;

    /** Round-robin counter; incremented once per round-robin assignment. */
    protected AtomicLong counter;

    /** Creates a router whose fixed target shard is chosen at random during {@link #init}. */
    public InterleavedRouter() {
        this.targetIndex = (short) -1;
        this.counter = new AtomicLong(0L);
    }

    /**
     * Creates a router with an explicit fixed target shard.
     *
     * @param i shard index used for non-{@link Frame} VoidMessages. It is narrowed to {@code short},
     *     so values outside the {@code short} range wrap. A negative value behaves like the
     *     no-arg constructor (random pick in {@link #init}).
     */
    public InterleavedRouter(int i) {
        this();
        this.targetIndex = (short) i;
    }

    /**
     * Initializes the base router and, if no fixed target was given, picks one uniformly from
     * {@code [0, numberOfShards)}.
     *
     * @param voidConfiguration cluster configuration; must not be null
     * @param transport transport layer; must not be null
     * @throws NullPointerException if either argument is null
     */
    @Override
    public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Transport transport) {
        if (voidConfiguration == null) {
            throw new NullPointerException("voidConfiguration is marked non-null but is null");
        }
        if (transport == null) {
            throw new NullPointerException("transport is marked non-null but is null");
        }
        super.init(voidConfiguration, transport);
        if (this.targetIndex < 0) {
            this.targetIndex = (short) RandomUtils.nextInt(0, voidConfiguration.getNumberOfShards());
        }
    }

    /**
     * Sets the originator and the target shard of a training message.
     *
     * @param trainingMessage message to route
     * @return the shard id written into the message
     */
    @Override
    public int assignTarget(TrainingMessage trainingMessage) {
        setOriginator(trainingMessage);
        final short target;
        if (trainingMessage instanceof SkipGramRequestMessage) {
            int w1 = ((SkipGramRequestMessage) trainingMessage).getW1();
            // floorMod (not a conditional %) so that a negative w1 still maps into
            // [0, numberOfShards) instead of producing a negative shard id.
            target = (short) Math.floorMod(w1, this.voidConfiguration.getNumberOfShards());
        } else {
            target = nextRoundRobinTarget();
        }
        trainingMessage.setTargetId(target);
        return trainingMessage.getTargetId();
    }

    /**
     * Sets the originator and the target shard of a generic message. {@link Frame}s are
     * round-robined; everything else goes to {@link #targetIndex}.
     *
     * @param voidMessage message to route
     * @return the shard id written into the message
     */
    @Override
    public int assignTarget(VoidMessage voidMessage) {
        setOriginator(voidMessage);
        if (voidMessage instanceof Frame) {
            voidMessage.setTargetId(nextRoundRobinTarget());
        } else {
            voidMessage.setTargetId(this.targetIndex);
        }
        return voidMessage.getTargetId();
    }

    /**
     * Advances the shared counter and maps it onto a shard index in {@code [0, numberOfShards)}.
     * For every positive counter value this equals the previous {@code Math.abs(c % shards)} form.
     *
     * @return the next round-robin shard index
     * @throws ArithmeticException if the configured number of shards is zero
     */
    private short nextRoundRobinTarget() {
        long tick = this.counter.incrementAndGet();
        // The long cast selects floorMod(long, long), which exists on Java 8.
        return (short) Math.floorMod(tick, (long) this.voidConfiguration.getNumberOfShards());
    }
}
