package org.deeplearning4j.iterativereduce.runtime.yarn.appworker;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.avro.AvroRemoteException;
import org.apache.avro.ipc.NettyTransceiver;
import org.apache.avro.ipc.specific.SpecificRequestor;
import org.apache.hadoop.conf.Configuration;
import org.deeplearning4j.iterativereduce.impl.reader.CanovaRecordReader;
import org.deeplearning4j.iterativereduce.runtime.ComputableWorker;
import org.deeplearning4j.iterativereduce.runtime.Utils;
import org.deeplearning4j.iterativereduce.runtime.yarn.avro.generated.IterativeReduceService;
import org.deeplearning4j.iterativereduce.runtime.yarn.avro.generated.ProgressReport;
import org.deeplearning4j.iterativereduce.runtime.yarn.avro.generated.ServiceError;
import org.deeplearning4j.iterativereduce.runtime.yarn.avro.generated.StartupConfiguration;
import org.deeplearning4j.iterativereduce.runtime.yarn.avro.generated.WorkerId;
import org.deeplearning4j.scaleout.api.ir.Updateable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Worker side of an IterativeReduce job: connects to the master's {@link IterativeReduceService}
 * over Avro/Netty and drives a {@link ComputableWorker} through the configured number of
 * iterations.
 *
 * <p>Each iteration computes a local result, sends it to the master via {@code update}, polls
 * {@code waiting} until the master reports an update id, then fetches that update, deserializes it
 * into a fresh {@code T} and hands it to the worker. While the worker is in
 * {@link WorkerState#RUNNING}, a background thread sends a progress report every
 * {@code statusSleepTime} milliseconds.
 *
 * @param <T> the state exchanged with the master; fetched updates are deserialized into an
 *     instance created reflectively through its no-argument constructor
 */
public class ApplicationWorkerService<T extends Updateable> {
    private static final Logger LOG = LoggerFactory.getLogger(ApplicationWorkerService.class);

    /**
     * Guards state transitions. It is held for the whole update RPC and for each heartbeat RPC,
     * so a progress report is never sent while an update is in flight. A dedicated object is used
     * because the state field itself is reassigned (and enum constants are JVM-wide monitors).
     */
    private final Object stateLock = new Object();

    private final WorkerId workerId;
    private final InetSocketAddress masterAddr;
    /** Written under {@link #stateLock}; volatile so unlocked reads see the latest value. */
    private volatile WorkerState currentState;
    private NettyTransceiver nettyTransceiver;
    private IterativeReduceService masterService;
    private StartupConfiguration workerConf;
    private final CanovaRecordReader recordParser;
    private final ComputableWorker<T> computable;
    /** Class used to instantiate the update objects fetched from the master. */
    private final Class<T> updateable;
    /** Counters published in progress reports; guarded by synchronizing on the map itself. */
    private final Map<String, Integer> progressCounters;
    private ProgressReport progressReport;
    /** Sleep between heartbeat progress reports, in milliseconds. */
    private final long statusSleepTime;
    /** Sleep between polls of the master while waiting for an update, in milliseconds. */
    private final long updateSleepTime;
    private ExecutorService updateExecutor;
    private final Configuration conf;

    // Metrics sent to the master at the end of the run.
    private long mWorkerTime;
    private long mWorkerExecutions;
    private long mWaits;
    private long mWaitTime;
    private long mUpdates;

    /**
     * Heartbeat task: while the worker is {@link WorkerState#RUNNING}, sends a progress report to
     * the master every {@code statusSleepTime} ms. Stops when its thread is interrupted (e.g. by
     * {@link ExecutorService#shutdownNow()}).
     */
    class PeriodicUpdateThread implements Runnable {
        PeriodicUpdateThread() {
        }

        @Override
        public void run() {
            Thread.currentThread().setName("Periodic worker heartbeat thread");
            while (!Thread.currentThread().isInterrupted()) {
                LOG.debug("Attemping to acquire state lock");
                synchronized (stateLock) {
                    if (currentState == WorkerState.RUNNING) {
                        LOG.debug("Worker is running, sending a progress report");
                        try {
                            masterService.progress(workerId, createProgressReport());
                        } catch (AvroRemoteException e) {
                            LOG.warn("Encountered an exception while heartbeating to master", e);
                        }
                    }
                }
                try {
                    LOG.debug("Thread {} is going to sleep for {}", Thread.currentThread().getName(), statusSleepTime);
                    Thread.sleep(statusSleepTime);
                } catch (InterruptedException e) {
                    LOG.warn("Interrupted while sleeping on progress report");
                    return;
                }
            }
        }
    }

    /** Lifecycle phases of the worker, as set by this service. */
    public enum WorkerState {
        /** Constructed, {@link #run()} not yet past initialization. */
        NONE,
        /** Connected, configured, heartbeat thread launched. */
        STARTED,
        /** Computing a batch; the only state in which heartbeats are sent. */
        RUNNING,
        /** Polling the master for a new update. */
        WAITING,
        /** Sending a local result to the master. */
        UPDATE
    }

    /**
     * Creates a worker service.
     *
     * @param id identifier passed to {@code Utils.createWorkerId}
     * @param masterAddress address of the master's Avro endpoint
     * @param recordReader reader that is initialized with the split assigned by the master
     * @param computableWorker the user computation driven by this service
     * @param updateableClass class used to materialize updates fetched from the master
     * @param configuration base configuration, merged with the master's startup configuration
     */
    public ApplicationWorkerService(String id, InetSocketAddress masterAddress, CanovaRecordReader recordReader,
            ComputableWorker<T> computableWorker, Class<T> updateableClass, Configuration configuration) {
        this.statusSleepTime = 2000L;
        this.updateSleepTime = 1000L;
        this.workerId = Utils.createWorkerId(id);
        this.currentState = WorkerState.NONE;
        this.masterAddr = masterAddress;
        this.recordParser = recordReader;
        this.computable = computableWorker;
        this.updateable = updateableClass;
        this.progressCounters = new HashMap<>();
        this.conf = configuration;
    }

    /** Same as the six-argument constructor, using a fresh {@link Configuration}. */
    public ApplicationWorkerService(String id, InetSocketAddress masterAddress, CanovaRecordReader recordReader,
            ComputableWorker<T> computableWorker, Class<T> updateableClass) {
        this(id, masterAddress, recordReader, computableWorker, updateableClass, new Configuration());
    }

    /**
     * Runs the worker to completion.
     *
     * <p>The heartbeat executor and the Netty connection are released on every exit path,
     * including failures, so a failed worker does not leave the (non-daemon) heartbeat thread
     * running.
     *
     * @return {@code 0} on success, {@code -1} if connecting, configuring, initializing the reader
     *     or any per-iteration exchange with the master failed
     */
    public int run() {
        Thread.currentThread().setName("ApplicationWorkerService Thread - " + Utils.getWorkerId(this.workerId));
        try {
            return runWorker();
        } finally {
            releaseResources();
        }
    }

    /** Connect → initialize → iterate → finish; resource cleanup is left to {@link #run()}. */
    private int runWorker() {
        if (!initializeService()) {
            return -1;
        }
        LOG.info("Worker " + Utils.getWorkerId(this.workerId) + " initialized");
        if (!initializeRecordReader()) {
            return -1;
        }

        LOG.debug("Launching periodic update thread");
        this.updateExecutor = Executors.newSingleThreadExecutor();
        this.updateExecutor.execute(new PeriodicUpdateThread());
        setState(WorkerState.STARTED);

        this.computable.setRecordReader(this.recordParser);
        if (!runIterations()) {
            return -1;
        }

        finishAndNotifyMaster();
        LOG.debug("Returning with code 0");
        return 0;
    }

    /**
     * Initializes the record reader with the split assigned in the startup configuration.
     *
     * @return {@code false} (after logging) if initialization failed
     */
    private boolean initializeRecordReader() {
        try {
            this.recordParser.initialize(Utils.getSplit(this.workerConf.getSplit()));
            return true;
        } catch (IOException e) {
            LOG.error("Unable to initialize the record reader for the assigned split", e);
            return false;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOG.error("Interrupted while initializing the record reader", e);
            return false;
        }
    }

    /**
     * Executes the configured number of compute → update → wait → fetch rounds.
     *
     * @return {@code false} as soon as a round fails (the failure has already been logged)
     */
    private boolean runIterations() {
        final int iterations = this.workerConf.getIterations().intValue();
        int totalBatches = 0;
        int batchesSinceUpdate = 0; // reset after every successfully applied update
        int lastUpdateId = 0;
        for (int iteration = 0; iteration < iterations; iteration++) {
            LOG.debug("Beginning iteration {}/{}", iteration + 1, iterations);
            setState(WorkerState.RUNNING);
            totalBatches++;
            batchesSinceUpdate++;
            synchronized (this.progressCounters) {
                this.progressCounters.put("countTotal", totalBatches);
                this.progressCounters.put("countCurrent", batchesSinceUpdate);
                this.progressCounters.put("currentIteration", iteration);
            }

            long computeStart = System.currentTimeMillis();
            T result = this.computable.compute();
            this.mWorkerExecutions++;
            this.mWorkerTime += System.currentTimeMillis() - computeStart;

            if (!sendUpdate(result)) {
                return false;
            }
            int appliedUpdateId = awaitAndApplyUpdate(lastUpdateId);
            if (appliedUpdateId < 0) {
                return false;
            }
            lastUpdateId = appliedUpdateId;
            batchesSinceUpdate = 0;
        }
        return true;
    }

    /**
     * Serializes {@code result} and sends it to the master. This is done under {@link #stateLock}
     * so no heartbeat RPC overlaps it.
     *
     * @return {@code false} (after logging) if the RPC failed; a rejected update is only logged
     */
    private boolean sendUpdate(T result) {
        try {
            synchronized (this.stateLock) {
                ByteBuffer payload = result.toBytes();
                payload.rewind();
                LOG.info("Sending an update to master");
                this.currentState = WorkerState.UPDATE;
                if (!this.masterService.update(this.workerId, payload)) {
                    LOG.warn("The master rejected our update");
                }
                this.mUpdates++;
            }
            return true;
        } catch (AvroRemoteException e) {
            LOG.error("Unable to send update message to master", e);
            return false;
        }
    }

    /**
     * Blocks until the master has an update for this worker, then fetches it and applies it to the
     * computable worker.
     *
     * @param lastUpdateId id of the last update this worker applied
     * @return the id of the update just applied, or {@code -1} on failure (already logged)
     */
    private int awaitAndApplyUpdate(int lastUpdateId) {
        final int targetUpdateId;
        try {
            LOG.info("Completed a batch, waiting on an update from master");
            targetUpdateId = waitOnMasterUpdate(lastUpdateId);
        } catch (AvroRemoteException e) {
            LOG.error("Got an error while waiting on updates from master", e);
            return -1;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOG.warn("Interrupted while waiting on master", e);
            return -1;
        }

        try {
            ByteBuffer payload = this.masterService.fetch(this.workerId, targetUpdateId);
            payload.rewind();
            T update = this.updateable.getDeclaredConstructor().newInstance();
            update.fromBytes(payload);
            this.computable.update(update);
            LOG.info("Requested to fetch an update from master, workerId=" + Utils.getWorkerId(this.workerId)
                    + ", requestedUpdatedId=" + targetUpdateId + ", lastUpdate=" + lastUpdateId
                    + ", responseLength=" + payload.limit());
            return targetUpdateId;
        } catch (AvroRemoteException e) {
            // Must precede the generic handler: AvroRemoteException is itself an Exception.
            LOG.error("Got exception while fetching an update from master", e);
            return -1;
        } catch (Exception e) {
            // Reflection, deserialization or the worker's own update logic failed.
            LOG.error("Got exception while processing update from master", e);
            return -1;
        }
    }

    /** Sends metrics, the final result (if any) and the completion notice to the master. */
    private void finishAndNotifyMaster() {
        reportMetrics();
        T results = this.computable.getResults();
        if (results != null) {
            try {
                LOG.info("Sending final update to master");
                ByteBuffer payload = results.toBytes();
                payload.rewind(); // same treatment as the per-iteration update
                this.masterService.update(this.workerId, payload);
            } catch (AvroRemoteException e) {
                LOG.warn("Failed to send final update to master", e);
            }
        }
        LOG.info("Completed processing, notfiying master that we're done");
        this.masterService.complete(this.workerId, createProgressReport());
        try {
            // NOTE(review): presumably gives the master time to handle complete() before the
            // connection is closed — confirm.
            Thread.sleep(1000L);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /** Stops the heartbeat first (it uses the connection), then closes the connection. */
    private void releaseResources() {
        if (this.updateExecutor != null) {
            this.updateExecutor.shutdownNow();
        }
        if (this.nettyTransceiver != null) {
            this.nettyTransceiver.close();
        }
    }

    /** Transitions the worker to {@code state} under {@link #stateLock}. */
    private void setState(WorkerState state) {
        synchronized (this.stateLock) {
            this.currentState = state;
        }
    }

    /**
     * Opens the Netty connection to the master, creates the Avro client and downloads the startup
     * configuration.
     *
     * @return {@code false} (after logging) if connecting or configuring failed
     */
    private boolean initializeService() {
        try {
            this.nettyTransceiver = new NettyTransceiver(this.masterAddr);
            this.masterService = (IterativeReduceService) SpecificRequestor.getClient(IterativeReduceService.class, this.nettyTransceiver);
            LOG.info("Connected to master via NettyTransiever at " + this.masterAddr);
            return getConfiguration();
        } catch (IOException e) {
            LOG.error("Unable to connect to master at " + this.masterAddr, e);
            return false;
        }
    }

    /**
     * Calls {@code startup} on the master, merges the returned configuration into {@link #conf}
     * and passes it to the computable worker's {@code setup}.
     *
     * @return {@code false} (after logging) if the startup RPC failed
     */
    private boolean getConfiguration() {
        try {
            LOG.info("Checking in and downloading configuration from master");
            this.workerConf = this.masterService.startup(this.workerId);
            LOG.info("Received startup configuration from master, fileSplit=[" + this.workerConf.getSplit().getPath()
                    + ", " + this.workerConf.getSplit().getOffset() + ", " + this.workerConf.getSplit().getLength()
                    + "], batchSize=" + this.workerConf.getBatchSize() + ", iterations=" + this.workerConf.getIterations());
            Utils.mergeConfigs(this.workerConf, this.conf);
            this.computable.setup(this.conf);
            return true;
        } catch (AvroRemoteException e) {
            if (e instanceof ServiceError) {
                LOG.error("Unable to call startup(): " + ((ServiceError) e).getDescription(), e);
                return false;
            }
            LOG.error("Unable to call startup()", e);
            return false;
        }
    }

    /**
     * Polls the master's {@code waiting} RPC every {@code updateSleepTime} ms until it returns a
     * non-negative update id. A negative result is treated as "no update ready yet".
     *
     * @param lastUpdateId id of the last update this worker applied
     * @return the non-negative update id reported by the master
     * @throws InterruptedException if interrupted while sleeping between polls
     * @throws AvroRemoteException if the {@code waiting} RPC fails
     */
    private int waitOnMasterUpdate(int lastUpdateId) throws InterruptedException, AvroRemoteException {
        long startTime = System.currentTimeMillis();
        long waitedMs = 0;
        while (true) {
            int updateId = this.masterService.waiting(this.workerId, lastUpdateId, waitedMs);
            if (updateId >= 0) {
                this.mWaitTime += waitedMs;
                return updateId;
            }
            setState(WorkerState.WAITING);
            Thread.sleep(this.updateSleepTime);
            waitedMs = System.currentTimeMillis() - startTime;
            LOG.info("Waiting on update from master with lastID " + lastUpdateId + " for " + waitedMs + "ms");
            this.mWaits++;
        }
    }

    /**
     * Builds a progress report from a snapshot of {@link #progressCounters}. The returned
     * {@link ProgressReport} instance is reused across calls and only its report map is replaced.
     *
     * @return the (shared) progress report for this worker
     */
    public ProgressReport createProgressReport() {
        if (this.progressReport == null) {
            this.progressReport = new ProgressReport();
            this.progressReport.setWorkerId(this.workerId);
        }
        Map<CharSequence, CharSequence> report = new HashMap<>();
        synchronized (this.progressCounters) {
            for (Map.Entry<String, Integer> counter : this.progressCounters.entrySet()) {
                report.put(counter.getKey(), String.valueOf(counter.getValue()));
            }
        }
        this.progressReport.setReport(report);
        if (LOG.isDebugEnabled()) {
            StringBuilder message = new StringBuilder("Created a progress report");
            message.append(", workerId=").append(Utils.getWorkerId(this.progressReport.getWorkerId()));
            for (Map.Entry<CharSequence, CharSequence> entry : this.progressReport.getReport().entrySet()) {
                message.append(", ").append(entry.getKey()).append("=").append(entry.getValue());
            }
            LOG.debug(message.toString());
        }
        return this.progressReport;
    }

    /** Sends the accumulated timing/count metrics to the master. */
    private void reportMetrics() {
        Map<CharSequence, Long> metrics = new HashMap<>();
        metrics.put("ComputableWorkerTime", this.mWorkerTime);
        metrics.put("ComputableWorkerExecutions", this.mWorkerExecutions);
        metrics.put("WaitCount", this.mWaits);
        metrics.put("WaitTime", this.mWaitTime);
        metrics.put("UpdatesSent", this.mUpdates);
        this.masterService.metricsReport(this.workerId, metrics);
    }
}
