package org.platanios.tensorflow.api.ops.control_flow;

import org.platanios.tensorflow.api.core.Shape$;
import org.platanios.tensorflow.api.core.package$exception$;
import org.platanios.tensorflow.api.implicits.Implicits$;
import org.platanios.tensorflow.api.ops.Basic$;
import org.platanios.tensorflow.api.ops.Gradients$Registry$;
import org.platanios.tensorflow.api.ops.Math$;
import org.platanios.tensorflow.api.ops.Op;
import org.platanios.tensorflow.api.ops.Op$;
import org.platanios.tensorflow.api.ops.Output;
import org.platanios.tensorflow.api.ops.OutputIndexedSlices;
import org.platanios.tensorflow.api.ops.OutputLike;
import org.platanios.tensorflow.api.ops.SparseOutput;
import org.platanios.tensorflow.api.tensors.TensorConvertible$;
import org.platanios.tensorflow.api.types.DataType;
import org.platanios.tensorflow.api.types.SupportedType$;
import org.platanios.tensorflow.api.types.package$RESOURCE$;
import org.platanios.tensorflow.jni.UnimplementedException;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Some;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.mutable.SetLike;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.Null$;
import scala.runtime.RichInt$;

/* compiled from: ControlFlow.scala */
/* loaded from: input_file:org/platanios/tensorflow/api/ops/control_flow/ControlFlow$Gradients$.class */
/**
 * Gradient functions for the control-flow ops: LoopCond, NextIteration, Enter, Exit, Switch and Merge, plus their
 * Ref* variants. ControlTrigger is registered as non-differentiable.
 *
 * <p>This is decompiled output of the Scala object {@code ControlFlow.Gradients} (from ControlFlow.scala).
 * <ul>
 *   <li>The static initializer creates the sole instance.</li>
 *   <li>The constructor stores that instance in {@code MODULE$}.</li>
 *   <li>The constructor also registers every gradient function with {@code Gradients.Registry}, so registration is
 *       a side effect of initializing this class.</li>
 * </ul>
 *
 * <p>NOTE(review): this file contains decompiler artifacts:
 * <ul>
 *   <li>{@code $plus$eq} is Scala {@code +=}.</li>
 *   <li>{@code m402switch} is presumably the Scala method {@code switch}, renamed because {@code switch} is a Java
 *       keyword.</li>
 *   <li>Several spots are unlikely to compile as Java, e.g. the shadowed lambda parameter in {@code exitGradient} and
 *       the {@code Some flatMap = ...flatMap(...)} assignment in {@code switchGradient}.</li>
 * </ul>
 * Treat the Scala source as authoritative and make changes there.
 */
public class ControlFlow$Gradients$ {
    /** The singleton instance of this Scala object; assigned in the constructor. */
    public static ControlFlow$Gradients$ MODULE$;

    // Class initialization constructs the singleton; the constructor also performs all gradient registrations.
    static {
        new ControlFlow$Gradients$();
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Gradient of {@code LoopCond}. It returns a single {@code null} gradient, so nothing is back-propagated to
     * the loop predicate.
     *
     * @param op  the LoopCond op (unused)
     * @param seq gradients w.r.t. the op's outputs (unused)
     * @return a one-element sequence containing {@code null}
     */
    public Seq<OutputLike> loopCondGradient(Op op, Seq<OutputLike> seq) {
        return Seq$.MODULE$.apply(Predef$.MODULE$.genericWrapArray(new Null$[]{null}));
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Gradient of {@code NextIteration} / {@code RefNextIteration}. It is the identity: the incoming gradients are
     * returned unchanged.
     *
     * @param op  the NextIteration op (unused)
     * @param seq gradients w.r.t. the op's outputs
     * @return {@code seq} itself
     */
    public Seq<OutputLike> nextIterationGradient(Op op, Seq<OutputLike> seq) {
        return seq;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Gradient of {@code Enter} / {@code RefEnter}.
     *
     * <p>Works in the current (gradient) control-flow context. If that context back-propagates and has a gradient
     * loop state:
     * <ul>
     *   <li>For an Enter op whose {@code is_constant} attribute is true, a backward accumulator is added for the
     *       first incoming gradient, and its result is returned.</li>
     *   <li>Otherwise the first incoming gradient is fed through an Exit op. Every tensor of the Exit output is
     *       appended to the context's {@code loopExits}. The output is passed to {@code exitResult} and returned.</li>
     * </ul>
     * In every other case the incoming gradients are returned unchanged.
     *
     * <p>NOTE(review): the result is taken with {@code Option.get()}. If there is no current control-flow context,
     * this throws instead of passing the gradients through. Confirm that this gradient is only invoked inside a
     * context.
     *
     * @param op  the Enter op being differentiated
     * @param seq gradients w.r.t. the op's outputs; only the head is used when new ops are created
     * @return gradients w.r.t. the op's inputs
     */
    public Seq<OutputLike> enterGradient(Op op, Seq<OutputLike> seq) {
        return (Seq) Op$.MODULE$.currentControlFlowContext().map(context -> {
            // Decompiler artifact: receives the (never read) results of the `+=` calls below.
            SetLike $plus$eq;
            if (context.backPropagate() && !context.gradientLoopState().isEmpty()) {
                // Loop-invariant Enter: accumulate its gradient across iterations instead of exiting it.
                if (op.booleanAttribute("is_constant")) {
                    return Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new OutputLike[]{((WhileLoopContext) context).addBackwardAccumulator(op, (OutputLike) seq.head())}));
                }
                // NOTE(review): unchecked cast; assumes a context with a gradient loop state is a WhileLoopContext.
                WhileLoopContext whileLoopContext = (WhileLoopContext) context;
                Seq<OutputLike> seq2 = (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new OutputLike[]{ControlFlow$.MODULE$.exit((OutputLike) seq.head(), ControlFlow$.MODULE$.exit$default$2())}));
                // Register each tensor of the Exit output in loopExits. The dense shape of sparse/indexed-slices
                // values may be null, in which case it is skipped.
                OutputLike outputLike = (OutputLike) seq2.apply(0);
                if (outputLike instanceof Output) {
                    $plus$eq = whileLoopContext.loopExits().$plus$eq((Output) outputLike);
                } else if (outputLike instanceof OutputIndexedSlices) {
                    OutputIndexedSlices outputIndexedSlices = (OutputIndexedSlices) outputLike;
                    whileLoopContext.loopExits().$plus$eq(outputIndexedSlices.indices());
                    whileLoopContext.loopExits().$plus$eq(outputIndexedSlices.values());
                    $plus$eq = outputIndexedSlices.denseShape() != null ? whileLoopContext.loopExits().$plus$eq(outputIndexedSlices.denseShape()) : BoxedUnit.UNIT;
                } else {
                    if (!(outputLike instanceof SparseOutput)) {
                        throw new MatchError(outputLike);
                    }
                    SparseOutput sparseOutput = (SparseOutput) outputLike;
                    whileLoopContext.loopExits().$plus$eq(sparseOutput.indices());
                    whileLoopContext.loopExits().$plus$eq(sparseOutput.values());
                    $plus$eq = sparseOutput.denseShape() != null ? whileLoopContext.loopExits().$plus$eq(sparseOutput.denseShape()) : BoxedUnit.UNIT;
                }
                context.exitResult(seq2);
                return seq2;
            }
            // Not back-propagating, or not inside a gradient while loop: pass the gradients through.
            return seq;
        }).get();
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Gradient of {@code Exit} / {@code RefExit}.
     *
     * <p>Returns a single {@code null} gradient when the current context does not back-propagate. Otherwise it:
     * <ol>
     *   <li>registers the names of the tensors of the first incoming gradient in the current context's
     *       {@code values};</li>
     *   <li>enters that context;</li>
     *   <li>feeds the gradient through a non-constant Enter op named {@code "ExitGradient"}, which uses the
     *       context's name and its {@code parallelIterations};</li>
     *   <li>records the result in {@code loopEnters};</li>
     *   <li>exits the context again.</li>
     * </ol>
     *
     * <p>NOTE(review): two unchecked assumptions:
     * <ul>
     *   <li>The current context is cast to {@code WhileLoopContext}.</li>
     *   <li>The result is taken with {@code Option.get()}, which throws if there is no current control-flow
     *       context.</li>
     * </ul>
     *
     * @param op  the Exit op being differentiated
     * @param seq gradients w.r.t. the op's outputs; only the head is used
     * @return a one-element sequence with the gradient w.r.t. the op's input
     * @throws UnimplementedException if the Exit op's own context already has a gradient loop state
     *                                (second-order gradients of while loops are not supported)
     */
    public Seq<OutputLike> exitGradient(Op op, Seq<OutputLike> seq) throws UnimplementedException {
        return (Seq) Op$.MODULE$.currentControlFlowContext().map(context -> {
            // Decompiler artifacts: receive the (never read) results of the `+=` calls below.
            SetLike $plus$eq;
            SetLike $plus$eq2;
            if (!context.backPropagate()) {
                return Seq$.MODULE$.apply(Predef$.MODULE$.genericWrapArray(new Null$[]{null}));
            }
            // NOTE(review): this inner lambda parameter shadows the outer `context`. It ranges over the op's OWN
            // control-flow context, not the current one (decompiler artifact).
            if (op.controlFlowContext().flatMap(context -> {
                return context.gradientLoopState();
            }).isDefined()) {
                throw package$exception$.MODULE$.UnimplementedException().apply("Second-order gradients are not supported for while loops.");
            }
            // Record the name(s) of the incoming gradient tensor(s) in the current context's value set.
            // The dense shape is optional (may be null).
            OutputLike outputLike = (OutputLike) seq.head();
            if (outputLike instanceof Output) {
                $plus$eq = context.values().$plus$eq(((Output) outputLike).name());
            } else if (outputLike instanceof OutputIndexedSlices) {
                OutputIndexedSlices outputIndexedSlices = (OutputIndexedSlices) outputLike;
                context.values().$plus$eq(outputIndexedSlices.indices().name());
                context.values().$plus$eq(outputIndexedSlices.values().name());
                $plus$eq = outputIndexedSlices.denseShape() != null ? context.values().$plus$eq(outputIndexedSlices.denseShape().name()) : BoxedUnit.UNIT;
            } else {
                if (!(outputLike instanceof SparseOutput)) {
                    throw new MatchError(outputLike);
                }
                SparseOutput sparseOutput = (SparseOutput) outputLike;
                context.values().$plus$eq(sparseOutput.indices().name());
                context.values().$plus$eq(sparseOutput.values().name());
                $plus$eq = sparseOutput.denseShape() != null ? context.values().$plus$eq(sparseOutput.denseShape().name()) : BoxedUnit.UNIT;
            }
            WhileLoopContext whileLoopContext = (WhileLoopContext) context;
            // Create the Enter op inside the gradient context, then leave it again after bookkeeping.
            context.enter();
            Seq apply = Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new OutputLike[]{ControlFlow$.MODULE$.enter((OutputLike) seq.head(), context.name(), false, whileLoopContext.parallelIterations(), ControlFlow$.MODULE$.enter$default$5(), "ExitGradient")}));
            // Register each tensor of the Enter output in loopEnters (dense shape optional).
            OutputLike outputLike2 = (OutputLike) apply.apply(0);
            if (outputLike2 instanceof Output) {
                $plus$eq2 = whileLoopContext.loopEnters().$plus$eq((Output) outputLike2);
            } else if (outputLike2 instanceof OutputIndexedSlices) {
                OutputIndexedSlices outputIndexedSlices2 = (OutputIndexedSlices) outputLike2;
                whileLoopContext.loopEnters().$plus$eq(outputIndexedSlices2.indices());
                whileLoopContext.loopEnters().$plus$eq(outputIndexedSlices2.values());
                $plus$eq2 = outputIndexedSlices2.denseShape() != null ? whileLoopContext.loopEnters().$plus$eq(outputIndexedSlices2.denseShape()) : BoxedUnit.UNIT;
            } else {
                if (!(outputLike2 instanceof SparseOutput)) {
                    throw new MatchError(outputLike2);
                }
                SparseOutput sparseOutput2 = (SparseOutput) outputLike2;
                whileLoopContext.loopEnters().$plus$eq(sparseOutput2.indices());
                whileLoopContext.loopEnters().$plus$eq(sparseOutput2.values());
                $plus$eq2 = sparseOutput2.denseShape() != null ? whileLoopContext.loopEnters().$plus$eq(sparseOutput2.denseShape()) : BoxedUnit.UNIT;
            }
            context.exit();
            return apply;
        }).get();
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v114, types: [org.platanios.tensorflow.api.ops.Output] */
    /* JADX WARN: Type inference failed for: r0v96, types: [org.platanios.tensorflow.api.ops.OutputIndexedSlices] */
    /**
     * Gradient of {@code Switch} / {@code RefSwitch}.
     *
     * <p>It always returns two gradients, one for the data input and one for the predicate input. The second is
     * always {@code null}. The behavior depends on the op's own control-flow context.
     *
     * <p><b>{@code CondContext}</b>
     * <ul>
     *   <li>If the gradient of the branch not taken ({@code 1 - branch}) is non-null, the incoming gradients are
     *       merged.</li>
     *   <li>If it is null and the Switch data input ({@code op.inputs()[0]}) is not {@code RESOURCE}, both results
     *       are {@code null}.</li>
     *   <li>For a resource input:
     *     <ol>
     *       <li>a zero value shaped like the taken-branch gradient is routed through a colocated Switch on the cond
     *           predicate;</li>
     *       <li>the other branch's Switch output is selected;</li>
     *       <li>that output is merged with the taken-branch gradient, ordered by branch index.</li>
     *     </ol>
     *   </li>
     * </ul>
     *
     * <p><b>{@code WhileLoopContext}</b>: this op is looked up in the current gradient loop state's
     * {@code switchMap}.
     * <ul>
     *   <li>If a merge is already recorded, the second incoming gradient (when non-null) is added to it as the
     *       next-iteration/back edge, and {@code (null, null)} is returned.</li>
     *   <li>Otherwise, if the first incoming gradient is non-null, it is merged with itself. The merge is recorded in
     *       {@code switchMap} and returned.</li>
     *   <li>Otherwise {@code (null, null)} is returned.</li>
     * </ul>
     *
     * <p><b>Any other case</b>:
     * <ol>
     *   <li>The first incoming gradient is routed through a Switch on the predicate input ({@code op.inputs()[1]}),
     *       keeping its first output.</li>
     *   <li>The second incoming gradient is routed through a Switch on the same predicate, keeping its second
     *       output.</li>
     *   <li>The two are merged.</li>
     * </ol>
     *
     * @param op  the Switch op being differentiated
     * @param seq the two gradients w.r.t. the Switch outputs; entries may be {@code null}
     * @return a two-element sequence: the gradient for the data input and {@code null} for the predicate
     */
    public Seq<OutputLike> switchGradient(Op op, Seq<OutputLike> seq) {
        Seq<OutputLike> apply;
        Seq<OutputLike> apply2;
        Seq<OutputLike> seq2;
        SparseOutput sparseOutput;
        Option<Context> currentControlFlowContext = Op$.MODULE$.currentControlFlowContext();
        // z/some: decompiled form of a Scala pattern match on op.controlFlowContext().
        boolean z = false;
        Some some = null;
        Option<Context> controlFlowContext = op.controlFlowContext();
        if (controlFlowContext instanceof Some) {
            z = true;
            some = (Some) controlFlowContext;
            Context context = (Context) some.value();
            if (context instanceof CondContext) {
                CondContext condContext = (CondContext) context;
                // branch().value() is presumably 0 or 1; `1 - value` indexes the gradient of the branch not taken.
                if (seq.apply(1 - condContext.branch().value()) != null) {
                    seq2 = (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new OutputLike[]{(OutputLike) ControlFlow$.MODULE$.merge(seq, "CondGradient")._1(), null}));
                } else {
                    DataType dataType = op.inputs()[0].dataType();
                    package$RESOURCE$ package_resource_ = package$RESOURCE$.MODULE$;
                    // Null-safe "dataType != RESOURCE": non-resource inputs get no gradient here.
                    if (dataType != null ? !dataType.equals(package_resource_) : package_resource_ != null) {
                        seq2 = (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.genericWrapArray(new Null$[]{null, null}));
                    } else {
                        // Build a zero value of the same kind as the taken-branch gradient.
                        OutputLike outputLike = (OutputLike) seq.apply(condContext.branch().value());
                        if (outputLike instanceof Output) {
                            sparseOutput = Basic$.MODULE$.zerosLike((Output) outputLike, Basic$.MODULE$.zerosLike$default$2(), Basic$.MODULE$.zerosLike$default$3(), Basic$.MODULE$.zerosLike$default$4());
                        } else if (outputLike instanceof OutputIndexedSlices) {
                            // Zero slices: indices of shape [1], values of shape [1, values.shape(1)].
                            // NOTE(review): assumes values have a second dimension; TODO confirm.
                            OutputIndexedSlices outputIndexedSlices = (OutputIndexedSlices) outputLike;
                            sparseOutput = new OutputIndexedSlices(Basic$.MODULE$.zeros(outputIndexedSlices.indices().dataType(), Implicits$.MODULE$.tensorConvertibleToOutput(Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{1})), TensorConvertible$.MODULE$.fromShape()), Basic$.MODULE$.zeros$default$3()), Basic$.MODULE$.zeros(outputIndexedSlices.values().dataType(), Implicits$.MODULE$.tensorConvertibleToOutput(Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{1, outputIndexedSlices.values().shape().apply(1)})), TensorConvertible$.MODULE$.fromShape()), Basic$.MODULE$.zeros$default$3()), outputIndexedSlices.denseShape());
                        } else {
                            if (!(outputLike instanceof SparseOutput)) {
                                throw new MatchError(outputLike);
                            }
                            // Zero sparse value: indices of shape [1, indices.shape(1)], values of shape [1].
                            SparseOutput sparseOutput2 = (SparseOutput) outputLike;
                            sparseOutput = new SparseOutput(Basic$.MODULE$.zeros(sparseOutput2.indices().dataType(), Implicits$.MODULE$.tensorConvertibleToOutput(Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{1, sparseOutput2.indices().shape().apply(1)})), TensorConvertible$.MODULE$.fromShape()), Basic$.MODULE$.zeros$default$3()), Basic$.MODULE$.zeros(sparseOutput2.values().dataType(), Implicits$.MODULE$.tensorConvertibleToOutput(Shape$.MODULE$.apply((Seq<Object>) Predef$.MODULE$.wrapIntArray(new int[]{1})), TensorConvertible$.MODULE$.fromShape()), Basic$.MODULE$.zeros$default$3()), sparseOutput2.denseShape());
                        }
                        // Route the zero through a Switch on the predicate and keep the other branch's output.
                        OutputLike selectSwitchResult = condContext.branch().other().selectSwitchResult(ControlFlow$.MODULE$.colocatedSwitch(sparseOutput, condContext.predicate(), ControlFlow$.MODULE$.colocatedSwitch$default$3()));
                        // Merge inputs are ordered by branch index: the branch-0 gradient comes first.
                        seq2 = condContext.branch().value() == 0 ? (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new OutputLike[]{(OutputLike) ControlFlow$.MODULE$.merge(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new OutputLike[]{outputLike, selectSwitchResult})), "CondGradient")._1(), null})) : (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new OutputLike[]{(OutputLike) ControlFlow$.MODULE$.merge(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new OutputLike[]{selectSwitchResult, outputLike})), "CondGradient")._1(), null}));
                    }
                }
                apply = seq2;
                return apply;
            }
        }
        if (z && (some.value() instanceof WhileLoopContext)) {
            // Merge previously recorded for this Switch op in the current gradient loop state, if any.
            Some flatMap = currentControlFlowContext.flatMap(context2 -> {
                return context2.gradientLoopState();
            }).flatMap(gradientLoopState -> {
                return gradientLoopState.switchMap().get(op);
            });
            if (flatMap instanceof Some) {
                // A merge already exists: wire the second gradient into it as the back edge.
                OutputLike outputLike2 = (OutputLike) flatMap.value();
                if (seq.apply(1) != null) {
                    WhileLoopContext$.MODULE$.addNextIterationAndBackEdge(outputLike2, (OutputLike) seq.apply(1), false);
                } else {
                    // Decompiled Scala unit value; a no-op.
                    BoxedUnit boxedUnit = BoxedUnit.UNIT;
                }
                apply2 = Seq$.MODULE$.apply(Predef$.MODULE$.genericWrapArray(new Null$[]{null, null}));
            } else if (!None$.MODULE$.equals(flatMap) || seq.head() == null) {
                // Effectively: no merge recorded and no first gradient; nothing to propagate.
                apply2 = Seq$.MODULE$.apply(Predef$.MODULE$.genericWrapArray(new Null$[]{null, null}));
            } else {
                // No merge yet: merge the first gradient with itself and record it for this op.
                OutputLike outputLike3 = (OutputLike) ControlFlow$.MODULE$.merge(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new OutputLike[]{(OutputLike) seq.apply(0), (OutputLike) seq.apply(0)})), "SwitchGradient")._1();
                currentControlFlowContext.flatMap(context3 -> {
                    return context3.gradientLoopState();
                }).map(gradientLoopState2 -> {
                    return gradientLoopState2.switchMap();
                }).foreach(map -> {
                    return map.$plus$eq(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(op), outputLike3));
                });
                apply2 = Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new OutputLike[]{outputLike3, null}));
            }
            apply = apply2;
        } else {
            // Neither cond nor while context: Switch each gradient on the predicate input (first output for seq(0),
            // second output for seq(1)), then merge. `m402switch` is the decompiler's name for `switch`.
            apply = Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new OutputLike[]{(OutputLike) ControlFlow$.MODULE$.merge(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new OutputLike[]{(OutputLike) ControlFlow$.MODULE$.m402switch((OutputLike) seq.apply(0), op.inputs()[1], ControlFlow$.MODULE$.switch$default$3())._1(), (OutputLike) ControlFlow$.MODULE$.m402switch((OutputLike) seq.apply(1), op.inputs()[1], ControlFlow$.MODULE$.switch$default$3())._2()})), ControlFlow$.MODULE$.merge$default$2())._1(), null}));
        }
        return apply;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Gradient of {@code Merge} / {@code RefMerge}. It is built from Switch ops on the incoming gradient
     * {@code seq.head()}.
     *
     * <p>The case is chosen from the output context of the op that produces the Merge's first input.
     *
     * <p><b>{@code CondContext}</b>: switches on the cond predicate and returns both Switch outputs.
     * <ul>
     *   <li>If the current context has no gradient loop state, the cond predicate is used directly.</li>
     *   <li>If it has one, the predicate value is taken from the loop state's {@code historyMap}.</li>
     *   <li>If {@code historyMap} has no entry yet:
     *     <ol>
     *       <li>the backward context is temporarily exited;</li>
     *       <li>a forward accumulator is added for the predicate;</li>
     *       <li>the backward context is re-entered;</li>
     *       <li>the backward accumulated value is created and cached in {@code historyMap} under the predicate's
     *           name.</li>
     *     </ol>
     *   </li>
     * </ul>
     *
     * <p><b>{@code WhileLoopContext}</b>: switches on the pivot of the current context and returns both Switch
     * outputs. The current context is cast to {@code WhileLoopContext} and obtained via {@code Option.get()}; both
     * are unchecked.
     *
     * <p><b>Otherwise</b>: for each input {@code i}, switches on {@code op.outputs()[1] == i} and keeps the second
     * Switch output, which yields one gradient per input.
     *
     * @param op  the Merge op being differentiated
     * @param seq gradients w.r.t. the op's outputs; only the head is used
     * @return gradients w.r.t. the op's inputs
     */
    public Seq<OutputLike> mergeGradient(Op op, Seq<OutputLike> seq) {
        Seq<OutputLike> seq2;
        Option<Context> currentControlFlowContext = Op$.MODULE$.currentControlFlowContext();
        // z/some: decompiled form of a Scala pattern match on the output context of the first input's op.
        boolean z = false;
        Some some = null;
        Option<Context> outputContext = ControlFlow$.MODULE$.getOutputContext(op.inputs()[0].op());
        if (outputContext instanceof Some) {
            z = true;
            some = (Some) outputContext;
            Context context = (Context) some.value();
            if (context instanceof CondContext) {
                CondContext condContext = (CondContext) context;
                Tuple2 colocatedSwitch = ControlFlow$.MODULE$.colocatedSwitch((OutputLike) seq.head(), (Output) currentControlFlowContext.flatMap(context2 -> {
                    return context2.gradientLoopState();
                }).map(gradientLoopState -> {
                    // Inside a gradient while loop: reuse (or create and cache) the accumulated predicate value.
                    return (Output) gradientLoopState.historyMap().getOrElse(condContext.predicate().name(), () -> {
                        // The forward accumulator is added outside the backward context, then we re-enter it.
                        gradientLoopState.backwardContext().exit();
                        Output addForwardAccumulator = gradientLoopState.addForwardAccumulator(condContext.predicate(), gradientLoopState.addForwardAccumulator$default$2());
                        gradientLoopState.backwardContext().enter();
                        Output addBackwardAccumulatedValue = gradientLoopState.addBackwardAccumulatedValue(addForwardAccumulator, condContext.predicate(), gradientLoopState.addBackwardAccumulatedValue$default$3());
                        gradientLoopState.historyMap().$plus$eq(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(condContext.predicate().name()), addBackwardAccumulatedValue));
                        return addBackwardAccumulatedValue;
                    });
                }).getOrElse(() -> {
                    // No gradient loop state: switch directly on the cond predicate.
                    return condContext.predicate();
                }), ControlFlow$.MODULE$.colocatedSwitch$default$3());
                seq2 = (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new OutputLike[]{(OutputLike) colocatedSwitch._1(), (OutputLike) colocatedSwitch._2()}));
                return seq2;
            }
        }
        if (z && (some.value() instanceof WhileLoopContext)) {
            // NOTE(review): unchecked `get()` and cast on the current context.
            Tuple2 colocatedSwitch2 = ControlFlow$.MODULE$.colocatedSwitch((OutputLike) seq.head(), ((WhileLoopContext) currentControlFlowContext.get()).pivot(), ControlFlow$.MODULE$.colocatedSwitch$default$3());
            seq2 = Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new OutputLike[]{(OutputLike) colocatedSwitch2._1(), (OutputLike) colocatedSwitch2._2()}));
        } else {
            // One gradient per input index 0 until numInputs (see $anonfun$mergeGradient$5).
            seq2 = (Seq) RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), op.numInputs()).map(obj -> {
                return $anonfun$mergeGradient$5(op, seq, BoxesRunTime.unboxToInt(obj));
            }, IndexedSeq$.MODULE$.canBuildFrom());
        }
        return seq2;
    }

    /**
     * Synthetic lambda body for the fallback case of {@code mergeGradient}.
     *
     * <p>It routes {@code seq.head()} through a colocated Switch whose predicate is {@code op.outputs()[1] == i},
     * and returns the Switch's second output.
     *
     * @param op  the Merge op
     * @param seq gradients w.r.t. the Merge outputs; only the head is used
     * @param i   the input index being compared against {@code op.outputs()[1]}
     * @return the second output of the Switch
     */
    public static final /* synthetic */ OutputLike $anonfun$mergeGradient$5(Op op, Seq seq, int i) {
        return (OutputLike) ControlFlow$.MODULE$.colocatedSwitch((OutputLike) seq.head(), Math$.MODULE$.equal(op.outputs()[1], Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger(i), TensorConvertible$.MODULE$.fromSupportedType(SupportedType$.MODULE$.intIsSupported())), Math$.MODULE$.equal$default$3()), ControlFlow$.MODULE$.colocatedSwitch$default$3())._2();
    }

    /**
     * Creates the singleton and publishes it in {@code MODULE$}. It then registers the control-flow gradients:
     * <ul>
     *   <li>{@code ControlTrigger} is marked non-differentiable.</li>
     *   <li>Each op and its {@code Ref*} variant are mapped to the same gradient function.</li>
     * </ul>
     */
    public ControlFlow$Gradients$() {
        MODULE$ = this;
        Gradients$Registry$.MODULE$.registerNonDifferentiable("ControlTrigger");
        Gradients$Registry$.MODULE$.register("LoopCond", (op, seq) -> {
            return MODULE$.loopCondGradient(op, seq);
        });
        Gradients$Registry$.MODULE$.register("NextIteration", (op2, seq2) -> {
            return MODULE$.nextIterationGradient(op2, seq2);
        });
        Gradients$Registry$.MODULE$.register("RefNextIteration", (op3, seq3) -> {
            return MODULE$.nextIterationGradient(op3, seq3);
        });
        Gradients$Registry$.MODULE$.register("Enter", (op4, seq4) -> {
            return MODULE$.enterGradient(op4, seq4);
        });
        Gradients$Registry$.MODULE$.register("RefEnter", (op5, seq5) -> {
            return MODULE$.enterGradient(op5, seq5);
        });
        Gradients$Registry$.MODULE$.register("Exit", (op6, seq6) -> {
            return MODULE$.exitGradient(op6, seq6);
        });
        Gradients$Registry$.MODULE$.register("RefExit", (op7, seq7) -> {
            return MODULE$.exitGradient(op7, seq7);
        });
        Gradients$Registry$.MODULE$.register("Switch", (op8, seq8) -> {
            return MODULE$.switchGradient(op8, seq8);
        });
        Gradients$Registry$.MODULE$.register("RefSwitch", (op9, seq9) -> {
            return MODULE$.switchGradient(op9, seq9);
        });
        Gradients$Registry$.MODULE$.register("Merge", (op10, seq10) -> {
            return MODULE$.mergeGradient(op10, seq10);
        });
        Gradients$Registry$.MODULE$.register("RefMerge", (op11, seq11) -> {
            return MODULE$.mergeGradient(op11, seq11);
        });
    }
}
