package org.deeplearning4j.spark.impl.paramavg.aggregator;

import org.apache.spark.api.java.function.Function2;
import org.deeplearning4j.nn.updater.aggregate.UpdaterAggregator;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/spark/impl/paramavg/aggregator/ParameterAveragingElementCombineFunction.class */
/**
 * Spark {@link Function2} that folds two {@link ParameterAveragingAggregationTuple}
 * instances into a single tuple.
 *
 * <p>The parameter sums are added with {@code addi}, the updater aggregators are merged,
 * the score sums and aggregation counts are added, and the training statistics (when
 * present) are combined.
 */
public class ParameterAveragingElementCombineFunction implements Function2<ParameterAveragingAggregationTuple, ParameterAveragingAggregationTuple, ParameterAveragingAggregationTuple> {

    /**
     * Combines two aggregation tuples.
     *
     * @param parameterAveragingAggregationTuple  first tuple; may be {@code null}
     * @param parameterAveragingAggregationTuple2 second tuple; may be {@code null}
     * @return the other argument unchanged when one argument is {@code null} or carries a
     *         {@code null} parameters sum (the first argument is inspected first at each
     *         step); otherwise a new tuple holding the combined values
     * @throws Exception declared by the {@code call} signature; nothing is thrown explicitly here
     */
    public ParameterAveragingAggregationTuple call(ParameterAveragingAggregationTuple parameterAveragingAggregationTuple, ParameterAveragingAggregationTuple parameterAveragingAggregationTuple2) throws Exception {
        final ParameterAveragingAggregationTuple first = parameterAveragingAggregationTuple;
        final ParameterAveragingAggregationTuple second = parameterAveragingAggregationTuple2;

        // A missing tuple contributes nothing: hand back the other side as-is.
        if (first == null) {
            return second;
        }
        if (second == null) {
            return first;
        }

        // A tuple without a parameters sum is treated the same way as a missing tuple.
        final INDArray firstSum = first.getParametersSum();
        if (firstSum == null) {
            return second;
        }
        final INDArray secondSum = second.getParametersSum();
        if (secondSum == null) {
            return first;
        }

        final INDArray combinedParams = firstSum.addi(secondSum);

        final UpdaterAggregator combinedUpdater =
                mergeUpdaters(first.getUpdaterAggregator(), second.getUpdaterAggregator());
        final ComputationGraphUpdater.Aggregator combinedGraphUpdater =
                mergeGraphUpdaters(first.getUpdaterAggregatorGraph(), second.getUpdaterAggregatorGraph());

        final double totalScore = first.getScoreSum() + second.getScoreSum();
        final int totalAggregations = first.getAggregationsCount() + second.getAggregationsCount();

        // Keep the first tuple's stats when available and fold the second tuple's stats into them.
        SparkTrainingStats combinedStats = first.getSparkTrainingStats();
        final SparkTrainingStats secondStats = second.getSparkTrainingStats();
        if (secondStats != null) {
            if (combinedStats != null) {
                combinedStats.addOtherTrainingStats(secondStats);
            } else {
                combinedStats = secondStats;
            }
        }

        return new ParameterAveragingAggregationTuple(combinedParams, combinedUpdater, combinedGraphUpdater, totalScore, totalAggregations, combinedStats);
    }

    /** Returns {@code right} when {@code left} is null; otherwise merges {@code right} (if any) into {@code left} and returns {@code left}. */
    private static UpdaterAggregator mergeUpdaters(UpdaterAggregator left, UpdaterAggregator right) {
        if (left == null) {
            return right;
        }
        if (right != null) {
            left.merge(right);
        }
        return left;
    }

    /** Returns {@code right} when {@code left} is null; otherwise merges {@code right} (if any) into {@code left} and returns {@code left}. */
    private static ComputationGraphUpdater.Aggregator mergeGraphUpdaters(ComputationGraphUpdater.Aggregator left, ComputationGraphUpdater.Aggregator right) {
        if (left == null) {
            return right;
        }
        if (right != null) {
            left.merge(right);
        }
        return left;
    }
}
