package org.platanios.tensorflow.api.ops.training.distribute.ops;

import org.platanios.tensorflow.api.core.DeviceSpecification;
import org.platanios.tensorflow.api.core.DeviceSpecification$;
import org.platanios.tensorflow.api.core.Devices$;
import org.platanios.tensorflow.api.core.client.SessionConfig;
import org.platanios.tensorflow.api.ops.Basic$;
import org.platanios.tensorflow.api.ops.Op$;
import org.platanios.tensorflow.api.ops.OutputLike;
import org.platanios.tensorflow.api.ops.training.distribute.Destination;
import org.platanios.tensorflow.api.ops.training.distribute.Distributable$;
import org.platanios.tensorflow.api.ops.training.distribute.Reduction;
import org.platanios.tensorflow.api.ops.training.distribute.package$;
import org.platanios.tensorflow.api.ops.training.distribute.packers.ConcatenateAndSplitPacker$;
import org.platanios.tensorflow.api.ops.training.distribute.values.MirroredValue;
import org.platanios.tensorflow.api.ops.training.distribute.values.MirroredValue$;
import org.platanios.tensorflow.api.ops.training.distribute.values.PerDeviceValue;
import org.tensorflow.framework.DeviceAttributes;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Tuple2;
import scala.collection.IterableLike;
import scala.collection.JavaConverters$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.SetLike;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Set;
import scala.collection.mutable.Buffer$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: CrossTowerOps.scala */
/* loaded from: input_file:org/platanios/tensorflow/api/ops/training/distribute/ops/CrossTowerOps$.class */
public final class CrossTowerOps$ {
    /** The singleton instance; assigned by the private constructor, which the static initializer runs once. */
    public static CrossTowerOps$ MODULE$;

    /**
     * Reference DGX-1 GPU link table. Row {@code i} contains {@code i} itself together with (presumably)
     * the ids of the GPUs directly linked to GPU {@code i} on a DGX-1 — confirm against hardware docs.
     * Contains exactly 8 rows.
     */
    private final Seq<Set<Object>> DGX1_LINKS;

    static {
        new CrossTowerOps$();
    }

    /**
     * Chooses the cross-tower ops implementation for the requested devices.
     *
     * <p>Falls back to {@code SingleDeviceReduceCrossTowerOps} (with its default argument) when not all
     * requested devices are among the local devices, or when any selected device is not a GPU. Otherwise,
     * an all-reduce algorithm is picked from the GPUs' interconnect links.
     *
     * @param set    devices requested by the distribution strategy.
     * @param option optional session configuration, forwarded to {@code Devices.local}.
     * @return the chosen cross-tower ops.
     */
    public CrossTowerOps best(Set<DeviceSpecification> set, Option<SessionConfig> option) {
        // Keep only local devices that were requested; unrequested ones are logged in $anonfun$best$1.
        Seq seq = (Seq) Devices$.MODULE$.local(option).filter(deviceAttributes -> {
            return BoxesRunTime.boxToBoolean($anonfun$best$1(set, (DeviceAttributes) deviceAttributes));
        });
        if (seq.size() != set.size()) {
            logInfo("Not all devices requested in the distribute strategy are visible to TensorFlow sessions and thus, defaulting to 'SingleDeviceReduceCrossTowerOps'..");
            return singleDeviceReduceCrossTowerOps();
        }
        if (seq.exists(deviceAttributes -> {
            return BoxesRunTime.boxToBoolean($anonfun$best$2((DeviceAttributes) deviceAttributes));
        })) {
            logInfo("Non-GPU devices do not support all-reduce cross-tower ops and thus, defaulting to 'SingleDeviceReduceCrossTowerOps'.");
            return singleDeviceReduceCrossTowerOps();
        }
        // For each GPU (in local-device order), the set of device ids it has interconnect links to.
        Seq deviceLinks = (Seq) seq.map(deviceAttributes -> {
            return ((TraversableOnce) ((TraversableLike) JavaConverters$.MODULE$.asScalaBufferConverter(((DeviceAttributes) deviceAttributes).getLocality().getLinks().getLinkList()).asScala()).map(interconnectLink -> {
                return BoxesRunTime.boxToInteger(interconnectLink.getDeviceId());
            }, Buffer$.MODULE$.canBuildFrom())).toSet();
        }, Seq$.MODULE$.canBuildFrom());
        return pickAllReduceAlgorithm(deviceLinks);
    }

    /** Default value of the {@code option} parameter of {@link #best}: no session configuration. */
    public Option<SessionConfig> best$default$2() {
        return None$.MODULE$;
    }

    /**
     * Picks an all-reduce algorithm from the per-GPU interconnect link sets.
     *
     * <p>DGX-1-like topologies use hierarchical copy with one pack per device; everything else uses NCCL
     * with a single pack (as the log message states).
     *
     * @param seq per-device sets of linked device ids, in local-device order.
     * @return all-reduce cross-tower ops with a concatenate-and-split packer.
     */
    private CrossTowerOps pickAllReduceAlgorithm(Seq<Set<Object>> seq) {
        int numDevices = seq.size();
        if (hasDgx1LikeLinks(seq)) {
            logInfo("Configured cross-tower ops to use the hierarchical copy all-reduce, "
                    + "using a concatenate-and-split packer with " + numDevices + " packs.");
            return AllReduceCrossTowerOps$.MODULE$.apply(ConcatenateAndSplitPacker$.MODULE$.apply(numDevices), AllReduceCrossTowerOps$HierarchicalCopy$.MODULE$);
        }
        logInfo("Configured cross-tower ops to use the NCCL all-reduce, using a concatenate-and-split packer with 1 pack.");
        // Previously this used numDevices packs, contradicting the message above.
        return AllReduceCrossTowerOps$.MODULE$.apply(ConcatenateAndSplitPacker$.MODULE$.apply(1), AllReduceCrossTowerOps$NCCL$.MODULE$);
    }

    /**
     * Returns whether the device link sets match the DGX-1 reference topology.
     *
     * <p>Requires at least as many devices as the reference table has rows. Without this guard, an empty
     * list satisfied {@code forall} vacuously (selecting hierarchical copy with a 0-pack packer). A
     * smaller device list was also compared against only a prefix of the 8-GPU table.
     */
    private boolean hasDgx1LikeLinks(Seq<Set<Object>> seq) {
        if (seq.size() < DGX1_LINKS().size()) {
            return false;
        }
        return ((IterableLike) ((IterableLike) seq.zip(DGX1_LINKS(), Seq$.MODULE$.canBuildFrom())).zipWithIndex(Seq$.MODULE$.canBuildFrom())).forall(tuple2 -> {
            return BoxesRunTime.boxToBoolean($anonfun$pickAllReduceAlgorithm$1((Tuple2) tuple2));
        });
    }

    /** Accessor for the DGX-1 reference link table. */
    private Seq<Set<Object>> DGX1_LINKS() {
        return this.DGX1_LINKS;
    }

    /**
     * Broadcasts {@code o} to every device of destination {@code d}.
     *
     * <p>On each destination device, an identity op of {@code o} is created under that device's scope. The
     * copies are wrapped in a {@code MirroredValue} keyed by device.
     *
     * @param o           value to broadcast.
     * @param d           destination whose devices receive the copies.
     * @param destination type class resolving the devices of {@code d}.
     * @return the mirrored per-device copies.
     */
    public <O extends OutputLike, D> MirroredValue<O> simpleBroadcast(O o, D d, Destination<D> destination) {
        return MirroredValue$.MODULE$.apply(((TraversableOnce) ((Destination) Predef$.MODULE$.implicitly(destination)).devices(d).map(deviceSpecification -> {
            return (Tuple2) Op$.MODULE$.device(deviceSpecification.toString(), Op$.MODULE$.device$default$2(), () -> {
                return Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(deviceSpecification), Basic$.MODULE$.identity(o, Basic$.MODULE$.identity$default$2()));
            });
        }, Seq$.MODULE$.canBuildFrom())).toMap(Predef$.MODULE$.$conforms()), Distributable$.MODULE$.outputLikeDistributable(Predef$.MODULE$.$conforms()));
    }

    /**
     * Reduces all per-device values with {@code reduction}, placing the reduce op on {@code deviceSpecification}.
     *
     * @throws IllegalArgumentException if {@code perDeviceValue} has no values.
     */
    public OutputLike simpleReduce(PerDeviceValue<OutputLike> perDeviceValue, DeviceSpecification deviceSpecification, Reduction reduction) {
        Seq seq = perDeviceValue.index().values().toSeq();
        if (seq.isEmpty()) {
            throw new IllegalArgumentException("The value being reduced must be non-empty.");
        }
        return (OutputLike) Op$.MODULE$.device(deviceSpecification.toString(), Op$.MODULE$.device$default$2(), () -> {
            return reduction.reduce(seq, reduction.reduce$default$2());
        });
    }

    /** Logs {@code message} at INFO level through the package logger, if INFO is enabled. */
    private static void logInfo(String message) {
        if (package$.MODULE$.logger().underlying().isInfoEnabled()) {
            package$.MODULE$.logger().underlying().info(message);
        }
    }

    /** Builds the fallback: {@code SingleDeviceReduceCrossTowerOps} with its default argument. */
    private static CrossTowerOps singleDeviceReduceCrossTowerOps() {
        return SingleDeviceReduceCrossTowerOps$.MODULE$.apply(SingleDeviceReduceCrossTowerOps$.MODULE$.apply$default$1());
    }

    /**
     * Filter predicate for {@link #best}: whether the device named by {@code deviceAttributes} is in
     * {@code set}. Unrequested devices are logged at INFO level.
     */
    public static final /* synthetic */ boolean $anonfun$best$1(Set set, DeviceAttributes deviceAttributes) {
        String name = deviceAttributes.getName();
        boolean requested = set.contains(DeviceSpecification$.MODULE$.fromString(name));
        if (!requested && package$.MODULE$.logger().underlying().isInfoEnabled()) {
            package$.MODULE$.logger().underlying().info("Available device not used by the distribution strategy because it was not requested: {}.", new Object[]{name});
        }
        return requested;
    }

    /** Returns {@code true} when the device type, lower-cased, is anything other than {@code "gpu"}. */
    public static final /* synthetic */ boolean $anonfun$best$2(DeviceAttributes deviceAttributes) {
        // Null-safe on the receiver side; replaces the non-compiling decompiled `"gpu" != 0`.
        return !"gpu".equals(deviceAttributes.getDeviceType().toLowerCase());
    }

    /**
     * Checks one entry {@code ((deviceLinks, dgx1Row), index)} of the DGX-1 topology comparison.
     * {@code deviceLinks} may match the reference row either including or excluding the device's own index.
     */
    public static final /* synthetic */ boolean $anonfun$pickAllReduceAlgorithm$1(Tuple2 tuple2) {
        Tuple2 pair = (Tuple2) tuple2._1();
        Object deviceLinks = pair._1();
        if (BoxesRunTime.equals(deviceLinks, pair._2())) {
            return true;
        }
        Object rowWithoutSelf = ((SetLike) pair._2()).$minus(BoxesRunTime.boxToInteger(tuple2._2$mcI$sp()));
        return deviceLinks != null ? deviceLinks.equals(rowWithoutSelf) : rowWithoutSelf == null;
    }

    /** Builds one immutable row of the DGX-1 link table from the given device ids. */
    private static Set dgx1Row(int... ids) {
        return (Set) Predef$.MODULE$.Set().apply(Predef$.MODULE$.wrapIntArray(ids));
    }

    private CrossTowerOps$() {
        MODULE$ = this;
        this.DGX1_LINKS = (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Set[]{
            dgx1Row(0, 1, 2, 3, 4),
            dgx1Row(0, 1, 2, 3, 5),
            dgx1Row(0, 1, 2, 3, 6),
            dgx1Row(0, 1, 2, 3, 7),
            dgx1Row(0, 4, 5, 6, 7),
            dgx1Row(1, 4, 5, 6, 7),
            dgx1Row(2, 4, 5, 6, 7),
            dgx1Row(3, 4, 5, 6, 7)
        }));
    }
}
