package org.deeplearning4j.ui.play;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageEvent;
import org.deeplearning4j.api.storage.StatsStorageListener;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.i18n.I18NProvider;
import org.deeplearning4j.ui.module.convolutional.ConvolutionalListenerModule;
import org.deeplearning4j.ui.module.defaultModule.DefaultModule;
import org.deeplearning4j.ui.module.flow.FlowListenerModule;
import org.deeplearning4j.ui.module.histogram.HistogramModule;
import org.deeplearning4j.ui.module.remote.RemoteReceiverModule;
import org.deeplearning4j.ui.module.train.TrainModule;
import org.deeplearning4j.ui.module.tsne.TsneModule;
import org.deeplearning4j.ui.play.misc.FunctionUtil;
import org.deeplearning4j.ui.play.staticroutes.Assets;
import org.deeplearning4j.ui.play.staticroutes.I18NRoute;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.deeplearning4j.ui.storage.impl.QueueStatsStorageListener;
import org.reflections.ReflectionUtils;
import org.reflections.Reflections;
import org.reflections.scanners.SubTypesScanner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import play.Mode;
import play.mvc.Results;
import play.routing.RoutingDsl;
import play.server.Server;

/**
 * {@link UIServer} implementation backed by the Play framework.
 *
 * <p>{@link #runMain(String[])} does the following, in order:
 * <ol>
 *   <li>Registers the static routes (language selection and assets).</li>
 *   <li>Registers the built-in {@link UIModule}s and any modules found via {@link ServiceLoader}.</li>
 *   <li>When the {@link #UI_CUSTOM_MODULE_PROPERTY} system property is {@code "true"}, registers any other
 *       {@link UIModule} implementations found on the classpath.</li>
 *   <li>Creates the Play server.</li>
 *   <li>Starts a daemon thread that forwards {@link StatsStorageEvent}s from attached {@link StatsStorage}
 *       instances to the modules whose callback type IDs match the event.</li>
 * </ol>
 *
 * <p>{@code equals}/{@code hashCode}/{@code toString} are value-based over all (mutable) fields.
 */
public class PlayUIServer extends UIServer {
    private static final Logger log = LoggerFactory.getLogger(PlayUIServer.class);

    /** System property that, when set to an integer, overrides the port used by {@link #runMain(String[])}. */
    public static final String UI_SERVER_PORT_PROPERTY = "org.deeplearning4j.ui.port";

    /** Port used by the no-arg constructor. */
    public static final int DEFAULT_UI_PORT = 9000;

    /** System property; when {@code "true"}, the classpath is scanned for additional {@link UIModule}s. */
    public static final String UI_CUSTOM_MODULE_PROPERTY = "org.deeplearning4j.ui.custommodule.enable";

    /** Directory passed to the {@link Assets} route serving {@code /assets/*file}. */
    public static final String ASSETS_ROOT_DIRECTORY = "deeplearning4jUiAssets/";

    /**
     * Maximum time the event-routing thread blocks waiting for an event before re-checking the shutdown flag.
     * Keeps {@link #stop()} effective without having to interrupt the routing thread.
     */
    private static final long EVENT_POLL_TIMEOUT_MS = 1000L;

    private Server server;
    private final BlockingQueue<StatsStorageEvent> eventQueue = new LinkedBlockingQueue<>();
    private List<Pair<StatsStorage, StatsStorageListener>> listeners = new ArrayList<>();
    private List<StatsStorage> statsStorageInstances = new ArrayList<>();
    private List<UIModule> uiModules = new ArrayList<>();
    private RemoteReceiverModule remoteReceiverModule;
    private Map<String, List<UIModule>> typeIDModuleMap = new ConcurrentHashMap<>();
    /** Pause (ms) between routing passes, so events are delivered to modules in batches. */
    private long uiProcessingDelay = 500L;
    private final AtomicBoolean shutdown = new AtomicBoolean(false);
    private Thread uiEventRoutingThread;

    @Parameter(names = {"-r", "-enableRemote"}, description = "Whether to enable remote or not", arity = 1)
    private boolean enableRemote;

    @Parameter(names = {"--uiPort"}, description = "Custom HTTP port for UI", arity = 1)
    private int port = DEFAULT_UI_PORT;

    /**
     * Background task that drains {@link #eventQueue} and hands each UI module the events whose type ID is in
     * that module's {@link UIModule#getCallbackTypeIDs()}. Runs until the shutdown flag is set.
     */
    public class StatsEventRouterRunnable implements Runnable {
        private StatsEventRouterRunnable() {
        }

        @Override
        public void run() {
            try {
                runHelper();
            } catch (Exception e) {
                log.error("Unexpected exception from Event routing runnable", e);
            }
        }

        private void runHelper() throws Exception {
            log.debug("PlayUIServer.StatsEventRouterRunnable started");
            while (!shutdown.get()) {
                // poll with a timeout (instead of take()) so that a shutdown requested via stop() is noticed
                // even if no further events ever arrive
                StatsStorageEvent first = eventQueue.poll(EVENT_POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS);
                if (first == null) {
                    continue;
                }
                List<StatsStorageEvent> events = new ArrayList<>();
                events.add(first);
                eventQueue.drainTo(events);

                routeToModules(events);

                try {
                    Thread.sleep(uiProcessingDelay);
                } catch (InterruptedException e) {
                    if (!shutdown.get()) {
                        throw new RuntimeException("Unexpected interrupted exception", e);
                    }
                }
            }
        }

        /** Passes each module the subset of {@code events} matching its callback type IDs (possibly empty). */
        private void routeToModules(List<StatsStorageEvent> events) {
            for (UIModule module : uiModules) {
                List<String> callbackTypeIDs = module.getCallbackTypeIDs();
                List<StatsStorageEvent> moduleEvents = new ArrayList<>();
                for (StatsStorageEvent event : events) {
                    if (callbackTypeIDs.contains(event.getTypeID())) {
                        moduleEvents.add(event);
                    }
                }
                module.reportStorageEvents(moduleEvents);
            }
        }
    }

    /** Creates a server configured for {@link #DEFAULT_UI_PORT}. */
    public PlayUIServer() {
        this(DEFAULT_UI_PORT);
    }

    /**
     * Creates a server configured for the given port. The port may still be overridden by the {@code --uiPort}
     * argument or the {@link #UI_SERVER_PORT_PROPERTY} system property when {@link #runMain(String[])} runs.
     *
     * @param port HTTP port to use
     */
    public PlayUIServer(int port) {
        this.port = port;
    }

    /**
     * Parses command-line arguments, registers all routes and modules, creates the Play server and starts the
     * event-routing thread. Exits the JVM with status 1 if the arguments cannot be parsed.
     *
     * @param args command-line arguments ({@code -r/-enableRemote}, {@code --uiPort})
     */
    public void runMain(String[] args) {
        parseArgsOrExit(args);

        RoutingDsl routingDsl = new RoutingDsl();
        routingDsl.GET("/setlang/:to").routeTo(FunctionUtil.function(new I18NRoute()));
        routingDsl.GET("/lang/getCurrent").routeTo(() -> Results.ok(I18NProvider.getInstance().getDefaultLanguage()));
        routingDsl.GET("/assets/*file").routeTo(FunctionUtil.function(new Assets(ASSETS_ROOT_DIRECTORY)));

        loadUIModules();
        for (UIModule module : uiModules) {
            registerModule(routingDsl, module);
        }

        applyPortSystemProperty();

        server = Server.forRouter(routingDsl.build(), Mode.DEV, port);
        log.info("DL4J UI Server started at {}", getAddress());

        uiEventRoutingThread = new Thread(new StatsEventRouterRunnable());
        uiEventRoutingThread.setDaemon(true);
        uiEventRoutingThread.start();

        if (enableRemote) {
            enableRemoteListener();
        }
    }

    /** Parses {@code args} into this instance's {@code @Parameter} fields; prints usage and exits on failure. */
    private void parseArgsOrExit(String[] args) {
        JCommander jCommander = new JCommander(this);
        try {
            jCommander.parse(args);
        } catch (ParameterException e) {
            log.error("Could not parse command line arguments: {}", e.getMessage());
            jCommander.usage();
            try {
                // Presumably gives the usage output time to be written before the JVM exits
                Thread.sleep(500L);
            } catch (InterruptedException ignored) {
                Thread.currentThread().interrupt();
            }
            System.exit(1);
        }
    }

    /** Adds built-in, service-loaded and (optionally) classpath-scanned modules to {@link #uiModules}. */
    private void loadUIModules() {
        uiModules.add(new DefaultModule());
        uiModules.add(new HistogramModule());
        uiModules.add(new TrainModule());
        uiModules.add(new ConvolutionalListenerModule());
        uiModules.add(new FlowListenerModule());
        uiModules.add(new TsneModule());
        // Reuse a module created earlier (e.g. by enableRemoteListener) instead of discarding its state
        if (remoteReceiverModule == null) {
            remoteReceiverModule = new RemoteReceiverModule();
        }
        uiModules.add(remoteReceiverModule);
        uiModules.addAll(modulesViaServiceLoader());

        // Boolean.parseBoolean(null) is false, so an unset property disables scanning
        if (Boolean.parseBoolean(System.getProperty(UI_CUSTOM_MODULE_PROPERTY))) {
            List<Class<?>> alreadyLoaded = new ArrayList<>();
            for (UIModule module : uiModules) {
                alreadyLoaded.add(module.getClass());
            }
            uiModules.addAll(getCustomUIModules(alreadyLoaded));
        }
    }

    /** Registers the module's HTTP routes and indexes it in {@link #typeIDModuleMap} by callback type ID. */
    private void registerModule(RoutingDsl routingDsl, UIModule module) {
        for (Route route : module.getRoutes()) {
            RoutingDsl.PathPatternMatcher matcher = routingDsl.match(route.getHttpMethod().name(), route.getRoute());
            switch (route.getFunctionType()) {
                case Supplier:
                    matcher.routeTo(FunctionUtil.function0(route.getSupplier()));
                    break;
                case Function:
                    matcher.routeTo(FunctionUtil.function(route.getFunction()));
                    break;
                case BiFunction:
                case Function3:
                default:
                    throw new RuntimeException("Not yet implemented");
            }
        }
        for (String typeID : module.getCallbackTypeIDs()) {
            typeIDModuleMap.computeIfAbsent(typeID, k -> Collections.synchronizedList(new ArrayList<>())).add(module);
        }
    }

    /** Overrides {@link #port} from {@link #UI_SERVER_PORT_PROPERTY} if it is set and parses as an int. */
    private void applyPortSystemProperty() {
        String portProperty = System.getProperty(UI_SERVER_PORT_PROPERTY);
        if (portProperty == null) {
            return;
        }
        try {
            port = Integer.parseInt(portProperty);
        } catch (NumberFormatException e) {
            log.warn("Could not parse UI port property \"{}\" with value \"{}\"", UI_SERVER_PORT_PROPERTY, portProperty, e);
        }
    }

    /**
     * Returns the server's main address. A wildcard bind address (IPv6 {@code ::} or IPv4 {@code 0.0.0.0}) is
     * reported as {@code http://localhost:<port>}. Requires {@link #runMain(String[])} to have been called.
     */
    @Override
    public String getAddress() {
        String address = server.mainAddress().toString();
        if (address.startsWith("/0:0:0:0:0:0:0:0") || address.startsWith("/0.0.0.0")) {
            int portSeparator = address.lastIndexOf(':');
            if (portSeparator > 0) {
                address = "http://localhost:" + address.substring(portSeparator + 1);
            }
        }
        return address;
    }

    /** Instantiates every {@link UIModule} registered through {@link ServiceLoader}. */
    private List<UIModule> modulesViaServiceLoader() {
        List<UIModule> modules = new ArrayList<>();
        for (UIModule module : ServiceLoader.load(UIModule.class)) {
            log.info("Loaded UI module via service loader: {}", module.getClass());
            modules.add(module);
        }
        return modules;
    }

    public static void main(String[] args) {
        new PlayUIServer().runMain(args);
    }

    /**
     * Scans the classpath for {@link UIModule} subtypes and instantiates those not in {@code excludeClasses}
     * using their no-arg constructor. Types that cannot be instantiated are logged and skipped.
     *
     * @param excludeClasses module classes that are already loaded
     * @return newly created module instances
     */
    private List<UIModule> getCustomUIModules(List<Class<?>> excludeClasses) {
        Iterable<String> subtypeNames = new Reflections().getStore().getAll(SubTypesScanner.class.getSimpleName(),
                        Collections.singletonList(UIModule.class.getName()));

        List<Class<?>> toCreate = new ArrayList<>();
        for (Class<?> c : Sets.newHashSet(ReflectionUtils.forNames(subtypeNames))) {
            if (!excludeClasses.contains(c)) {
                toCreate.add(c);
            }
        }

        List<UIModule> created = new ArrayList<>(toCreate.size());
        for (Class<?> c : toCreate) {
            try {
                UIModule module = (UIModule) c.newInstance();
                log.debug("Created instance of custom UI module: {}", c);
                created.add(module);
            } catch (Exception e) {
                log.warn("Could not create instance of custom UIModule of type {}; skipping", c, e);
            }
        }
        return created;
    }

    @Override
    public int getPort() {
        return this.port;
    }

    /**
     * Attaches a storage. The method does the following:
     * <ul>
     *   <li>Registers a {@link QueueStatsStorageListener} on the storage that feeds this server's event queue.</li>
     *   <li>Notifies every module via {@link UIModule#onAttach}.</li>
     * </ul>
     * Attaching an already-attached storage is a no-op.
     *
     * @throws IllegalArgumentException if {@code statsStorage} is null
     */
    @Override
    public synchronized void attach(StatsStorage statsStorage) {
        if (statsStorage == null) {
            throw new IllegalArgumentException("StatsStorage cannot be null");
        }
        if (statsStorageInstances.contains(statsStorage)) {
            return;
        }
        StatsStorageListener listener = new QueueStatsStorageListener(eventQueue);
        listeners.add(new Pair<>(statsStorage, listener));
        statsStorage.registerStatsStorageListener(listener);
        statsStorageInstances.add(statsStorage);
        for (UIModule module : uiModules) {
            module.onAttach(statsStorage);
        }
        log.info("StatsStorage instance attached to UI: {}", statsStorage);
    }

    /**
     * Detaches a previously attached storage. The method does the following:
     * <ul>
     *   <li>Deregisters the listener(s) added by {@link #attach}, matched by identity.</li>
     *   <li>Removes the storage from the attached set.</li>
     *   <li>Notifies every module via {@link UIModule#onDetach}.</li>
     * </ul>
     * Detaching a storage that is not attached is a no-op.
     *
     * @throws IllegalArgumentException if {@code statsStorage} is null
     */
    @Override
    public synchronized void detach(StatsStorage statsStorage) {
        if (statsStorage == null) {
            throw new IllegalArgumentException("StatsStorage cannot be null");
        }
        if (!statsStorageInstances.contains(statsStorage)) {
            return;
        }
        boolean found = false;
        // Remove through the iterator: List.remove inside a for-each throws ConcurrentModificationException
        Iterator<Pair<StatsStorage, StatsStorageListener>> iter = listeners.iterator();
        while (iter.hasNext()) {
            Pair<StatsStorage, StatsStorageListener> pair = iter.next();
            if (pair.getFirst() == statsStorage) {
                statsStorage.deregisterStatsStorageListener(pair.getSecond());
                iter.remove();
                found = true;
            }
        }
        statsStorageInstances.remove(statsStorage);
        for (UIModule module : uiModules) {
            module.onDetach(statsStorage);
        }
        if (found) {
            log.info("StatsStorage instance detached from UI: {}", statsStorage);
        }
    }

    @Override
    public boolean isAttached(StatsStorage statsStorage) {
        return this.statsStorageInstances.contains(statsStorage);
    }

    /** @return a copy of the currently attached storages */
    @Override
    public List<StatsStorage> getStatsStorageInstances() {
        return new ArrayList<>(this.statsStorageInstances);
    }

    /**
     * Enables the remote receiver backed by a new {@link InMemoryStatsStorage}, which is also attached. This is a
     * no-op if the remote receiver is already enabled.
     */
    @Override
    public void enableRemoteListener() {
        if (this.remoteReceiverModule == null) {
            this.remoteReceiverModule = new RemoteReceiverModule();
        }
        if (this.remoteReceiverModule.isEnabled()) {
            return;
        }
        enableRemoteListener(new InMemoryStatsStorage(), true);
    }

    /**
     * Enables the remote receiver, routing received data to {@code statsStorageRouter}.
     *
     * @param statsStorageRouter destination for remotely received stats
     * @param attach             if true and the router is also a {@link StatsStorage}, attach it to this UI
     */
    @Override
    public void enableRemoteListener(StatsStorageRouter statsStorageRouter, boolean attach) {
        if (this.remoteReceiverModule == null) {
            this.remoteReceiverModule = new RemoteReceiverModule();
        }
        this.remoteReceiverModule.setEnabled(true);
        this.remoteReceiverModule.setStatsStorage(statsStorageRouter);
        if (attach && statsStorageRouter instanceof StatsStorage) {
            attach((StatsStorage) statsStorageRouter);
        }
    }

    @Override
    public void disableRemoteListener() {
        if (this.remoteReceiverModule != null) {
            this.remoteReceiverModule.setEnabled(false);
        }
    }

    @Override
    public boolean isRemoteListenerEnabled() {
        return this.remoteReceiverModule != null && this.remoteReceiverModule.isEnabled();
    }

    /**
     * Stops the Play server (if started) and signals the event-routing thread to exit. The thread exits within
     * about {@code EVENT_POLL_TIMEOUT_MS} plus the current routing pass.
     */
    @Override
    public void stop() {
        this.shutdown.set(true);
        if (this.server != null) {
            this.server.stop();
        }
    }

    public Server getServer() {
        return this.server;
    }

    public BlockingQueue<StatsStorageEvent> getEventQueue() {
        return this.eventQueue;
    }

    public List<Pair<StatsStorage, StatsStorageListener>> getListeners() {
        return this.listeners;
    }

    public List<UIModule> getUiModules() {
        return this.uiModules;
    }

    public RemoteReceiverModule getRemoteReceiverModule() {
        return this.remoteReceiverModule;
    }

    public Map<String, List<UIModule>> getTypeIDModuleMap() {
        return this.typeIDModuleMap;
    }

    public long getUiProcessingDelay() {
        return this.uiProcessingDelay;
    }

    public AtomicBoolean getShutdown() {
        return this.shutdown;
    }

    public Thread getUiEventRoutingThread() {
        return this.uiEventRoutingThread;
    }

    public boolean isEnableRemote() {
        return this.enableRemote;
    }

    public void setServer(Server server) {
        this.server = server;
    }

    public void setListeners(List<Pair<StatsStorage, StatsStorageListener>> listeners) {
        this.listeners = listeners;
    }

    public void setStatsStorageInstances(List<StatsStorage> statsStorageInstances) {
        this.statsStorageInstances = statsStorageInstances;
    }

    public void setUiModules(List<UIModule> uiModules) {
        this.uiModules = uiModules;
    }

    public void setRemoteReceiverModule(RemoteReceiverModule remoteReceiverModule) {
        this.remoteReceiverModule = remoteReceiverModule;
    }

    public void setTypeIDModuleMap(Map<String, List<UIModule>> typeIDModuleMap) {
        this.typeIDModuleMap = typeIDModuleMap;
    }

    public void setUiProcessingDelay(long uiProcessingDelay) {
        this.uiProcessingDelay = uiProcessingDelay;
    }

    public void setUiEventRoutingThread(Thread uiEventRoutingThread) {
        this.uiEventRoutingThread = uiEventRoutingThread;
    }

    public void setEnableRemote(boolean enableRemote) {
        this.enableRemote = enableRemote;
    }

    public void setPort(int port) {
        this.port = port;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof PlayUIServer)) {
            return false;
        }
        PlayUIServer other = (PlayUIServer) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        Server server = getServer();
        Server server2 = other.getServer();
        if (server == null ? server2 != null : !server.equals(server2)) {
            return false;
        }
        BlockingQueue<StatsStorageEvent> eventQueue = getEventQueue();
        BlockingQueue<StatsStorageEvent> eventQueue2 = other.getEventQueue();
        if (eventQueue == null ? eventQueue2 != null : !eventQueue.equals(eventQueue2)) {
            return false;
        }
        List<Pair<StatsStorage, StatsStorageListener>> listeners = getListeners();
        List<Pair<StatsStorage, StatsStorageListener>> listeners2 = other.getListeners();
        if (listeners == null ? listeners2 != null : !listeners.equals(listeners2)) {
            return false;
        }
        List<StatsStorage> statsStorageInstances = getStatsStorageInstances();
        List<StatsStorage> statsStorageInstances2 = other.getStatsStorageInstances();
        if (statsStorageInstances == null ? statsStorageInstances2 != null
                        : !statsStorageInstances.equals(statsStorageInstances2)) {
            return false;
        }
        List<UIModule> uiModules = getUiModules();
        List<UIModule> uiModules2 = other.getUiModules();
        if (uiModules == null ? uiModules2 != null : !uiModules.equals(uiModules2)) {
            return false;
        }
        RemoteReceiverModule remoteReceiverModule = getRemoteReceiverModule();
        RemoteReceiverModule remoteReceiverModule2 = other.getRemoteReceiverModule();
        if (remoteReceiverModule == null ? remoteReceiverModule2 != null
                        : !remoteReceiverModule.equals(remoteReceiverModule2)) {
            return false;
        }
        Map<String, List<UIModule>> typeIDModuleMap = getTypeIDModuleMap();
        Map<String, List<UIModule>> typeIDModuleMap2 = other.getTypeIDModuleMap();
        if (typeIDModuleMap == null ? typeIDModuleMap2 != null : !typeIDModuleMap.equals(typeIDModuleMap2)) {
            return false;
        }
        if (getUiProcessingDelay() != other.getUiProcessingDelay()) {
            return false;
        }
        AtomicBoolean shutdown = getShutdown();
        AtomicBoolean shutdown2 = other.getShutdown();
        if (shutdown == null ? shutdown2 != null : !shutdown.equals(shutdown2)) {
            return false;
        }
        Thread uiEventRoutingThread = getUiEventRoutingThread();
        Thread uiEventRoutingThread2 = other.getUiEventRoutingThread();
        if (uiEventRoutingThread == null ? uiEventRoutingThread2 != null
                        : !uiEventRoutingThread.equals(uiEventRoutingThread2)) {
            return false;
        }
        return isEnableRemote() == other.isEnableRemote() && getPort() == other.getPort();
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof PlayUIServer;
    }

    @Override
    public int hashCode() {
        Server server = getServer();
        int hashCode = (1 * 59) + (server == null ? 43 : server.hashCode());
        BlockingQueue<StatsStorageEvent> eventQueue = getEventQueue();
        int hashCode2 = (hashCode * 59) + (eventQueue == null ? 43 : eventQueue.hashCode());
        List<Pair<StatsStorage, StatsStorageListener>> listeners = getListeners();
        int hashCode3 = (hashCode2 * 59) + (listeners == null ? 43 : listeners.hashCode());
        List<StatsStorage> statsStorageInstances = getStatsStorageInstances();
        int hashCode4 = (hashCode3 * 59) + (statsStorageInstances == null ? 43 : statsStorageInstances.hashCode());
        List<UIModule> uiModules = getUiModules();
        int hashCode5 = (hashCode4 * 59) + (uiModules == null ? 43 : uiModules.hashCode());
        RemoteReceiverModule remoteReceiverModule = getRemoteReceiverModule();
        int hashCode6 = (hashCode5 * 59) + (remoteReceiverModule == null ? 43 : remoteReceiverModule.hashCode());
        Map<String, List<UIModule>> typeIDModuleMap = getTypeIDModuleMap();
        int hashCode7 = (hashCode6 * 59) + (typeIDModuleMap == null ? 43 : typeIDModuleMap.hashCode());
        long uiProcessingDelay = getUiProcessingDelay();
        int i = (hashCode7 * 59) + ((int) ((uiProcessingDelay >>> 32) ^ uiProcessingDelay));
        AtomicBoolean shutdown = getShutdown();
        int hashCode8 = (i * 59) + (shutdown == null ? 43 : shutdown.hashCode());
        Thread uiEventRoutingThread = getUiEventRoutingThread();
        return (((((hashCode8 * 59) + (uiEventRoutingThread == null ? 43 : uiEventRoutingThread.hashCode())) * 59)
                        + (isEnableRemote() ? 79 : 97)) * 59) + getPort();
    }

    @Override
    public String toString() {
        return "PlayUIServer(server=" + getServer() + ", eventQueue=" + getEventQueue() + ", listeners="
                        + getListeners() + ", statsStorageInstances=" + getStatsStorageInstances() + ", uiModules="
                        + getUiModules() + ", remoteReceiverModule=" + getRemoteReceiverModule()
                        + ", typeIDModuleMap=" + getTypeIDModuleMap() + ", uiProcessingDelay="
                        + getUiProcessingDelay() + ", shutdown=" + getShutdown() + ", uiEventRoutingThread="
                        + getUiEventRoutingThread() + ", enableRemote=" + isEnableRemote() + ", port=" + getPort()
                        + ")";
    }
}
