/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mxnet;

import java.io.Serializable;
import org.apache.mxnet.Base$;
import org.apache.mxnet.Context;
import org.apache.mxnet.DataBatch;
import org.apache.mxnet.DataParallelExecutorManager$;
import org.apache.mxnet.Executor;
import org.apache.mxnet.NDArray;
import org.apache.mxnet.NDArray$;
import org.apache.mxnet.Shape;
import org.apache.mxnet.Symbol;
import scala.Array$;
import scala.Enumeration;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.StringContext;
import scala.Tuple2;
import scala.Tuple2$mcII$sp;
import scala.Tuple3;
import scala.collection.GenMap;
import scala.collection.GenTraversable;
import scala.collection.IndexedSeq;
import scala.collection.IndexedSeq$;
import scala.collection.IterableLike;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.ListMap;
import scala.collection.immutable.Map;
import scala.collection.immutable.Map$;
import scala.collection.immutable.Set;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.collection.mutable.HashMap$;
import scala.collection.mutable.StringBuilder;
import scala.math.Numeric$FloatIsFractional$;
import scala.math.Numeric$IntIsIntegral$;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;

/**
 * Executor-management helpers: splitting a batch into per-slice index ranges,
 * validating symbol argument names, copying batch data/labels into executor
 * arrays, and allocating arrays for (and binding) an {@link Executor}.
 *
 * <p>NOTE(review): this appears to be CFR output of a Scala {@code object}
 * (a singleton exposed through {@code MODULE$}). The anonymous
 * {@code new scala.Serializable(args){...}} closures, the {@code R} placeholder
 * types and unused locals such as {@code serializable} / {@code boxedUnit2} look
 * like decompiler artifacts and are not valid Java as written; treat this file
 * as a reading aid rather than maintainable source.
 */
public final class ExecutorManager$ {
    /** The singleton instance; assigned by the private constructor. */
    public static final ExecutorManager$ MODULE$;

    // Constructing the single instance is what publishes it in MODULE$ (see the constructor).
    static {
        new ExecutorManager$();
    }

    /**
     * Splits {@code batchSize} samples into contiguous index ranges, one per
     * entry of {@code workLoadList}, proportionally to each entry's workload.
     *
     * <p>Entry {@code k} is assigned {@code round(workLoad_k * batchSize / totalWorkLoad)}
     * samples. If the rounded counts sum to less than {@code batchSize}, the
     * shortfall is added to the last entry. Ranges are then laid out back to back,
     * each starting where the previous one ended, and both ends are clamped to
     * {@code batchSize}.
     *
     * <p>NOTE(review): if rounding makes the counts sum to MORE than
     * {@code batchSize}, nothing rebalances them. The clamping truncates trailing
     * ranges, and the {@code require} rejects any range that becomes empty.
     * A zero total workload makes the float division non-finite before rounding.
     * Confirm that callers never pass such inputs.
     *
     * @param batchSize    total number of samples in the batch
     * @param workLoadList relative workload per slice (boxed floats)
     * @return one {@code (begin, end)} pair per slice; the next slice begins at the previous {@code end}
     * @throws IllegalArgumentException (via {@code require}) if any range would be
     *         empty ("Too many slices such that some splits are empty")
     */
    public Tuple2<Object, Object>[] splitInputSlice(int batchSize, Seq<Object> workLoadList) {
        // Total workload; each slice's share of the batch is proportional to its own workload.
        float totalWorkLoad = BoxesRunTime.unboxToFloat(workLoadList.sum(Numeric$FloatIsFractional$.MODULE$));
        // batchNumList[k] = round(workLoad_k * batchSize / totalWorkLoad)
        int[] batchNumList = (int[])((TraversableOnce)workLoadList.map(new scala.Serializable(batchSize, totalWorkLoad){
            public static final long serialVersionUID = 0L;
            private final int batchSize$1;
            private final float totalWorkLoad$1;

            public final int apply(float workLoad) {
                return this.apply$mcIF$sp(workLoad);
            }

            public int apply$mcIF$sp(float workLoad) {
                return package$.MODULE$.round(workLoad * (float)this.batchSize$1 / this.totalWorkLoad$1);
            }
            {
                this.batchSize$1 = batchSize$1;
                this.totalWorkLoad$1 = totalWorkLoad$1;
            }
        }, Seq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.Int());
        int batchNumSum = BoxesRunTime.unboxToInt(Predef$.MODULE$.intArrayOps(batchNumList).sum(Numeric$IntIsIntegral$.MODULE$));
        // Rounding can leave samples unassigned: give the shortfall to the last slice.
        // NOTE(review): an empty workLoadList (with batchSize > 0) makes n == -1 here,
        // which indexes outside batchNumList.
        if (batchNumSum < batchSize) {
            int n = batchNumList.length - 1;
            batchNumList[n] = batchNumList[n] + (batchSize - batchNumSum);
        }
        ArrayBuffer slices = (ArrayBuffer)ArrayBuffer$.MODULE$.empty();
        // Running end index, mutated by the closure below (IntRef boxes the captured var).
        IntRef end = IntRef.create(0);
        Predef$.MODULE$.intArrayOps(batchNumList).foreach(new scala.Serializable(batchSize, slices, end){
            public static final long serialVersionUID = 0L;
            private final int batchSize$1;
            private final ArrayBuffer slices$1;
            private final IntRef end$1;

            public final void apply(int batchNum) {
                this.apply$mcVI$sp(batchNum);
            }

            public void apply$mcVI$sp(int batchNum) {
                // Start where the previous slice ended; clamp both ends to batchSize.
                int begin = package$.MODULE$.min(this.end$1.elem, this.batchSize$1);
                this.end$1.elem = package$.MODULE$.min(begin + batchNum, this.batchSize$1);
                Predef$.MODULE$.require(begin < this.end$1.elem, (Function0<Object>)((Object)new scala.Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final String apply() {
                        return "Too many slices such that some splits are empty";
                    }
                }));
                this.slices$1.append(Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{new Tuple2$mcII$sp(begin, this.end$1.elem)}));
            }
            {
                this.batchSize$1 = batchSize$1;
                this.slices$1 = slices$1;
                this.end$1 = end$1;
            }
        });
        return (Tuple2[])slices.toArray(ClassTag$.MODULE$.apply(Tuple2.class));
    }

    /**
     * Verifies that {@code symbol} has no duplicated argument names and no
     * duplicated auxiliary-state names. A list contains a duplicate when the
     * size of its set is smaller than its length.
     *
     * <p>NOTE(review): the auxiliary-state failure message labels the names
     * "arguments are ...", which looks copied from the argument message. It is
     * left unchanged here because it is runtime text.
     *
     * @param symbol the symbol whose {@code listArguments()} and
     *               {@code listAuxiliaryStates()} are checked
     * @throws IllegalArgumentException (via {@code require}) if either list
     *         contains a duplicate name
     */
    public void checkArguments(Symbol symbol) {
        IndexedSeq<String> argNames = symbol.listArguments();
        Predef$.MODULE$.require(argNames.toSet().size() == argNames.length(), (Function0<Object>)((Object)new scala.Serializable(argNames){
            public static final long serialVersionUID = 0L;
            private final IndexedSeq argNames$1;

            public final String apply() {
                return new StringBuilder().append((Object)"Find duplicated argument name,please make the weight name non-duplicated(using name arguments),").append((Object)new StringContext(Predef$.MODULE$.wrapRefArray((Object[])new String[]{"arguments are ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{this.argNames$1}))).toString();
            }
            {
                this.argNames$1 = argNames$1;
            }
        }));
        IndexedSeq<String> auxNames = symbol.listAuxiliaryStates();
        Predef$.MODULE$.require(auxNames.toSet().size() == auxNames.length(), (Function0<Object>)((Object)new scala.Serializable(auxNames){
            public static final long serialVersionUID = 0L;
            private final IndexedSeq auxNames$1;

            public final String apply() {
                return new StringBuilder().append((Object)"Find duplicated auxiliary param name,please make the weight name non-duplicated(using name arguments),").append((Object)new StringContext(Predef$.MODULE$.wrapRefArray((Object[])new String[]{"arguments are ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{this.auxNames$1}))).toString();
            }
            {
                this.auxNames$1 = auxNames$1;
            }
        }));
    }

    /**
     * Copies each array in {@code data} into the array at the same position in
     * {@code targets}, after requiring that the two shapes are equal.
     *
     * <p>Pairs are formed with {@code zip}, so surplus elements in the longer
     * sequence are ignored.
     *
     * @param data    source arrays
     * @param targets destination arrays, matched to {@code data} by position
     * @throws IllegalArgumentException (via {@code require}) on a shape mismatch
     */
    public void loadGeneral(Seq<NDArray> data, Seq<NDArray> targets) {
        ((IterableLike)data.zip(targets, Seq$.MODULE$.canBuildFrom())).foreach(new scala.Serializable(){
            public static final long serialVersionUID = 0L;

            public final NDArray apply(Tuple2<NDArray, NDArray> x0$3) {
                Tuple2<NDArray, NDArray> tuple2 = x0$3;
                if (tuple2 != null) {
                    NDArray dSrc = tuple2._1();
                    NDArray dTarget = tuple2._2();
                    Shape shape2 = dSrc.shape();
                    Shape shape3 = dTarget.shape();
                    // Null-safe equality check of the source and destination shapes.
                    Predef$.MODULE$.require(!(shape2 != null ? !((Object)shape2).equals(shape3) : shape3 != null), (Function0<Object>)((Object)new scala.Serializable(this, dSrc, dTarget){
                        public static final long serialVersionUID = 0L;
                        private final NDArray dSrc$1;
                        private final NDArray dTarget$1;

                        public final String apply() {
                            return new StringContext(Predef$.MODULE$.wrapRefArray((Object[])new String[]{"src shape ", " mismatch dst shape ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{this.dSrc$1.shape(), this.dTarget$1.shape()}));
                        }
                        {
                            this.dSrc$1 = dSrc$1;
                            this.dTarget$1 = dTarget$1;
                        }
                    }));
                    NDArray nDArray = dSrc.copyTo(dTarget);
                    return nDArray;
                }
                throw new MatchError(tuple2);
            }
        });
    }

    /**
     * Scatters each source array in {@code data} across several destinations.
     * For every {@code (start, end, dst)} triple in the matching element of
     * {@code targets}, it copies {@code src.slice(start, end)} into {@code dst}
     * after requiring that the shapes match.
     *
     * <p>Pairing uses {@code zip}, so unmatched trailing elements are ignored.
     * The {@code withFilter(... != null)} steps skip null entries; they are
     * presumably generated from a Scala for-comprehension pattern.
     *
     * @param data    source arrays, one per element of {@code targets}
     * @param targets per-source arrays of {@code (start, end, dst)} triples
     * @throws IllegalArgumentException (via {@code require}) on a shape mismatch
     */
    public void loadGeneralMulti(Seq<NDArray> data, Seq<Tuple3<Object, Object, NDArray>[]> targets) {
        ((TraversableLike)data.zip(targets, Seq$.MODULE$.canBuildFrom())).withFilter(new scala.Serializable(){
            public static final long serialVersionUID = 0L;

            public final boolean apply(Tuple2<NDArray, Tuple3<Object, Object, NDArray>[]> check$ifrefutable$3) {
                Tuple2<NDArray, Tuple3<Object, Object, NDArray>[]> tuple2 = check$ifrefutable$3;
                boolean bl = tuple2 != null;
                return bl;
            }
        }).foreach(new scala.Serializable(){
            public static final long serialVersionUID = 0L;

            public final void apply(Tuple2<NDArray, Tuple3<Object, Object, NDArray>[]> x$12) {
                Tuple2<NDArray, Tuple3<Object, Object, NDArray>[]> tuple2 = x$12;
                if (tuple2 != null) {
                    NDArray src = tuple2._1();
                    Tuple3<Object, Object, NDArray>[] dTargets = tuple2._2();
                    Predef$.MODULE$.refArrayOps((Object[])dTargets).withFilter(new scala.Serializable(this){
                        public static final long serialVersionUID = 0L;

                        public final boolean apply(Tuple3<Object, Object, NDArray> check$ifrefutable$4) {
                            Tuple3<Object, Object, NDArray> tuple3 = check$ifrefutable$4;
                            boolean bl = tuple3 != null;
                            return bl;
                        }
                    }).foreach(new scala.Serializable(this, src){
                        public static final long serialVersionUID = 0L;
                        private final NDArray src$1;

                        public final NDArray apply(Tuple3<Object, Object, NDArray> x$11) {
                            Tuple3<Object, Object, NDArray> tuple3 = x$11;
                            if (tuple3 != null) {
                                int start2 = BoxesRunTime.unboxToInt(tuple3._1());
                                int end = BoxesRunTime.unboxToInt(tuple3._2());
                                NDArray dst = tuple3._3();
                                // Take this destination's sub-range of the source, then copy it.
                                NDArray sliced = this.src$1.slice(start2, end);
                                Shape shape2 = sliced.shape();
                                Shape shape3 = dst.shape();
                                Predef$.MODULE$.require(!(shape2 != null ? !((Object)shape2).equals(shape3) : shape3 != null), (Function0<Object>)((Object)new scala.Serializable(this, dst, sliced){
                                    public static final long serialVersionUID = 0L;
                                    private final NDArray dst$1;
                                    private final NDArray sliced$1;

                                    public final String apply() {
                                        return new StringContext(Predef$.MODULE$.wrapRefArray((Object[])new String[]{"src shape ", " mismatch dst shape ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{this.sliced$1.shape(), this.dst$1.shape()}));
                                    }
                                    {
                                        this.dst$1 = dst$1;
                                        this.sliced$1 = sliced$1;
                                    }
                                }));
                                NDArray nDArray = sliced.copyTo(dst);
                                return nDArray;
                            }
                            throw new MatchError(tuple3);
                        }
                        {
                            this.src$1 = src$1;
                        }
                    });
                    BoxedUnit boxedUnit = BoxedUnit.UNIT;
                    return;
                }
                throw new MatchError(tuple2);
            }
        });
    }

    /** Scatters {@code batch.data()} into sliced destinations; delegates to {@code loadGeneralMulti}. */
    public void loadDataMulti(DataBatch batch, Seq<Tuple3<Object, Object, NDArray>[]> targets) {
        this.loadGeneralMulti(batch.data(), targets);
    }

    /** Copies {@code batch.data()} position-wise into {@code targets}; delegates to {@code loadGeneral}. */
    public void loadData(DataBatch batch, Seq<NDArray> targets) {
        this.loadGeneral(batch.data(), targets);
    }

    /** Scatters {@code batch.label()} into sliced destinations; delegates to {@code loadGeneralMulti}. */
    public void loadLabelMulti(DataBatch batch, Seq<Tuple3<Object, Object, NDArray>[]> targets) {
        this.loadGeneralMulti(batch.label(), targets);
    }

    /** Copies {@code batch.label()} position-wise into {@code targets}; delegates to {@code loadGeneral}. */
    public void loadLabel(DataBatch batch, Seq<NDArray> targets) {
        this.loadGeneral(batch.label(), targets);
    }

    /**
     * Allocates or reuses argument, gradient and auxiliary arrays for {@code sym},
     * then binds it into an {@link Executor} on {@code ctx}.
     *
     * <p>Shapes come from {@code sym.inferShape(inputShapes)}. Types come from
     * {@code sym.inferType(...)} applied to {@code inputTypes}; when that is null,
     * every key of {@code inputShapes} is mapped to {@code Base.MX_REAL_TYPE} instead.
     *
     * <p>For each argument name:
     * <ul>
     *   <li>A name in {@code paramNames} is treated as a parameter:
     *       <ul>
     *         <li>If {@code baseExec} is null, a fresh zero array is allocated.</li>
     *         <li>Otherwise the array is taken from {@code baseExec.argDict()},
     *             after requiring identical shape and dtype.</li>
     *         <li>If the name is in the gradient set, its gradient array is likewise
     *             zero-allocated or taken from {@code baseExec.gradDict()}.</li>
     *       </ul></li>
     *   <li>Any other argument (presumably a data/label input):
     *       <ul>
     *         <li>If {@code sharedDataArrays} holds an entry with at least as many
     *             elements, that entry is reshaped and reused. Its dtype must equal
     *             the inferred one.</li>
     *         <li>Otherwise a new zero array is allocated. It is recorded in
     *             {@code sharedDataArrays} when that map is non-null.</li>
     *       </ul></li>
     * </ul>
     * Auxiliary arrays are zero-allocated, or reused from
     * {@code baseExec.auxArrays()} after shape and dtype checks.
     *
     * <p>The gradient set is {@code grads} when given; otherwise it is every
     * argument name not in {@code inputShapes}. It is empty when {@code needGrad}
     * is false. Names in the set get grad requirement {@code "write"}; all others
     * get {@code "null"}.
     *
     * <p>NOTE(review): when {@code needGrad} is false (the default),
     * {@code gradArrays} stays null but {@code gradArrays.toMap(...)} is still
     * evaluated for {@code sym.bind}. That looks like a NullPointerException.
     * Confirm against the original Scala source and its callers.
     *
     * @param sym              symbol to bind
     * @param ctx              device context for newly allocated arrays and the executor
     * @param inputShapes      known input shapes used for shape inference
     * @param paramNames       argument names treated as parameters
     * @param needGrad         whether gradient arrays are created (default {@code false})
     * @param grads            explicit gradient set, or null for the default described above
     * @param baseExec         executor whose parameter/gradient/aux arrays are reused, or null
     * @param sharedDataArrays mutable pool of non-parameter arrays shared across binds, or null
     * @param inputTypes       input dtypes, or null to assume {@code MX_REAL_TYPE}
     * @return the executor returned by {@code sym.bind}
     * @throws IllegalArgumentException (via {@code require}) if shape/type inference
     *         yields null argument shapes/types, or a reused array's shape/dtype differs
     * @throws scala.MatchError if inference returns a null tuple
     */
    public Executor bindExec(Symbol sym, Context ctx, Map<String, Shape> inputShapes, Set<String> paramNames, boolean needGrad, Set<String> grads, Executor baseExec, scala.collection.mutable.Map<String, NDArray> sharedDataArrays, ListMap<String, Enumeration.Value> inputTypes) {
        // Shape inference; only the first (argument) and third (auxiliary) results are used.
        Tuple3<IndexedSeq<Shape>, IndexedSeq<Shape>, IndexedSeq<Shape>> tuple3 = sym.inferShape(inputShapes);
        if (tuple3 != null) {
            Tuple2<IndexedSeq<Shape>, IndexedSeq<Shape>> tuple2;
            IndexedSeq<Shape> argShape = tuple3._1();
            IndexedSeq<Shape> auxShape = tuple3._3();
            Tuple2<IndexedSeq<Shape>, IndexedSeq<Shape>> tuple22 = tuple2 = new Tuple2<IndexedSeq<Shape>, IndexedSeq<Shape>>(argShape, auxShape);
            IndexedSeq<Shape> argShape2 = tuple22._1();
            IndexedSeq<Shape> auxShape2 = tuple22._2();
            Predef$.MODULE$.require(argShape2 != null);
            // With no explicit input types, every input named in inputShapes defaults to MX_REAL_TYPE.
            ListMap<String, Enumeration.Value> inputTypesUpdate = inputTypes == null ? inputShapes.map((Function1<String, Shape>)((Object)new scala.Serializable(){
                public static final long serialVersionUID = 0L;

                public final Tuple2<String, Enumeration.Value> apply(Tuple2<String, Shape> x0$4) {
                    Tuple2<String, Shape> tuple2 = x0$4;
                    if (tuple2 != null) {
                        String key = tuple2._1();
                        Tuple2<String, Enumeration.Value> tuple22 = new Tuple2<String, Enumeration.Value>(key, Base$.MODULE$.MX_REAL_TYPE());
                        return tuple22;
                    }
                    throw new MatchError(tuple2);
                }
            }), Map$.MODULE$.canBuildFrom()) : inputTypes;
            // Type inference; again only argument (first) and auxiliary (third) results are used.
            Tuple3<Seq<Enumeration.Value>, Seq<Enumeration.Value>, Seq<Enumeration.Value>> tuple32 = sym.inferType(inputTypesUpdate);
            if (tuple32 != null) {
                Tuple2<Seq<Enumeration.Value>, Seq<Enumeration.Value>> tuple23;
                Seq<Enumeration.Value> argTypes = tuple32._1();
                Seq<Enumeration.Value> auxTypes = tuple32._3();
                Tuple2<Seq<Enumeration.Value>, Seq<Enumeration.Value>> tuple24 = tuple23 = new Tuple2<Seq<Enumeration.Value>, Seq<Enumeration.Value>>(argTypes, auxTypes);
                Seq<Enumeration.Value> argTypes2 = tuple24._1();
                Seq<Enumeration.Value> auxTypes2 = tuple24._2();
                Predef$.MODULE$.require(argTypes2 != null);
                ArrayBuffer argArrays = (ArrayBuffer)ArrayBuffer$.MODULE$.empty();
                // NOTE(review): null when needGrad is false, yet dereferenced via toMap(...) in the bind call at the end.
                GenMap gradArrays = needGrad ? HashMap$.MODULE$.empty() : null;
                IndexedSeq<String> argNames = sym.listArguments();
                // Gradient set: explicit grads, else all arguments not supplied as inputs; empty without needGrad.
                GenTraversable<String> gradSet = needGrad ? (grads == null ? (Set)argNames.toSet().$minus$minus(inputShapes.keySet()) : grads) : Predef$.MODULE$.Set().empty();
                // Per-argument gradient requirement: "write" for names in gradSet, "null" otherwise.
                Map<String, String> gradReq = argNames.map(new scala.Serializable((Set)gradSet){
                    public static final long serialVersionUID = 0L;
                    private final Set gradSet$1;

                    public final Tuple2<String, String> apply(String name) {
                        return this.gradSet$1.contains(name) ? Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(name), "write") : Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(name), "null");
                    }
                    {
                        this.gradSet$1 = gradSet$1;
                    }
                }, scala.collection.package$.MODULE$.breakOut(Map$.MODULE$.canBuildFrom()));
                // Build argArrays in argument order (and fill gradArrays) by allocating or reusing arrays.
                ((IterableLike)argNames.zipWithIndex(IndexedSeq$.MODULE$.canBuildFrom())).foreach(new scala.Serializable(ctx, paramNames, baseExec, sharedDataArrays, argShape2, argTypes2, argArrays, (scala.collection.mutable.Map)gradArrays, (Set)gradSet){
                    public static final long serialVersionUID = 0L;
                    private final Context ctx$1;
                    private final Set paramNames$1;
                    private final Executor baseExec$1;
                    private final scala.collection.mutable.Map sharedDataArrays$1;
                    private final IndexedSeq argShape$1;
                    private final Seq argTypes$1;
                    private final ArrayBuffer argArrays$1;
                    private final scala.collection.mutable.Map gradArrays$1;
                    private final Set gradSet$1;

                    public final void apply(Tuple2<String, Object> x0$5) {
                        Tuple2<String, Object> tuple2 = x0$5;
                        if (tuple2 != null) {
                            BoxedUnit boxedUnit;
                            String name = tuple2._1();
                            int i = tuple2._2$mcI$sp();
                            if (this.paramNames$1.contains(name)) {
                                // Parameter argument.
                                NDArray nDArray;
                                if (this.baseExec$1 == null) {
                                    // No base executor: zero-allocate the parameter (and its gradient if required).
                                    Serializable serializable;
                                    if (this.gradSet$1.contains(name)) {
                                        NDArray gradArr = NDArray$.MODULE$.zeros((Shape)this.argShape$1.apply(i), this.ctx$1, (Enumeration.Value)this.argTypes$1.apply(i));
                                        serializable = this.gradArrays$1.put(name, gradArr);
                                    } else {
                                        serializable = BoxedUnit.UNIT;
                                    }
                                    nDArray = NDArray$.MODULE$.zeros((Shape)this.argShape$1.apply(i), this.ctx$1, (Enumeration.Value)this.argTypes$1.apply(i));
                                } else {
                                    // Reuse the base executor's array; its shape and dtype must match exactly.
                                    NDArray arr = (NDArray)this.baseExec$1.argDict().apply(name);
                                    Shape shape2 = arr.shape();
                                    R r = this.argShape$1.apply(i);
                                    Predef$.MODULE$.require(!(shape2 != null ? !((Object)shape2).equals(r) : r != null));
                                    Enumeration.Value value2 = arr.dtype();
                                    R r2 = this.argTypes$1.apply(i);
                                    Predef$.MODULE$.require(!(value2 != null ? !((Object)value2).equals(r2) : r2 != null));
                                    Serializable serializable = this.gradSet$1.contains(name) ? this.gradArrays$1.put(name, this.baseExec$1.gradDict().apply(name)) : BoxedUnit.UNIT;
                                    nDArray = arr;
                                }
                                NDArray argArr = nDArray;
                                this.argArrays$1.append(Predef$.MODULE$.wrapRefArray((Object[])new NDArray[]{argArr}));
                                boxedUnit = BoxedUnit.UNIT;
                            } else {
                                // Non-parameter argument: try the shared pool before allocating.
                                NDArray nDArray;
                                if (this.sharedDataArrays$1 != null && this.sharedDataArrays$1.contains(name)) {
                                    NDArray arr = (NDArray)this.sharedDataArrays$1.apply(name);
                                    if (arr.shape().product() >= ((Shape)this.argShape$1.apply(i)).product()) {
                                        // Pooled array is big enough: dtype must match; view it with the needed shape.
                                        R r = this.argTypes$1.apply(i);
                                        Enumeration.Value value3 = arr.dtype();
                                        Predef$.MODULE$.require(!(r != null ? !r.equals(value3) : value3 != null));
                                        nDArray = arr.reshape((Shape)this.argShape$1.apply(i));
                                    } else {
                                        // Pooled array too small: warn, allocate a larger one and replace the pool entry.
                                        // NOTE(review): the concatenated message lacks spaces ("re-allocate.Consider", "default_bucket_keyto").
                                        DataParallelExecutorManager$.MODULE$.logger().warn(new StringBuilder().append((Object)new StringContext(Predef$.MODULE$.wrapRefArray((Object[])new String[]{"bucketing: data ", " has a shape ", ","})).s(Predef$.MODULE$.genericWrapArray(new Object[]{name, this.argShape$1.apply(i)}))).append((Object)new StringContext(Predef$.MODULE$.wrapRefArray((Object[])new String[]{"which is larger than already allocated shape ", "."})).s(Predef$.MODULE$.genericWrapArray(new Object[]{arr.shape()}))).append((Object)"Need to re-allocate.Consider putting default_bucket_key").append((Object)"to be the bucket taking the largest input for better memory sharing.").toString());
                                        NDArray zeros = NDArray$.MODULE$.zeros((Shape)this.argShape$1.apply(i), this.ctx$1, (Enumeration.Value)this.argTypes$1.apply(i));
                                        this.sharedDataArrays$1.put(name, zeros);
                                        nDArray = zeros;
                                    }
                                } else {
                                    // Not pooled yet: allocate, and record in the pool when one was supplied.
                                    NDArray zeros = NDArray$.MODULE$.zeros((Shape)this.argShape$1.apply(i), this.ctx$1, (Enumeration.Value)this.argTypes$1.apply(i));
                                    Serializable serializable = this.sharedDataArrays$1 == null ? BoxedUnit.UNIT : this.sharedDataArrays$1.put(name, zeros);
                                    nDArray = zeros;
                                }
                                NDArray argArr = nDArray;
                                this.argArrays$1.append(Predef$.MODULE$.wrapRefArray((Object[])new NDArray[]{argArr}));
                                boxedUnit = BoxedUnit.UNIT;
                            }
                            BoxedUnit boxedUnit2 = boxedUnit;
                            return;
                        }
                        throw new MatchError(tuple2);
                    }
                    {
                        this.ctx$1 = ctx$1;
                        this.paramNames$1 = paramNames$1;
                        this.baseExec$1 = baseExec$1;
                        this.sharedDataArrays$1 = sharedDataArrays$1;
                        this.argShape$1 = argShape$1;
                        this.argTypes$1 = argTypes$1;
                        this.argArrays$1 = argArrays$1;
                        this.gradArrays$1 = gradArrays$1;
                        this.gradSet$1 = gradSet$1;
                    }
                });
                // Auxiliary arrays: zero-allocate from inferred (shape, type) pairs, or reuse
                // the base executor's aux arrays after checking each shape and dtype by index.
                Seq auxArrays = baseExec == null ? (Seq)((TraversableLike)auxShape2.zip(auxTypes2, IndexedSeq$.MODULE$.canBuildFrom())).map(new scala.Serializable(ctx){
                    public static final long serialVersionUID = 0L;
                    private final Context ctx$1;

                    public final NDArray apply(Tuple2<Shape, Enumeration.Value> x0$6) {
                        Tuple2<Shape, Enumeration.Value> tuple2 = x0$6;
                        if (tuple2 != null) {
                            Shape s2 = tuple2._1();
                            Enumeration.Value t = tuple2._2();
                            NDArray nDArray = NDArray$.MODULE$.zeros(s2, this.ctx$1, t);
                            return nDArray;
                        }
                        throw new MatchError(tuple2);
                    }
                    {
                        this.ctx$1 = ctx$1;
                    }
                }, IndexedSeq$.MODULE$.canBuildFrom()) : Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.refArrayOps((Object[])baseExec.auxArrays()).zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).map(new scala.Serializable(auxShape2, auxTypes2){
                    public static final long serialVersionUID = 0L;
                    private final IndexedSeq auxShape$1;
                    private final Seq auxTypes$1;

                    public final NDArray apply(Tuple2<NDArray, Object> x0$7) {
                        Tuple2<NDArray, Object> tuple2 = x0$7;
                        if (tuple2 != null) {
                            NDArray a = tuple2._1();
                            int i = tuple2._2$mcI$sp();
                            R r = this.auxShape$1.apply(i);
                            Shape shape2 = a.shape();
                            Predef$.MODULE$.require(!(r != null ? !r.equals(shape2) : shape2 != null));
                            R r2 = this.auxTypes$1.apply(i);
                            Enumeration.Value value2 = a.dtype();
                            Predef$.MODULE$.require(!(r2 != null ? !r2.equals(value2) : value2 != null));
                            NDArray nDArray = a;
                            return nDArray;
                        }
                        throw new MatchError(tuple2);
                    }
                    {
                        this.auxShape$1 = auxShape$1;
                        this.auxTypes$1 = auxTypes$1;
                    }
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(NDArray.class)))).toSeq();
                // Bind with the collected arrays; the sixth argument is passed as null.
                return sym.bind(ctx, (Seq<NDArray>)argArrays.toSeq(), gradArrays.toMap(Predef$.MODULE$.$conforms()), gradReq, (Seq<NDArray>)auxArrays, null, baseExec);
            }
            throw new MatchError(tuple32);
        }
        throw new MatchError(tuple3);
    }

    /** Default value of {@code bindExec}'s 5th parameter, {@code needGrad}: {@code false}. */
    public boolean bindExec$default$5() {
        return false;
    }

    /** Default value of {@code bindExec}'s 6th parameter, {@code grads}: {@code null}. */
    public Set<String> bindExec$default$6() {
        return null;
    }

    /** Default value of {@code bindExec}'s 7th parameter, {@code baseExec}: {@code null}. */
    public Executor bindExec$default$7() {
        return null;
    }

    /** Default value of {@code bindExec}'s 8th parameter, {@code sharedDataArrays}: {@code null}. */
    public scala.collection.mutable.Map<String, NDArray> bindExec$default$8() {
        return null;
    }

    /** Default value of {@code bindExec}'s 9th parameter, {@code inputTypes}: {@code null}. */
    public ListMap<String, Enumeration.Value> bindExec$default$9() {
        return null;
    }

    // Singleton constructor: publishes this instance as MODULE$.
    private ExecutorManager$() {
        MODULE$ = this;
    }
}

