package org.platanios.tensorflow.api.ops;

import org.platanios.tensorflow.api.ops.Gradients;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.immutable.Set;
import scala.collection.mutable.Map;
import scala.collection.mutable.Seq;
import scala.math.Ordering$String$;
import scala.runtime.BoxesRunTime;

/* compiled from: Gradients.scala */
/* loaded from: input_file:org/platanios/tensorflow/api/ops/Gradients$AccumulateAggregationMethod$.class */
/**
 * Singleton gradient-aggregation strategy (decompiled form of the Scala object
 * {@code Gradients.AccumulateAggregationMethod}).
 *
 * <p>If every gradient is an {@link Output}, they are combined with {@code Math.accumulateN}.
 * If every gradient is an {@link OutputIndexedSlices}, they are grouped by device. Each group's
 * indices and values are concatenated, and the per-device results are then concatenated the same
 * way. Mixed inputs are rejected.
 */
public class Gradients$AccumulateAggregationMethod$ implements Gradients.AggregationMethod {
    /** Singleton instance; assigned by the constructor, which the static initializer invokes. */
    public static Gradients$AccumulateAggregationMethod$ MODULE$;

    static {
        new Gradients$AccumulateAggregationMethod$();
    }

    /**
     * Forwards to the implementation of {@code aggregateGradients} inherited from
     * {@link Gradients.AggregationMethod}.
     *
     * @param map gradients collected per op
     * @param op  the op whose gradients are being aggregated
     * @param str passed through unchanged to the trait implementation
     * @return whatever the trait implementation returns
     */
    @Override // org.platanios.tensorflow.api.ops.Gradients.AggregationMethod
    public Seq<scala.collection.Seq<OutputLike>> aggregateGradients(Map<Op, Seq<scala.collection.Seq<OutputLike>>> map, Op op, String str) {
        // The decompiled forwarder called this very method, recursing until StackOverflowError.
        // Under the Scala 2.12+ trait encoding (see the `$init$` call in the constructor), the
        // trait's concrete method is reachable through the static accessor `aggregateGradients$`.
        return Gradients.AggregationMethod.aggregateGradients$(this, map, op, str);
    }

    /**
     * Aggregates a sequence of gradients into a single gradient.
     *
     * @param seq    the gradients to aggregate; must be all {@link Output} or all
     *               {@link OutputIndexedSlices}
     * @param option forwarded as the second argument of {@code Op.colocateWithForGradient} in the
     *               sparse branch (presumably a gradient UID; TODO confirm). It is ignored in the
     *               dense branch.
     * @return the aggregated gradient
     * @throws IllegalArgumentException if the gradients are of mixed or unsupported types, or (from
     *                                  {@code addNOutputIndexedSlices$2}) if a sparse group is empty
     */
    @Override // org.platanios.tensorflow.api.ops.Gradients.AggregationMethod
    public OutputLike aggregate(scala.collection.Seq<OutputLike> seq, Option<String> option) {
        // Dense case: every gradient is an Output.
        if (seq.forall(outputLike -> {
            return BoxesRunTime.boxToBoolean($anonfun$aggregate$15(outputLike));
        })) {
            return Math$.MODULE$.accumulateN((scala.collection.Seq) seq.map(outputLike2 -> {
                return (Output) outputLike2;
            }, Seq$.MODULE$.canBuildFrom()), Math$.MODULE$.accumulateN$default$2(), Math$.MODULE$.accumulateN$default$3());
        }
        // Sparse case: every gradient is an OutputIndexedSlices. Group by device, visit the
        // groups in device-name order, combine each group, then combine the per-device results.
        if (seq.forall(outputLike3 -> {
            return BoxesRunTime.boxToBoolean($anonfun$aggregate$17(outputLike3));
        })) {
            return addNOutputIndexedSlices$2((scala.collection.Seq) ((TraversableLike) seq.groupBy(outputLike4 -> {
                return outputLike4.device();
            }).toSeq().sortBy(tuple2 -> {
                return (String) tuple2._1();
            }, Ordering$String$.MODULE$)).map(tuple22 -> {
                // Mirrors the null check the Scala compiler emits for the tuple pattern match.
                if (tuple22 == null) {
                    throw new MatchError(tuple22);
                }
                scala.collection.Seq deviceGradients = (scala.collection.Seq) tuple22._2();
                // Anchor each partial sum on the first op of *its own* device group, as TF Python's
                // _MultiDeviceAddN does with tensors[0].op. The original used the first op of all
                // gradients, which anchored every group to one op and undermined the grouping.
                Op anchor = ((OutputLike) deviceGradients.head()).op();
                // NOTE(review): the `true` argument is presumably "ignore existing colocation"; confirm.
                return (OutputIndexedSlices) Op$.MODULE$.colocateWithForGradient((Set) Predef$.MODULE$.Set().apply(Predef$.MODULE$.wrapRefArray(new Op[]{anchor})), option, true, () -> {
                    return addNOutputIndexedSlices$2((scala.collection.Seq) deviceGradients.map(outputLike5 -> {
                        return (OutputIndexedSlices) outputLike5;
                    }, Seq$.MODULE$.canBuildFrom()));
                });
            }, Seq$.MODULE$.canBuildFrom()));
        }
        throw new IllegalArgumentException("The gradients being aggregated need to be all of type 'Output' or 'OutputIndexedSlices'.");
    }

    /**
     * Default for the {@code option} parameter of {@link #aggregate}.
     *
     * @return {@code None}
     */
    @Override // org.platanios.tensorflow.api.ops.Gradients.AggregationMethod
    public Option<String> aggregate$default$2() {
        return None$.MODULE$;
    }

    /** Predicate used by {@link #aggregate}: is the gradient a dense {@link Output}? */
    public static final /* synthetic */ boolean $anonfun$aggregate$15(OutputLike outputLike) {
        return outputLike instanceof Output;
    }

    /** Predicate used by {@link #aggregate}: is the gradient an {@link OutputIndexedSlices}? */
    public static final /* synthetic */ boolean $anonfun$aggregate$17(OutputLike outputLike) {
        return outputLike instanceof OutputIndexedSlices;
    }

    /**
     * Combines {@link OutputIndexedSlices} gradients by concatenating their indices and values.
     *
     * @param seq the slices to combine (elements are cast to {@link OutputIndexedSlices})
     * @return the single element unchanged if {@code seq} has length 1; otherwise a new
     *         {@link OutputIndexedSlices} whose indices and values are the concatenations of the
     *         inputs' indices and values, and whose dense shape is taken from the first element
     * @throws IllegalArgumentException if {@code seq} is empty
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static final OutputIndexedSlices addNOutputIndexedSlices$2(scala.collection.Seq seq) {
        if (seq.isEmpty()) {
            throw new IllegalArgumentException("Can not aggregate empty gradients list.");
        }
        return seq.length() == 1 ? (OutputIndexedSlices) seq.head() : new OutputIndexedSlices(Basic$.MODULE$.concatenate((scala.collection.Seq) seq.map(outputIndexedSlices -> {
            return outputIndexedSlices.indices();
        }, Seq$.MODULE$.canBuildFrom()), Basic$.MODULE$.concatenate$default$2(), Basic$.MODULE$.concatenate$default$3()), Basic$.MODULE$.concatenate((scala.collection.Seq) seq.map(outputIndexedSlices2 -> {
            return outputIndexedSlices2.values();
        }, Seq$.MODULE$.canBuildFrom()), Basic$.MODULE$.concatenate$default$2(), Basic$.MODULE$.concatenate$default$3()), ((OutputIndexedSlices) seq.head()).denseShape());
    }

    /** Registers the singleton and runs the trait initializer. */
    public Gradients$AccumulateAggregationMethod$() {
        MODULE$ = this;
        Gradients.AggregationMethod.$init$(this);
    }
}
