package org.deeplearning4j.spark.impl.paramavg;

import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.util.ComputationGraphUtil;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.MultiLayerUpdater;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.spark.api.TrainingWorker;
import org.deeplearning4j.spark.api.WorkerConfiguration;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.api.worker.NetBroadcastTuple;
import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingWorkerStats;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.class */
/**
 * {@link TrainingWorker} used for parameter-averaging training on Spark.
 * <p>
 * The worker rebuilds a local network (a {@link MultiLayerNetwork} or a {@link ComputationGraph}) from the broadcast
 * {@link NetBroadcastTuple}, fits it on minibatches, and reports the network parameters and score. When
 * {@code saveUpdater} is set, it also reports the updater state. Timing statistics are recorded only when the worker
 * configuration has training-stats collection enabled.
 */
public class ParameterAveragingTrainingWorker implements TrainingWorker<ParameterAveragingTrainingResult> {
    private final Broadcast<NetBroadcastTuple> broadcast;
    private final boolean saveUpdater;
    private final WorkerConfiguration configuration;
    /** Created by getInitialModel()/getInitialModelGraph() when stats collection is enabled; null otherwise. */
    private ParameterAveragingTrainingWorkerStats.ParameterAveragingTrainingWorkerStatsHelper stats = null;

    /**
     * @param broadcast     broadcast holding the network configuration, parameters and (optionally) updater state
     * @param saveUpdater   whether the updater state is included in the final result
     * @param configuration worker configuration; among other things, controls whether training stats are collected
     */
    public ParameterAveragingTrainingWorker(Broadcast<NetBroadcastTuple> broadcast, boolean saveUpdater,
                    WorkerConfiguration configuration) {
        this.broadcast = broadcast;
        this.saveUpdater = saveUpdater;
        this.configuration = configuration;
    }

    /**
     * Builds this worker's MultiLayerNetwork from the broadcast configuration, parameters and updater state.
     *
     * @return a freshly initialised network that does not share arrays or configuration with the broadcast value
     */
    @Override
    public MultiLayerNetwork getInitialModel() {
        NetBroadcastTuple tuple = fetchBroadcastTuple();

        // Clone the configuration so networks built from the same broadcast value (which Spark may hand to every
        // task on an executor) do not share one configuration object.
        MultiLayerNetwork net = new MultiLayerNetwork(tuple.getConfiguration().clone());
        // init(..., false) uses the given array directly, so duplicate it to avoid sharing the broadcast parameters.
        net.init(tuple.getParameters().unsafeDuplication(), false);
        if (tuple.getUpdaterState() != null) {
            // Likewise, updater state must not be shared between workers.
            net.setUpdater(new MultiLayerUpdater(net, tuple.getUpdaterState().unsafeDuplication()));
        }

        // Flush before logging init end so the recorded init time includes any queued ops
        // (same ordering as getInitialModelGraph).
        flushGridExecutionerQueue();
        if (configuration.isCollectTrainingStats()) {
            stats.logInitEnd();
        }
        return net;
    }

    /**
     * Builds this worker's ComputationGraph from the broadcast configuration, parameters and updater state.
     *
     * @return a freshly initialised graph that does not share arrays or configuration with the broadcast value
     */
    @Override
    public ComputationGraph getInitialModelGraph() {
        NetBroadcastTuple tuple = fetchBroadcastTuple();

        // Clone the configuration so graphs built from the same broadcast value do not share one configuration object.
        ComputationGraph graph = new ComputationGraph(tuple.getGraphConfiguration().clone());
        // init(..., false) uses the given array directly, so duplicate it to avoid sharing the broadcast parameters.
        graph.init(tuple.getParameters().unsafeDuplication(), false);
        if (tuple.getUpdaterState() != null) {
            graph.setUpdater(new ComputationGraphUpdater(graph, tuple.getUpdaterState().unsafeDuplication()));
        }

        flushGridExecutionerQueue();
        if (configuration.isCollectTrainingStats()) {
            stats.logInitEnd();
        }
        return graph;
    }

    /**
     * Fits the network on one minibatch.
     *
     * @param dataSet minibatch to fit
     * @param network network to train
     * @param isLast  whether this is the last minibatch for this worker
     * @return the final result when {@code isLast} is true; {@code null} otherwise
     */
    @Override
    public ParameterAveragingTrainingResult processMinibatch(DataSet dataSet, MultiLayerNetwork network,
                    boolean isLast) {
        if (configuration.isCollectTrainingStats()) {
            stats.logFitStart();
        }
        network.fit(dataSet);
        if (configuration.isCollectTrainingStats()) {
            stats.logFitEnd(dataSet.numExamples());
        }

        flushGridExecutionerQueue();

        return isLast ? getFinalResult(network) : null;
    }

    /**
     * Fits the graph on one minibatch by converting the DataSet to a MultiDataSet.
     *
     * @return the final result when {@code isLast} is true; {@code null} otherwise
     */
    @Override
    public ParameterAveragingTrainingResult processMinibatch(DataSet dataSet, ComputationGraph graph, boolean isLast) {
        return processMinibatch(ComputationGraphUtil.toMultiDataSet(dataSet), graph, isLast);
    }

    /**
     * Fits the graph on one minibatch.
     *
     * @param dataSet minibatch to fit; the example count for stats is taken from the first features array
     * @param graph   graph to train
     * @param isLast  whether this is the last minibatch for this worker
     * @return the final result when {@code isLast} is true; {@code null} otherwise
     */
    @Override
    public ParameterAveragingTrainingResult processMinibatch(MultiDataSet dataSet, ComputationGraph graph,
                    boolean isLast) {
        if (configuration.isCollectTrainingStats()) {
            stats.logFitStart();
        }
        graph.fit(dataSet);
        if (configuration.isCollectTrainingStats()) {
            stats.logFitEnd(dataSet.getFeatures(0).size(0));
        }

        flushGridExecutionerQueue();

        return isLast ? getFinalResult(graph) : null;
    }

    /**
     * Same as {@link #processMinibatch(DataSet, MultiLayerNetwork, boolean)}, paired with the collected stats.
     *
     * @return {@code null} when no result is produced (i.e. not the last minibatch)
     */
    @Override
    public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet,
                    MultiLayerNetwork network, boolean isLast) {
        ParameterAveragingTrainingResult result = processMinibatch(dataSet, network, isLast);
        if (result == null) {
            return null;
        }
        return new Pair<>(result, buildStatsOrNull());
    }

    /** Converts the DataSet to a MultiDataSet and delegates to the MultiDataSet overload. */
    @Override
    public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet,
                    ComputationGraph graph, boolean isLast) {
        return processMinibatchWithStats(ComputationGraphUtil.toMultiDataSet(dataSet), graph, isLast);
    }

    /**
     * Same as {@link #processMinibatch(MultiDataSet, ComputationGraph, boolean)}, paired with the collected stats.
     *
     * @return {@code null} when no result is produced (i.e. not the last minibatch)
     */
    @Override
    public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> processMinibatchWithStats(MultiDataSet dataSet,
                    ComputationGraph graph, boolean isLast) {
        ParameterAveragingTrainingResult result = processMinibatch(dataSet, graph, isLast);
        if (result == null) {
            return null;
        }
        return new Pair<>(result, buildStatsOrNull());
    }

    /**
     * Packages the network's parameters, score and (if {@code saveUpdater}) updater state.
     *
     * @param network trained network
     * @return the result; the updater state is {@code null} when not saved or when the network has no updater
     */
    @Override
    public ParameterAveragingTrainingResult getFinalResult(MultiLayerNetwork network) {
        INDArray updaterState = null;
        if (saveUpdater) {
            Updater updater = network.getUpdater();
            if (updater != null) {
                updaterState = updater.getStateViewArray();
            }
        }

        flushGridExecutionerQueue();

        return new ParameterAveragingTrainingResult(network.params(), updaterState, network.score());
    }

    /**
     * Packages the graph's parameters, score and (if {@code saveUpdater}) updater state.
     *
     * @param graph trained graph
     * @return the result; the updater state is {@code null} when not saved or when the graph has no updater
     */
    @Override
    public ParameterAveragingTrainingResult getFinalResult(ComputationGraph graph) {
        INDArray updaterState = null;
        if (saveUpdater) {
            ComputationGraphUpdater updater = graph.getUpdater();
            if (updater != null) {
                updaterState = updater.getStateViewArray();
            }
        }

        flushGridExecutionerQueue();

        return new ParameterAveragingTrainingResult(graph.params(), updaterState, graph.score());
    }

    /** Result reported by a worker that received no data: no parameters, no updater state, score 0. */
    @Override
    public ParameterAveragingTrainingResult getFinalResultNoData() {
        return new ParameterAveragingTrainingResult(null, null, 0.0d, null);
    }

    /** {@link #getFinalResultNoData()} paired with no stats. */
    @Override
    public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> getFinalResultNoDataWithStats() {
        // Plain null (not "(Object) null") so the diamond infers Pair<..., SparkTrainingStats>.
        return new Pair<>(getFinalResultNoData(), null);
    }

    /** {@link #getFinalResult(MultiLayerNetwork)} paired with the collected stats (null if not collected). */
    @Override
    public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> getFinalResultWithStats(
                    MultiLayerNetwork network) {
        ParameterAveragingTrainingResult result = getFinalResult(network);
        if (result == null) {
            return null;
        }
        return new Pair<>(result, buildStatsOrNull());
    }

    /** {@link #getFinalResult(ComputationGraph)} paired with the collected stats (null if not collected). */
    @Override
    public Pair<ParameterAveragingTrainingResult, SparkTrainingStats> getFinalResultWithStats(ComputationGraph graph) {
        ParameterAveragingTrainingResult result = getFinalResult(graph);
        if (result == null) {
            return null;
        }
        return new Pair<>(result, buildStatsOrNull());
    }

    @Override
    public WorkerConfiguration getDataConfiguration() {
        return configuration;
    }

    /**
     * Reads the broadcast value. When stats collection is enabled, it first creates a fresh stats helper and then
     * logs the broadcast access time.
     */
    private NetBroadcastTuple fetchBroadcastTuple() {
        boolean collectStats = configuration.isCollectTrainingStats();
        if (collectStats) {
            stats = new ParameterAveragingTrainingWorkerStats.ParameterAveragingTrainingWorkerStatsHelper();
            stats.logBroadcastGetValueStart();
        }
        NetBroadcastTuple tuple = broadcast.getValue();
        if (collectStats) {
            stats.logBroadcastGetValueEnd();
        }
        return tuple;
    }

    /** Blocks until queued ops have executed when the current executioner is a GridExecutioner; otherwise no-op. */
    private static void flushGridExecutionerQueue() {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
        }
    }

    /** Builds the collected stats, or returns null when stats are not being collected. */
    private SparkTrainingStats buildStatsOrNull() {
        return stats != null ? stats.build() : null;
    }
}
