package org.deeplearning4j.iterativereduce.impl.multilayer;

import java.util.List;
import org.apache.commons.lang3.time.StopWatch;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.RecordReader;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.iterativereduce.impl.reader.RecordReaderDataSetIterator;
import org.deeplearning4j.iterativereduce.runtime.ComputableWorker;
import org.deeplearning4j.nn.conf.DeepLearningConfigurable;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.scaleout.api.ir.ParameterVectorUpdateable;
import org.nd4j.linalg.dataset.DataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/iterativereduce/impl/multilayer/WorkerNode.class */
/**
 * IterativeReduce worker that fits a local {@link MultiLayerNetwork} on mini-batches read from a
 * {@link RecordReader} and reports its parameter vector back to the master.
 *
 * <p>Lifecycle as used by this class: {@link #setup(Configuration)} builds the network from the
 * JSON configuration, {@link #setRecordReader(RecordReader)} wires the input (either order is
 * tolerated), each {@link #compute()} fits at most one mini-batch, and
 * {@link #update(ParameterVectorUpdateable)} installs parameters received from the master.
 *
 * <p>NOTE(review): no synchronization is used; presumably each instance is driven by a single
 * thread — confirm against the runtime.
 */
public class WorkerNode implements ComputableWorker<ParameterVectorUpdateable>, DeepLearningConfigurable {
    /** Hadoop configuration key holding the JSON-serialized {@link MultiLayerConfiguration}. */
    private static final String MULTI_LAYER_CONF_KEY = "org.deeplearning4j.scaleout.multilayerconf";

    /** Hadoop configuration key for the label column index; {@link #setup(Configuration)} requires it to be >= 0. */
    public static final String LABEL_INDEX = "org.deeplearning4j.labelindex";

    private static final Logger log = LoggerFactory.getLogger(WorkerNode.class);

    private MultiLayerNetwork multiLayerNetwork;
    /** Reader supplied by the runtime; kept so the iterator can be rebuilt once settings are known. */
    private RecordReader recordParser;
    private DataSetIterator hdfsDataSetIterator = null;
    private long totalRecordsProcessed = 0;
    /** Measures time since {@link #setup(Configuration)} completed. */
    private final StopWatch totalRunTimeWatch = new StopWatch();
    /** Measures the duration of the most recent {@code fit} call. */
    private final StopWatch batchWatch = new StopWatch();
    private int batchSize = 20;
    private int numberClasses = 2;
    private int labelIndex = -1;

    /**
     * Fits the local network on the next mini-batch, if one is available, and returns the current
     * parameters.
     *
     * <p>When the iterator is exhausted or yields an empty batch, this is an idle pass. The
     * unchanged parameters are still returned.
     *
     * @return the network's current parameter vector
     * @throws IllegalStateException if {@link #setRecordReader(RecordReader)} has not been called
     */
    @Override // org.deeplearning4j.iterativereduce.runtime.ComputableWorker
    public ParameterVectorUpdateable compute() {
        log.info("Worker > Compute() -------------------------- ");
        if (this.hdfsDataSetIterator == null) {
            throw new IllegalStateException("setRecordReader() must be called before compute()");
        }
        if (this.hdfsDataSetIterator.hasNext()) {
            DataSet dataSet = (DataSet) this.hdfsDataSetIterator.next();
            int rows = dataSet.getFeatures().rows();
            if (rows > 0) {
                log.info("Rows: {}, inputs: {}", dataSet.numExamples(), dataSet.numInputs());
                // Rendering a whole DataSet (feature/label matrices) is expensive; only do it at debug level.
                log.debug("DataSet: {}", dataSet);
                this.totalRecordsProcessed += rows;
                this.batchWatch.reset();
                this.batchWatch.start();
                this.multiLayerNetwork.fit(dataSet);
                this.batchWatch.stop();
                log.info("Worker > Processed Total {}, Batch Time {} Total Time {}",
                        this.totalRecordsProcessed, this.batchWatch, this.totalRunTimeWatch);
            } else {
                log.info("Worker > Idle pass, no records left to process");
            }
        } else {
            log.info("Worker > Idle pass, no records left to process");
        }
        return getResults();
    }

    /**
     * Wires the input source and builds the mini-batch iterator over it.
     *
     * <p>The iterator uses the batch size, label index and class count known at call time. If
     * this is called before {@link #setup(Configuration)}, setup rebuilds the iterator with the
     * configured values.
     *
     * @param recordReader the reader supplying input records
     */
    @Override // org.deeplearning4j.iterativereduce.runtime.ComputableWorker
    public void setRecordReader(RecordReader recordReader) {
        this.recordParser = recordReader;
        this.hdfsDataSetIterator = createIterator(recordReader);
    }

    /** Builds a mini-batch iterator over {@code recordReader} from the current batch/label/class settings. */
    private DataSetIterator createIterator(RecordReader recordReader) {
        return new RecordReaderDataSetIterator(recordReader, null, this.batchSize, this.labelIndex, this.numberClasses);
    }

    /**
     * Runs one compute pass. The supplied list is ignored; this simply delegates to
     * {@link #compute()}.
     *
     * @param list ignored
     * @return the network's current parameter vector
     */
    @Override // org.deeplearning4j.iterativereduce.runtime.ComputableWorker
    public ParameterVectorUpdateable compute(List<ParameterVectorUpdateable> list) {
        return compute();
    }

    /**
     * Returns the network's current parameters.
     *
     * @return a new {@link ParameterVectorUpdateable} wrapping {@code multiLayerNetwork.params()}
     */
    @Override // org.deeplearning4j.iterativereduce.runtime.ComputableWorker
    public ParameterVectorUpdateable getResults() {
        return new ParameterVectorUpdateable(this.multiLayerNetwork.params());
    }

    /**
     * Builds the local network from the Hadoop configuration.
     *
     * <p>The batch size is read from the first layer's configuration. The class count is the
     * {@code nOut} of the last layer. The label index comes from {@link #LABEL_INDEX}.
     *
     * @param configuration Hadoop configuration containing the JSON network definition and label index
     * @throws IllegalStateException if the network configuration is missing or the label index is negative/absent
     */
    @Override // org.deeplearning4j.iterativereduce.runtime.ComputableWorker
    public void setup(Configuration configuration) {
        String confJson = configuration.get(MULTI_LAYER_CONF_KEY);
        log.info("Worker-Conf: {}", confJson);
        if (confJson == null) {
            throw new IllegalStateException("Missing multi-layer configuration under key " + MULTI_LAYER_CONF_KEY);
        }
        MultiLayerConfiguration networkConf = MultiLayerConfiguration.fromJson(confJson);
        this.batchSize = networkConf.getConf(0).getBatchSize();
        this.numberClasses = networkConf.getConf(networkConf.getConfs().size() - 1).getnOut();
        this.labelIndex = configuration.getInt(LABEL_INDEX, -1);
        if (this.labelIndex < 0) {
            throw new IllegalStateException("Illegal label index: " + this.labelIndex);
        }
        this.multiLayerNetwork = new MultiLayerNetwork(networkConf);
        // A reader wired before setup() got an iterator built from default settings; rebuild it now.
        if (this.recordParser != null) {
            this.hdfsDataSetIterator = createIterator(this.recordParser);
        }
        // Without an explicit start, the "Total Time" log would always report zero.
        this.totalRunTimeWatch.reset();
        this.totalRunTimeWatch.start();
    }

    /**
     * Replaces the local network's parameters with those received from the master.
     *
     * @param parameterVectorUpdateable wrapper around the new parameter vector
     */
    @Override // org.deeplearning4j.iterativereduce.runtime.ComputableWorker
    public void update(ParameterVectorUpdateable parameterVectorUpdateable) {
        this.multiLayerNetwork.setParameters(parameterVectorUpdateable.get());
    }

    /**
     * Intentionally a no-op. This worker is configured through the Hadoop
     * {@link #setup(Configuration)} overload.
     *
     * <p>NOTE(review): presumably this satisfies {@link DeepLearningConfigurable} — confirm.
     *
     * @param configuration ignored
     */
    public void setup(org.canova.api.conf.Configuration configuration) {
    }
}
