package org.deeplearning4j.spark.impl.common.updater;

import org.apache.spark.api.java.function.Function2;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;

/* loaded from: input_file:org/deeplearning4j/spark/impl/common/updater/UpdaterElementCombinerCG.class */
/**
 * Spark {@link Function2} that folds a single {@link ComputationGraphUpdater} into a running
 * {@link ComputationGraphUpdater.Aggregator}.
 *
 * <p>Either argument may be {@code null}; see {@link #call} for how each combination is handled.
 * NOTE(review): presumably used as the per-element function of a Spark aggregation over
 * updaters; confirm against the call site.
 */
public class UpdaterElementCombinerCG implements Function2<ComputationGraphUpdater.Aggregator, ComputationGraphUpdater, ComputationGraphUpdater.Aggregator> {
    /**
     * Combines the accumulated aggregator with one more updater.
     *
     * @param aggregator the aggregator accumulated so far, or {@code null} if there is none yet
     * @param computationGraphUpdater the updater to merge in, or {@code null} if there is nothing to merge
     * @return {@code aggregator} unchanged when the updater is {@code null} (which is {@code null}
     *         when both arguments are {@code null}); the result of
     *         {@code computationGraphUpdater.getAggregator(true)} when only the aggregator is
     *         {@code null}; otherwise the same {@code aggregator} instance after
     *         {@code aggregate(computationGraphUpdater)} has been invoked on it
     * @throws Exception declared by the {@link Function2} contract
     */
    public ComputationGraphUpdater.Aggregator call(ComputationGraphUpdater.Aggregator aggregator, ComputationGraphUpdater computationGraphUpdater) throws Exception {
        // Nothing to merge: hand back whatever has been accumulated so far.
        // This also covers the case where both arguments are null (returns null).
        if (computationGraphUpdater == null) {
            return aggregator;
        }

        // No accumulator yet: obtain one from the updater itself.
        if (aggregator == null) {
            return computationGraphUpdater.getAggregator(true);
        }

        // Both present: merge the updater into the existing accumulator in place.
        aggregator.aggregate(computationGraphUpdater);
        return aggregator;
    }
}
