package org.deeplearning4j.scaleout.actor.core.actor;

import akka.actor.ActorRef;
import akka.actor.ActorSystem;
import akka.actor.Cancellable;
import akka.actor.OneForOneStrategy;
import akka.actor.PoisonPill;
import akka.actor.SupervisorStrategy;
import akka.actor.UntypedActor;
import akka.contrib.pattern.ClusterSingletonManager;
import akka.contrib.pattern.DistributedPubSubExtension;
import akka.contrib.pattern.DistributedPubSubMediator;
import akka.dispatch.Futures;
import akka.event.Logging;
import akka.event.LoggingAdapter;
import akka.japi.Function;
import akka.routing.RoundRobinPool;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.canova.api.conf.Configuration;
import org.deeplearning4j.nn.conf.DeepLearningConfigurable;
import org.deeplearning4j.scaleout.actor.core.protocol.Ack;
import org.deeplearning4j.scaleout.actor.util.ActorRefUtils;
import org.deeplearning4j.scaleout.api.statetracker.StateTracker;
import org.deeplearning4j.scaleout.api.workrouter.WorkRouter;
import org.deeplearning4j.scaleout.job.Job;
import org.deeplearning4j.scaleout.messages.DoneMessage;
import org.deeplearning4j.scaleout.messages.MoreWorkMessage;
import org.deeplearning4j.scaleout.perform.WorkerPerformer;
import org.deeplearning4j.scaleout.perform.WorkerPerformerFactory;
import scala.Option;
import scala.concurrent.duration.Duration;

/* loaded from: input_file:org/deeplearning4j/scaleout/actor/core/actor/MasterActor.class */
/**
 * Cluster master for distributed training.
 *
 * <p>On construction it starts the worker pool (see {@link #setup(Configuration)}), runs the
 * configured number of pre-train passes on the {@link StateTracker}, subscribes to the
 * {@link #MASTER} and {@link #FINISH} pub-sub topics, and schedules two periodic tasks:
 * <ul>
 *   <li>a heart beat every {@link #secondsPoll} seconds that asks the {@link WorkRouter} whether to
 *       dispatch the next batch and drops jobs owned by recently cleared workers;</li>
 *   <li>a sweep every minute that removes workers whose last heart beat is at least
 *       {@value #STALE_WORKER_SECONDS} seconds old.</li>
 * </ul>
 *
 * <p>NOTE(review): the scheduled tasks run on the dispatcher rather than inside
 * {@link #onReceive(Object)}, so they may execute concurrently with message handling; this relies on
 * the {@link StateTracker} and {@link WorkRouter} tolerating calls from multiple threads — confirm.
 */
public class MasterActor extends UntypedActor implements DeepLearningConfigurable {
    protected Configuration conf;
    protected ActorRef batchActor;
    protected StateTracker stateTracker;
    public static String BROADCAST = "broadcast";
    public static String MASTER = "result";
    public static String SHUTDOWN = "shutdown";
    public static String FINISH = "finish";
    public static final String NAME_SPACE = "org.deeplearning4j.scaleout.actor.core.actor";
    public static final String POLL_FOR_WORK = "org.deeplearning4j.scaleout.actor.core.actor.poll";

    /** Configuration key naming the {@link WorkerPerformerFactory} implementation class. */
    private static final String WORKER_PERFORMER_KEY = "org.deeplearning4j.scaleout.perform.workerperformer";
    /** Configuration key for the number of pre-train iterations run at start-up. */
    private static final String NUM_PASSES_KEY = "org.deeplearning4j.numpasses";
    /** A worker whose last heart beat is at least this many seconds old is removed. */
    private static final long STALE_WORKER_SECONDS = 120;

    protected Cancellable forceNextPhase;
    protected Cancellable clearStateWorkers;
    protected WorkRouter workRouter;
    protected LoggingAdapter log = Logging.getLogger(getContext().system(), this);
    protected final ActorRef mediator = DistributedPubSubExtension.get(getContext().system()).mediator();
    /** Heart-beat period in seconds; overwritten from {@link #POLL_FOR_WORK} (default 10) by {@link #setup}. */
    protected int secondsPoll = 1;
    /** Guards the final transition in {@link #doDoneOrNextPhase()} so it runs at most once. */
    protected AtomicBoolean doneCalled = new AtomicBoolean(false);

    /**
     * Creates the master, starts workers and schedules the periodic heart-beat and stale-worker tasks.
     *
     * @param configuration training configuration; also used to start the worker pool
     * @param actorRef      the batch actor reference, stored in {@link #batchActor}
     * @param stateTracker  shared training state
     * @param workRouter    decides when work is sent and applies worker updates
     */
    public MasterActor(Configuration configuration, ActorRef actorRef, final StateTracker stateTracker, WorkRouter workRouter) {
        this.conf = configuration;
        this.batchActor = actorRef;
        this.workRouter = workRouter;
        this.stateTracker = stateTracker;
        // NOTE(review): setup() is overridable and runs from the constructor; subclasses overriding it
        // will see their own fields uninitialized.
        setup(configuration);
        stateTracker.runPreTrainIterations(configuration.getInt(NUM_PASSES_KEY, 1));
        mediator.tell(new DistributedPubSubMediator.Subscribe(MASTER, getSelf()), getSelf());
        mediator.tell(new DistributedPubSubMediator.Subscribe(FINISH, getSelf()), getSelf());

        // Akka aborts a repeating task whose Runnable throws, so both tasks log failures instead of
        // propagating them; otherwise a single transient error would stop the task forever.
        this.forceNextPhase = context().system().scheduler().schedule(
                Duration.create(secondsPoll, TimeUnit.SECONDS),
                Duration.create(secondsPoll, TimeUnit.SECONDS),
                new Runnable() {
                    @Override
                    public void run() {
                        try {
                            heartBeat();
                        } catch (Exception e) {
                            log.error(e, "Error during master heart beat");
                        }
                    }
                }, context().dispatcher());
        this.clearStateWorkers = context().system().scheduler().schedule(
                Duration.create(1L, TimeUnit.MINUTES),
                Duration.create(1L, TimeUnit.MINUTES),
                new Runnable() {
                    @Override
                    public void run() {
                        try {
                            removeStaleWorkers();
                        } catch (Exception e) {
                            log.error(e, "Error while removing stale workers");
                        }
                    }
                }, context().dispatcher());
    }

    /**
     * One tick of the heart-beat task. Logs the worker count and then stops if training is done.
     * Otherwise it dispatches the next batch when the router asks for it and reconciles jobs of
     * recently cleared workers.
     */
    private void heartBeat() throws Exception {
        log.info("Heart beat on " + stateTracker.workers().size() + " workers");
        if (stateTracker.isDone()) {
            return;
        }
        if (workRouter.sendWork()) {
            nextBatch();
        }
        reconcileClearedJobs();
    }

    /**
     * Clears and removes current jobs whose worker id is in the recently-cleared set. Once no
     * current jobs remain, it also empties the recently-cleared set.
     */
    private void reconcileClearedJobs() throws Exception {
        HashSet<Job> staleJobs = new HashSet<Job>();
        for (Job job : stateTracker.currentJobs()) {
            if (stateTracker.recentlyCleared().contains(job.workerId())) {
                stateTracker.clearJob(job.workerId());
                staleJobs.add(job);
                log.info("Found job that wasn't clear " + job.workerId());
            }
        }
        stateTracker.currentJobs().removeAll(staleJobs);
        if (stateTracker.currentJobs().isEmpty()) {
            stateTracker.recentlyCleared().clear();
        }
    }

    /**
     * Removes every worker whose recorded heart-beat timestamp (epoch millis, compared against
     * {@link System#currentTimeMillis()}) is at least {@link #STALE_WORKER_SECONDS} old.
     * Does nothing once training is done.
     */
    private void removeStaleWorkers() throws Exception {
        if (stateTracker.isDone()) {
            return;
        }
        long now = System.currentTimeMillis();
        for (Map.Entry<?, ?> entry : stateTracker.getHeartBeats().entrySet()) {
            String workerId = (String) entry.getKey();
            long lastBeat = (Long) entry.getValue();
            if (TimeUnit.MILLISECONDS.toSeconds(now - lastBeat) >= STALE_WORKER_SECONDS) {
                log.info("Removing stale worker " + workerId);
                stateTracker.removeWorker(workerId);
            }
        }
    }

    /**
     * Starts the worker pool and reads the heart-beat period.
     *
     * <p>The {@link WorkerPerformerFactory} class named by the configuration is instantiated through
     * a {@code (StateTracker)} constructor, falling back to its no-arg constructor. Its performer backs a
     * round-robin pool of {@link WorkerActor}s, one per available processor, started under a
     * {@link ClusterSingletonManager} named {@code "worker"}.
     *
     * @param configuration source of the performer factory class and {@link #POLL_FOR_WORK}
     * @throws IllegalStateException if no performer factory class is configured
     * @throws RuntimeException wrapping any failure to load, instantiate or start the workers
     */
    public void setup(Configuration configuration) {
        log.info("Starting workers");
        ActorSystem system = context().system();
        RoundRobinPool pool = new RoundRobinPool(Runtime.getRuntime().availableProcessors());
        String factoryClassName = configuration.get(WORKER_PERFORMER_KEY);
        if (factoryClassName == null) {
            throw new IllegalStateException("No worker performer factory configured under " + WORKER_PERFORMER_KEY);
        }
        try {
            Class<?> factoryClass = Class.forName(factoryClassName);
            WorkerPerformerFactory factory;
            try {
                factory = (WorkerPerformerFactory) factoryClass.getConstructor(StateTracker.class).newInstance(stateTracker);
            } catch (NoSuchMethodException e) {
                factory = (WorkerPerformerFactory) factoryClass.newInstance();
            }
            WorkerPerformer performer = factory.create(configuration);
            secondsPoll = configuration.getInt(POLL_FOR_WORK, 10);
            system.actorOf(ClusterSingletonManager.defaultProps(
                    pool.props(WorkerActor.propsFor(configuration, stateTracker, performer)),
                    "master", PoisonPill.getInstance(), "master"), "worker");
        } catch (Exception e) {
            throw new RuntimeException("Unable to start workers using " + factoryClassName, e);
        }
    }

    /**
     * Handles the following messages:
     * <ul>
     *   <li>pub-sub subscribe/unsubscribe acks are republished on {@code "topics"} and logged;</li>
     *   <li>{@link DoneMessage} triggers {@link #doDoneOrNextPhase()};</li>
     *   <li>any {@link String} is answered with an {@link Ack};</li>
     *   <li>{@link MoreWorkMessage} is forwarded to the {@link BatchActor#BATCH} topic.</li>
     * </ul>
     * Everything else is passed to {@link #unhandled(Object)}.
     */
    @Override
    public void onReceive(Object message) throws Exception {
        if (message instanceof DistributedPubSubMediator.SubscribeAck
                || message instanceof DistributedPubSubMediator.UnsubscribeAck) {
            mediator.tell(new DistributedPubSubMediator.Publish("topics", message), getSelf());
            // Both ack types land here; do not cast to SubscribeAck.
            String verb = message instanceof DistributedPubSubMediator.SubscribeAck ? "Subscribed " : "Unsubscribed ";
            log.info(verb + message);
        } else if (message instanceof DoneMessage) {
            log.info("Received done message");
            doDoneOrNextPhase();
        } else if (message instanceof String) {
            getSender().tell(Ack.getInstance(), getSelf());
        } else if (message instanceof MoreWorkMessage) {
            log.info("Prompted for more work, starting pipeline");
            mediator.tell(new DistributedPubSubMediator.Publish(BatchActor.BATCH, MoreWorkMessage.getInstance()), getSelf());
        } else {
            unhandled(message);
        }
    }

    /**
     * Advances the pipeline.
     *
     * <p>If worker updates are pending and no jobs are in flight, it applies the updates through the
     * router and publishes a {@link MoreWorkMessage} to {@link BatchActor#BATCH}. If no updates are
     * pending and no jobs are in flight, it finishes and shuts down the state tracker and the actor
     * system.
     *
     * @throws RuntimeException wrapping any failure from the state tracker or router
     */
    protected void nextBatch() {
        Collection<?> workerUpdates = stateTracker.workerUpdates();
        try {
            List<?> currentJobs = stateTracker.currentJobs();
            if (!workerUpdates.isEmpty() && currentJobs.isEmpty()) {
                workRouter.update();
                ActorRefUtils.throwExceptionIfExists(Futures.future(new Callable<Void>() {
                    @Override
                    public Void call() throws Exception {
                        mediator.tell(new DistributedPubSubMediator.Publish(BatchActor.BATCH, MoreWorkMessage.getInstance()), MasterActor.this.getSelf());
                        log.info("Requesting more work...");
                        return null;
                    }
                }, context().dispatcher()), context().dispatcher());
            } else if (currentJobs.isEmpty()) {
                // Log before shutting the system down; afterwards the log event may never be emitted.
                log.info("Current jobs is empty and no more updates; terminating");
                stateTracker.finish();
                stateTracker.shutdown();
                context().system().shutdown();
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Applies any pending worker updates. Then, the first time this is called while no jobs are in
     * flight, it advances via {@link #nextBatch()} and marks the state tracker finished.
     */
    protected void doDoneOrNextPhase() throws Exception {
        if (!stateTracker.workerUpdates().isEmpty()) {
            workRouter.update();
        }
        // currentJobs is checked first so doneCalled is only claimed when the phase can really end.
        if (!stateTracker.currentJobs().isEmpty() || !doneCalled.compareAndSet(false, true)) {
            return;
        }
        nextBatch();
        stateTracker.finish();
        log.info("Done training!");
    }

    /** Logs the cause after the actor has been restarted. */
    public void aroundPostRestart(Throwable th) {
        super.aroundPostRestart(th);
        log.info("Restarted because of {}", th);
    }

    /** Logs the cause and the message being processed before the actor is restarted. */
    public void aroundPreRestart(Throwable th, Option<Object> option) {
        super.aroundPreRestart(th, option);
        log.info("Restarting because of {} with message {}", th, option);
    }

    /** Registers this actor with the pub-sub mediator (via {@code Put}) so it is addressable by path. */
    @Override
    public void preStart() throws Exception {
        super.preStart();
        mediator.tell(new DistributedPubSubMediator.Put(getSelf()), getSelf());
        log.info("Setup master with path " + self().path());
        log.info("Pre start on master " + self().path().toString());
    }

    /** Cancels both periodic tasks scheduled by the constructor. */
    @Override
    public void postStop() throws Exception {
        super.postStop();
        log.info("Post stop on master");
        if (clearStateWorkers != null) {
            clearStateWorkers.cancel();
        }
        if (forceNextPhase != null) {
            forceNextPhase.cancel();
        }
    }

    /** Logs any child failure with its stack trace and resumes the failing child. */
    @Override
    public SupervisorStrategy supervisorStrategy() {
        return new OneForOneStrategy(0, Duration.Zero(), new Function<Throwable, SupervisorStrategy.Directive>() {
            public SupervisorStrategy.Directive apply(Throwable th) {
                // error(Throwable, String) keeps the stack trace; error(String, Object) would drop it.
                log.error(th, "Problem with processing");
                return SupervisorStrategy.resume();
            }
        });
    }
}
