package org.deeplearning4j.iterativereduce.impl.single;

import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.util.ToolRunner;
import org.deeplearning4j.iterativereduce.impl.ParameterVectorUpdateable;
import org.deeplearning4j.iterativereduce.runtime.ComputableWorker;
import org.deeplearning4j.iterativereduce.runtime.io.RecordParser;
import org.deeplearning4j.iterativereduce.runtime.io.TextRecordParser;
import org.deeplearning4j.iterativereduce.runtime.yarn.appworker.ApplicationWorker;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.DeepLearningConfigurable;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/iterativereduce/impl/single/WorkerNode.class */
public class WorkerNode implements ComputableWorker<ParameterVectorUpdateable>, DeepLearningConfigurable {
    private static final Logger LOG = LoggerFactory.getLogger(WorkerNode.class);

    /** Hadoop configuration key whose value is the JSON-serialized {@link NeuralNetConfiguration}. */
    public static final String NEURAL_NET_CONF_KEY = "org.deeplearning4j.scaleout.neuralnetconf";

    /** Layer built from the job configuration in {@link #setup(Configuration)}; null until then. */
    private Layer neuralNetwork;

    /** Source of training records; null until {@link #setRecordParser(RecordParser)} is called. */
    private RecordParser recordParser;

    /**
     * Fits the local layer on every remaining record from the record parser, then returns the
     * layer's resulting parameters.
     *
     * @return the layer's parameters after fitting on all remaining records
     * @throws IllegalStateException if {@link #setup(Configuration)} or
     *     {@link #setRecordParser(RecordParser)} has not been called yet
     */
    @Override
    public ParameterVectorUpdateable compute() {
        Layer network = requireNetwork();
        if (this.recordParser == null) {
            throw new IllegalStateException(
                    "No record parser set; call setRecordParser() before compute()");
        }
        while (this.recordParser.hasMoreRecords()) {
            network.fit(this.recordParser.nextRecord().getFeatureMatrix());
        }
        return new ParameterVectorUpdateable(network.params());
    }

    /**
     * Same as {@link #compute()}. The supplied list of updates is ignored by this single-layer
     * worker.
     *
     * @param list updates passed in by the runtime (unused)
     * @return the layer's parameters after fitting on all remaining records
     * @throws IllegalStateException if the worker has not been set up
     */
    @Override
    public ParameterVectorUpdateable compute(List<ParameterVectorUpdateable> list) {
        return compute();
    }

    /**
     * Returns the layer's current parameters without doing any further training.
     *
     * @return the current parameters wrapped as an updateable
     * @throws IllegalStateException if {@link #setup(Configuration)} has not been called yet
     */
    @Override
    public ParameterVectorUpdateable getResults() {
        return new ParameterVectorUpdateable(requireNetwork().params());
    }

    /**
     * Sets the parser that {@link #compute()} reads training records from.
     *
     * @param recordParser the record source to train on
     */
    @Override
    public void setRecordParser(RecordParser recordParser) {
        this.recordParser = recordParser;
    }

    /**
     * Builds the local layer from the JSON configuration stored under {@link #NEURAL_NET_CONF_KEY}.
     *
     * @param configuration the Hadoop job configuration
     * @throws IllegalArgumentException if {@link #NEURAL_NET_CONF_KEY} is not set
     */
    @Override
    public void setup(Configuration configuration) {
        String json = configuration.get(NEURAL_NET_CONF_KEY);
        // Configuration.get returns null for a missing key; fail with a message that names it
        // instead of an opaque error from the JSON parser.
        if (json == null) {
            throw new IllegalArgumentException(
                    "Missing required configuration property: " + NEURAL_NET_CONF_KEY);
        }
        NeuralNetConfiguration conf = NeuralNetConfiguration.fromJson(json);
        this.neuralNetwork = conf.getLayerFactory().create(conf);
    }

    /**
     * Replaces the layer's parameters with those carried by the given update.
     *
     * @param parameterVectorUpdateable the update holding the new parameter vector
     * @throws IllegalStateException if {@link #setup(Configuration)} has not been called yet
     */
    @Override
    public void update(ParameterVectorUpdateable parameterVectorUpdateable) {
        requireNetwork().setParams(parameterVectorUpdateable.get());
    }

    /** Launches this worker via {@link ToolRunner}, using a {@link TextRecordParser} for input. */
    public static void main(String[] strArr) throws Exception {
        ToolRunner.run(new ApplicationWorker(new TextRecordParser(), new WorkerNode(), ParameterVectorUpdateable.class), strArr);
    }

    /**
     * Intentionally a no-op. This worker takes its configuration from the Hadoop
     * {@link Configuration} passed to {@link #setup(Configuration)}.
     */
    public void setup(org.deeplearning4j.nn.conf.Configuration configuration) {
    }

    /** Returns the configured layer, or throws if {@link #setup(Configuration)} has not run. */
    private Layer requireNetwork() {
        Layer network = this.neuralNetwork;
        if (network == null) {
            throw new IllegalStateException("Worker not set up; call setup(Configuration) first");
        }
        return network;
    }
}
