package org.deeplearning4j.spark.impl.paramavg;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Random;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaRDDLike;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.input.PortableDataStream;
import org.apache.spark.storage.StorageLevel;
import org.datavec.spark.util.BroadcastHadoopConfigHolder;
import org.deeplearning4j.api.loader.DataSetLoader;
import org.deeplearning4j.api.loader.MultiDataSetLoader;
import org.deeplearning4j.api.loader.impl.SerializedDataSetLoader;
import org.deeplearning4j.api.loader.impl.SerializedMultiDataSetLoader;
import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.api.storage.StatsStorageRouterProvider;
import org.deeplearning4j.api.storage.StorageMetaData;
import org.deeplearning4j.nn.api.Trainable;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.spark.api.RDDTrainingApproach;
import org.deeplearning4j.spark.api.Repartition;
import org.deeplearning4j.spark.api.RepartitionStrategy;
import org.deeplearning4j.spark.api.TrainingHook;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.api.WorkerConfiguration;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerMultiDataSetFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerPDSFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerPDSMDSFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerPathFlatMap;
import org.deeplearning4j.spark.api.worker.ExecuteWorkerPathMDSFlatMap;
import org.deeplearning4j.spark.api.worker.NetBroadcastTuple;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn;
import org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouterProvider;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingAggregationTuple;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementAddFunction;
import org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementCombineFunction;
import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingMasterStats;
import org.deeplearning4j.spark.util.SparkUtils;
import org.deeplearning4j.util.UIDProvider;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.base.Preconditions;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@JsonIgnoreProperties({"stats", "listeners", "iterationCount", "rng", "lastExportedRDDId", "lastRDDExportPath", "trainingMasterUID"})
/* loaded from: input_file:org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.class */
public class ParameterAveragingTrainingMaster extends BaseTrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker> implements TrainingMaster<ParameterAveragingTrainingResult, ParameterAveragingTrainingWorker> {
    private static final Logger log = LoggerFactory.getLogger(ParameterAveragingTrainingMaster.class);
    // NOTE(review): not referenced by any method visible in this chunk; confirm whether it is still used.
    protected static final int COALESCE_THRESHOLD = 3;
    // Flag handed to every ParameterAveragingTrainingWorker; the averaged updater state (if any) is applied back in processResults.
    protected boolean saveUpdater;
    // Number of workers; may be null, in which case executeTraining* fills it from JavaSparkContext.defaultParallelism().
    protected Integer numWorkers;
    // Number of examples contained in each DataSet/MultiDataSet object of the input RDD (validated to be >= 1).
    protected int rddDataSetNumExamples;
    // Passed to WorkerConfiguration; also scales how many objects each worker gets per split (see numObjectsEachWorker).
    protected int averagingFrequency;
    // Depth used for JavaRDD.treeAggregate when combining worker results.
    protected int aggregationDepth;
    // Number of minibatches to prefetch, passed through to WorkerConfiguration.
    protected int prefetchNumBatches;
    // Initialised to 0 by the constructors; excluded from JSON/YAML via @JsonIgnoreProperties.
    protected int iterationCount;
    // Training hooks handed to each worker; may be null until addHook is called.
    protected Collection<TrainingHook> trainingHookList;

    /**
     * Fluent builder for {@link ParameterAveragingTrainingMaster}.
     *
     * <p>Defaults: batch size per worker 16, averaging frequency 5, aggregation depth 2, no prefetch,
     * {@link Repartition#Always}, {@link RepartitionStrategy#Balanced}, {@code MEMORY_ONLY_SER} storage for
     * DataSet RDDs, {@code MEMORY_ONLY} for path RDDs, {@link RDDTrainingApproach#Export}, no stats collection.
     */
    public static class Builder {
        protected boolean saveUpdater;
        protected Integer numWorkers;
        protected int rddDataSetNumExamples;
        protected int batchSizePerWorker;
        protected int averagingFrequency;
        protected int aggregationDepth;
        protected int prefetchNumBatches;
        protected Repartition repartition;
        protected RepartitionStrategy repartitionStrategy;
        protected StorageLevel storageLevel;
        protected StorageLevel storageLevelStreams;
        protected RDDTrainingApproach rddTrainingApproach;
        protected String exportDirectory;
        protected Long rngSeed;
        protected Collection<TrainingHook> trainingHooks;
        protected boolean collectTrainingStats;

        /**
         * Sets the training hooks passed to each worker.
         *
         * <p>The hooks are copied into a mutable list. The built master stores this list and later mutates it
         * in addHook/removeHook. Storing the caller's collection directly would make those calls fail for
         * immutable collections, and would alias the caller's collection.
         *
         * @param collection hooks to use; may be null
         * @return this builder
         */
        public Builder trainingHooks(Collection<TrainingHook> collection) {
            this.trainingHooks = collection == null ? null : new ArrayList<>(collection);
            return this;
        }

        /**
         * Sets the training hooks passed to each worker.
         *
         * <p>{@code Arrays.asList} returns a fixed-size list, so the hooks are copied into a mutable
         * {@link ArrayList}. Otherwise addHook/removeHook on the built master would throw
         * {@link UnsupportedOperationException}.
         *
         * @param trainingHookArr hooks to use
         * @return this builder
         */
        public Builder trainingHooks(TrainingHook... trainingHookArr) {
            this.trainingHooks = new ArrayList<>(Arrays.asList(trainingHookArr));
            return this;
        }

        /**
         * Creates a builder whose worker count is determined at training time (Spark default parallelism).
         *
         * @param i number of examples in each DataSet object of the RDD (must be >= 1)
         */
        public Builder(int i) {
            this(null, i);
        }

        /**
         * Creates a builder.
         *
         * @param num number of workers, or null to use the Spark default parallelism at training time
         * @param i   number of examples in each DataSet object of the RDD (must be >= 1)
         * @throws IllegalArgumentException if {@code num} is non-null and not positive, or {@code i < 1}
         */
        public Builder(Integer num, int i) {
            this.batchSizePerWorker = 16;
            this.averagingFrequency = 5;
            this.aggregationDepth = 2;
            this.prefetchNumBatches = 0;
            this.repartition = Repartition.Always;
            this.repartitionStrategy = RepartitionStrategy.Balanced;
            this.storageLevel = StorageLevel.MEMORY_ONLY_SER();
            this.storageLevelStreams = StorageLevel.MEMORY_ONLY();
            this.rddTrainingApproach = RDDTrainingApproach.Export;
            this.exportDirectory = null;
            this.collectTrainingStats = false;
            Preconditions.checkArgument(num == null || num.intValue() > 0, "Invalid number of workers: " + num + " (must be >= 1)");
            Preconditions.checkArgument(i > 0, "Invalid rdd data set size: " + i + " (must be >= 1)");
            this.numWorkers = num;
            this.rddDataSetNumExamples = i;
        }

        /** Sets the minibatch size used by each worker. */
        public Builder batchSizePerWorker(int i) {
            this.batchSizePerWorker = i;
            return this;
        }

        /**
         * Sets how often parameters are averaged.
         *
         * @throws IllegalArgumentException if {@code i < 1}
         */
        public Builder averagingFrequency(int i) {
            Preconditions.checkArgument(i > 0, "Invalid input: averaging frequency must be >= 1");
            this.averagingFrequency = i;
            return this;
        }

        /**
         * Sets the depth used for tree aggregation of worker results.
         *
         * @throws IllegalArgumentException if {@code i < 1}
         */
        public Builder aggregationDepth(int i) {
            Preconditions.checkArgument(i > 0, "Invalid input: tree aggregation channels must be >= 1");
            this.aggregationDepth = i;
            return this;
        }

        /** Sets the number of minibatches each worker prefetches. */
        public Builder workerPrefetchNumBatches(int i) {
            this.prefetchNumBatches = i;
            return this;
        }

        /** Sets the saveUpdater flag passed to each worker. */
        public Builder saveUpdater(boolean z) {
            this.saveUpdater = z;
            return this;
        }

        /** Sets when data is repartitioned before each split is trained (method name kept for compatibility). */
        public Builder repartionData(Repartition repartition) {
            this.repartition = repartition;
            return this;
        }

        /** Sets the strategy used when repartitioning. */
        public Builder repartitionStrategy(RepartitionStrategy repartitionStrategy) {
            this.repartitionStrategy = repartitionStrategy;
            return this;
        }

        /** Sets the storage level used to persist DataSet/MultiDataSet RDDs for direct training; null disables persisting. */
        public Builder storageLevel(StorageLevel storageLevel) {
            this.storageLevel = storageLevel;
            return this;
        }

        /** Sets the storage level used to persist path RDDs; null disables persisting. */
        public Builder storageLevelStreams(StorageLevel storageLevel) {
            this.storageLevelStreams = storageLevel;
            return this;
        }

        /** Chooses between direct RDD training and export-then-train-from-paths. */
        public Builder rddTrainingApproach(RDDTrainingApproach rDDTrainingApproach) {
            this.rddTrainingApproach = rDDTrainingApproach;
            return this;
        }

        /** Sets the directory used when exporting RDDs. */
        public Builder exportDirectory(String str) {
            this.exportDirectory = str;
            return this;
        }

        /** Seeds the random number generator used for splitting data. */
        public Builder rngSeed(long j) {
            this.rngSeed = Long.valueOf(j);
            return this;
        }

        /** Enables or disables collection of training statistics. */
        public Builder collectTrainingStats(boolean z) {
            this.collectTrainingStats = z;
            return this;
        }

        /** Builds the training master from the current builder state. */
        public ParameterAveragingTrainingMaster build() {
            return new ParameterAveragingTrainingMaster(this);
        }
    }

    /**
     * No-arg constructor (used for deserialization). It creates an unseeded RNG and a fresh
     * training-master UID made of the current time plus at most 8 characters of the JVM UID.
     */
    protected ParameterAveragingTrainingMaster() {
        this.iterationCount = 0;
        String jvmUid = UIDProvider.getJVMUID();
        String uidPrefix = jvmUid.length() > 8 ? jvmUid.substring(0, 8) : jvmUid;
        this.trainingMasterUID = System.currentTimeMillis() + "_" + uidPrefix;
        this.rng = new Random();
    }

    protected ParameterAveragingTrainingMaster(Builder builder) {
        this.iterationCount = 0;
        this.saveUpdater = builder.saveUpdater;
        this.numWorkers = builder.numWorkers;
        this.rddDataSetNumExamples = builder.rddDataSetNumExamples;
        this.batchSizePerWorker = builder.batchSizePerWorker;
        this.averagingFrequency = builder.averagingFrequency;
        this.aggregationDepth = builder.aggregationDepth;
        this.prefetchNumBatches = builder.prefetchNumBatches;
        this.repartition = builder.repartition;
        this.repartitionStrategy = builder.repartitionStrategy;
        this.storageLevel = builder.storageLevel;
        this.storageLevelStreams = builder.storageLevelStreams;
        this.rddTrainingApproach = builder.rddTrainingApproach;
        this.exportDirectory = builder.exportDirectory;
        this.trainingHookList = builder.trainingHooks;
        this.collectTrainingStats = builder.collectTrainingStats;
        if (this.collectTrainingStats) {
            this.stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();
        }
        if (builder.rngSeed == null) {
            this.rng = new Random();
        } else {
            this.rng = new Random(builder.rngSeed.longValue());
        }
        String jvmuid = UIDProvider.getJVMUID();
        this.trainingMasterUID = System.currentTimeMillis() + "_" + (jvmuid.length() <= 8 ? jvmuid : jvmuid.substring(0, 8));
    }

    /**
     * Convenience constructor. It uses aggregation depth 2, {@link Repartition#Always},
     * {@link RepartitionStrategy#Balanced}, {@code MEMORY_ONLY_SER} storage, and no stats collection.
     *
     * @param z   saveUpdater flag passed to each worker
     * @param num number of workers
     * @param i   number of examples in each DataSet object of the RDD
     * @param i2  minibatch size per worker
     * @param i3  averaging frequency
     * @param i4  number of minibatches each worker prefetches
     */
    public ParameterAveragingTrainingMaster(boolean z, Integer num, int i, int i2, int i3, int i4) {
        this(z, num, i, i2, i3, 2, i4, Repartition.Always, RepartitionStrategy.Balanced,
                StorageLevel.MEMORY_ONLY_SER(), false);
    }

    /**
     * Same as the full constructor, but always uses {@code MEMORY_ONLY_SER} as the storage level.
     *
     * @param z                   saveUpdater flag passed to each worker
     * @param num                 number of workers
     * @param i                   number of examples in each DataSet object of the RDD
     * @param i2                  minibatch size per worker
     * @param i3                  averaging frequency
     * @param i4                  tree aggregation depth
     * @param i5                  number of minibatches each worker prefetches
     * @param repartition         when to repartition
     * @param repartitionStrategy how to repartition
     * @param z2                  whether to collect training statistics
     */
    public ParameterAveragingTrainingMaster(boolean z, Integer num, int i, int i2, int i3, int i4, int i5, Repartition repartition, RepartitionStrategy repartitionStrategy, boolean z2) {
        this(z, num, i, i2, i3, i4, i5,
                repartition, repartitionStrategy,
                StorageLevel.MEMORY_ONLY_SER(),
                z2);
    }

    /**
     * Full legacy constructor.
     *
     * @param z                   saveUpdater flag passed to each worker
     * @param num                 number of workers, or null to use the Spark default parallelism at training
     *                            time (same meaning as in {@link Builder})
     * @param i                   number of examples in each DataSet object of the RDD (must be >= 1)
     * @param i2                  minibatch size per worker
     * @param i3                  averaging frequency (must be >= 1)
     * @param i4                  tree aggregation depth (must be >= 1)
     * @param i5                  number of minibatches each worker prefetches
     * @param repartition         when to repartition
     * @param repartitionStrategy how to repartition
     * @param storageLevel        storage level for persisting DataSet RDDs; null disables persisting
     * @param z2                  whether to collect training statistics
     * @throws IllegalArgumentException if any numeric argument is out of range
     */
    public ParameterAveragingTrainingMaster(boolean z, Integer num, int i, int i2, int i3, int i4, int i5, Repartition repartition, RepartitionStrategy repartitionStrategy, StorageLevel storageLevel, boolean z2) {
        this.iterationCount = 0;
        // A null worker count previously caused an unexplained NPE here. The Builder accepts null and
        // executeTraining* resolves it from defaultParallelism(), so this constructor accepts null too.
        Preconditions.checkArgument(num == null || num.intValue() > 0, "Invalid number of workers: " + num + " (must be >= 1)");
        Preconditions.checkArgument(i > 0, "Invalid rdd data set size: " + i + " (must be >= 1)");
        Preconditions.checkArgument(i3 > 0, "Invalid input: averaging frequency must be >= 1");
        Preconditions.checkArgument(i4 > 0, "Invalid input: tree aggregation channels must be >= 1");
        this.saveUpdater = z;
        this.numWorkers = num;
        this.rddDataSetNumExamples = i;
        this.batchSizePerWorker = i2;
        this.averagingFrequency = i3;
        this.aggregationDepth = i4;
        this.prefetchNumBatches = i5;
        this.collectTrainingStats = z2;
        this.repartition = repartition;
        this.repartitionStrategy = repartitionStrategy;
        this.storageLevel = storageLevel;
        if (z2) {
            this.stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();
        }
        String jvmuid = UIDProvider.getJVMUID();
        this.trainingMasterUID = System.currentTimeMillis() + "_" + (jvmuid.length() <= 8 ? jvmuid : jvmuid.substring(0, 8));
        this.rng = new Random();
    }

    /**
     * Removes a training hook. Does nothing when no hooks have been registered.
     *
     * @param trainingHook hook to remove
     */
    @Override
    public void removeHook(TrainingHook trainingHook) {
        if (this.trainingHookList != null) {
            this.trainingHookList.remove(trainingHook);
        }
    }

    /**
     * Registers a training hook. The hook list is created lazily on first use.
     *
     * @param trainingHook hook to add
     */
    @Override
    public void addHook(TrainingHook trainingHook) {
        Collection<TrainingHook> hooks = this.trainingHookList;
        if (hooks == null) {
            hooks = new ArrayList<>();
            this.trainingHookList = hooks;
        }
        hooks.add(trainingHook);
    }

    /**
     * Serializes this training master to JSON.
     *
     * @return JSON representation
     * @throws RuntimeException wrapping any {@link JsonProcessingException}
     */
    @Override
    public String toJson() {
        String json;
        try {
            json = getJsonMapper().writeValueAsString(this);
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Error producing JSON representation for ParameterAveragingTrainingMaster", e);
        }
        return json;
    }

    /**
     * Serializes this training master to YAML.
     *
     * @return YAML representation
     * @throws RuntimeException wrapping any {@link JsonProcessingException}
     */
    @Override
    public String toYaml() {
        String yaml;
        try {
            yaml = getYamlMapper().writeValueAsString(this);
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Error producing YAML representation for ParameterAveragingTrainingMaster", e);
        }
        return yaml;
    }

    /**
     * Parses a training master from its JSON form.
     *
     * @param str JSON produced by {@link #toJson()}
     * @return the deserialized training master
     * @throws RuntimeException wrapping any {@link IOException}
     */
    public static ParameterAveragingTrainingMaster fromJson(String str) {
        ParameterAveragingTrainingMaster parsed;
        try {
            parsed = getJsonMapper().readValue(str, ParameterAveragingTrainingMaster.class);
        } catch (IOException e) {
            throw new RuntimeException("Could not parse JSON", e);
        }
        return parsed;
    }

    /**
     * Parses a training master from its YAML form.
     *
     * @param str YAML produced by {@link #toYaml()}
     * @return the deserialized training master
     * @throws RuntimeException wrapping any {@link IOException}
     */
    public static ParameterAveragingTrainingMaster fromYaml(String str) {
        ParameterAveragingTrainingMaster parsed;
        try {
            parsed = getYamlMapper().readValue(str, ParameterAveragingTrainingMaster.class);
        } catch (IOException e) {
            throw new RuntimeException("Could not parse YAML", e);
        }
        return parsed;
    }

    /**
     * Creates a worker for a MultiLayerNetwork. The network's configuration, parameters and updater
     * state are broadcast, and the broadcast time is recorded when stats are enabled.
     *
     * @param sparkDl4jMultiLayer Spark wrapper of the network being trained
     * @return a new worker configured for non-graph training
     */
    @Override
    public ParameterAveragingTrainingWorker getWorkerInstance(SparkDl4jMultiLayer sparkDl4jMultiLayer) {
        MultiLayerNetwork net = sparkDl4jMultiLayer.getNetwork();
        NetBroadcastTuple tuple = new NetBroadcastTuple(net.getLayerWiseConfigurations(), net.params(), net.getUpdater().getStateViewArray());

        if (this.collectTrainingStats) {
            this.stats.logBroadcastStart();
        }
        Broadcast<NetBroadcastTuple> broadcast = sparkDl4jMultiLayer.getSparkContext().broadcast(tuple);
        if (this.collectTrainingStats) {
            this.stats.logBroadcastEnd();
        }

        // First argument false: this worker trains a MultiLayerNetwork, not a ComputationGraph.
        WorkerConfiguration workerConf = new WorkerConfiguration(false, this.rddDataSetNumExamples, this.batchSizePerWorker,
                this.averagingFrequency, this.prefetchNumBatches, this.collectTrainingStats);
        return new ParameterAveragingTrainingWorker(broadcast, this.saveUpdater, workerConf, this.trainingHookList, this.listeners, getRouterProvider());
    }

    /**
     * Creates a worker for a ComputationGraph. The graph's configuration, parameters and updater state
     * are broadcast, and the broadcast time is recorded when stats are enabled.
     *
     * @param sparkComputationGraph Spark wrapper of the graph being trained
     * @return a new worker configured for graph training
     */
    @Override
    public ParameterAveragingTrainingWorker getWorkerInstance(SparkComputationGraph sparkComputationGraph) {
        ComputationGraph graph = sparkComputationGraph.getNetwork();
        NetBroadcastTuple tuple = new NetBroadcastTuple(graph.getConfiguration(), graph.params(), graph.getUpdater().getStateViewArray());

        if (this.collectTrainingStats) {
            this.stats.logBroadcastStart();
        }
        Broadcast<NetBroadcastTuple> broadcast = sparkComputationGraph.getSparkContext().broadcast(tuple);
        if (this.collectTrainingStats) {
            this.stats.logBroadcastEnd();
        }

        // First argument true: this worker trains a ComputationGraph.
        WorkerConfiguration workerConf = new WorkerConfiguration(true, this.rddDataSetNumExamples, this.batchSizePerWorker,
                this.averagingFrequency, this.prefetchNumBatches, this.collectTrainingStats);
        return new ParameterAveragingTrainingWorker(broadcast, this.saveUpdater, workerConf, this.trainingHookList, this.listeners, getRouterProvider());
    }

    /**
     * Number of RDD objects each worker handles per split. This is
     * {@code batchSizePerWorker * averagingFrequency} examples divided (integer division) by the examples per object.
     *
     * @param i number of examples per RDD object
     * @return objects per worker, possibly 0
     */
    protected int numObjectsEachWorker(int i) {
        int examplesPerWorker = this.batchSizePerWorker * this.averagingFrequency;
        return examplesPerWorker / i;
    }

    /**
     * Number of RDD objects in each training split across all workers.
     *
     * @param i number of examples per RDD object
     * @return objects per split; when {@code i != 1}, each worker is assigned at least one object
     */
    protected int getNumDataSetObjectsPerSplit(int i) {
        if (i == 1) {
            return this.numWorkers.intValue() * this.batchSizePerWorker * this.averagingFrequency;
        }
        int perWorker = Math.max(1, numObjectsEachWorker(i));
        return perWorker * this.numWorkers.intValue();
    }

    /**
     * Trains a MultiLayerNetwork on an RDD of DataSets. Depending on {@code rddTrainingApproach}, it trains
     * directly on the RDD or exports it and trains from the exported paths.
     *
     * @param sparkDl4jMultiLayer network wrapper
     * @param javaRDD             training data
     */
    @Override
    public void executeTraining(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaRDD<DataSet> javaRDD) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkDl4jMultiLayer.getSparkContext().defaultParallelism();
        }
        if (this.rddTrainingApproach != RDDTrainingApproach.Direct) {
            JavaRDD<String> exportedPaths = exportIfRequired(sparkDl4jMultiLayer.getSparkContext(), javaRDD);
            executeTrainingPathsHelper(sparkDl4jMultiLayer, null, exportedPaths, new SerializedDataSetLoader(), null, this.batchSizePerWorker);
            return;
        }
        executeTrainingDirect(sparkDl4jMultiLayer, javaRDD);
    }

    /**
     * Counts the elements of an RDD (a Spark action) and records the count time when stats are enabled.
     *
     * @param javaRDDLike RDD to count
     * @return number of elements
     */
    protected <T, Repr extends JavaRDDLike<T, Repr>> long getTotalDataSetObjectCount(JavaRDDLike<T, Repr> javaRDDLike) {
        if (this.collectTrainingStats) {
            this.stats.logCountStart();
        }
        long total = javaRDDLike.count();
        if (this.collectTrainingStats) {
            this.stats.logCountEnd();
        }
        return total;
    }

    /**
     * Splits a pair RDD into balanced random splits sized by {@code rddDataSetNumExamples}.
     *
     * @param javaPairRDD data to split
     * @param i           total number of objects in {@code javaPairRDD}
     * @return the splits
     */
    protected <T, Repr> JavaPairRDD<T, Repr>[] getSplitRDDs(JavaPairRDD<T, Repr> javaPairRDD, int i) {
        int objectsPerSplit = getNumDataSetObjectsPerSplit(this.rddDataSetNumExamples);
        if (this.collectTrainingStats) {
            this.stats.logSplitStart();
        }
        long seed = this.rng.nextLong();
        JavaPairRDD<T, Repr>[] splits = SparkUtils.balancedRandomSplit(i, objectsPerSplit, javaPairRDD, seed);
        if (this.collectTrainingStats) {
            this.stats.logSplitEnd();
        }
        return splits;
    }

    /**
     * Splits an RDD into balanced random splits.
     *
     * @param javaRDD data to split
     * @param i       total number of objects in {@code javaRDD}
     * @param i2      number of examples per RDD object, used to size each split
     * @return the splits
     */
    protected <T> JavaRDD<T>[] getSplitRDDs(JavaRDD<T> javaRDD, int i, int i2) {
        int objectsPerSplit = getNumDataSetObjectsPerSplit(i2);
        if (this.collectTrainingStats) {
            this.stats.logSplitStart();
        }
        long seed = this.rng.nextLong();
        JavaRDD<T>[] splits = SparkUtils.balancedRandomSplit(i, objectsPerSplit, javaRDD, seed);
        if (this.collectTrainingStats) {
            this.stats.logSplitEnd();
        }
        return splits;
    }

    /**
     * Trains a MultiLayerNetwork directly on an RDD of DataSets. The RDD is optionally persisted,
     * counted, split, and each split is trained in turn.
     *
     * @param sparkDl4jMultiLayer network wrapper
     * @param javaRDD             training data
     */
    protected void executeTrainingDirect(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaRDD<DataSet> javaRDD) {
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        if (this.storageLevel != null) {
            javaRDD.persist(this.storageLevel);
        }
        long totalCount = getTotalDataSetObjectCount(javaRDD);
        JavaRDD<DataSet>[] splits = getSplitRDDs(javaRDD, (int) totalCount, this.rddDataSetNumExamples);
        for (int idx = 0; idx < splits.length; idx++) {
            // Split numbers are 1-based for logging.
            doIteration(sparkDl4jMultiLayer, splits[idx], idx + 1, splits.length);
        }
        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int) totalCount);
        }
    }

    /**
     * Trains from an RDD of file paths, using the configured number of examples per serialized object.
     *
     * @param sparkDl4jMultiLayer   MultiLayerNetwork wrapper, or null when training a graph
     * @param sparkComputationGraph ComputationGraph wrapper, or null when training a MultiLayerNetwork
     * @param javaRDD               paths of serialized data
     * @param dataSetLoader         loader for DataSet files, or null
     * @param multiDataSetLoader    loader for MultiDataSet files, or null
     */
    @Override
    public void executeTrainingPaths(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD, DataSetLoader dataSetLoader, MultiDataSetLoader multiDataSetLoader) {
        int examplesPerObject = this.rddDataSetNumExamples;
        executeTrainingPathsHelper(sparkDl4jMultiLayer, sparkComputationGraph, javaRDD, dataSetLoader, multiDataSetLoader, examplesPerObject);
    }

    /**
     * Trains from an RDD of file paths. The path RDD is optionally persisted, counted, split,
     * and each split is trained in turn.
     *
     * @param sparkDl4jMultiLayer   MultiLayerNetwork wrapper, or null when training a graph
     * @param sparkComputationGraph ComputationGraph wrapper, or null when training a MultiLayerNetwork
     * @param javaRDD               paths of serialized data
     * @param dataSetLoader         loader for DataSet files, or null
     * @param multiDataSetLoader    loader for MultiDataSet files, or null
     * @param i                     number of examples per serialized object
     */
    protected void executeTrainingPathsHelper(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD, DataSetLoader dataSetLoader, MultiDataSetLoader multiDataSetLoader, int i) {
        if (this.numWorkers == null) {
            // Only one of the two networks is non-null. Previously this always used sparkDl4jMultiLayer,
            // which threw an NPE for ComputationGraph training with an unset worker count.
            JavaSparkContext sc = sparkDl4jMultiLayer != null ? sparkDl4jMultiLayer.getSparkContext() : sparkComputationGraph.getSparkContext();
            this.numWorkers = sc.defaultParallelism();
        }
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        if (this.storageLevelStreams != null) {
            javaRDD.persist(this.storageLevelStreams);
        }
        long totalDataSetObjectCount = getTotalDataSetObjectCount(javaRDD);
        JavaRDD<String>[] splitRDDs = getSplitRDDs(javaRDD, (int) totalDataSetObjectCount, i);
        for (int s = 0; s < splitRDDs.length; s++) {
            // Split numbers are 1-based for logging.
            doIterationPaths(sparkDl4jMultiLayer, sparkComputationGraph, splitRDDs[s], s + 1, splitRDDs.length, i, dataSetLoader, multiDataSetLoader);
        }
        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int) totalDataSetObjectCount);
        }
    }

    /**
     * Trains a ComputationGraph on an RDD of DataSets by converting each to a MultiDataSet.
     *
     * @param sparkComputationGraph graph wrapper
     * @param javaRDD               training data
     */
    @Override
    public void executeTraining(SparkComputationGraph sparkComputationGraph, JavaRDD<DataSet> javaRDD) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkComputationGraph.getSparkContext().defaultParallelism();
        }
        JavaRDD<MultiDataSet> asMultiDataSets = javaRDD.map(new DataSetToMultiDataSetFn());
        executeTrainingMDS(sparkComputationGraph, asMultiDataSets);
    }

    /**
     * Trains a ComputationGraph on an RDD of MultiDataSets. Depending on {@code rddTrainingApproach},
     * it trains directly or from exported paths.
     *
     * @param sparkComputationGraph graph wrapper
     * @param javaRDD               training data
     */
    @Override
    public void executeTrainingMDS(SparkComputationGraph sparkComputationGraph, JavaRDD<MultiDataSet> javaRDD) {
        if (this.numWorkers == null) {
            this.numWorkers = sparkComputationGraph.getSparkContext().defaultParallelism();
        }
        if (this.rddTrainingApproach != RDDTrainingApproach.Direct) {
            JavaRDD<String> exportedPaths = exportIfRequiredMDS(sparkComputationGraph.getSparkContext(), javaRDD);
            executeTrainingPathsHelper(null, sparkComputationGraph, exportedPaths, null, new SerializedMultiDataSetLoader(), this.batchSizePerWorker);
            return;
        }
        executeTrainingDirect(sparkComputationGraph, javaRDD);
    }

    /**
     * Trains a ComputationGraph directly on an RDD of MultiDataSets. The RDD is optionally persisted,
     * counted, split, and each split is trained in turn.
     *
     * @param sparkComputationGraph graph wrapper
     * @param javaRDD               training data
     */
    protected void executeTrainingDirect(SparkComputationGraph sparkComputationGraph, JavaRDD<MultiDataSet> javaRDD) {
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        if (this.storageLevel != null) {
            javaRDD.persist(this.storageLevel);
        }
        long totalCount = getTotalDataSetObjectCount(javaRDD);
        JavaRDD<MultiDataSet>[] splits = getSplitRDDs(javaRDD, (int) totalCount, this.rddDataSetNumExamples);
        for (int idx = 0; idx < splits.length; idx++) {
            // Split numbers are 1-based for logging.
            doIteration(sparkComputationGraph, splits[idx], idx + 1, splits.length);
        }
        if (this.collectTrainingStats) {
            this.stats.logFitEnd((int) totalCount);
        }
    }

    /**
     * Enables or disables stats collection. Disabling discards existing stats; enabling creates a
     * stats helper only if none exists yet.
     *
     * @param z whether to collect training stats
     */
    @Override
    public void setCollectTrainingStats(boolean z) {
        this.collectTrainingStats = z;
        if (z) {
            if (this.stats == null) {
                this.stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();
            }
        } else {
            this.stats = null;
        }
    }

    /**
     * @return true if training statistics are being collected
     */
    @Override
    public boolean getIsCollectTrainingStats() {
        boolean collecting = this.collectTrainingStats;
        return collecting;
    }

    /**
     * @return the collected training stats, or null when no stats helper exists
     */
    @Override
    public SparkTrainingStats getTrainingStats() {
        return this.stats == null ? null : this.stats.build();
    }

    /**
     * Sets the training listeners without a stats storage router.
     *
     * @param collection listeners to use; may be null
     */
    @Override
    public void setListeners(Collection<TrainingListener> collection) {
        StatsStorageRouter noRouter = null;
        setListeners(noRouter, collection);
    }

    /**
     * Sets the stats storage router and the training listeners. The listeners are copied into a new list.
     *
     * @param statsStorageRouter router for stats; may be null
     * @param collection         listeners to use; may be null
     */
    @Override
    public void setListeners(StatsStorageRouter statsStorageRouter, Collection<TrainingListener> collection) {
        this.statsStorage = statsStorageRouter;
        if (collection == null) {
            this.listeners = null;
        } else {
            this.listeners = new ArrayList<>(collection);
        }
    }

    /**
     * Trains one split of DataSets for a MultiLayerNetwork. The split is repartitioned, workers are run
     * via mapPartitions, and the results are averaged back into the network.
     *
     * @param sparkDl4jMultiLayer network wrapper
     * @param javaRDD             data of this split
     * @param i                   1-based split number
     * @param i2                  total number of splits
     */
    protected void doIteration(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaRDD<DataSet> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                i, i2, this.batchSizePerWorker, this.averagingFrequency, this.numWorkers);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }

        JavaRDD repartitioned = SparkUtils.repartition(javaRDD, this.repartition, this.repartitionStrategy,
                numObjectsEachWorker(this.rddDataSetNumExamples), this.numWorkers.intValue());
        int numPartitions = repartitioned.partitions().size();
        // The repartition end is only logged when repartitioning was possible.
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }

        ExecuteWorkerFlatMap workerFn = new ExecuteWorkerFlatMap(getWorkerInstance(sparkDl4jMultiLayer));
        processResults(sparkDl4jMultiLayer, null, repartitioned.mapPartitions(workerFn), i, i2);

        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(numPartitions);
        }
    }

    /**
     * Trains one split of PortableDataStreams for either a MultiLayerNetwork or a ComputationGraph.
     *
     * @param sparkDl4jMultiLayer   MultiLayerNetwork wrapper, or null when training a graph
     * @param sparkComputationGraph ComputationGraph wrapper, used when {@code sparkDl4jMultiLayer} is null
     * @param javaRDD               data of this split
     * @param i                     1-based split number
     * @param i2                    total number of splits
     * @deprecated retained for compatibility
     */
    @Deprecated
    protected void doIterationPDS(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph, JavaRDD<PortableDataStream> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                i, i2, this.batchSizePerWorker, this.averagingFrequency, this.numWorkers);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }

        JavaRDD repartitioned = SparkUtils.repartition(javaRDD, this.repartition, this.repartitionStrategy,
                numObjectsEachWorker(this.rddDataSetNumExamples), this.numWorkers.intValue());
        int numPartitions = repartitioned.partitions().size();
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }

        ParameterAveragingTrainingWorker worker = sparkDl4jMultiLayer != null
                ? getWorkerInstance(sparkDl4jMultiLayer)
                : getWorkerInstance(sparkComputationGraph);
        processResults(sparkDl4jMultiLayer, sparkComputationGraph, repartitioned.mapPartitions(new ExecuteWorkerPDSFlatMap(worker)), i, i2);

        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(numPartitions);
        }
    }

    /**
     * Trains one split of file paths. The worker function is chosen by network type (MultiLayerNetwork
     * vs ComputationGraph) and loader type (DataSet vs MultiDataSet).
     *
     * @param sparkDl4jMultiLayer   MultiLayerNetwork wrapper, or null when training a graph
     * @param sparkComputationGraph ComputationGraph wrapper, used when {@code sparkDl4jMultiLayer} is null
     * @param javaRDD               paths of this split
     * @param i                     1-based split number
     * @param i2                    total number of splits
     * @param i3                    number of examples per serialized object
     * @param dataSetLoader         loader for DataSet files; when null, {@code multiDataSetLoader} is used
     * @param multiDataSetLoader    loader for MultiDataSet files
     */
    protected void doIterationPaths(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD, int i, int i2, int i3, DataSetLoader dataSetLoader, MultiDataSetLoader multiDataSetLoader) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                i, i2, this.batchSizePerWorker, this.averagingFrequency, this.numWorkers);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }

        JavaRDD repartitioned = SparkUtils.repartition(javaRDD, this.repartition, this.repartitionStrategy,
                numObjectsEachWorker(i3), this.numWorkers.intValue());
        int numPartitions = repartitioned.partitions().size();
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }

        boolean isMultiLayer = sparkDl4jMultiLayer != null;
        JavaSparkContext sc = isMultiLayer ? sparkDl4jMultiLayer.getSparkContext() : sparkComputationGraph.getSparkContext();
        ParameterAveragingTrainingWorker worker = isMultiLayer
                ? getWorkerInstance(sparkDl4jMultiLayer)
                : getWorkerInstance(sparkComputationGraph);

        JavaRDD results;
        if (dataSetLoader != null) {
            results = repartitioned.mapPartitions(new ExecuteWorkerPathFlatMap(worker, dataSetLoader, BroadcastHadoopConfigHolder.get(sc)));
        } else {
            results = repartitioned.mapPartitions(new ExecuteWorkerPathMDSFlatMap(worker, multiDataSetLoader, BroadcastHadoopConfigHolder.get(sc)));
        }
        processResults(sparkDl4jMultiLayer, sparkComputationGraph, results, i, i2);

        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(numPartitions);
        }
    }

    /**
     * Trains one split of MultiDataSets for a ComputationGraph. The split is repartitioned, workers are
     * run via mapPartitions, and the results are averaged back into the graph.
     *
     * @param sparkComputationGraph graph wrapper
     * @param javaRDD               data of this split
     * @param i                     1-based split number
     * @param i2                    total number of splits
     */
    protected void doIteration(SparkComputationGraph sparkComputationGraph, JavaRDD<MultiDataSet> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers", new Object[]{Integer.valueOf(i), Integer.valueOf(i2), Integer.valueOf(this.batchSizePerWorker), Integer.valueOf(this.averagingFrequency), this.numWorkers});
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
        }
        // Record repartition timing, as every other doIteration* variant does.
        if (this.collectTrainingStats) {
            this.stats.logRepartitionStart();
        }
        JavaRDD repartition = SparkUtils.repartition(javaRDD, this.repartition, this.repartitionStrategy, numObjectsEachWorker(this.rddDataSetNumExamples), this.numWorkers.intValue());
        // Count the partitions of the RDD that mapPartitions actually runs on (after repartitioning),
        // not the partitions of the input split.
        int size = repartition.partitions().size();
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }
        processResults(null, sparkComputationGraph, repartition.mapPartitions(new ExecuteWorkerMultiDataSetFlatMap(getWorkerInstance(sparkComputationGraph))), i, i2);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(size);
        }
    }

    /**
     * Trains one split of PortableDataStreams containing MultiDataSets for a ComputationGraph.
     *
     * @param sparkComputationGraph graph wrapper
     * @param javaRDD               data of this split
     * @param i                     1-based split number
     * @param i2                    total number of splits
     */
    protected void doIterationPDS_MDS(SparkComputationGraph sparkComputationGraph, JavaRDD<PortableDataStream> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, averagingFreq={}, Configured for {} workers",
                i, i2, this.batchSizePerWorker, this.averagingFrequency, this.numWorkers);
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }

        JavaRDD repartitioned = SparkUtils.repartition(javaRDD, this.repartition, this.repartitionStrategy,
                numObjectsEachWorker(this.rddDataSetNumExamples), this.numWorkers.intValue());
        int numPartitions = repartitioned.partitions().size();
        if (this.collectTrainingStats && this.repartition != Repartition.Never) {
            this.stats.logRepartitionEnd();
        }

        ExecuteWorkerPDSMDSFlatMap workerFn = new ExecuteWorkerPDSMDSFlatMap(getWorkerInstance(sparkComputationGraph));
        processResults(null, sparkComputationGraph, repartitioned.mapPartitions(workerFn), i, i2);

        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(numPartitions);
        }
    }

    /**
     * Combines the worker results of one training split and applies the averaged state to the driver's network.
     *
     * <p>Steps performed:
     * <ol>
     *   <li>{@code treeAggregate} the per-partition results, with depth {@code aggregationDepth}, into a single
     *       {@code ParameterAveragingAggregationTuple} holding sums.</li>
     *   <li>Divide the parameter sum, and the updater-state sum if present, by the aggregation count. Write the
     *       averages into the MultiLayerNetwork when {@code sparkDl4jMultiLayer} is non-null, otherwise into the
     *       ComputationGraph. The averaged score is set on the same wrapper.</li>
     *   <li>Forward listener metadata, static info and updates collected by the workers to {@code statsStorage},
     *       when one is configured.</li>
     *   <li>Advance the network configuration's iteration count by {@code averagingFrequency}.</li>
     * </ol>
     * If no parameters were produced, steps 2 and 4 are skipped and a message is logged. This happens when no
     * executor had data, or when the RDD has no partitions at all.
     *
     * @param sparkDl4jMultiLayer   MultiLayerNetwork wrapper to update, or {@code null} to update the graph instead
     * @param sparkComputationGraph ComputationGraph wrapper; used only when {@code sparkDl4jMultiLayer} is {@code null}
     * @param javaRDD               per-partition training results produced by the worker flat-map functions
     * @param i                     index of the split just trained (logging only)
     * @param i2                    total number of splits (logging only)
     */
    protected void processResults(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph, JavaRDD<ParameterAveragingTrainingResult> javaRDD, int i, int i2) {
        if (this.collectTrainingStats) {
            this.stats.logAggregateStartTime();
        }
        // The aggregate type is inferred from the add/combine functions; the null zero value stands for
        // "nothing aggregated yet".
        ParameterAveragingAggregationTuple tuple = javaRDD.treeAggregate(null,
                new ParameterAveragingElementAddFunction(), new ParameterAveragingElementCombineFunction(),
                this.aggregationDepth);
        // treeAggregate hands back its zero value (null) when the RDD has no partitions. Treat that like a split
        // in which no executor produced parameters, rather than failing with a NullPointerException below.
        INDArray parametersSum = tuple == null ? null : tuple.getParametersSum();
        int aggregationsCount = tuple == null ? 0 : tuple.getAggregationsCount();
        SparkTrainingStats sparkTrainingStats = tuple == null ? null : tuple.getSparkTrainingStats();
        if (this.collectTrainingStats) {
            this.stats.logAggregationEndTime();
            this.stats.logProcessParamsUpdaterStart();
        }
        if (parametersSum != null) {
            // The tuple holds sums over all aggregated worker results: divide in place to obtain averages.
            parametersSum.divi(aggregationsCount);
            INDArray updaterStateSum = tuple.getUpdaterStateSum();
            if (updaterStateSum != null) {
                updaterStateSum.divi(aggregationsCount);
            }
            if (sparkDl4jMultiLayer != null) {
                MultiLayerNetwork network = sparkDl4jMultiLayer.getNetwork();
                network.setParameters(parametersSum);
                if (updaterStateSum != null) {
                    // The (Trainable) cast selects the three-argument overload.
                    network.getUpdater().setStateViewArray((Trainable) null, updaterStateSum, false);
                }
                sparkDl4jMultiLayer.setScore(tuple.getScoreSum() / aggregationsCount);
            } else {
                ComputationGraph graph = sparkComputationGraph.getNetwork();
                graph.setParams(parametersSum);
                if (updaterStateSum != null) {
                    graph.getUpdater().setStateViewArray(updaterStateSum);
                }
                sparkComputationGraph.setScore(tuple.getScoreSum() / aggregationsCount);
            }
        } else {
            log.info("Skipping imbalanced split with no data for all executors");
        }
        if (this.collectTrainingStats) {
            this.stats.logProcessParamsUpdaterEnd();
            this.stats.addWorkerStats(sparkTrainingStats);
        }
        if (this.statsStorage != null && tuple != null) {
            // Forward whatever listener output the workers collected; empty collections are not pushed.
            Collection<StorageMetaData> listenerMetaData = tuple.getListenerMetaData();
            if (listenerMetaData != null && !listenerMetaData.isEmpty()) {
                this.statsStorage.putStorageMetaData(listenerMetaData);
            }
            Collection<Persistable> listenerStaticInfo = tuple.getListenerStaticInfo();
            if (listenerStaticInfo != null && !listenerStaticInfo.isEmpty()) {
                this.statsStorage.putStaticInfo(listenerStaticInfo);
            }
            Collection<Persistable> listenerUpdates = tuple.getListenerUpdates();
            if (listenerUpdates != null && !listenerUpdates.isEmpty()) {
                this.statsStorage.putUpdate(listenerUpdates);
            }
        }
        Nd4j.getExecutioner().commit();
        log.info("Completed training of split {} of {}", i, i2);
        if (parametersSum != null) {
            // Only advance the iteration counter when parameters were actually applied for this split.
            if (sparkDl4jMultiLayer != null) {
                MultiLayerConfiguration layerWiseConfigurations = sparkDl4jMultiLayer.getNetwork().getLayerWiseConfigurations();
                layerWiseConfigurations.setIterationCount(layerWiseConfigurations.getIterationCount() + this.averagingFrequency);
            } else {
                ComputationGraphConfiguration configuration = sparkComputationGraph.getNetwork().getConfiguration();
                configuration.setIterationCount(configuration.getIterationCount() + this.averagingFrequency);
            }
        }
    }

    /**
     * Supplies the router provider handed to workers for listener output.
     *
     * @return a new {@code VanillaStatsStorageRouterProvider} when a stats storage router is configured,
     *         otherwise {@code null}
     */
    protected StatsStorageRouterProvider getRouterProvider() {
        return this.statsStorage != null ? new VanillaStatsStorageRouterProvider() : null;
    }

    /**
     * Returns the {@code saveUpdater} flag.
     * Presumably this controls whether updater state is kept and averaged along with the parameters; confirm
     * against the worker configuration.
     */
    public boolean isSaveUpdater() {
        return saveUpdater;
    }

    /**
     * Returns the configured number of workers, which may be {@code null}.
     * This value is used as the target partition count when a split is repartitioned.
     */
    public Integer getNumWorkers() {
        return numWorkers;
    }

    /**
     * Returns the number of examples per RDD element.
     * This value is passed to {@code numObjectsEachWorker} when sizing partitions.
     */
    public int getRddDataSetNumExamples() {
        return rddDataSetNumExamples;
    }

    /**
     * Returns the averaging frequency.
     * The network's iteration count is advanced by this amount after each split whose parameters were averaged.
     */
    public int getAveragingFrequency() {
        return averagingFrequency;
    }

    /** Returns the depth used for {@code treeAggregate} when combining worker results. */
    public int getAggregationDepth() {
        return aggregationDepth;
    }

    /**
     * Returns the {@code prefetchNumBatches} setting.
     * Presumably this is the number of minibatches each worker prefetches; confirm where it is consumed.
     */
    public int getPrefetchNumBatches() {
        return prefetchNumBatches;
    }

    /** Returns the {@code iterationCount} field. This field does not take part in equals/hashCode. */
    public int getIterationCount() {
        return iterationCount;
    }

    /** Returns the live training-hook collection, not a copy. The result may be {@code null}. */
    public Collection<TrainingHook> getTrainingHookList() {
        return trainingHookList;
    }

    /** Sets the {@code saveUpdater} flag. */
    public void setSaveUpdater(boolean saveUpdater) {
        this.saveUpdater = saveUpdater;
    }

    /** Sets the number of workers. {@code null} is accepted and stored as-is. */
    public void setNumWorkers(Integer numWorkers) {
        this.numWorkers = numWorkers;
    }

    /** Sets the number of examples per RDD element. */
    public void setRddDataSetNumExamples(int rddDataSetNumExamples) {
        this.rddDataSetNumExamples = rddDataSetNumExamples;
    }

    /** Sets the averaging frequency. */
    public void setAveragingFrequency(int averagingFrequency) {
        this.averagingFrequency = averagingFrequency;
    }

    /** Sets the {@code treeAggregate} depth used when combining worker results. */
    public void setAggregationDepth(int aggregationDepth) {
        this.aggregationDepth = aggregationDepth;
    }

    /** Sets the {@code prefetchNumBatches} setting. */
    public void setPrefetchNumBatches(int prefetchNumBatches) {
        this.prefetchNumBatches = prefetchNumBatches;
    }

    /** Sets the {@code iterationCount} field. */
    public void setIterationCount(int iterationCount) {
        this.iterationCount = iterationCount;
    }

    /** Replaces the training-hook collection. The reference is stored directly, without a copy. */
    public void setTrainingHookList(Collection<TrainingHook> trainingHookList) {
        this.trainingHookList = trainingHookList;
    }

    /**
     * Describes this training master's settings.
     * The output includes {@code iterationCount}, which equals/hashCode ignore.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("ParameterAveragingTrainingMaster(");
        sb.append("saveUpdater=").append(isSaveUpdater());
        sb.append(", numWorkers=").append(getNumWorkers());
        sb.append(", rddDataSetNumExamples=").append(getRddDataSetNumExamples());
        sb.append(", averagingFrequency=").append(getAveragingFrequency());
        sb.append(", aggregationDepth=").append(getAggregationDepth());
        sb.append(", prefetchNumBatches=").append(getPrefetchNumBatches());
        sb.append(", iterationCount=").append(getIterationCount());
        sb.append(", trainingHookList=").append(getTrainingHookList());
        sb.append(")");
        return sb.toString();
    }

    /**
     * Compares the settings declared on this class. The compared fields are saveUpdater, numWorkers,
     * rddDataSetNumExamples, averagingFrequency, aggregationDepth, prefetchNumBatches and trainingHookList.
     * {@code iterationCount} is not compared. The other object must also accept this one through
     * {@link #canEqual}.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ParameterAveragingTrainingMaster)) {
            return false;
        }
        ParameterAveragingTrainingMaster other = (ParameterAveragingTrainingMaster) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        if (isSaveUpdater() != other.isSaveUpdater()) {
            return false;
        }

        Integer ourWorkers = getNumWorkers();
        Integer theirWorkers = other.getNumWorkers();
        boolean workersMatch = (ourWorkers == null) ? (theirWorkers == null) : ourWorkers.equals(theirWorkers);
        if (!workersMatch) {
            return false;
        }

        boolean scalarsMatch = getRddDataSetNumExamples() == other.getRddDataSetNumExamples()
                && getAveragingFrequency() == other.getAveragingFrequency()
                && getAggregationDepth() == other.getAggregationDepth()
                && getPrefetchNumBatches() == other.getPrefetchNumBatches();
        if (!scalarsMatch) {
            return false;
        }

        Collection<TrainingHook> ourHooks = getTrainingHookList();
        Collection<TrainingHook> theirHooks = other.getTrainingHookList();
        if (ourHooks == null) {
            return theirHooks == null;
        }
        return ourHooks.equals(theirHooks);
    }

    /**
     * Symmetry hook used by {@link #equals}.
     *
     * @param obj the object being compared against
     * @return whether {@code obj} is a {@code ParameterAveragingTrainingMaster}
     */
    protected boolean canEqual(Object obj) {
        boolean sameKind = obj instanceof ParameterAveragingTrainingMaster;
        return sameKind;
    }

    /**
     * Hashes the same fields that {@link #equals} compares, using multiplier 59.
     * A boolean contributes 79 for true and 97 for false, and a {@code null} reference contributes 43.
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + (isSaveUpdater() ? 79 : 97);
        Integer workers = getNumWorkers();
        result = result * prime + (workers == null ? 43 : workers.hashCode());
        result = result * prime + getRddDataSetNumExamples();
        result = result * prime + getAveragingFrequency();
        result = result * prime + getAggregationDepth();
        result = result * prime + getPrefetchNumBatches();
        Collection<TrainingHook> hooks = getTrainingHookList();
        result = result * prime + (hooks == null ? 43 : hooks.hashCode());
        return result;
    }
}
