/*
 * Decompiled with CFR 0.152.
 */
package org.platanios.tensorflow.api.ops.training.distribute.ops;

import java.io.Serializable;
import org.platanios.tensorflow.api.core.DeviceSpecification;
import org.platanios.tensorflow.api.core.DeviceSpecification$;
import org.platanios.tensorflow.api.core.Devices$;
import org.platanios.tensorflow.api.core.client.SessionConfig;
import org.platanios.tensorflow.api.ops.Basic$;
import org.platanios.tensorflow.api.ops.Op$;
import org.platanios.tensorflow.api.ops.OutputLike;
import org.platanios.tensorflow.api.ops.training.distribute.Destination;
import org.platanios.tensorflow.api.ops.training.distribute.Distributable$;
import org.platanios.tensorflow.api.ops.training.distribute.Reduction;
import org.platanios.tensorflow.api.ops.training.distribute.ops.AllReduceCrossTowerOps;
import org.platanios.tensorflow.api.ops.training.distribute.ops.AllReduceCrossTowerOps$;
import org.platanios.tensorflow.api.ops.training.distribute.ops.AllReduceCrossTowerOps$HierarchicalCopy$;
import org.platanios.tensorflow.api.ops.training.distribute.ops.AllReduceCrossTowerOps$NCCL$;
import org.platanios.tensorflow.api.ops.training.distribute.ops.CrossTowerOps;
import org.platanios.tensorflow.api.ops.training.distribute.ops.SingleDeviceReduceCrossTowerOps$;
import org.platanios.tensorflow.api.ops.training.distribute.package$;
import org.platanios.tensorflow.api.ops.training.distribute.packers.ConcatenateAndSplitPacker$;
import org.platanios.tensorflow.api.ops.training.distribute.values.MirroredValue;
import org.platanios.tensorflow.api.ops.training.distribute.values.MirroredValue$;
import org.platanios.tensorflow.api.ops.training.distribute.values.PerDeviceValue;
import org.tensorflow.framework.DeviceAttributes;
import scala.Function0;
import scala.Function1;
import scala.None$;
import scala.Option;
import scala.Predef;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.IterableLike;
import scala.collection.JavaConverters$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.Set;
import scala.collection.SetLike;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Map;
import scala.collection.mutable.Buffer$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

public final class CrossTowerOps$ {
    public static CrossTowerOps$ MODULE$;
    private final Seq<scala.collection.immutable.Set<Object>> DGX1_LINKS;

    static {
        new CrossTowerOps$();
    }

    public CrossTowerOps best(scala.collection.immutable.Set<DeviceSpecification> requestedDevices, Option<SessionConfig> sessionConfig) {
        CrossTowerOps crossTowerOps;
        Seq<DeviceAttributes> machineDevices = Devices$.MODULE$.local(sessionConfig);
        Seq usingDevices = (Seq)machineDevices.filter((Function1 & Serializable & scala.Serializable)d -> BoxesRunTime.boxToBoolean((boolean)CrossTowerOps$.$anonfun$best$1(requestedDevices, d)));
        if (usingDevices.size() != requestedDevices.size()) {
            BoxedUnit boxedUnit;
            if (package$.MODULE$.logger().underlying().isInfoEnabled()) {
                package$.MODULE$.logger().underlying().info("Not all devices requested in the distribute strategy are visible to TensorFlow sessions and thus, defaulting to 'SingleDeviceReduceCrossTowerOps'..");
                boxedUnit = BoxedUnit.UNIT;
            } else {
                boxedUnit = BoxedUnit.UNIT;
            }
            crossTowerOps = SingleDeviceReduceCrossTowerOps$.MODULE$.apply(SingleDeviceReduceCrossTowerOps$.MODULE$.apply$default$1());
        } else if (usingDevices.exists((Function1 & Serializable & scala.Serializable)x$1 -> BoxesRunTime.boxToBoolean((boolean)CrossTowerOps$.$anonfun$best$2(x$1)))) {
            BoxedUnit boxedUnit;
            if (package$.MODULE$.logger().underlying().isInfoEnabled()) {
                package$.MODULE$.logger().underlying().info("Non-GPU devices do not support all-reduce cross-tower ops and thus, defaulting to 'SingleDeviceReduceCrossTowerOps'.");
                boxedUnit = BoxedUnit.UNIT;
            } else {
                boxedUnit = BoxedUnit.UNIT;
            }
            crossTowerOps = SingleDeviceReduceCrossTowerOps$.MODULE$.apply(SingleDeviceReduceCrossTowerOps$.MODULE$.apply$default$1());
        } else {
            Seq deviceLinks = (Seq)usingDevices.map((Function1 & Serializable & scala.Serializable)x$2 -> ((TraversableOnce)((TraversableLike)JavaConverters$.MODULE$.asScalaBufferConverter(x$2.getLocality().getLinks().getLinkList()).asScala()).map((Function1 & Serializable & scala.Serializable)x$3 -> BoxesRunTime.boxToInteger((int)x$3.getDeviceId()), Buffer$.MODULE$.canBuildFrom())).toSet(), Seq$.MODULE$.canBuildFrom());
            crossTowerOps = this.pickAllReduceAlgorithm((Seq<scala.collection.immutable.Set<Object>>)deviceLinks);
        }
        return crossTowerOps;
    }

    public Option<SessionConfig> best$default$2() {
        return None$.MODULE$;
    }

    private CrossTowerOps pickAllReduceAlgorithm(Seq<scala.collection.immutable.Set<Object>> deviceLinks) {
        AllReduceCrossTowerOps allReduceCrossTowerOps;
        boolean hasDGX1LikeLinks = ((IterableLike)((IterableLike)deviceLinks.zip(this.DGX1_LINKS(), Seq$.MODULE$.canBuildFrom())).zipWithIndex(Seq$.MODULE$.canBuildFrom())).forall((Function1 & Serializable & scala.Serializable)links -> BoxesRunTime.boxToBoolean((boolean)CrossTowerOps$.$anonfun$pickAllReduceAlgorithm$1(links)));
        if (hasDGX1LikeLinks) {
            BoxedUnit boxedUnit;
            if (package$.MODULE$.logger().underlying().isInfoEnabled()) {
                package$.MODULE$.logger().underlying().info(new StringBuilder(68).append("Configured cross-tower ops to use the hierarchical copy all-reduce, ").append(new StringBuilder(49).append("using a concatenate-and-split packer with ").append(deviceLinks.size()).append(" packs.").toString()).toString());
                boxedUnit = BoxedUnit.UNIT;
            } else {
                boxedUnit = BoxedUnit.UNIT;
            }
            allReduceCrossTowerOps = AllReduceCrossTowerOps$.MODULE$.apply(ConcatenateAndSplitPacker$.MODULE$.apply(deviceLinks.size()), AllReduceCrossTowerOps$HierarchicalCopy$.MODULE$);
        } else {
            BoxedUnit boxedUnit;
            if (package$.MODULE$.logger().underlying().isInfoEnabled()) {
                package$.MODULE$.logger().underlying().info("Configured cross-tower ops to use the NCCL all-reduce, using a concatenate-and-split packer with 1 pack.");
                boxedUnit = BoxedUnit.UNIT;
            } else {
                boxedUnit = BoxedUnit.UNIT;
            }
            allReduceCrossTowerOps = AllReduceCrossTowerOps$.MODULE$.apply(ConcatenateAndSplitPacker$.MODULE$.apply(deviceLinks.size()), AllReduceCrossTowerOps$NCCL$.MODULE$);
        }
        return allReduceCrossTowerOps;
    }

    private Seq<scala.collection.immutable.Set<Object>> DGX1_LINKS() {
        return this.DGX1_LINKS;
    }

    public <O extends OutputLike, D> MirroredValue<O> simpleBroadcast(O value, D destination, Destination<D> evidence$4) {
        Map index = ((TraversableOnce)((Destination)Predef$.MODULE$.implicitly(evidence$4)).devices(destination).map((Function1 & Serializable & scala.Serializable)d -> (Tuple2)Op$.MODULE$.device(d.toString(), Op$.MODULE$.device$default$2(), (Function0 & Serializable & scala.Serializable)() -> Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(d), (Object)Basic$.MODULE$.identity(value, Basic$.MODULE$.identity$default$2()))), Seq$.MODULE$.canBuildFrom())).toMap(Predef$.MODULE$.$conforms());
        return MirroredValue$.MODULE$.apply(index, Distributable$.MODULE$.outputLikeDistributable(Predef$.MODULE$.$conforms()));
    }

    public OutputLike simpleReduce(PerDeviceValue<OutputLike> value, DeviceSpecification reduceToDevice, Reduction reduction) {
        Seq values = value.index().values().toSeq();
        if (values.isEmpty()) {
            throw new IllegalArgumentException("The value being reduced must be non-empty.");
        }
        return (OutputLike)Op$.MODULE$.device(reduceToDevice.toString(), Op$.MODULE$.device$default$2(), (Function0 & Serializable & scala.Serializable)() -> reduction.reduce((Seq<OutputLike>)values, reduction.reduce$default$2()));
    }

    /*
     * WARNING - void declaration
     */
    public static final /* synthetic */ boolean $anonfun$best$1(scala.collection.immutable.Set requestedDevices$1, DeviceAttributes d) {
        void var3_3;
        BoxedUnit boxedUnit;
        String name = d.getName();
        boolean requested = requestedDevices$1.contains((Object)DeviceSpecification$.MODULE$.fromString(name));
        if (!requested) {
            if (package$.MODULE$.logger().underlying().isInfoEnabled()) {
                package$.MODULE$.logger().underlying().info("Available device not used by the distribution strategy because it was not requested: {}.", new Object[]{name});
                boxedUnit = BoxedUnit.UNIT;
            } else {
                boxedUnit = BoxedUnit.UNIT;
            }
        } else {
            boxedUnit = BoxedUnit.UNIT;
        }
        return (boolean)var3_3;
    }

    public static final /* synthetic */ boolean $anonfun$best$2(DeviceAttributes x$1) {
        String string = x$1.getDeviceType().toLowerCase();
        String string2 = "gpu";
        return string == null ? string2 != null : !string.equals(string2);
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    public static final /* synthetic */ boolean $anonfun$pickAllReduceAlgorithm$1(Tuple2 links) {
        if (BoxesRunTime.equals((Object)((Tuple2)links._1())._1(), (Object)((Tuple2)links._1())._2())) return true;
        Object object = ((Tuple2)links._1())._1();
        Set set = ((SetLike)((Tuple2)links._1())._2()).$minus((Object)BoxesRunTime.boxToInteger((int)links._2$mcI$sp()));
        if (object != null) {
            if (!object.equals(set)) return false;
            return true;
        }
        if (set == null) return true;
        return false;
    }

    private CrossTowerOps$() {
        MODULE$ = this;
        this.DGX1_LINKS = (Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new scala.collection.immutable.Set[]{(scala.collection.immutable.Set)Predef$.MODULE$.Set().apply((Seq)Predef$.MODULE$.wrapIntArray(new int[]{0, 1, 2, 3, 4})), (scala.collection.immutable.Set)Predef$.MODULE$.Set().apply((Seq)Predef$.MODULE$.wrapIntArray(new int[]{0, 1, 2, 3, 5})), (scala.collection.immutable.Set)Predef$.MODULE$.Set().apply((Seq)Predef$.MODULE$.wrapIntArray(new int[]{0, 1, 2, 3, 6})), (scala.collection.immutable.Set)Predef$.MODULE$.Set().apply((Seq)Predef$.MODULE$.wrapIntArray(new int[]{0, 1, 2, 3, 7})), (scala.collection.immutable.Set)Predef$.MODULE$.Set().apply((Seq)Predef$.MODULE$.wrapIntArray(new int[]{0, 4, 5, 6, 7})), (scala.collection.immutable.Set)Predef$.MODULE$.Set().apply((Seq)Predef$.MODULE$.wrapIntArray(new int[]{1, 4, 5, 6, 7})), (scala.collection.immutable.Set)Predef$.MODULE$.Set().apply((Seq)Predef$.MODULE$.wrapIntArray(new int[]{2, 4, 5, 6, 7})), (scala.collection.immutable.Set)Predef$.MODULE$.Set().apply((Seq)Predef$.MODULE$.wrapIntArray(new int[]{3, 4, 5, 6, 7}))}));
    }
}

