package org.deeplearning4j.spark.api.worker;

import java.util.Collections;
import java.util.Iterator;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.spark.api.TrainingResult;
import org.deeplearning4j.spark.api.TrainingWorker;
import org.deeplearning4j.spark.api.WorkerConfiguration;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.api.stats.StatsCalculationHelper;
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.class */
/**
 * Flat-map function that feeds an iterator of {@link MultiDataSet} examples (e.g. one Spark partition) to a
 * {@link TrainingWorker} and emits at most one training result.
 *
 * <p>Examples are grouped into minibatches of {@code batchSizePerWorker}. When {@code prefetchNumBatches > 0}, they
 * are prefetched asynchronously. Each minibatch is passed to the worker until one of three things happens:
 * <ul>
 *   <li>the worker returns a non-null result (early termination);</li>
 *   <li>the input is exhausted;</li>
 *   <li>{@code maxBatchesPerWorker} is reached (when that setting is positive).</li>
 * </ul>
 * In the last two cases, the worker's final result is emitted.
 */
public class ExecuteWorkerMultiDataSetFlatMap<R extends TrainingResult> implements FlatMapFunction<Iterator<MultiDataSet>, R> {
    /** Worker that supplies the configuration and initial model, and performs the per-minibatch training. */
    private final TrainingWorker<R> worker;

    /**
     * @param trainingWorker worker used to process every minibatch of the input iterator
     */
    public ExecuteWorkerMultiDataSetFlatMap(TrainingWorker<R> trainingWorker) {
        this.worker = trainingWorker;
    }

    /**
     * Trains the worker on the given examples.
     *
     * @param it examples to train on; may be empty
     * @return an empty iterator if {@code it} has no elements, otherwise a single-element iterator holding the
     *         worker's result (with training stats attached when stats collection is enabled)
     * @throws Exception anything thrown by the worker or by the underlying iterators
     */
    public Iterator<R> call(Iterator<MultiDataSet> it) throws Exception {
        WorkerConfiguration dataConfig = worker.getDataConfiguration();
        boolean collectStats = dataConfig.isCollectTrainingStats();
        StatsCalculationHelper stats = collectStats ? new StatsCalculationHelper() : null;
        if (collectStats) {
            stats.logMethodStartTime();
        }

        if (!it.hasNext()) {
            // No data for this call: emit nothing.
            if (collectStats) {
                stats.logReturnTime();
            }
            return Collections.emptyIterator();
        }

        int batchSizePerWorker = dataConfig.getBatchSizePerWorker();
        int prefetchNumBatches = dataConfig.getPrefetchNumBatches();
        MultiDataSetIterator batchedIterator = new IteratorMultiDataSetIterator(it, batchSizePerWorker);
        if (prefetchNumBatches > 0) {
            batchedIterator = new AsyncMultiDataSetIterator(batchedIterator, prefetchNumBatches);
        }

        try {
            return trainOnBatches(batchedIterator, dataConfig, stats);
        } finally {
            // Runs on every exit path, including exceptions, so the prefetch thread is never left running.
            try {
                Nd4j.getExecutioner().commit();
            } finally {
                // Nested so that the async iterator is still shut down even if commit() throws.
                if (batchedIterator instanceof AsyncMultiDataSetIterator) {
                    ((AsyncMultiDataSetIterator) batchedIterator).shutdown();
                }
            }
        }
    }

    /**
     * Runs the minibatch training loop. The caller is responsible for cleaning up {@code batchedIterator}.
     *
     * @param batchedIterator non-empty iterator of minibatches
     * @param dataConfig      worker configuration (used for the max-batches limit)
     * @param stats           stats helper, or {@code null} when stats collection is disabled
     * @return single-element iterator holding the worker's result
     */
    private Iterator<R> trainOnBatches(MultiDataSetIterator batchedIterator, WorkerConfiguration dataConfig,
                    StatsCalculationHelper stats) throws Exception {
        boolean collectStats = stats != null;

        if (collectStats) {
            stats.logInitialModelBefore();
        }
        ComputationGraph graph = worker.getInitialModelGraph();
        if (collectStats) {
            stats.logInitialModelAfter();
        }

        // A non-positive configured limit means "no limit".
        int maxBatches = dataConfig.getMaxBatchesPerWorker() > 0 ? dataConfig.getMaxBatchesPerWorker()
                        : Integer.MAX_VALUE;
        int batchCount = 0;
        while (batchedIterator.hasNext() && batchCount < maxBatches) {
            batchCount++;

            if (collectStats) {
                stats.logNextDataSetBefore();
            }
            MultiDataSet batch = batchedIterator.next();
            if (collectStats) {
                stats.logNextDataSetAfter(batch.getFeatures(0).size(0));
            }

            if (collectStats) {
                stats.logProcessMinibatchBefore();
                Pair<R, SparkTrainingStats> result =
                                worker.processMinibatchWithStats(batch, graph, !batchedIterator.hasNext());
                stats.logProcessMinibatchAfter();
                if (result != null) {
                    // A non-null result from the worker ends training immediately.
                    stats.logReturnTime();
                    return singletonWithStats(result, stats);
                }
            } else {
                R result = worker.processMinibatch(batch, graph, !batchedIterator.hasNext());
                if (result != null) {
                    // A non-null result from the worker ends training immediately.
                    return Collections.singletonList(result).iterator();
                }
            }
        }

        // Input exhausted or batch limit reached without an early result: ask the worker for its final result.
        if (collectStats) {
            stats.logReturnTime();
            return singletonWithStats(worker.getFinalResultWithStats(graph), stats);
        }
        return Collections.singletonList(worker.getFinalResult(graph)).iterator();
    }

    /**
     * Attaches the combined stats to the result and wraps the result in a single-element iterator.
     *
     * @param resultWithStats worker result paired with the worker's own training stats
     * @param stats           helper that combines the worker stats with the timings recorded here
     */
    private Iterator<R> singletonWithStats(Pair<R, SparkTrainingStats> resultWithStats,
                    StatsCalculationHelper stats) {
        R result = resultWithStats.getFirst();
        result.setStats(stats.build(resultWithStats.getSecond()));
        return Collections.singletonList(result).iterator();
    }
}
