package org.platanios.tensorflow.api.ops.variables;

import org.platanios.tensorflow.api.core.Indexer;
import org.platanios.tensorflow.api.core.Shape;
import org.platanios.tensorflow.api.implicits.Implicits$;
import org.platanios.tensorflow.api.ops.Basic$;
import org.platanios.tensorflow.api.ops.Gradients$Registry$;
import org.platanios.tensorflow.api.ops.Op;
import org.platanios.tensorflow.api.ops.Output;
import org.platanios.tensorflow.api.ops.OutputIndexedSlices;
import org.platanios.tensorflow.api.ops.OutputLike;
import org.platanios.tensorflow.api.tensors.TensorConvertible$;
import org.platanios.tensorflow.api.types.SupportedType$;
import scala.Predef$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.runtime.BoxesRunTime;

/* compiled from: Variable.scala */
/* loaded from: input_file:org/platanios/tensorflow/api/ops/variables/Variable$Gradients$.class */
public class Variable$Gradients$ {
    /** Singleton instance of this (Scala) object; assigned by the constructor during static initialization. */
    public static Variable$Gradients$ MODULE$;

    static {
        new Variable$Gradients$();
    }

    /**
     * Gradient function registered for {@code ReadVariableOp}.
     *
     * <p>Returns a one-element sequence containing the first incoming output gradient, unchanged.
     *
     * @param op  the {@code ReadVariableOp} being differentiated (unused)
     * @param seq gradients with respect to the op's outputs; only the head is used
     * @return a single-element sequence holding {@code seq.head()}
     */
    public Seq<OutputLike> readGradient(Op op, Seq<OutputLike> seq) {
        return Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new OutputLike[]{(OutputLike) seq.head()}));
    }

    /**
     * Gradient function registered for {@code ResourceGather}.
     *
     * <p>Builds an {@link OutputIndexedSlices} gradient for the gathered resource (input 0). The slices
     * have the following parts:
     * <ul>
     *   <li>indices: the gather indices (input 1), reshaped to rank 1;</li>
     *   <li>values: the incoming gradient, reshaped to {@code [size(indices)] ++ shape[1:]};</li>
     *   <li>dense shape: the {@code "shape"} attribute of the originating {@code VarHandleOp}.</li>
     * </ul>
     * The second returned element is {@code null}, so no gradient is produced for the indices input.
     *
     * @param op  the {@code ResourceGather} op being differentiated
     * @param seq gradients with respect to the op's outputs; only the head is used
     * @return a two-element sequence: the indexed-slices gradient and {@code null}
     * @throws IllegalArgumentException if no {@code VarHandleOp} is reachable by following first inputs
     */
    public Seq<OutputLike> gatherGradient(Op op, Seq<OutputLike> seq) {
        Output handle = findVarHandle(op.inputs()[0]);

        // Dense shape of the variable, taken from the handle op's static "shape" attribute.
        Shape variableShape = handle.op().shapeAttribute("shape");
        Output parametersShape =
                variableShape.toOutput(variableShape.toOutput$default$1(), variableShape.toOutput$default$2());

        // Number of gathered indices, as a rank-1 tensor so it can be concatenated into a shape.
        Output indices = op.inputs()[1];
        Output indicesCount = Basic$.MODULE$.size(
                indices, Basic$.MODULE$.size$default$2(), Basic$.MODULE$.size$default$3(), Basic$.MODULE$.size$default$4());
        Output indicesSize = Basic$.MODULE$.expandDims(indicesCount, intZero(), Basic$.MODULE$.expandDims$default$3());

        // Reshape the upstream gradient to [size(indices)] ++ shape[1:].
        Output upstreamGradient = ((OutputLike) seq.head()).toOutput();
        Output innerShape = parametersShape.apply(
                Predef$.MODULE$.wrapRefArray(new Indexer[]{Implicits$.MODULE$.intToIndexerConstruction(1).$colon$colon()}));
        Output valuesShape = Basic$.MODULE$.concatenate(
                (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Output[]{indicesSize, innerShape})),
                intZero(),
                Basic$.MODULE$.concatenate$default$3());
        Output values = Basic$.MODULE$.reshape(upstreamGradient, valuesShape, Basic$.MODULE$.reshape$default$3());

        Output flatIndices = Basic$.MODULE$.reshape(indices, indicesSize, Basic$.MODULE$.reshape$default$3());
        OutputIndexedSlices parametersGradient = new OutputIndexedSlices(flatIndices, values, parametersShape);
        // Indices are integral and receive no gradient.
        return Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new OutputIndexedSlices[]{parametersGradient, null}));
    }

    /**
     * Walks backwards along first inputs, starting from {@code start}, until it reaches the output of
     * a {@code VarHandleOp}.
     *
     * <p>A {@code null} op type is treated as "not a VarHandleOp", so the walk continues past it.
     *
     * @param start the output to begin the search from (inclusive)
     * @return the first output on the chain whose op type is {@code "VarHandleOp"}
     * @throws IllegalArgumentException if an op without inputs is reached before a {@code VarHandleOp}
     */
    private Output findVarHandle(Output start) {
        Output current = start;
        // Null-safe comparison: the decompiled original compared the literal against an int (L46),
        // which does not compile.
        while (!"VarHandleOp".equals(current.op().opType())) {
            Output[] inputs = current.op().inputs();
            if (inputs.length == 0) {
                throw new IllegalArgumentException(
                        "Could not find a VarHandleOp; reached op of type '" + current.op().opType()
                                + "' with no inputs.");
            }
            current = inputs[0];
        }
        return current;
    }

    /** Creates a new int32 scalar constant {@code 0}, used as the axis for expandDims/concatenate. */
    private Output intZero() {
        return Implicits$.MODULE$.tensorConvertibleToOutput(
                BoxesRunTime.boxToInteger(0),
                TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType()));
    }

    /** Initializes the singleton and registers the gradient functions for resource-variable ops. */
    public Variable$Gradients$() {
        MODULE$ = this;
        Gradients$Registry$.MODULE$.register("ReadVariableOp", (op, seq) -> MODULE$.readGradient(op, seq));
        Gradients$Registry$.MODULE$.register("ResourceGather", (op2, seq2) -> MODULE$.gatherGradient(op2, seq2));
    }
}
