package org.nd4j.parameterserver.node;

import io.aeron.Aeron;
import io.aeron.driver.MediaDriver;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.agrona.CloseHelper;
import org.nd4j.aeron.ipc.AeronUtil;
import org.nd4j.aeron.ipc.NDArrayCallback;
import org.nd4j.parameterserver.ParameterServerListener;
import org.nd4j.parameterserver.ParameterServerSubscriber;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/parameterserver/node/ParameterServerNode.class */
/**
 * Runs {@code numWorkers} {@link ParameterServerSubscriber}s that share one {@link Aeron} client
 * and one {@link MediaDriver}.
 *
 * <p>Worker 0 is started first. The {@link NDArrayCallback} and {@link ParameterServerListener}
 * it exposes afterwards are handed to every other worker before that worker is started, so all
 * workers share the same callback and listener.
 *
 * <p>{@link #close()} closes every subscriber, then the shared Aeron client, then the media
 * driver. A driver supplied by the caller is closed too.
 */
public class ParameterServerNode implements AutoCloseable {
    private static final Logger log = LoggerFactory.getLogger(ParameterServerNode.class);

    /** Status port used by the convenience constructors when none is given. */
    private static final int DEFAULT_STATUS_PORT = 9000;

    /** Keep-alive interval configured on the Aeron client context (100 microseconds). */
    private static final long KEEP_ALIVE_INTERVAL_NS = 100_000L;

    private ParameterServerSubscriber[] subscriber;
    private MediaDriver mediaDriver;
    private Aeron aeron;
    private int statusPort;
    private int numWorkers;

    /**
     * Creates a node with one worker per available processor.
     *
     * @param mediaDriver driver to use; if {@code null}, an embedded driver is launched by {@link #runMain}
     * @param statusPort  status port value stored on this node
     */
    public ParameterServerNode(MediaDriver mediaDriver, int statusPort) {
        this(mediaDriver, statusPort, Runtime.getRuntime().availableProcessors());
    }

    /**
     * Creates a node with an explicit worker count.
     *
     * @param mediaDriver driver to use; if {@code null}, an embedded driver is launched by {@link #runMain}
     * @param statusPort  status port value stored on this node
     * @param numWorkers  number of subscribers {@link #runMain} will start
     */
    public ParameterServerNode(MediaDriver mediaDriver, int statusPort, int numWorkers) {
        this.mediaDriver = mediaDriver;
        this.statusPort = statusPort;
        this.numWorkers = numWorkers;
        this.subscriber = new ParameterServerSubscriber[numWorkers];
    }

    /**
     * Creates a node on status port 9000 with one worker per available processor.
     *
     * @param mediaDriver driver to use; if {@code null}, an embedded driver is launched by {@link #runMain}
     */
    public ParameterServerNode(MediaDriver mediaDriver) {
        this(mediaDriver, DEFAULT_STATUS_PORT);
    }

    /**
     * Bean-style constructor: all fields start at their Java defaults (0 / {@code null}).
     * Configure the node through the setters before calling {@link #runMain}.
     */
    public ParameterServerNode() {
    }

    /**
     * Starts {@code numWorkers} subscribers.
     *
     * <p>If no media driver has been set, an embedded one is launched. If the arguments contain
     * {@code -id} (or, failing that, {@code --streamId}), the integer value that follows it is
     * offset by the worker index, so worker {@code i} receives {@code value + i}. The arguments
     * themselves are interpreted by {@link ParameterServerSubscriber#run}.
     *
     * @param strArr command-line style arguments forwarded to each subscriber
     * @throws IllegalArgumentException if {@code -id}/{@code --streamId} is the last argument
     * @throws NumberFormatException    if the value after {@code -id}/{@code --streamId} is not an int
     */
    public void runMain(String[] strArr) {
        if (this.mediaDriver == null) {
            this.mediaDriver = MediaDriver.launchEmbedded();
        }
        log.info("Started media driver with aeron directory {}", this.mediaDriver.aeronDirectoryName());

        // The array may be missing or too small if the node was configured via setters.
        ensureSubscriberCapacity();

        // One Aeron client is shared by every worker.
        if (this.aeron == null && this.numWorkers > 0) {
            this.aeron = Aeron.connect(getContext(this.mediaDriver));
        }

        NDArrayCallback sharedCallback = null;
        ParameterServerListener sharedListener = null;
        for (int worker = 0; worker < this.numWorkers; worker++) {
            ParameterServerSubscriber sub = new ParameterServerSubscriber(this.mediaDriver);
            this.subscriber[worker] = sub;
            sub.setAeron(this.aeron);
            String[] workerArgs = argsForWorker(strArr, worker);
            if (worker == 0) {
                // Worker 0 runs first; its callback and listener are then reused by the others.
                sub.run(workerArgs);
                sharedCallback = sub.getCallback();
                sharedListener = sub.getParameterServerListener();
            } else {
                sub.setCallback(sharedCallback);
                sub.setParameterServerListener(sharedListener);
                sub.run(workerArgs);
            }
        }
    }

    /** Grows (or creates) the subscriber array so it holds at least {@code numWorkers} slots. */
    private void ensureSubscriberCapacity() {
        if (this.subscriber == null) {
            this.subscriber = new ParameterServerSubscriber[Math.max(this.numWorkers, 0)];
        } else if (this.subscriber.length < this.numWorkers) {
            // Keep existing entries so they are still closed by close().
            this.subscriber = Arrays.copyOf(this.subscriber, this.numWorkers);
        }
    }

    /**
     * Returns a copy of {@code args} in which the value after {@code -id} (preferred) or
     * {@code --streamId} is increased by {@code workerIndex}. The input array is not modified.
     */
    private static String[] argsForWorker(String[] args, int workerIndex) {
        List<String> copy = new ArrayList<>(Arrays.asList(args));
        int flagIndex = copy.indexOf("-id");
        if (flagIndex < 0) {
            flagIndex = copy.indexOf("--streamId");
        }
        if (flagIndex >= 0) {
            int valueIndex = flagIndex + 1;
            if (valueIndex >= copy.size()) {
                throw new IllegalArgumentException("Missing value for argument " + copy.get(flagIndex));
            }
            int base = Integer.parseInt(copy.get(valueIndex));
            copy.set(valueIndex, String.valueOf(base + workerIndex));
        }
        return copy.toArray(new String[0]);
    }

    /**
     * Reports whether every worker's subscriber has launched.
     *
     * @return {@code true} if all {@code numWorkers} subscribers exist and report launched
     *         (trivially {@code true} when {@code numWorkers <= 0}); {@code false} if any
     *         subscriber has not been created yet or has not launched
     */
    public boolean subscriberLaunched() {
        for (int i = 0; i < this.numWorkers; i++) {
            if (this.subscriber == null || i >= this.subscriber.length) {
                return false;
            }
            ParameterServerSubscriber sub = this.subscriber[i];
            if (sub == null || !sub.subscriberLaunched()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Closes all subscribers, then the shared Aeron client, then the media driver.
     *
     * <p>A failing subscriber does not prevent the remaining resources from being closed. The
     * first subscriber failure is rethrown after cleanup, with later failures attached as
     * suppressed exceptions. Failures while closing the Aeron client or the media driver are
     * ignored ({@link CloseHelper#quietClose}).
     *
     * @throws Exception the first exception thrown by a subscriber's {@code close()}
     */
    @Override
    public void close() throws Exception {
        Exception failure = null;
        if (this.subscriber != null) {
            for (ParameterServerSubscriber sub : this.subscriber) {
                if (sub == null) {
                    continue;
                }
                try {
                    sub.close();
                } catch (Exception e) {
                    if (failure == null) {
                        failure = e;
                    } else {
                        failure.addSuppressed(e);
                    }
                }
            }
        }
        // Close the client before the driver it is attached to.
        if (this.aeron != null) {
            CloseHelper.quietClose(this.aeron);
        }
        if (this.mediaDriver != null) {
            CloseHelper.quietClose(this.mediaDriver);
        }
        if (failure != null) {
            throw failure;
        }
    }

    /**
     * Builds the Aeron client context for {@code mediaDriver}'s directory. It registers
     * {@link AeronUtil}'s image-availability handlers and logs client errors.
     */
    private static Aeron.Context getContext(MediaDriver mediaDriver) {
        return new Aeron.Context()
                .availableImageHandler(AeronUtil::printAvailableImage)
                .unavailableImageHandler(AeronUtil::printUnavailableImage)
                .aeronDirectoryName(mediaDriver.aeronDirectoryName())
                .keepAliveIntervalNs(KEEP_ALIVE_INTERVAL_NS)
                .errorHandler(th -> log.error(th.toString(), th));
    }

    /**
     * Entry point: runs a node on status port 9000 with one worker per available processor and
     * an embedded media driver.
     */
    public static void main(String[] strArr) {
        // The no-arg constructor leaves numWorkers at 0, which would start no workers at all.
        new ParameterServerNode((MediaDriver) null).runMain(strArr);
    }

    /** @return the internal subscriber array (not a copy); may be {@code null} */
    public ParameterServerSubscriber[] getSubscriber() {
        return this.subscriber;
    }

    public MediaDriver getMediaDriver() {
        return this.mediaDriver;
    }

    public Aeron getAeron() {
        return this.aeron;
    }

    public int getStatusPort() {
        return this.statusPort;
    }

    public int getNumWorkers() {
        return this.numWorkers;
    }

    public void setSubscriber(ParameterServerSubscriber[] parameterServerSubscriberArr) {
        this.subscriber = parameterServerSubscriberArr;
    }

    public void setMediaDriver(MediaDriver mediaDriver) {
        this.mediaDriver = mediaDriver;
    }

    public void setAeron(Aeron aeron) {
        this.aeron = aeron;
    }

    public void setStatusPort(int i) {
        this.statusPort = i;
    }

    /** Sets the worker count; the subscriber array is resized as needed by {@link #runMain}. */
    public void setNumWorkers(int i) {
        this.numWorkers = i;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ParameterServerNode)) {
            return false;
        }
        ParameterServerNode other = (ParameterServerNode) obj;
        return other.canEqual(this)
                && getStatusPort() == other.getStatusPort()
                && getNumWorkers() == other.getNumWorkers()
                && Arrays.deepEquals(getSubscriber(), other.getSubscriber())
                && Objects.equals(getMediaDriver(), other.getMediaDriver())
                && Objects.equals(getAeron(), other.getAeron());
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof ParameterServerNode;
    }

    @Override
    public int hashCode() {
        // Same combination scheme (prime 59, 43 for null) as the generated original.
        int result = 1;
        result = result * 59 + getStatusPort();
        result = result * 59 + getNumWorkers();
        result = result * 59 + Arrays.deepHashCode(getSubscriber());
        MediaDriver driver = getMediaDriver();
        result = result * 59 + (driver == null ? 43 : driver.hashCode());
        Aeron client = getAeron();
        result = result * 59 + (client == null ? 43 : client.hashCode());
        return result;
    }

    @Override
    public String toString() {
        return "ParameterServerNode(subscriber=" + Arrays.deepToString(getSubscriber())
                + ", mediaDriver=" + getMediaDriver()
                + ", aeron=" + getAeron()
                + ", statusPort=" + getStatusPort()
                + ", numWorkers=" + getNumWorkers() + ")";
    }
}
