package scorch.nn.rnn;

import botkop.numsca.package$;
import scala.Function1;
import scala.MatchError;
import scala.Option;
import scala.Product;
import scala.Serializable;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;
import scorch.autograd.Variable;
import scorch.nn.SeqModule;

/* compiled from: RnnBase.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005ed\u0001B\u0001\u0003\u0001&\u0011qA\u00158o\u0005\u0006\u001cXM\u0003\u0002\u0004\t\u0005\u0019!O\u001c8\u000b\u0005\u00151\u0011A\u00018o\u0015\u00059\u0011AB:d_J\u001c\u0007n\u0001\u0001\u0014\t\u0001Qa\u0002\u0006\t\u0003\u00171i\u0011\u0001B\u0005\u0003\u001b\u0011\u0011\u0011bU3r\u001b>$W\u000f\\3\u0011\u0005=\u0011R\"\u0001\t\u000b\u0003E\tQa]2bY\u0006L!a\u0005\t\u0003\u000fA\u0013x\u000eZ;diB\u0011q\"F\u0005\u0003-A\u0011AbU3sS\u0006d\u0017N_1cY\u0016D\u0001\u0002\u0007\u0001\u0003\u0016\u0004%\t!G\u0001\u0005G\u0016dG.F\u0001\u001b!\tYB$D\u0001\u0003\u0013\ti\"AA\u0006S]:\u001cU\r\u001c7CCN,\u0007\u0002C\u0010\u0001\u0005#\u0005\u000b\u0011\u0002\u000e\u0002\u000b\r,G\u000e\u001c\u0011\t\u000b\u0005\u0002A\u0011\u0001\u0012\u0002\rqJg.\u001b;?)\t\u0019C\u0005\u0005\u0002\u001c\u0001!)\u0001\u0004\ta\u00015!)a\u0005\u0001C!O\u00059am\u001c:xCJ$GC\u0001\u0015;!\rI\u0013\u0007\u000e\b\u0003U=r!a\u000b\u0018\u000e\u00031R!!\f\u0005\u0002\rq\u0012xn\u001c;?\u0013\u0005\t\u0012B\u0001\u0019\u0011\u0003\u001d\u0001\u0018mY6bO\u0016L!AM\u001a\u0003\u0007M+\u0017O\u0003\u00021!A\u0011Q\u0007O\u0007\u0002m)\u0011qGB\u0001\tCV$xn\u001a:bI&\u0011\u0011H\u000e\u0002\t-\u0006\u0014\u0018.\u00192mK\")1(\na\u0001Q\u0005\u0011\u0001p\u001d\u0005\u0006{\u0001!\tAP\u0001\u0007g\u0006l\u0007\u000f\\3\u0015\u000b}\u001a\u0005J\u0013'\u0011\u0007%\n\u0004\t\u0005\u0002\u0010\u0003&\u0011!\t\u0005\u0002\u0004\u0013:$\b\"\u0002#=\u0001\u0004)\u0015AB3oG>$W\r\u0005\u0003\u0010\r\u0002#\u0014BA$\u0011\u0005%1UO\\2uS>t\u0017\u0007C\u0003Jy\u0001\u0007\u0001)\u0001\u0005c_NLe\u000eZ3y\u0011\u0015YE\b1\u0001A\u0003!)wn]%oI\u0016D\b\"B'=\u0001\u0004\u0001\u0015aD7bqN+g\u000e^3oG\u0016\u001c\u0016N_3\t\u000f=\u0003\u0011\u0011!C\u0001!\u0006!1m\u001c9z)\t\u0019\u0013\u000bC\u0004\u0019\u001dB\u0005\t\u0019\u0001\u000e\t\u000fM\u0003\u0011\u0013!C\u0001)\u0006q1m\u001c9zI\u0011,g-Y;mi\u0012\nT#A++\u0005i16&A,\u0011\u0005akV\"A-\u000b\u0005i[\u001
6!C;oG\",7m[3e\u0015\ta\u0006#\u0001\u0006b]:|G/\u0019;j_:L!AX-\u0003#Ut7\r[3dW\u0016$g+\u0019:jC:\u001cW\rC\u0004a\u0001\u0005\u0005I\u0011I1\u0002\u001bA\u0014x\u000eZ;diB\u0013XMZ5y+\u0005\u0011\u0007CA2i\u001b\u0005!'BA3g\u0003\u0011a\u0017M\\4\u000b\u0003\u001d\fAA[1wC&\u0011\u0011\u000e\u001a\u0002\u0007'R\u0014\u0018N\\4\t\u000f-\u0004\u0011\u0011!C\u0001Y\u0006a\u0001O]8ek\u000e$\u0018I]5usV\t\u0001\tC\u0004o\u0001\u0005\u0005I\u0011A8\u0002\u001dA\u0014x\u000eZ;di\u0016cW-\\3oiR\u0011\u0001o\u001d\t\u0003\u001fEL!A\u001d\t\u0003\u0007\u0005s\u0017\u0010C\u0004u[\u0006\u0005\t\u0019\u0001!\u0002\u0007a$\u0013\u0007C\u0004w\u0001\u0005\u0005I\u0011I<\u0002\u001fA\u0014x\u000eZ;di&#XM]1u_J,\u0012\u0001\u001f\t\u0004sr\u0004X\"\u0001>\u000b\u0005m\u0004\u0012AC2pY2,7\r^5p]&\u0011QP\u001f\u0002\t\u0013R,'/\u0019;pe\"Aq\u0010AA\u0001\n\u0003\t\t!\u0001\u0005dC:,\u0015/^1m)\u0011\t\u0019!!\u0003\u0011\u0007=\t)!C\u0002\u0002\bA\u0011qAQ8pY\u0016\fg\u000eC\u0004u}\u0006\u0005\t\u0019\u00019\t\u0013\u00055\u0001!!A\u0005B\u0005=\u0011\u0001\u00035bg\"\u001cu\u000eZ3\u0015\u0003\u0001C\u0011\"a\u0005\u0001\u0003\u0003%\t%!\u0006\u0002\u0011Q|7\u000b\u001e:j]\u001e$\u0012A\u0019\u0005\n\u00033\u0001\u0011\u0011!C!\u00037\ta!Z9vC2\u001cH\u0003BA\u0002\u0003;A\u0001\u0002^A\f\u0003\u0003\u0005\r\u0001]\u0004\b\u0003C\u0011\u0001\u0012AA\u0012\u0003\u001d\u0011fN\u001c\"bg\u0016\u00042aGA\u0013\r\u0019\t!\u0001#\u0001\u0002(M)\u0011QEA\u0015)A\u0019q\"a\u000b\n\u0007\u00055\u0002C\u0001\u0004B]f\u0014VM\u001a\u0005\bC\u0005\u0015B\u0011AA\u0019)\t\t\u0019\u0003\u0003\u0005\u00026\u0005\u0015B\u0011AA\u001c\u0003\u0015\t\u0007\u000f\u001d7z)%\u0019\u0013\u0011HA&\u0003\u001f\n\u0019\u0006\u0003\u0005\u0002<\u0005M\u0002\u0019AA\u001f\u0003\u001d\u0011hN\u001c+za\u0016\u0004B!a\u0010\u0002H9!\u0011\u0011IA\"!\tY\u0003#C\u0002\u0002FA\ta\u0001\u0015:fI\u00164\u0017bA5\u0002J)\u0019\u0011Q\t\t\t\u000f\u00055\u00131\u0007a\u0001\u0001\u0006\u0011a.\u0019\u0005\b\u0003#\n\u001
9\u00041\u0001A\u0003\tq\u0007\u0010C\u0004\u0002V\u0005M\u0002\u0019\u0001!\u0002\u00059L\bBCA\u001b\u0003K\t\t\u0011\"!\u0002ZQ\u00191%a\u0017\t\ra\t9\u00061\u0001\u001b\u0011)\ty&!\n\u0002\u0002\u0013\u0005\u0015\u0011M\u0001\bk:\f\u0007\u000f\u001d7z)\u0011\t\u0019'!\u001b\u0011\t=\t)GG\u0005\u0004\u0003O\u0002\"AB(qi&|g\u000eC\u0005\u0002l\u0005u\u0013\u0011!a\u0001G\u0005\u0019\u0001\u0010\n\u0019\t\u0015\u0005=\u0014QEA\u0001\n\u0013\t\t(A\u0006sK\u0006$'+Z:pYZ,GCAA:!\r\u0019\u0017QO\u0005\u0004\u0003o\"'AB(cU\u0016\u001cG\u000f")
/* loaded from: input_file:scorch/nn/rnn/RnnBase.class */
/**
 * Sequence module that unrolls a single {@link RnnCellBase} over an input sequence.
 *
 * <p>This is the Java view of the Scala case class {@code RnnBase(cell)} compiled from
 * {@code RnnBase.scala}. The {@code Product}, {@code copy}, {@code equals}, {@code hashCode} and
 * {@code toString} members implement the case-class contract over the single {@code cell} field.
 */
public class RnnBase extends SeqModule implements Product, Serializable {
    /** The cell applied at every step; its {@code parameters()} are handed to the SeqModule constructor. */
    private final RnnCellBase cell;

    /** Scala extractor for {@code case RnnBase(cell)}; delegates to the companion object. */
    public static Option<RnnCellBase> unapply(RnnBase rnnBase) {
        return RnnBase$.MODULE$.unapply(rnnBase);
    }

    /** Returns the wrapped recurrent cell. */
    public RnnCellBase cell() {
        return this.cell;
    }

    /**
     * Runs the cell over {@code seq} one element at a time, left to right.
     *
     * <p>States start at {@code cell().initialTrackingStates()}. At each step the cell is applied to
     * {@code x +: states}. The head of its result is recorded as that step's output, and the tail
     * becomes the states for the next step.
     *
     * @param seq inputs in time order
     * @return one cell output per input, in the same order as {@code seq}
     */
    @Override // scorch.nn.SeqModule
    public Seq<Variable> forward(Seq<Variable> seq) {
        // Outputs are prepended (O(1)) and reversed once at the end. Appending with :+ to an
        // immutable List costs O(n) per step and made the whole unroll quadratic.
        Tuple2<List<Variable>, Seq<Variable>> start =
                new Tuple2<>(List$.MODULE$.<Variable>empty(), cell().initialTrackingStates());
        Tuple2<List<Variable>, Seq<Variable>> end = seq.foldLeft(start, (acc, x) -> {
            Tuple2<Variable, Seq<Variable>> out = stepCell(x, acc._2());
            return new Tuple2<>(acc._1().$colon$colon(out._1()), out._2());
        });
        return end._1().reverse();
    }

    /**
     * Generates a token sequence by repeatedly running the cell and sampling from its output.
     *
     * <p>Generation starts from {@code encode(startIndex)} and the cell's initial tracking states.
     * Each step does three things:
     * <ol>
     *   <li>applies the cell;</li>
     *   <li>draws a token index with {@code numsca.choice(arange(n), output.data)}, where {@code n}
     *       is the first shape dimension of the cell output (presumably the data holds probabilities
     *       over token indices — TODO confirm);</li>
     *   <li>feeds {@code encode(token)} back in as the next input.</li>
     * </ol>
     *
     * @param encode maps a token index to the cell input for that token
     * @param startIndex token used as the first input; it is not included in the result
     * @param endIndex token that stops generation as soon as it is sampled
     * @param maxLength length limit. Once reached, {@code endIndex} is appended and generation stops.
     * @return boxed token indices that always end with {@code endIndex}. The result has at most
     *     {@code maxLength} elements, and is exactly {@code [endIndex]} when {@code maxLength <= 1}.
     */
    public Seq<Object> sample(Function1<Object, Variable> encode, int startIndex, int endIndex, int maxLength) {
        Variable input = encode.apply(BoxesRunTime.boxToInteger(startIndex));
        Seq<Variable> states = cell().initialTrackingStates();
        // Tokens are prepended (O(1)) and reversed on return, instead of appended with :+ and
        // re-scanned with lastOption on every iteration.
        List<Object> reversed = List$.MODULE$.<Object>empty();
        for (int iteration = 1; ; iteration++) {
            if (iteration >= maxLength) {
                return reversed.$colon$colon(BoxesRunTime.boxToInteger(endIndex)).reverse();
            }
            Tuple2<Variable, Seq<Variable>> out = stepCell(input, states);
            int token = sampleToken(out._1());
            // encode is called for every sampled token, including the end token.
            input = encode.apply(BoxesRunTime.boxToInteger(token));
            states = out._2();
            reversed = reversed.$colon$colon(BoxesRunTime.boxToInteger(token));
            if (token == endIndex) {
                return reversed.reverse();
            }
        }
    }

    /** Case-class {@code copy}: returns a new module wrapping {@code rnnCellBase}. */
    public RnnBase copy(RnnCellBase rnnCellBase) {
        return new RnnBase(rnnCellBase);
    }

    /** Default argument for {@link #copy(RnnCellBase)}: this module's current cell. */
    public RnnCellBase copy$default$1() {
        return cell();
    }

    @Override
    public String productPrefix() {
        return "RnnBase";
    }

    @Override
    public int productArity() {
        return 1;
    }

    /**
     * Returns the case-class field at {@code i}.
     *
     * @throws IndexOutOfBoundsException if {@code i != 0}
     */
    @Override
    public Object productElement(int i) {
        if (i == 0) {
            return cell();
        }
        throw new IndexOutOfBoundsException(String.valueOf(i));
    }

    @Override
    public Iterator<Object> productIterator() {
        return ScalaRunTime$.MODULE$.typedProductIterator(this);
    }

    @Override
    public boolean canEqual(Object obj) {
        return obj instanceof RnnBase;
    }

    @Override
    public int hashCode() {
        return ScalaRunTime$.MODULE$._hashCode(this);
    }

    @Override
    public String toString() {
        return ScalaRunTime$.MODULE$._toString(this);
    }

    /**
     * Structural case-class equality.
     *
     * <p>Two modules are equal when their cells are equal (null-safe) and the other side accepts this
     * one via {@link #canEqual(Object)}. The previous decompiled body fell through to
     * {@code return false} even after a successful match.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof RnnBase)) {
            return false;
        }
        RnnBase that = (RnnBase) obj;
        RnnCellBase mine = cell();
        RnnCellBase theirs = that.cell();
        boolean sameCell = mine != null ? mine.equals(theirs) : theirs == null;
        return sameCell && that.canEqual(this);
    }

    /**
     * Applies the cell once to {@code input +: states}.
     *
     * @return (head of the cell result, tail of the cell result), i.e. (output, next states)
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private Tuple2<Variable, Seq<Variable>> stepCell(Variable input, Seq<Variable> states) {
        // Scala's CanBuildFrom-based +: does not type-check against Java generics
        // (Seq$.canBuildFrom is typed over Seq<?>), so the prepend goes through the raw type.
        Seq<Variable> cellInput = (Seq<Variable>) ((Seq) states).$plus$colon(input, Seq$.MODULE$.canBuildFrom());
        Seq<Variable> cellOutput = cell().apply(cellInput);
        return new Tuple2<>(cellOutput.head(), (Seq<Variable>) cellOutput.tail());
    }

    /**
     * Draws one token index from a cell output with {@code numsca.choice}.
     * The candidates are {@code arange(shape.head)} and the weights are {@code output.data()}.
     */
    private static int sampleToken(Variable output) {
        return (int) package$.MODULE$.choice(
                        package$.MODULE$.arange(BoxesRunTime.unboxToInt(output.shape().head())),
                        output.data(),
                        package$.MODULE$.choice$default$3())
                .squeeze();
    }

    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public RnnBase(RnnCellBase rnnCellBase) {
        super(rnnCellBase.parameters());
        this.cell = rnnCellBase;
        Product.$init$(this);
    }
}
