package org.platanios.tensorflow.api.ops.training.distribute.ops;

import org.platanios.tensorflow.api.core.DeviceSpecification;
import org.platanios.tensorflow.api.ops.OutputLike;
import org.platanios.tensorflow.api.ops.training.distribute.Distributable$;
import org.platanios.tensorflow.api.ops.training.distribute.Reduction;
import org.platanios.tensorflow.api.ops.training.distribute.ops.AllReduceCrossTowerOps;
import org.platanios.tensorflow.api.ops.training.distribute.packers.Packer;
import org.platanios.tensorflow.api.ops.training.distribute.values.DistributedValue;
import org.platanios.tensorflow.api.ops.training.distribute.values.MirroredValue;
import org.platanios.tensorflow.api.ops.training.distribute.values.MirroredValue$;
import org.platanios.tensorflow.api.ops.training.distribute.values.PerDeviceValue;
import org.platanios.tensorflow.api.package$;
import scala.Predef$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableOnce;
import scala.collection.generic.GenericTraversableTemplate;

/* compiled from: AllReduceCrossTowerOps.scala */
/* loaded from: input_file:org/platanios/tensorflow/api/ops/training/distribute/ops/AllReduceCrossTowerOps$.class */
/**
 * Decompiled Java view of the Scala companion object {@code AllReduceCrossTowerOps}.
 *
 * <p>It provides a factory for {@link AllReduceCrossTowerOps} and two helpers. The helpers move
 * distributed values between a per-value layout and a per-device layout by transposing nested
 * sequences.
 */
public final class AllReduceCrossTowerOps$ {
    /**
     * Singleton instance of this object.
     *
     * <p>It is assigned inside the private constructor, which the static initializer below invokes
     * exactly once.
     */
    public static AllReduceCrossTowerOps$ MODULE$;

    static {
        // The instance is not stored here; the constructor itself publishes it as MODULE$.
        new AllReduceCrossTowerOps$();
    }

    /**
     * Creates a new {@link AllReduceCrossTowerOps} from the given packer and all-reduce algorithm.
     *
     * @param packer    packer passed to the new instance
     * @param algorithm all-reduce algorithm passed to the new instance
     * @param <P>       type parameter of the packer (and of the created ops instance)
     * @return a newly constructed {@code AllReduceCrossTowerOps}
     */
    public <P> AllReduceCrossTowerOps<P> apply(Packer<P> packer, AllReduceCrossTowerOps.Algorithm algorithm) {
        return new AllReduceCrossTowerOps<>(packer, algorithm);
    }

    /**
     * Regroups per-device values so that the outer sequence is indexed by position within each
     * value's {@code index().values()}, rather than by input value.
     *
     * <p>How it works:
     * <ul>
     *   <li>The device list of the first value is taken as the reference.</li>
     *   <li>Every value must report an equal device list.</li>
     *   <li>Each value is mapped to {@code index().values()}.</li>
     *   <li>The resulting sequence of sequences is transposed.</li>
     * </ul>
     * Element {@code i} of the result therefore collects the i-th entry of every input value, in
     * input order.
     *
     * <p>NOTE(review): the result is only meaningful per device if {@code index().values()} iterates
     * in the same order for every input. When the result is later zipped with a device list (as in
     * {@link #ungroupToMirrored}), that order must also match the order of {@code devices()}.
     * Confirm this ordering guarantee; looking values up by device explicitly would remove the
     * dependency.
     *
     * <p>NOTE(review): {@code seq.head()} is called unconditionally. An empty {@code seq} therefore
     * presumably fails with {@code NoSuchElementException} rather than returning an empty result.
     *
     * @param seq values to regroup; all are expected to live on the same devices
     * @param <T> element type held per device
     * @return the transposed sequence of per-device entries
     * @throws IllegalArgumentException if any value's {@code devices()} differs from that of the
     *                                  first value (this is an unchecked exception; the
     *                                  {@code throws} clause is informational only)
     */
    public <T> Seq<Seq<T>> groupValueByDevice(Seq<PerDeviceValue<T>> seq) throws IllegalArgumentException {
        // Reference device list, taken from the first value.
        Seq<DeviceSpecification> devices = ((DistributedValue) seq.head()).devices();
        return ((GenericTraversableTemplate) seq.map(perDeviceValue -> {
            Seq<DeviceSpecification> devices2 = perDeviceValue.devices();
            // Null-safe equality (the lowering of Scala's `==`): the lists match when both are
            // null, or when devices2.equals(devices) holds.
            if (devices2 != null ? devices2.equals(devices) : devices == null) {
                return perDeviceValue.index().values();
            }
            throw new IllegalArgumentException("The values are not all distributed on the same devices.");
        }, Seq$.MODULE$.canBuildFrom())).transpose(Predef$.MODULE$.$conforms());
    }

    /**
     * Turns grouped outputs back into one {@link MirroredValue} per value.
     *
     * <p>{@code seq} is presumably in the layout produced by {@link #groupValueByDevice}, possibly
     * after a reduction step; confirm against callers. It is transposed, so each inner sequence is
     * expected to hold one output per entry of {@code seq2}. For every inner sequence:
     * <ol>
     *   <li>Each output is converted with {@code package$.outputConvertibleToOutput}.</li>
     *   <li>The converted output is passed through {@code reduction.processUngroupedValue}, along
     *       with the full device list {@code seq2}.</li>
     *   <li>The processed outputs are zipped positionally with {@code seq2}.</li>
     *   <li>The zipped pairs are collected into a device-to-output map.</li>
     *   <li>The map is wrapped via {@code MirroredValue$.apply}, using the
     *       {@code outputLikeDistributable} evidence.</li>
     * </ol>
     *
     * <p>NOTE(review): the zip pairs elements by position. If an inner sequence and {@code seq2}
     * differ in length, Scala's {@code zip} presumably drops the excess silently. If {@code seq2}
     * contains duplicate devices, {@code toMap} keeps only one entry per device. Confirm that
     * callers guarantee matching sizes and distinct devices.
     *
     * @param seq       nested outputs to transpose and wrap
     * @param seq2      devices; used both as map keys and as the argument to
     *                  {@code processUngroupedValue}
     * @param reduction reduction whose {@code processUngroupedValue} post-processes each output
     * @return one {@code MirroredValue} per inner sequence of the transposed {@code seq}
     */
    public Seq<MirroredValue<OutputLike>> ungroupToMirrored(Seq<Seq<OutputLike>> seq, Seq<DeviceSpecification> seq2, Reduction reduction) {
        return (Seq) seq.transpose(Predef$.MODULE$.$conforms()).map(seq3 -> {
            // seq3: one output per position, to be paired with seq2 by index.
            return MirroredValue$.MODULE$.apply(((TraversableOnce) seq2.zip((Seq) seq3.map(outputLike -> {
                return reduction.processUngroupedValue(package$.MODULE$.outputConvertibleToOutput(outputLike), seq2);
            }, Seq$.MODULE$.canBuildFrom()), Seq$.MODULE$.canBuildFrom())).toMap(Predef$.MODULE$.$conforms()), Distributable$.MODULE$.outputLikeDistributable(Predef$.MODULE$.$conforms()));
        }, Seq$.MODULE$.canBuildFrom());
    }

    /** Private constructor; it publishes this instance as {@link #MODULE$}. */
    private AllReduceCrossTowerOps$() {
        MODULE$ = this;
    }
}
