package org.platanios.tensorflow.api.ops.control_flow;

import org.platanios.tensorflow.api.core.Shape;
import org.platanios.tensorflow.api.ops.Basic$;
import org.platanios.tensorflow.api.ops.Op;
import org.platanios.tensorflow.api.ops.Output;
import org.platanios.tensorflow.api.ops.OutputIndexedSlices;
import org.platanios.tensorflow.api.ops.SparseOutput;
import org.platanios.tensorflow.api.tensors.TensorConvertible$;
import org.platanios.tensorflow.api.types.DataType;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Some;
import scala.Tuple2;
import scala.collection.Iterable$;
import scala.collection.IterableLike;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.mutable.ListBuffer;
import scala.collection.mutable.Map;
import scala.collection.mutable.Map$;
import scala.collection.mutable.Set;
import scala.collection.mutable.Set$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: GradientState.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005ua!B\u0001\u0003\u0001\u0011q!!D$sC\u0012LWM\u001c;Ti\u0006$XM\u0003\u0002\u0004\t\u0005a1m\u001c8ue>dwL\u001a7po*\u0011QAB\u0001\u0004_B\u001c(BA\u0004\t\u0003\r\t\u0007/\u001b\u0006\u0003\u0013)\t!\u0002^3og>\u0014h\r\\8x\u0015\tYA\"A\u0005qY\u0006$\u0018M\\5pg*\tQ\"A\u0002pe\u001e\u001c\"\u0001A\b\u0011\u0005A\u0019R\"A\t\u000b\u0003I\tQa]2bY\u0006L!\u0001F\t\u0003\r\u0005s\u0017PU3g\u0011\u00191\u0002\u0001\"\u0001\u0003/\u00051A(\u001b8jiz\"\u0012\u0001\u0007\t\u00033\u0001i\u0011AA\u0002\u0001\u0011\u0019a\u0002\u0001)A\u0005;\u0005\u0019Q.\u00199\u0011\ty\u0019S\u0005K\u0007\u0002?)\u0011\u0001%I\u0001\b[V$\u0018M\u00197f\u0015\t\u0011\u0013#\u0001\u0006d_2dWm\u0019;j_:L!\u0001J\u0010\u0003\u00075\u000b\u0007\u000f\u0005\u0002\u001aM%\u0011qE\u0001\u0002\b\u0007>tG/\u001a=u!\tI\u0012&\u0003\u0002+\u0005\t\trI]1eS\u0016tG\u000fT8paN#\u0018\r^3\t\r1\u0002A\u0011\u0001\u0003.\u0003Q9W\r^$sC\u0012LWM\u001c;M_>\u00048\u000b^1uKR\u0019a&M\u001c\u0011\u0007Ay\u0003&\u0003\u00021#\t1q\n\u001d;j_:DQAM\u0016A\u0002M\n!a\u001c9\u0011\u0005Q*T\"\u0001\u0003\n\u0005Y\"!AA(q\u0011\u0015A4\u00061\u0001:\u0003\u0019\u0011WMZ8sKB\u0011\u0001CO\u0005\u0003wE\u0011qAQ8pY\u0016\fg\u000e\u0003\u0004>\u0001\u0011\u0005AAP\u0001\u001eK:$XM]$sC\u0012LWM\u001c;XQ&dW\rT8pa\u000e{g\u000e^3yiR\u0019qHQ\"\u0011\u0005A\u0001\u0015BA!\u0012\u0005\u0011)f.\u001b;\t\u000bIb\u0004\u0019A\u001a\t\u000bab\u0004\u0019A\u001d\t\r\u0015\u0003A\u0011\u0001\u0003G\u0003q)\u00070\u001b;He\u0006$\u0017.\u001a8u/\"LG.\u001a'p_B\u001cuN\u001c;fqR$2aP$I\u0011\u0015\u0011D\t1\u00014\u0011\u0015AD\t1\u0001:\u0011\u0019Q\u0005\u0001\"\u0001\u0003\u0017\u0006\u0019\u0012\r\u001a3XQ&dW\rT8pa\u000e{g\u000e^3yiR!q\bT'S\u0011\u0015\u0011\u0014\n1\u00014\u0011\u0015q\u0015\n1\u0001P\u0003\u001d\u0011W\r^<fK:\u00042A\b)4\u0013\t\tvDA\u0002TKRDQaU%A\u0002Q\u000b1BY3uo\u0016,g\u000eT5tiB\u0019a$V\u001a\n\u0005Y{\"A\u0003'jgR\u0014UO\u001a4fe\"1\u0001\f\u0001C\u0001\te\u000b
\u0001C_3s_Nd\u0015n[3G_J,\u00050\u001b;\u0015\u0005ik\u0006C\u0001\u001b\\\u0013\taFA\u0001\u0004PkR\u0004X\u000f\u001e\u0005\u0006=^\u0003\rAW\u0001\u0006m\u0006dW/\u001a\u0005\u0007A\u0002!\t\u0001B1\u0002\u0013i,'o\\:MS.,Gc\u00012dIB\u0019\u0001c\f.\t\u000bIz\u0006\u0019A\u001a\t\u000b\u0015|\u0006\u0019\u00014\u0002\u000b%tG-\u001a=\u0011\u0005A9\u0017B\u00015\u0012\u0005\rIe\u000e\u001e\u0005\u0007U\u0002!\t\u0001B6\u0002-A\u0014xnY3tgVsWo]3e\u0019>|\u0007/\u0012=jiN$2\u0001\u001c<z!\riGO\u0017\b\u0003]J\u0004\"a\\\t\u000e\u0003AT!!\u001d\u000e\u0002\rq\u0012xn\u001c;?\u0013\t\u0019\u0018#\u0001\u0004Qe\u0016$WMZ\u0005\u0003#VT!a]\t\t\u000b]L\u0007\u0019\u0001=\u0002\u001bA,g\u000eZ5oO\u000e{WO\u001c;t!\u0011q2e\r4\t\u000biL\u0007\u0019A>\u0002\u001d\u0011,7\u000f^5oCRLwN\\(qgB\u0019Q\u000e^\u001a\t\ru\u0004A\u0011\u0001\u0003\u007f\u0003-\u0001xn\u001d;Qe>\u001cWm]:\u0015\u0003}:q!!\u0001\u0003\u0011\u0003\t\u0019!A\u0007He\u0006$\u0017.\u001a8u'R\fG/\u001a\t\u00043\u0005\u0015aAB\u0001\u0003\u0011\u0003\t9aE\u0002\u0002\u0006=AqAFA\u0003\t\u0003\tY\u0001\u0006\u0002\u0002\u0004!I\u0011qBA\u0003\t\u0003!\u0011\u0011C\u0001\f[\u0006L(-Z\"sK\u0006$X\r\u0006\u0005\u0002\u0014\u0005U\u0011qCA\r!\r\u0001r\u0006\u0007\u0005\u0007\u001d\u00065\u0001\u0019A(\t\rM\u000bi\u00011\u0001U\u0011\u001d\tY\"!\u0004A\u0002e\n\u0001dY8m_\u000e\fG/Z$sC\u0012LWM\u001c;t/&$\bn\u00149t\u0001")
/* loaded from: input_file:org/platanios/tensorflow/api/ops/control_flow/GradientState.class */
/*
 * Gradient-construction bookkeeping for while loops (decompiled from GradientState.scala).
 *
 * One GradientLoopState is kept per forward while-loop context; addWhileLoopContext registers
 * them. The other methods look that state up and build zero-valued gradients inside the
 * appropriate control-flow context.
 *
 * Most methods create graph ops and enter/exit control-flow contexts as side effects. The
 * statement order below therefore decides which context each op is created in, and must not be
 * rearranged.
 */
public class GradientState {
    // Forward while-loop context -> gradient loop state registered for it by addWhileLoopContext.
    private final Map<Context, GradientLoopState> map = Map$.MODULE$.empty();

    /**
     * Returns the gradient loop state registered for the while loop that {@code op} belongs to.
     *
     * <p>How the lookup key is chosen:
     * <ul>
     *   <li>If {@code z} is true and {@code op} is a loop exit, the key is the while-loop context
     *       reached from the outer context of {@code op}'s own control-flow context.</li>
     *   <li>Otherwise the key is {@code WhileLoopContext.getWhileLoopContext(op)}.</li>
     * </ul>
     *
     * @param op the op whose gradient loop state is requested
     * @param z  whether a loop-exit op should be resolved through its outer context
     * @return the registered state, or {@code None} if no context is found or no state is registered for it
     */
    public Option<GradientLoopState> getGradientLoopState(Op op, boolean z) {
        return ((z && ControlFlow$.MODULE$.isLoopExit(op)) ? op.controlFlowContext().flatMap(context -> {
            return context.outerContext();
        }).flatMap(context2 -> {
            return context2.whileLoopContext(context2.whileLoopContext$default$1());
        }) : WhileLoopContext$.MODULE$.getWhileLoopContext(op)).flatMap(context3 -> {
            return this.map.get(context3);
        });
    }

    /**
     * Enters the backward context of the gradient loop state found for {@code op}, if any.
     *
     * @param op the op whose state is looked up via {@link #getGradientLoopState}
     * @param z  forwarded unchanged to {@link #getGradientLoopState}
     */
    public void enterGradientWhileLoopContext(Op op, boolean z) {
        getGradientLoopState(op, z).foreach(gradientLoopState -> {
            $anonfun$enterGradientWhileLoopContext$1(gradientLoopState);
            return BoxedUnit.UNIT;
        });
    }

    /**
     * Exits the backward context of the gradient loop state found for {@code op}, if any.
     *
     * @param op the op whose state is looked up via {@link #getGradientLoopState}
     * @param z  forwarded unchanged to {@link #getGradientLoopState}
     */
    public void exitGradientWhileLoopContext(Op op, boolean z) {
        getGradientLoopState(op, z).foreach(gradientLoopState -> {
            $anonfun$exitGradientWhileLoopContext$1(gradientLoopState);
            return BoxedUnit.UNIT;
        });
    }

    /**
     * Registers a gradient loop state for the while loop containing {@code op}.
     *
     * <p>Does nothing if {@code op} has no while-loop context, or if that context is already
     * registered. Otherwise, each forward loop exit of the new state whose op is not yet in
     * {@code set} is added to both {@code set} and {@code listBuffer}.
     *
     * @param op         an op inside the forward while loop
     * @param set        ops already collected; extended with the loop's exit ops
     * @param listBuffer ordered collection that receives the same newly added exit ops
     */
    public void addWhileLoopContext(Op op, Set<Op> set, ListBuffer<Op> listBuffer) {
        WhileLoopContext$.MODULE$.getWhileLoopContext(op).foreach(whileLoopContext -> {
            $anonfun$addWhileLoopContext$1(this, set, listBuffer, whileLoopContext);
            return BoxedUnit.UNIT;
        });
    }

    /**
     * Creates a zeros tensor with the data type and shape of {@code output} (meant for loop exits).
     *
     * <p>The enclosing loop is the while loop reached from the outer context of {@code output}'s
     * control-flow context. If that enclosing loop has a registered gradient loop state, the zeros
     * are built in that state's backward context:
     * <ul>
     *   <li>A fully defined static shape is used directly.</li>
     *   <li>Otherwise, the runtime shape is computed in the outer forward context. It is saved via
     *       {@code addForwardAccumulator} and retrieved via {@code addBackwardAccumulatedValue}.</li>
     * </ul>
     * Without such a state, no context is entered. The method then uses {@code zeros} for a static
     * shape, or {@code zerosLike} for a dynamic one.
     *
     * @param output the output to create zeros for
     * @return the zeros tensor
     */
    public Output zerosLikeForExit(Output output) {
        Output zerosLike;
        Output output2;
        Output output3;
        Option<Context> controlFlowContext = output.op().controlFlowContext();
        // Declared as Option: flatMap returns an Option, which cannot be assigned to a Some-typed
        // local without a cast. The Some case is unwrapped explicitly below.
        Option flatMap = controlFlowContext.flatMap(context -> {
            return context.outerContext();
        }).flatMap(context2 -> {
            return context2.whileLoopContext(context2.whileLoopContext$default$1());
        }).flatMap(context3 -> {
            return this.map.get(context3);
        });
        if (flatMap instanceof Some) {
            GradientLoopState gradientLoopState = (GradientLoopState) ((Some) flatMap).value();
            if (output.shape().isFullyDefined()) {
                // Static shape: build the zeros directly inside the backward context.
                gradientLoopState.backwardContext().enter();
                Output zeros = Basic$.MODULE$.zeros(output.dataType(), org.platanios.tensorflow.api.package$.MODULE$.tensorConvertibleToOutput(output.shape(), TensorConvertible$.MODULE$.shapeTensorConvertible()), Basic$.MODULE$.zeros$default$3());
                gradientLoopState.backwardContext().exit();
                output3 = zeros;
            } else {
                // Dynamic shape: compute it in the outer forward context.
                controlFlowContext.flatMap(context4 -> {
                    return context4.outerContext();
                }).foreach(context5 -> {
                    context5.enter();
                    return BoxedUnit.UNIT;
                });
                // NOTE(review): the third argument `false` presumably disables static-shape
                // optimization; confirm against Basic.shape.
                Output shape = Basic$.MODULE$.shape(output, Basic$.MODULE$.shape$default$2(), false, Basic$.MODULE$.shape$default$4());
                controlFlowContext.flatMap(context6 -> {
                    return context6.outerContext();
                }).foreach(context7 -> {
                    context7.exit();
                    return BoxedUnit.UNIT;
                });
                // Save the shape in the forward pass, then read it back inside the backward context.
                Output addForwardAccumulator = gradientLoopState.addForwardAccumulator(shape, gradientLoopState.addForwardAccumulator$default$2());
                gradientLoopState.backwardContext().enter();
                Output zeros2 = Basic$.MODULE$.zeros(output.dataType(), gradientLoopState.addBackwardAccumulatedValue(addForwardAccumulator, shape, gradientLoopState.addBackwardAccumulatedValue$default$3()), Basic$.MODULE$.zeros$default$3());
                gradientLoopState.backwardContext().exit();
                output3 = zeros2;
            }
            output2 = output3;
        } else {
            if (!None$.MODULE$.equals(flatMap)) {
                throw new MatchError(flatMap);
            }
            // No registered state for an enclosing loop: no context is entered here.
            if (output.shape().isFullyDefined()) {
                zerosLike = Basic$.MODULE$.zeros(output.dataType(), org.platanios.tensorflow.api.package$.MODULE$.tensorConvertibleToOutput(output.shape(), TensorConvertible$.MODULE$.shapeTensorConvertible()), Basic$.MODULE$.zeros$default$3());
            } else {
                zerosLike = Basic$.MODULE$.zerosLike(output, Basic$.MODULE$.zerosLike$default$2(), false, Basic$.MODULE$.zerosLike$default$4());
            }
            output2 = zerosLike;
        }
        return output2;
    }

    /**
     * Creates a zero gradient for output {@code i} of {@code op}.
     *
     * <p>Cases:
     * <ul>
     *   <li>Loop switch: returns {@code None}.</li>
     *   <li>{@code op}'s while loop has no registered gradient loop state: delegates to
     *       {@code Context.zerosLikeOutsideLoop}.</li>
     *   <li>Static shape: zeros of that shape. For a (non-loop) switch whose cond predicate name has
     *       an entry in the state's history map, the zeros are passed through a colocated switch
     *       on that history value. The result selected is the one for the other branch.</li>
     *   <li>Dynamic shape: the runtime shape is computed in the forward graph. For a switch, it comes
     *       from a colocated switch on the op's first input, created in the cond context's outer
     *       context. The shape is then accumulated forward and read back to build the zeros.</li>
     * </ul>
     *
     * <p>NOTE(review): the dynamic-shape path exits and then re-enters the state's backward context
     * around {@code addForwardAccumulator}. It therefore appears to assume the caller is already
     * inside that backward context.
     *
     * @param op the op whose output needs a zero gradient
     * @param i  index into {@code op.outputs()}
     * @return the zero gradient, or {@code None} for loop switches
     * @throws ClassCastException if {@code op} is a switch whose control-flow context is not a {@code CondContext}
     */
    public Option<Output> zerosLike(Op op, int i) {
        Some some;
        Option some2;
        Some some3;
        Some some4;
        Tuple2 tuple2;
        if (ControlFlow$.MODULE$.isLoopSwitch(op)) {
            return None$.MODULE$;
        }
        boolean isSwitch = ControlFlow$.MODULE$.isSwitch(op);
        // Declared as Option: flatMap returns an Option, which is not assignable to a Some-typed local.
        Option flatMap = WhileLoopContext$.MODULE$.getWhileLoopContext(op).flatMap(context -> {
            return this.map.get(context);
        });
        if (flatMap instanceof Some) {
            GradientLoopState gradientLoopState = (GradientLoopState) ((Some) flatMap).value();
            Output output = op.outputs()[i];
            if (output.shape().isFullyDefined()) {
                Output zeros = Basic$.MODULE$.zeros(output.dataType(), org.platanios.tensorflow.api.package$.MODULE$.tensorConvertibleToOutput(output.shape(), TensorConvertible$.MODULE$.shapeTensorConvertible()), Basic$.MODULE$.zeros$default$3());
                if (isSwitch) {
                    // Pair the history value recorded for the cond predicate with the switch's branch.
                    Option flatMap2 = op.controlFlowContext().flatMap(context2 -> {
                        CondContext condContext = (CondContext) context2;
                        return gradientLoopState.historyMap().get(condContext.predicate().name()).map(output2 -> {
                            return new Tuple2(output2, condContext.branch());
                        });
                    });
                    if ((flatMap2 instanceof Some) && (tuple2 = (Tuple2) ((Some) flatMap2).value()) != null) {
                        // Guard the zeros with a colocated switch and select the other branch's result.
                        some4 = new Some(((CondBranch) tuple2._2()).other().selectSwitchResult(ControlFlow$.MODULE$.colocatedSwitch(zeros, (Output) tuple2._1(), ControlFlow$.MODULE$.colocatedSwitch$default$3())));
                    } else {
                        if (!None$.MODULE$.equals(flatMap2)) {
                            throw new MatchError(flatMap2);
                        }
                        some4 = new Some(zeros);
                    }
                    some3 = some4;
                } else {
                    some3 = new Some(zeros);
                }
            } else {
                if (isSwitch) {
                    // Build a guarded copy of the switch input in the cond's outer context, take its shape,
                    // then move both new ops into the cond context.
                    // NOTE(review): if op has no control-flow context this yields None and
                    // option.get() below throws.
                    some2 = op.controlFlowContext().map(context3 -> {
                        CondContext condContext = (CondContext) context3;
                        condContext.outerContext().foreach(outerContext -> {
                            outerContext.enter();
                            return BoxedUnit.UNIT;
                        });
                        Output output2 = (Output) condContext.branch().other().selectSwitchResult(ControlFlow$.MODULE$.colocatedSwitch(op.inputs()[0], condContext.predicate(), ControlFlow$.MODULE$.colocatedSwitch$default$3()));
                        Output shape = Basic$.MODULE$.shape(output2, Basic$.MODULE$.shape$default$2(), false, Basic$.MODULE$.shape$default$4());
                        condContext.outerContext().foreach(outerContext -> {
                            outerContext.exit();
                            return BoxedUnit.UNIT;
                        });
                        output2.op().controlFlowContext_$eq(new Some(condContext));
                        shape.op().controlFlowContext_$eq(new Some(condContext));
                        return shape;
                    });
                } else {
                    // Compute the runtime shape inside op's own control-flow context.
                    op.controlFlowContext().foreach(context4 -> {
                        context4.enter();
                        return BoxedUnit.UNIT;
                    });
                    Output shape = Basic$.MODULE$.shape(output, Basic$.MODULE$.shape$default$2(), false, Basic$.MODULE$.shape$default$4());
                    op.controlFlowContext().foreach(context5 -> {
                        context5.exit();
                        return BoxedUnit.UNIT;
                    });
                    some2 = new Some(shape);
                }
                Option option = some2;
                // The forward accumulator is added outside the backward context; the zeros are built inside it.
                gradientLoopState.backwardContext().exit();
                Output addForwardAccumulator = gradientLoopState.addForwardAccumulator((Output) option.get(), isSwitch);
                gradientLoopState.backwardContext().enter();
                some3 = new Some(Basic$.MODULE$.zeros(output.dataType(), gradientLoopState.addBackwardAccumulatedValue(addForwardAccumulator, (Output) option.get(), isSwitch), Basic$.MODULE$.zeros$default$3()));
            }
            some = some3;
        } else {
            if (!None$.MODULE$.equals(flatMap)) {
                throw new MatchError(flatMap);
            }
            some = new Some(Context$.MODULE$.zerosLikeOutsideLoop(op, i));
        }
        return some;
    }

    /**
     * Collects forward loop exits that backpropagation will not reach.
     *
     * <p>For every registered state, each forward loop exit whose op has a pending count of 0
     * (absent counts as 0) decrements the state's {@code pendingExitsCount}. If its op is not in
     * {@code set}, the exit is also recorded in {@code unusedExits}. Whenever the count reaches 0,
     * the state's unused exits are added to the result.
     *
     * <p>Afterwards, every loop-enter op of the state's forward context with a pending count of 0
     * is set to 1 in {@code map}. This mutates the caller's map.
     *
     * @param map pending count per op, stored boxed; read as int, absent treated as 0
     * @param set ops whose exits are never recorded as unused (presumably the gradient targets; confirm)
     * @return the unused exits gathered across all states
     */
    public scala.collection.immutable.Set<Output> processUnusedLoopExits(Map<Op, Object> map, scala.collection.immutable.Set<Op> set) {
        Set empty = Set$.MODULE$.empty();
        this.map.values().foreach(gradientLoopState -> {
            $anonfun$processUnusedLoopExits$1(map, set, empty, gradientLoopState);
            return BoxedUnit.UNIT;
        });
        return empty.toSet();
    }

    /**
     * Applies the switch-map fix-up ({@code $anonfun$postProcess$1}) to every registered gradient loop state.
     */
    public void postProcess() {
        this.map.values().foreach(gradientLoopState -> {
            $anonfun$postProcess$1(gradientLoopState);
            return BoxedUnit.UNIT;
        });
    }

    /** Enters {@code gradientLoopState}'s backward context. */
    public static final /* synthetic */ void $anonfun$enterGradientWhileLoopContext$1(GradientLoopState gradientLoopState) {
        gradientLoopState.backwardContext().enter();
    }

    /** Exits {@code gradientLoopState}'s backward context. */
    public static final /* synthetic */ void $anonfun$exitGradientWhileLoopContext$1(GradientLoopState gradientLoopState) {
        gradientLoopState.backwardContext().exit();
    }

    /** Returns true when {@code output}'s op is not yet in {@code set}. */
    public static final /* synthetic */ boolean $anonfun$addWhileLoopContext$4(Set set, Output output) {
        return !set.contains(output.op());
    }

    /**
     * Body of {@link #addWhileLoopContext} for a single while-loop context.
     *
     * <p>If the context is unregistered, creates its state and registers it. The state is linked
     * to the enclosing while loop's state when that loop is registered. Exit ops not yet in
     * {@code set} are then appended to both {@code set} and {@code listBuffer}.
     */
    public static final /* synthetic */ void $anonfun$addWhileLoopContext$1(GradientState gradientState, Set set, ListBuffer listBuffer, WhileLoopContext whileLoopContext) {
        if (gradientState.map.contains(whileLoopContext)) {
            return;
        }
        GradientLoopState apply = GradientLoopState$.MODULE$.apply(whileLoopContext, whileLoopContext.outerContext().flatMap(context -> {
            return context.whileLoopContext(context.whileLoopContext$default$1());
        }).flatMap(context2 -> {
            return gradientState.map.get(context2);
        }));
        // Same mapping as `map += (whileLoopContext -> apply)`; the key is known to be absent here.
        gradientState.map.update(whileLoopContext, apply);
        ((IterableLike) apply.forwardLoopExits().filter(output -> {
            return BoxesRunTime.boxToBoolean($anonfun$addWhileLoopContext$4(set, output));
        })).foreach(output2 -> {
            // foreach on a raw IterableLike supplies Object; the elements are forward loop exits.
            Output exit = (Output) output2;
            set.$plus$eq(exit.op());
            return listBuffer.$plus$eq(exit.op());
        });
    }

    /** Returns true when {@code output}'s op has a pending count of 0 in {@code map} (absent counts as 0). */
    public static final /* synthetic */ boolean $anonfun$processUnusedLoopExits$2(Map map, Output output) {
        return BoxesRunTime.unboxToInt(map.getOrElse(output.op(), () -> {
            return 0;
        })) == 0;
    }

    /** Returns true when {@code op} has a pending count of 0 in {@code map} (absent counts as 0). */
    public static final /* synthetic */ boolean $anonfun$processUnusedLoopExits$6(Map map, Op op) {
        return BoxesRunTime.unboxToInt(map.getOrElse(op, () -> {
            return 0;
        })) == 0;
    }

    /** Sets {@code op}'s pending count in {@code map} to 1. */
    public static final /* synthetic */ void $anonfun$processUnusedLoopExits$8(Map map, Op op) {
        map.update(op, BoxesRunTime.boxToInteger(1));
    }

    /**
     * Body of {@link #processUnusedLoopExits} for one gradient loop state.
     *
     * <p>Handles the state's zero-pending exits and adds its unused exits to {@code set2} once none
     * are pending. Then sets the pending count of its zero-pending loop-enter ops to 1.
     */
    public static final /* synthetic */ void $anonfun$processUnusedLoopExits$1(Map map, scala.collection.immutable.Set set, Set set2, GradientLoopState gradientLoopState) {
        ((IterableLike) gradientLoopState.forwardLoopExits().filter(output -> {
            return BoxesRunTime.boxToBoolean($anonfun$processUnusedLoopExits$2(map, output));
        })).foreach(output2 -> {
            // foreach on a raw IterableLike supplies Object; the elements are forward loop exits.
            Output exit = (Output) output2;
            gradientLoopState.pendingExitsCount_$eq(gradientLoopState.pendingExitsCount() - 1);
            if (!set.contains(exit.op())) {
                gradientLoopState.unusedExits().$plus$eq(exit);
            }
            return gradientLoopState.pendingExitsCount() == 0 ? set2.$plus$plus$eq(gradientLoopState.unusedExits()) : BoxedUnit.UNIT;
        });
        // Raw CanBuildFrom: the companion's CanBuildFrom<Set<?>, ...> cannot bind to the invariant Repr parameter.
        ((IterableLike) ((TraversableLike) gradientLoopState.forwardContext().loopEnters().map(output3 -> {
            return output3.op();
        }, (scala.collection.generic.CanBuildFrom) Set$.MODULE$.canBuildFrom())).filter(op -> {
            return BoxesRunTime.boxToBoolean($anonfun$processUnusedLoopExits$6(map, (Op) op));
        })).foreach(op2 -> {
            $anonfun$processUnusedLoopExits$8(map, (Op) op2);
            return BoxedUnit.UNIT;
        });
    }

    /** Returns true when input 0 and input 1 of {@code output}'s op are equal (null-safe). */
    public static final /* synthetic */ boolean $anonfun$postProcess$3(Output output) {
        Output output2 = output.op().inputs()[0];
        Output output3 = output.op().inputs()[1];
        return output2 != null ? output2.equals(output3) : output3 == null;
    }

    /**
     * Replaces input 1 of {@code output}'s op with a {@code nextIteration} of zeros.
     *
     * <p>The zeros take input 0's data type. With a fully defined static shape, they are built
     * inside the backward context. Otherwise, they are built in the backward context's outer
     * context, from the runtime shape of the first input of input 0's op. In both cases the
     * {@code nextIteration} is created inside the backward context.
     *
     * <p>NOTE(review): {@code output} is presumably a backward merge taken from the switch map; confirm.
     */
    public static final /* synthetic */ void $anonfun$postProcess$4(GradientLoopState gradientLoopState, Output output) {
        Output output2;
        DataType dataType = output.op().inputs()[0].dataType();
        Shape shape = output.op().inputs()[0].shape();
        if (shape.isFullyDefined()) {
            gradientLoopState.backwardContext().enter();
            Output output3 = (Output) ControlFlow$.MODULE$.nextIteration(Basic$.MODULE$.zeros(dataType, org.platanios.tensorflow.api.package$.MODULE$.tensorConvertibleToOutput(shape, TensorConvertible$.MODULE$.shapeTensorConvertible()), Basic$.MODULE$.zeros$default$3()), ControlFlow$.MODULE$.nextIteration$default$2());
            gradientLoopState.backwardContext().exit();
            output2 = output3;
        } else {
            Option<Context> outerContext = gradientLoopState.backwardContext().outerContext();
            outerContext.foreach(context -> {
                context.enter();
                return BoxedUnit.UNIT;
            });
            Output zeros = Basic$.MODULE$.zeros(dataType, Basic$.MODULE$.shape(output.op().inputs()[0].op().inputs()[0], Basic$.MODULE$.shape$default$2(), false, Basic$.MODULE$.shape$default$4()), Basic$.MODULE$.zeros$default$3());
            outerContext.foreach(context2 -> {
                context2.exit();
                return BoxedUnit.UNIT;
            });
            gradientLoopState.backwardContext().enter();
            Output output4 = (Output) ControlFlow$.MODULE$.nextIteration(zeros, ControlFlow$.MODULE$.nextIteration$default$2());
            gradientLoopState.backwardContext().exit();
            output2 = output4;
        }
        ControlFlow$.MODULE$.updateInput(output.op(), 1, output2);
    }

    /**
     * Body of {@link #postProcess} for one gradient loop state.
     *
     * <p>Flattens every switch-map value into plain outputs:
     * <ul>
     *   <li>An {@code Output} is kept as is.</li>
     *   <li>{@code OutputIndexedSlices} and {@code SparseOutput} contribute their indices and
     *       values, plus their dense shape when it is non-null.</li>
     * </ul>
     * Every output whose op has identical inputs 0 and 1 is then rewired by
     * {@code $anonfun$postProcess$4}.
     *
     * @throws MatchError if a switch-map value is none of the three supported output kinds
     */
    public static final /* synthetic */ void $anonfun$postProcess$1(GradientLoopState gradientLoopState) {
        // Raw CanBuildFrom: the companion's CanBuildFrom<Iterable<?>, ...> cannot bind to the invariant Repr parameter.
        ((IterableLike) ((TraversableLike) gradientLoopState.switchMap().values().flatMap(outputLike -> {
            Seq apply;
            if (outputLike instanceof Output) {
                apply = (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{(Output) outputLike}));
            } else if (outputLike instanceof OutputIndexedSlices) {
                OutputIndexedSlices outputIndexedSlices = (OutputIndexedSlices) outputLike;
                apply = outputIndexedSlices.denseShape() == null ? (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{outputIndexedSlices.indices(), outputIndexedSlices.values()})) : (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{outputIndexedSlices.indices(), outputIndexedSlices.values(), outputIndexedSlices.denseShape()}));
            } else {
                if (!(outputLike instanceof SparseOutput)) {
                    throw new MatchError(outputLike);
                }
                SparseOutput sparseOutput = (SparseOutput) outputLike;
                apply = sparseOutput.denseShape() == null ? (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{sparseOutput.indices(), sparseOutput.values()})) : (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{sparseOutput.indices(), sparseOutput.values(), sparseOutput.denseShape()}));
            }
            return apply;
        }, (scala.collection.generic.CanBuildFrom) Iterable$.MODULE$.canBuildFrom())).filter(output -> {
            // filter/foreach on raw collections supply Object; the elements are the flattened Outputs.
            return BoxesRunTime.boxToBoolean($anonfun$postProcess$3((Output) output));
        })).foreach(output2 -> {
            $anonfun$postProcess$4(gradientLoopState, (Output) output2);
            return BoxedUnit.UNIT;
        });
    }
}
