package org.deeplearning4j.spark.impl;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import lombok.NonNull;
import org.deeplearning4j.core.storage.StatsStorageRouter;
import org.deeplearning4j.core.storage.listener.RoutingIterationListener;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouterProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/spark/impl/SparkListenable.class */
/**
 * Holds a set of {@link TrainingListener}s for a Spark network wrapper and forwards them to the
 * configured {@link TrainingMaster}, when one is set.
 *
 * <p>Every {@code setListeners} overload replaces the currently registered listeners; none of them
 * append to the existing set.
 */
public class SparkListenable {
    private static final Logger log = LoggerFactory.getLogger(SparkListenable.class);

    /** Training master that listeners are forwarded to; may be {@code null}, in which case nothing is forwarded. */
    protected TrainingMaster trainingMaster;

    /** Currently registered listeners; cleared and refilled by each {@code setListeners} call. */
    private final List<TrainingListener> listeners = new ArrayList<>();

    /**
     * Replaces the registered listeners with {@code collection} and, if a training master is set,
     * forwards them to it (without a stats storage router).
     *
     * @param collection listeners to register; must not be {@code null}
     * @throws NullPointerException if {@code collection} is {@code null}
     */
    public void setListeners(@NonNull Collection<TrainingListener> collection) {
        if (collection == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        this.listeners.clear();
        this.listeners.addAll(collection);
        if (this.trainingMaster != null) {
            this.trainingMaster.setListeners(this.listeners);
        }
    }

    /**
     * Varargs convenience for {@link #setListeners(Collection)}.
     *
     * @param trainingListenerArr listeners to register; must not be {@code null}
     * @throws NullPointerException if {@code trainingListenerArr} is {@code null}
     */
    public void setListeners(@NonNull TrainingListener... trainingListenerArr) {
        if (trainingListenerArr == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        setListeners(Arrays.asList(trainingListenerArr));
    }

    /**
     * Varargs convenience for {@link #setListeners(StatsStorageRouter, Collection)}.
     *
     * <p>A {@code null} array is treated the same as a {@code null} collection (all listeners are
     * cleared) rather than failing inside {@link Arrays#asList}.
     *
     * @param statsStorageRouter router forwarded to the training master; may be {@code null}
     * @param trainingListenerArr listeners to register; may be {@code null}
     */
    public void setListeners(StatsStorageRouter statsStorageRouter, TrainingListener... trainingListenerArr) {
        // Explicitly typed local keeps overload resolution on the Collection variant.
        Collection<? extends TrainingListener> asCollection =
                trainingListenerArr == null ? null : Arrays.asList(trainingListenerArr);
        setListeners(statsStorageRouter, asCollection);
    }

    /**
     * Validates and registers {@code collection}, replacing any previously registered listeners, and
     * forwards the result together with {@code statsStorageRouter} to the training master, if set.
     *
     * <p>Each {@link RoutingIterationListener} in the collection is checked first: a warning is logged
     * when neither {@code statsStorageRouter} nor the listener's own router is available, and an
     * exception is thrown when the listener's own router is not {@link Serializable}. Validation
     * happens before the registered set is modified, so a rejected collection leaves it unchanged.
     *
     * @param statsStorageRouter router forwarded to the training master; may be {@code null}
     * @param collection listeners to register; {@code null} clears all listeners
     * @throws IllegalStateException if a {@link RoutingIterationListener} carries a non-serializable router
     */
    public void setListeners(StatsStorageRouter statsStorageRouter, Collection<? extends TrainingListener> collection) {
        if (collection != null) {
            for (TrainingListener listener : collection) {
                validateRoutingListener(statsStorageRouter, listener);
            }
        }
        this.listeners.clear();
        if (collection != null) {
            this.listeners.addAll(collection);
        }
        // Propagate even when collection is null so the training master does not keep stale listeners
        // after they have been cleared here.
        if (this.trainingMaster != null) {
            this.trainingMaster.setListeners(statsStorageRouter, this.listeners);
        }
    }

    /**
     * Checks a single listener before registration; listeners that are not
     * {@link RoutingIterationListener}s are accepted without checks.
     *
     * @param statsStorageRouter router supplied by the caller of {@code setListeners}; may be {@code null}
     * @param listener listener to check
     * @throws IllegalStateException if the listener's own storage router is not {@link Serializable}
     */
    private static void validateRoutingListener(StatsStorageRouter statsStorageRouter, TrainingListener listener) {
        if (!(listener instanceof RoutingIterationListener)) {
            return;
        }
        RoutingIterationListener routingListener = (RoutingIterationListener) listener;
        StatsStorageRouter listenerRouter = routingListener.getStorageRouter();
        if (statsStorageRouter == null && listenerRouter == null) {
            log.warn("RoutingIterationListener provided without providing any StatsStorage instance. Iterator may not function without one. Listener: {}", listener);
        } else if (listenerRouter != null && !(listenerRouter instanceof Serializable)) {
            throw new IllegalStateException("RoutingIterationListener provided with non-serializable storage router \nRoutingIterationListener class: " + routingListener.getClass().getName() + "\nStatsStorageRouter class: " + listenerRouter.getClass().getName());
        }
    }
}
