package org.platanios.tensorflow.api.ops;

import org.platanios.tensorflow.api.tensors.TensorConvertible$;
import org.platanios.tensorflow.api.types.SupportedType$;
import scala.Array$;
import scala.Predef$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;

/* compiled from: DataFlow.scala */
/* loaded from: input_file:org/platanios/tensorflow/api/ops/DataFlow$Gradients$.class */
/**
 * Java view of the Scala singleton object {@code DataFlow.Gradients} (compiled from DataFlow.scala).
 *
 * <p>Initializing this class runs the static initializer, which constructs the single instance.
 * The constructor does three things:
 * <ul>
 *   <li>publishes the instance through {@link #MODULE$};</li>
 *   <li>calls {@code Gradients$Registry$.registerNonDifferentiable} for the stack and
 *       session-handle/session-tensor op types;</li>
 *   <li>registers gradient functions for {@code DynamicPartition} and {@code DynamicStitch}.</li>
 * </ul>
 */
public class DataFlow$Gradients$ {
    /** The singleton instance; assigned by the constructor, which the static initializer invokes. */
    public static DataFlow$Gradients$ MODULE$;

    // Eagerly create the singleton at class-initialization time. This is what triggers the
    // gradient registrations performed in the constructor.
    static {
        new DataFlow$Gradients$();
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Gradient function registered for {@code DynamicPartition} ops.
     *
     * <p>The gradient for {@code op.inputs()[0]} is rebuilt by tracking where each element was routed:
     * <ol>
     *   <li>Build {@code range(0, prod(shape(inputs[1])))}, reshaped to {@code shape(inputs[1])}, so
     *       that each position holds its own flat index.</li>
     *   <li>{@code dynamicPartition} that index tensor with {@code inputs[1]} and the op's
     *       {@code num_partitions} attribute. This presumably mirrors how the forward op partitioned
     *       {@code inputs[0]}.</li>
     *   <li>{@code dynamicStitch} the incoming gradients back at those original positions.</li>
     *   <li>Reshape the stitched result to {@code shape(inputs[0])}.</li>
     * </ol>
     *
     * <p>NOTE(review): every element of {@code seq} is dereferenced via {@code toOutput()}. A
     * {@code null} gradient (e.g. for a partition with no downstream use) would throw a
     * NullPointerException. Confirm that the gradient registry never passes {@code null} entries here.
     *
     * @param op  the {@code DynamicPartition} op. {@code inputs()[0]} is the partitioned tensor and
     *            {@code inputs()[1]} the partition indices (names assumed from the op's semantics;
     *            the code only fixes their positions)
     * @param seq the gradients with respect to the op's outputs (presumably one per partition); each
     *            is converted with {@code toOutput()} before stitching
     * @return a two-element sequence: the reconstructed gradient for {@code inputs()[0]}, followed
     *         by {@code null} (no gradient) for {@code inputs()[1]}
     */
    public Seq<OutputLike> dynamicPartitionGradient(Op op, Seq<OutputLike> seq) {
        // First input: the tensor whose shape the returned gradient is reshaped to.
        Output output = op.inputs()[0];
        // Second input: used both for its shape and as the partitioning tensor below.
        Output output2 = op.inputs()[1];
        // Stored as a long attribute; narrowed to int for dynamicPartition.
        int longAttribute = (int) op.longAttribute("num_partitions");
        Output shape = Basic$.MODULE$.shape(output2, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), Basic$.MODULE$.shape$default$4());
        // Innermost-first, graph ops are created in this order:
        //   constant(0), prod(shape), range, reshape(range, shape),
        //   dynamicPartition(..., output2, numPartitions), dynamicStitch(partitioned, grads),
        //   shape(output), reshape(stitched, shape(output)).
        // The trailing null is the (absent) gradient for the second input.
        return Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{Basic$.MODULE$.reshape(DataFlow$.MODULE$.dynamicStitch(DataFlow$.MODULE$.dynamicPartition(Basic$.MODULE$.reshape(Math$.MODULE$.range(Basic$.MODULE$.constant(org.platanios.tensorflow.api.package$.MODULE$.tensorConvertibleToTensor(BoxesRunTime.boxToInteger(0), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())), Basic$.MODULE$.constant$default$2(), Basic$.MODULE$.constant$default$3(), Basic$.MODULE$.constant$default$4()), Math$.MODULE$.prod(shape, Math$.MODULE$.prod$default$2(), Math$.MODULE$.prod$default$3(), Math$.MODULE$.prod$default$4()), Math$.MODULE$.range$default$3(), Math$.MODULE$.range$default$4(), Math$.MODULE$.range$default$5()), shape, Basic$.MODULE$.reshape$default$3()), output2, longAttribute, DataFlow$.MODULE$.dynamicPartition$default$4()), (Seq) seq.map(outputLike -> {
            return outputLike.toOutput();
        }, Seq$.MODULE$.canBuildFrom()), DataFlow$.MODULE$.dynamicStitch$default$3()), Basic$.MODULE$.shape(output, Basic$.MODULE$.shape$default$2(), Basic$.MODULE$.shape$default$3(), Basic$.MODULE$.shape$default$4()), Basic$.MODULE$.reshape$default$3()), null}));
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Gradient function registered for {@code DynamicStitch} ops.
     *
     * <p>The code treats {@code op.inputs()} as two halves of length {@code inputs().length / 2}:
     * <ul>
     *   <li>the first half are index tensors and the second half the matching data tensors;</li>
     *   <li>index inputs receive no gradient ({@code null});</li>
     *   <li>data input {@code i} receives {@code gather(grad, cast(inputs[i], INT32))}.</li>
     * </ul>
     *
     * <p>NOTE(review): {@code seq.head()} throws if {@code seq} is empty. If {@code inputs().length}
     * were odd, the integer division would make the returned sequence one shorter than the inputs.
     * Presumably DynamicStitch always has an even number of inputs. Confirm.
     *
     * @param op  the {@code DynamicStitch} op
     * @param seq the gradients with respect to the op's outputs; only the first element is used,
     *            after {@code toOutput()} conversion
     * @return {@code inputs().length / 2} {@code null}s followed by one gathered gradient per data
     *         input, in input order
     */
    public Seq<OutputLike> dynamicStitchGradient(Op op, Seq<OutputLike> seq) {
        // Only the head gradient is used; toOutput() presumably densifies non-dense gradients.
        Output output = ((OutputLike) seq.head()).toOutput();
        // Number of index inputs (equal to the number of data inputs).
        int length = op.inputs().length / 2;
        // Result = [null x length] ++ inputs.take(length).map(cast to INT32).map(gather(output, _)).
        return (Seq) Seq$.MODULE$.fill(length, () -> {
            return null;
        }).$plus$plus(new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Output[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(op.inputs())).take(length))).map(output2 -> {
            // Always cast the index tensor to INT32 before using it with gather.
            return (Output) Math$.MODULE$.cast(output2, org.platanios.tensorflow.api.types.package$.MODULE$.INT32(), Math$.MODULE$.cast$default$3(), OutputOps$.MODULE$.outputOps());
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Output.class))))).map(output3 -> {
            // Gather the output gradient at the (cast) indices of this partition.
            return Basic$.MODULE$.gather(output, output3, Basic$.MODULE$.gather$default$3(), Basic$.MODULE$.gather$default$4());
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Output.class))))), Seq$.MODULE$.canBuildFrom());
    }

    /**
     * Creates the singleton and performs this module's gradient registrations.
     *
     * <p>Invoked from the static initializer. {@link #MODULE$} is assigned first, so the registered
     * lambdas, which dispatch through {@code MODULE$}, see the initialized instance. The constructor
     * is public in this decompiled form. Calling it again would reassign {@code MODULE$} and repeat the
     * registry calls; how the registry treats duplicate registrations is not visible here.
     */
    public DataFlow$Gradients$() {
        MODULE$ = this;
        // Stack ops (original and V2 variants).
        Gradients$Registry$.MODULE$.registerNonDifferentiable("Stack");
        Gradients$Registry$.MODULE$.registerNonDifferentiable("StackPush");
        Gradients$Registry$.MODULE$.registerNonDifferentiable("StackPop");
        Gradients$Registry$.MODULE$.registerNonDifferentiable("StackClose");
        Gradients$Registry$.MODULE$.registerNonDifferentiable("StackV2");
        Gradients$Registry$.MODULE$.registerNonDifferentiable("StackPushV2");
        Gradients$Registry$.MODULE$.registerNonDifferentiable("StackPopV2");
        Gradients$Registry$.MODULE$.registerNonDifferentiable("StackCloseV2");
        // Session handle / session tensor ops.
        Gradients$Registry$.MODULE$.registerNonDifferentiable("GetSessionHandle");
        Gradients$Registry$.MODULE$.registerNonDifferentiable("GetSessionHandleV2");
        Gradients$Registry$.MODULE$.registerNonDifferentiable("GetSessionTensor");
        Gradients$Registry$.MODULE$.registerNonDifferentiable("DeleteSessionTensor");
        // Gradient functions, looked up by op type name.
        Gradients$Registry$.MODULE$.register("DynamicPartition", (op, seq) -> {
            return MODULE$.dynamicPartitionGradient(op, seq);
        });
        Gradients$Registry$.MODULE$.register("DynamicStitch", (op2, seq2) -> {
            return MODULE$.dynamicStitchGradient(op2, seq2);
        });
    }
}
