/*
 * Decompiled with CFR 0.152.
 */
package org.platanios.tensorflow.api.ops.rnn.cell;

import java.io.Serializable;
import org.platanios.tensorflow.api.core.package$exception$;
import org.platanios.tensorflow.api.implicits.Implicits$;
import org.platanios.tensorflow.api.ops.Basic$;
import org.platanios.tensorflow.api.ops.Clip;
import org.platanios.tensorflow.api.ops.Math$;
import org.platanios.tensorflow.api.ops.NN$;
import org.platanios.tensorflow.api.ops.Op$;
import org.platanios.tensorflow.api.ops.Output;
import org.platanios.tensorflow.api.ops.OutputOps$;
import org.platanios.tensorflow.api.ops.rnn.cell.package;
import org.platanios.tensorflow.api.ops.rnn.cell.package$;
import org.platanios.tensorflow.api.ops.rnn.cell.package$Tuple$;
import org.platanios.tensorflow.api.tensors.TensorConvertible$;
import org.platanios.tensorflow.api.types.SupportedType$;
import org.platanios.tensorflow.jni.InvalidArgumentException;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple4;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.runtime.BoxesRunTime;

/**
 * Functional implementations of basic recurrent cells (vanilla RNN, GRU, basic LSTM and LSTM with
 * optional peepholes, clipping and projection).
 *
 * <p>NOTE(review): this is CFR-decompiled Scala (the companion object of {@code RNNCell}); names such
 * as {@code package.Tuple} refer to the Scala package object and are not valid Java source as written.
 */
public final class RNNCell$ {
    public static RNNCell$ MODULE$;

    /** Sentinel used by {@link #lstmCell} for "no clipping" (see its {@code $default$4}/{@code $default$9}). */
    private static final float NO_CLIP = -1.0f;

    static {
        new RNNCell$();
    }

    /**
     * Validates the cell input: it must be rank-2 and its last axis must be statically known.
     *
     * @throws InvalidArgumentException (unchecked, from the project exception factory) if either check fails
     */
    private static void checkInput(Output output) {
        if (output.rank() != 2) {
            throw package$exception$.MODULE$.InvalidArgumentException().apply(new StringBuilder(38).append("Input must be rank-2 (provided rank-").append(output.rank()).append(").").toString());
        }
        // A static dimension of -1 means "unknown", which the error message reports as not allowed.
        if (output.shape().apply(1) == -1) {
            throw package$exception$.MODULE$.InvalidArgumentException().apply(new StringBuilder(42).append("Last axis of input shape (").append(output.shape()).append(") must be known.").toString());
        }
    }

    /** Computes {@code concatenate([a, b], axis = 1) · kernel + bias}. */
    private static Output linear(Output a, Output b, Output kernel, Output bias) {
        Output concatenated = Basic$.MODULE$.concatenate((Seq<Output>)((Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Output[]{a, b}))), Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger((int)1), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())), Basic$.MODULE$.concatenate$default$3());
        Output product = Math$.MODULE$.matmul(concatenated, kernel, Math$.MODULE$.matmul$default$3(), Math$.MODULE$.matmul$default$4(), Math$.MODULE$.matmul$default$5(), Math$.MODULE$.matmul$default$6(), Math$.MODULE$.matmul$default$7(), Math$.MODULE$.matmul$default$8(), Math$.MODULE$.matmul$default$9());
        return NN$.MODULE$.addBias(product, bias, NN$.MODULE$.addBias$default$3(), NN$.MODULE$.addBias$default$4());
    }

    /**
     * Clips {@code x} element-wise to {@code [-limit, limit]}. The bound constant is created with
     * {@code x}'s own data type so that the clip operands always agree in dtype.
     */
    private static Output clipSymmetric(Output x, float limit) {
        Output limitTensor = Basic$.MODULE$.constant(Implicits$.MODULE$.tensorConvertibleToTensor(BoxesRunTime.boxToFloat((float)limit), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.floatIsSupportedType())), x.dataType(), Basic$.MODULE$.constant$default$3(), Basic$.MODULE$.constant$default$4());
        Clip.ClipOps clipOps = Implicits$.MODULE$.outputToClipOps(x);
        Output lower = Implicits$.MODULE$.outputToMathOps(limitTensor).unary_$minus();
        return clipOps.clipByValue(lower, limitTensor, clipOps.clipByValue$default$3());
    }

    /**
     * Vanilla RNN cell: {@code h' = activation([output, state] · kernel + bias)}.
     *
     * @param input      current input ({@code output}) and previous state
     * @param kernel     weight matrix applied to the concatenation of input and state
     * @param bias       bias added after the matrix multiplication
     * @param activation non-linearity (default: {@code tanh})
     * @param name       name scope for the created ops (default: {@code "BasicRNNCell"})
     * @return a tuple whose output and state are both {@code h'}
     * @throws InvalidArgumentException if the input is not rank-2 or its last axis is unknown
     */
    public package.Tuple<Output, Output> basicRNNCell(package.Tuple<Output, Output> input, Output kernel, Output bias, Function1<Output, Output> activation, String name) throws InvalidArgumentException {
        return (package.Tuple)Op$.MODULE$.createWithNameScope(name, Op$.MODULE$.createWithNameScope$default$2(), (Function0 & Serializable & scala.Serializable)() -> {
            Output output = (Output)input.output();
            Output state = (Output)input.state();
            checkInput(output);
            Output newOutput = (Output)activation.apply((Object)linear(output, state, kernel, bias));
            return package$Tuple$.MODULE$.apply(newOutput, newOutput);
        });
    }

    /** Default activation for {@link #basicRNNCell}: {@code tanh}. */
    public Function1<Output, Output> basicRNNCell$default$4() {
        return (Function1 & Serializable & scala.Serializable)x -> Math$.MODULE$.tanh(x, Math$.MODULE$.tanh$default$2(), OutputOps$.MODULE$.outputOps());
    }

    /** Default name scope for {@link #basicRNNCell}. */
    public String basicRNNCell$default$5() {
        return "BasicRNNCell";
    }

    /**
     * GRU cell.
     *
     * <pre>
     *   [r, u] = split(sigmoid([output, state] · gateKernel + gateBias), 2, axis = 1)
     *   c      = activation([output, r * state] · candidateKernel + candidateBias)
     *   h'     = u * state + (1 - u) * c
     * </pre>
     *
     * @param input           current input ({@code output}) and previous state
     * @param gateKernel      weights producing the reset and update gates
     * @param gateBias        bias for the gates
     * @param candidateKernel weights producing the candidate state
     * @param candidateBias   bias for the candidate state
     * @param activation      non-linearity applied to the candidate (default: {@code tanh})
     * @param name            name scope for the created ops (default: {@code "GRUCell"})
     * @return a tuple whose output and state are both {@code h'}
     * @throws InvalidArgumentException if the input is not rank-2 or its last axis is unknown
     */
    public package.Tuple<Output, Output> gruCell(package.Tuple<Output, Output> input, Output gateKernel, Output gateBias, Output candidateKernel, Output candidateBias, Function1<Output, Output> activation, String name) throws InvalidArgumentException {
        return (package.Tuple)Op$.MODULE$.createWithNameScope(name, Op$.MODULE$.createWithNameScope$default$2(), (Function0 & Serializable & scala.Serializable)() -> {
            Output output = (Output)input.output();
            Output state = (Output)input.state();
            checkInput(output);
            Output gateIn = linear(output, state, gateKernel, gateBias);
            Seq<Output> gates = Basic$.MODULE$.splitEvenly(Math$.MODULE$.sigmoid(gateIn, Math$.MODULE$.sigmoid$default$2(), OutputOps$.MODULE$.outputOps()), 2, Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger((int)1), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())), Basic$.MODULE$.splitEvenly$default$4());
            Output r = (Output)gates.apply(0);
            Output u = (Output)gates.apply(1);
            Output rState = Math$.MODULE$.multiply(r, state, Math$.MODULE$.multiply$default$3());
            // The candidate must go through `activation`; otherwise the argument (and its tanh default)
            // would be silently ignored and the cell would not be a standard GRU.
            Output c = (Output)activation.apply((Object)linear(output, rState, candidateKernel, candidateBias));
            Output oneMinusU = Implicits$.MODULE$.outputConvertibleToMathOps(BoxesRunTime.boxToInteger((int)1), (Function1 & Serializable & scala.Serializable)value -> Implicits$.MODULE$.tensorConvertibleToOutput(BoxesRunTime.boxToInteger((int)BoxesRunTime.unboxToInt((Object)value)), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType()))).$minus(u);
            Output newH = Math$.MODULE$.add(Math$.MODULE$.multiply(u, state, Math$.MODULE$.multiply$default$3()), Math$.MODULE$.multiply(oneMinusU, c, Math$.MODULE$.multiply$default$3()), Math$.MODULE$.add$default$3());
            return package$Tuple$.MODULE$.apply(newH, newH);
        });
    }

    /** Default activation for {@link #gruCell}: {@code tanh}. */
    public Function1<Output, Output> gruCell$default$6() {
        return (Function1 & Serializable & scala.Serializable)x -> Math$.MODULE$.tanh(x, Math$.MODULE$.tanh$default$2(), OutputOps$.MODULE$.outputOps());
    }

    /** Default name scope for {@link #gruCell}. */
    public String gruCell$default$7() {
        return "GRUCell";
    }

    /**
     * Basic LSTM cell (no peepholes, clipping or projection).
     *
     * <pre>
     *   [i, j, f, o] = split([output, m] · kernel + bias, 4, axis = 1)
     *   c' = c * sigmoid(f + forgetBias) + sigmoid(i) * activation(j)
     *   m' = activation(c') * sigmoid(o)
     * </pre>
     *
     * @param input      current input ({@code output}) and previous {@code LSTMState(c, m)}
     * @param kernel     weight matrix for all four gate blocks
     * @param bias       bias for all four gate blocks
     * @param activation non-linearity (default: {@code tanh})
     * @param forgetBias constant added to the forget-gate pre-activation (default: {@code 1.0f})
     * @param name       name scope for the created ops (default: {@code "BasicLSTMCell"})
     * @return {@code LSTMTuple(m', LSTMState(c', m'))}
     * @throws InvalidArgumentException if the input is not rank-2 or its last axis is unknown
     */
    public package.Tuple<Output, package.LSTMState> basicLSTMCell(package.Tuple<Output, package.LSTMState> input, Output kernel, Output bias, Function1<Output, Output> activation, float forgetBias, String name) throws InvalidArgumentException {
        return (package.Tuple)Op$.MODULE$.createWithNameScope(name, Op$.MODULE$.createWithNameScope$default$2(), (Function0 & Serializable & scala.Serializable)() -> {
            Output output = (Output)input.output();
            checkInput(output);
            package.LSTMState state = (package.LSTMState)input.state();
            // INT32 scalar used as the split axis.
            Output one = Basic$.MODULE$.constant(Implicits$.MODULE$.tensorConvertibleToTensor(BoxesRunTime.boxToInteger((int)1), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())), org.platanios.tensorflow.api.types.package$.MODULE$.INT32(), Basic$.MODULE$.constant$default$3(), Basic$.MODULE$.constant$default$4());
            Output lstmMatrix = linear(output, state.m(), kernel, bias);
            Seq<Output> blocks = Basic$.MODULE$.splitEvenly(lstmMatrix, 4, one, Basic$.MODULE$.splitEvenly$default$4());
            Output i = (Output)blocks.apply(0);
            Output j = (Output)blocks.apply(1);
            Output f = (Output)blocks.apply(2);
            Output o = (Output)blocks.apply(3);
            // Created with the gate's dtype so the addition below is type-consistent.
            Output forgetBiasTensor = Basic$.MODULE$.constant(Implicits$.MODULE$.tensorConvertibleToTensor(BoxesRunTime.boxToFloat((float)forgetBias), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.floatIsSupportedType())), f.dataType(), Basic$.MODULE$.constant$default$3(), Basic$.MODULE$.constant$default$4());
            Output c = Math$.MODULE$.add(
                Math$.MODULE$.multiply(state.c(), Math$.MODULE$.sigmoid(Implicits$.MODULE$.outputToMathOps(f).$plus(forgetBiasTensor), Math$.MODULE$.sigmoid$default$2(), OutputOps$.MODULE$.outputOps()), Math$.MODULE$.multiply$default$3()),
                Math$.MODULE$.multiply(Math$.MODULE$.sigmoid(i, Math$.MODULE$.sigmoid$default$2(), OutputOps$.MODULE$.outputOps()), (Output)activation.apply((Object)j), Math$.MODULE$.multiply$default$3()),
                Math$.MODULE$.add$default$3());
            Output m = Math$.MODULE$.multiply((Output)activation.apply((Object)c), Math$.MODULE$.sigmoid(o, Math$.MODULE$.sigmoid$default$2(), OutputOps$.MODULE$.outputOps()), Math$.MODULE$.multiply$default$3());
            return package$.MODULE$.LSTMTuple(m, new package.LSTMState(c, m));
        });
    }

    /** Default activation for {@link #basicLSTMCell}: {@code tanh}. */
    public Function1<Output, Output> basicLSTMCell$default$4() {
        return (Function1 & Serializable & scala.Serializable)x -> Math$.MODULE$.tanh(x, Math$.MODULE$.tanh$default$2(), OutputOps$.MODULE$.outputOps());
    }

    /** Default forget bias for {@link #basicLSTMCell}. */
    public float basicLSTMCell$default$5() {
        return 1.0f;
    }

    /** Default name scope for {@link #basicLSTMCell}. */
    public String basicLSTMCell$default$6() {
        return "BasicLSTMCell";
    }

    /**
     * LSTM cell with optional peephole connections, cell clipping, output projection and projection
     * clipping.
     *
     * <pre>
     *   [i, j, f, o] = split([output, m] · kernel + bias, 4, axis = 1)
     *   c' = c * sigmoid(f + forgetBias [+ wfDiag * c]) + sigmoid(i [+ wiDiag * c]) * activation(j)
     *   c' = clip(c', -cellClip, cellClip)                 if cellClip != -1
     *   m' = activation(c') * sigmoid(o [+ woDiag * c'])
     *   m' = m' · projectionKernel                         if projectionKernel != null
     *   m' = clip(m', -projectionClip, projectionClip)     if projection is used and projectionClip != -1
     * </pre>
     *
     * @param input            current input ({@code output}) and previous {@code LSTMState(c, m)}
     * @param kernel           weight matrix for all four gate blocks
     * @param bias             bias for all four gate blocks
     * @param cellClip         clipping bound for the cell state; {@code -1} disables clipping
     * @param wfDiag           forget-gate peephole weights, or {@code null} to disable
     * @param wiDiag           input-gate peephole weights, or {@code null} to disable
     * @param woDiag           output-gate peephole weights, or {@code null} to disable
     * @param projectionKernel projection matrix for the output, or {@code null} to disable
     * @param projectionClip   clipping bound for the projected output; {@code -1} disables clipping
     * @param activation       non-linearity (default: {@code tanh})
     * @param forgetBias       constant added to the forget-gate pre-activation (default: {@code 1.0f})
     * @param name             name scope for the created ops (default: {@code "LSTMCell"})
     * @return {@code LSTMTuple(m', LSTMState(c', m'))}
     * @throws InvalidArgumentException if the input is not rank-2 or its last axis is unknown
     */
    public package.Tuple<Output, package.LSTMState> lstmCell(package.Tuple<Output, package.LSTMState> input, Output kernel, Output bias, float cellClip, Output wfDiag, Output wiDiag, Output woDiag, Output projectionKernel, float projectionClip, Function1<Output, Output> activation, float forgetBias, String name) throws InvalidArgumentException {
        return (package.Tuple)Op$.MODULE$.createWithNameScope(name, Op$.MODULE$.createWithNameScope$default$2(), (Function0 & Serializable & scala.Serializable)() -> {
            Output output = (Output)input.output();
            checkInput(output);
            package.LSTMState state = (package.LSTMState)input.state();
            // INT32 scalar used as the split axis.
            Output one = Basic$.MODULE$.constant(Implicits$.MODULE$.tensorConvertibleToTensor(BoxesRunTime.boxToInteger((int)1), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.intIsSupportedType())), org.platanios.tensorflow.api.types.package$.MODULE$.INT32(), Basic$.MODULE$.constant$default$3(), Basic$.MODULE$.constant$default$4());
            Output lstmMatrix = linear(output, state.m(), kernel, bias);
            Seq<Output> blocks = Basic$.MODULE$.splitEvenly(lstmMatrix, 4, one, Basic$.MODULE$.splitEvenly$default$4());
            Output i = (Output)blocks.apply(0);
            Output j = (Output)blocks.apply(1);
            Output f = (Output)blocks.apply(2);
            Output o = (Output)blocks.apply(3);
            // Created with the gate's dtype so the addition below is type-consistent.
            Output forgetBiasTensor = Basic$.MODULE$.constant(Implicits$.MODULE$.tensorConvertibleToTensor(BoxesRunTime.boxToFloat((float)forgetBias), TensorConvertible$.MODULE$.supportedTypeTensorConvertible(SupportedType$.MODULE$.floatIsSupportedType())), f.dataType(), Basic$.MODULE$.constant$default$3(), Basic$.MODULE$.constant$default$4());

            // Forget-gate pre-activation, with optional peephole on the previous cell state.
            Output forgetPre = Implicits$.MODULE$.outputToMathOps(f).$plus(forgetBiasTensor);
            if (wfDiag != null) {
                forgetPre = Implicits$.MODULE$.outputToMathOps(forgetPre).$plus(Math$.MODULE$.multiply(wfDiag, state.c(), Math$.MODULE$.multiply$default$3()));
            }
            // Input-gate pre-activation, with optional peephole on the previous cell state.
            Output inputPre = i;
            if (wiDiag != null) {
                inputPre = Implicits$.MODULE$.outputToMathOps(inputPre).$plus(Math$.MODULE$.multiply(wiDiag, state.c(), Math$.MODULE$.multiply$default$3()));
            }
            Output c = Math$.MODULE$.add(
                Math$.MODULE$.multiply(state.c(), Math$.MODULE$.sigmoid(forgetPre, Math$.MODULE$.sigmoid$default$2(), OutputOps$.MODULE$.outputOps()), Math$.MODULE$.multiply$default$3()),
                Math$.MODULE$.multiply(Math$.MODULE$.sigmoid(inputPre, Math$.MODULE$.sigmoid$default$2(), OutputOps$.MODULE$.outputOps()), (Output)activation.apply((Object)j), Math$.MODULE$.multiply$default$3()),
                Math$.MODULE$.add$default$3());
            if (cellClip != NO_CLIP) {
                c = clipSymmetric(c, cellClip);
            }

            // activation(c) is built before the output gate to keep the original op-creation order.
            Output activatedC = (Output)activation.apply((Object)c);
            Output outputPre = woDiag != null
                ? Implicits$.MODULE$.outputToMathOps(o).$plus(Math$.MODULE$.multiply(woDiag, c, Math$.MODULE$.multiply$default$3()))
                : o;
            Output m = Math$.MODULE$.multiply(activatedC, Math$.MODULE$.sigmoid(outputPre, Math$.MODULE$.sigmoid$default$2(), OutputOps$.MODULE$.outputOps()), Math$.MODULE$.multiply$default$3());

            if (projectionKernel != null) {
                m = Math$.MODULE$.matmul(m, projectionKernel, Math$.MODULE$.matmul$default$3(), Math$.MODULE$.matmul$default$4(), Math$.MODULE$.matmul$default$5(), Math$.MODULE$.matmul$default$6(), Math$.MODULE$.matmul$default$7(), Math$.MODULE$.matmul$default$8(), Math$.MODULE$.matmul$default$9());
                // Projection clipping only applies when a projection is actually performed.
                if (projectionClip != NO_CLIP) {
                    m = clipSymmetric(m, projectionClip);
                }
            }
            return package$.MODULE$.LSTMTuple(m, new package.LSTMState(c, m));
        });
    }

    /** Default cell clip for {@link #lstmCell}: {@code -1} (disabled). */
    public float lstmCell$default$4() {
        return -1.0f;
    }

    /** Default forget-gate peephole for {@link #lstmCell}: none. */
    public Output lstmCell$default$5() {
        return null;
    }

    /** Default input-gate peephole for {@link #lstmCell}: none. */
    public Output lstmCell$default$6() {
        return null;
    }

    /** Default output-gate peephole for {@link #lstmCell}: none. */
    public Output lstmCell$default$7() {
        return null;
    }

    /** Default projection kernel for {@link #lstmCell}: none. */
    public Output lstmCell$default$8() {
        return null;
    }

    /** Default projection clip for {@link #lstmCell}: {@code -1} (disabled). */
    public float lstmCell$default$9() {
        return -1.0f;
    }

    /** Default activation for {@link #lstmCell}: {@code tanh}. */
    public Function1<Output, Output> lstmCell$default$10() {
        return (Function1 & Serializable & scala.Serializable)x -> Math$.MODULE$.tanh(x, Math$.MODULE$.tanh$default$2(), OutputOps$.MODULE$.outputOps());
    }

    /** Default forget bias for {@link #lstmCell}. */
    public float lstmCell$default$11() {
        return 1.0f;
    }

    /** Default name scope for {@link #lstmCell}. */
    public String lstmCell$default$12() {
        return "LSTMCell";
    }

    private RNNCell$() {
        MODULE$ = this;
    }
}

