package scorch.sandbox.rnn;

import botkop.numsca.package$;
import scala.App;
import scala.Array$;
import scala.Function0;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.SeqLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.Map;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.ListBuffer;
import scala.io.Codec$;
import scala.io.Source$;
import scala.math.Ordering$Char$;
import scala.reflect.ClassTag$;
import scala.runtime.AbstractFunction0;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.RichInt$;
import scala.util.Random$;
import scorch.autograd.Variable;
import scorch.autograd.Variable$;
import scorch.optim.Optimizer;
import scorch.sandbox.rnn.DinosaurIslandCharRnn;

/* compiled from: DinosaurIslandCharRnn.scala */
/* loaded from: input_file:scorch/sandbox/rnn/DinosaurIslandCharRnn$.class */
/**
 * Decompiled module class for the Scala {@code object DinosaurIslandCharRnn}: a sandbox script that
 * trains a character-level recurrent cell (RNN, LSTM or GRU) on the lines of a text file and
 * periodically prints the average training loss and a few sampled sequences.
 *
 * <p>The singleton instance is published in {@link #MODULE$} by the private constructor, which the
 * static initializer invokes. Because the object extends {@code scala.App}, the script body
 * ({@code delayedEndpoint$scorch$sandbox$rnn$DinosaurIslandCharRnn$1}) is not executed during
 * construction; the constructor only registers it through {@link #delayedInit(Function0)}.
 *
 * <p>NOTE(review): this is decompiler output and is not valid compilable Java as written. Examples:
 * <ul>
 *   <li>the {@code final} fields are assigned in the delayed endpoint rather than in a constructor;</li>
 *   <li>the inner sampling lambda in {@code model} reuses the name {@code i7} of its enclosing
 *       lambda's parameter.</li>
 * </ul>
 * Make changes in the Scala source ({@code DinosaurIslandCharRnn.scala}) rather than in this file.
 */
public final class DinosaurIslandCharRnn$ implements App {
    /** The singleton instance; assigned as the first statement of the private constructor. */
    public static DinosaurIslandCharRnn$ MODULE$;
    /** Training examples: the lines of the dataset file, lower-cased and then shuffled. */
    private final List<String> examples;
    /**
     * Vocabulary.
     *
     * <p>Contains the distinct characters of all examples concatenated, sorted, with '\n' appended
     * as the last entry.
     */
    private final char[] chars;
    /** Number of entries in {@link #chars}. */
    private final int vocabSize;
    /** Maps each character of {@link #chars} (boxed) to its position in that array (boxed int). */
    private final Map<Object, Object> charToIx;
    /** Index of '\n' in {@link #chars}; appended as the final training target of every example. */
    private final int EosIndex;
    /** Sentinel index (-1) prepended to every input sequence; encoded as an all-zero input vector. */
    private final int BosIndex;
    // The three fields below are backing state for the scala.App trait.
    private final long executionStart;
    private String[] scala$App$$_args;
    private final ListBuffer<Function0<BoxedUnit>> scala$App$$initCode;

    // Creates the singleton; the constructor stores it into MODULE$.
    static {
        new DinosaurIslandCharRnn$();
    }

    // ---- scala.App forwarders and field accessors generated by the Scala compiler ----

    /** Returns the command-line arguments (forwards to {@code App.args$}). */
    public String[] args() {
        return App.args$(this);
    }

    /** Registers {@code function0} as initialization code (forwards to {@code App.delayedInit$}). */
    public void delayedInit(Function0<BoxedUnit> function0) {
        App.delayedInit$(this, function0);
    }

    /**
     * Program entry point.
     *
     * <p>Forwards to {@code App.main$}, which (per {@code scala.App}) runs the initialization code
     * registered via {@link #delayedInit(Function0)}, i.e. the script body below.
     */
    public void main(String[] strArr) {
        App.main$(this, strArr);
    }

    public long executionStart() {
        return this.executionStart;
    }

    public String[] scala$App$$_args() {
        return this.scala$App$$_args;
    }

    public void scala$App$$_args_$eq(String[] strArr) {
        this.scala$App$$_args = strArr;
    }

    public ListBuffer<Function0<BoxedUnit>> scala$App$$initCode() {
        return this.scala$App$$initCode;
    }

    public void scala$App$_setter_$executionStart_$eq(long j) {
        this.executionStart = j;
    }

    public final void scala$App$_setter_$scala$App$$initCode_$eq(ListBuffer<Function0<BoxedUnit>> listBuffer) {
        this.scala$App$$initCode = listBuffer;
    }

    // ---- accessors for the script-level vals ----

    /** @return the shuffled, lower-cased training examples */
    public List<String> examples() {
        return this.examples;
    }

    /** @return the sorted vocabulary characters, with '\n' as the last entry */
    public char[] chars() {
        return this.chars;
    }

    /** @return the number of vocabulary characters */
    public int vocabSize() {
        return this.vocabSize;
    }

    /** @return the character-to-index map for the vocabulary */
    public Map<Object, Object> charToIx() {
        return this.charToIx;
    }

    /** @return the index of '\n', used as the end-of-sequence target */
    public int EosIndex() {
        return this.EosIndex;
    }

    /** @return the begin-of-sequence sentinel index (-1) */
    public int BosIndex() {
        return this.BosIndex;
    }

    /**
     * Builds a recurrent cell of the requested type and trains it on {@code list}, one example per
     * iteration.
     *
     * <p>Every {@code i6} iterations it prints the average loss of the last {@code i6} iterations,
     * followed by {@code i3} samples.
     *
     * <p>Iteration {@code i7} (running 1..{@code i})
     * <ul>
     *   <li>takes example {@code list.apply(i7 % list.length())}, so the first pass starts at
     *       index 1, not 0;</li>
     *   <li>maps its characters to indices through {@code map};</li>
     *   <li>prepends {@link #BosIndex()} to form the inputs;</li>
     *   <li>uses the inputs without their head, plus {@link #EosIndex()} appended, as the targets.</li>
     * </ul>
     *
     * @param str cell type: "rnn", "lstm" or "gru"
     * @param list training examples; each character is looked up in {@code map}
     * @param map character-to-index mapping (boxed char to boxed int); also handed to the sampler
     * @param i number of training iterations; Scala default 35000
     * @param i2 first argument of the cell factory and last argument of the sampler. Presumably the
     *     hidden-state size (TODO confirm against the Scala source). Scala default 50.
     * @param i3 number of samples printed at each report; Scala default 7
     * @param i4 passed as both the second and third cell-factory arguments. Presumably the input
     *     and output (vocabulary) size (TODO confirm). No Scala default.
     * @param i5 fourth argument of the sampler; its meaning is not visible here (TODO confirm).
     *     Scala default 50.
     * @param i6 reporting interval in iterations; must be non-zero. Scala default 1000.
     * @throws Error if {@code str} is not one of the known cell types
     */
    public void model(String str, List<String> list, Map<Object, Object> map, int i, int i2, int i3, int i4, int i5, int i6) {
        // Typed as Serializable by the decompiler; it is used below as the recurrent cell
        // (parameters(), Sampler, optimize).
        Serializable apply;
        // Decompiled pattern match on the Scala tuple (i4, i4). The null checks cannot fire because
        // the tuples are freshly constructed.
        Tuple2.mcII.sp spVar = new Tuple2.mcII.sp(i4, i4);
        if (spVar == null) {
            throw new MatchError(spVar);
        }
        Tuple2.mcII.sp spVar2 = new Tuple2.mcII.sp(spVar._1$mcI$sp(), spVar._2$mcI$sp());
        int _1$mcI$sp = spVar2._1$mcI$sp();
        int _2$mcI$sp = spVar2._2$mcI$sp();
        if ("rnn".equals(str)) {
            apply = DinosaurIslandCharRnn$RnnCell$.MODULE$.apply(i2, _1$mcI$sp, _2$mcI$sp);
        } else if ("lstm".equals(str)) {
            apply = DinosaurIslandCharRnn$LstmCell$.MODULE$.apply(i2, _1$mcI$sp, _2$mcI$sp);
        } else {
            if (!"gru".equals(str)) {
                // NOTE(review): throws a raw java.lang.Error. An IllegalArgumentException would
                // better signal a bad argument; change this in the Scala source.
                throw new Error(new StringBuilder(18).append("unknown cell type ").append(str).toString());
            }
            apply = DinosaurIslandCharRnn$GruCell$.MODULE$.apply(i2, _1$mcI$sp, _2$mcI$sp);
        }
        Serializable serializable = apply;
        // 0.05 and 5.0 are presumably the learning rate and the gradient-clipping bound
        // (TODO confirm against ClippingSGD).
        DinosaurIslandCharRnn.ClippingSGD clippingSGD = new DinosaurIslandCharRnn.ClippingSGD(serializable.parameters(), 0.05d, 5.0d);
        DinosaurIslandCharRnn.Sampler sampler = new DinosaurIslandCharRnn.Sampler(serializable, map, EosIndex(), i5, i2);
        // Mutable loss accumulator captured by the lambda (decompiled Scala `var`). It is reset after
        // every report.
        DoubleRef create = DoubleRef.create(0.0d);
        RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(1), i).foreach$mVc$sp(i7 -> {
            // Inputs: BosIndex followed by the character indices of the selected example.
            List<Object> list2 = (List) ((TraversableOnce) new StringOps(Predef$.MODULE$.augmentString((String) list.apply(i7 % list.length()))).map(map, Predef$.MODULE$.fallbackStringCanBuildFrom())).toList().$plus$colon(BoxesRunTime.boxToInteger(MODULE$.BosIndex()), List$.MODULE$.canBuildFrom());
            // Targets: the inputs shifted left by one, with EosIndex appended.
            create.elem += MODULE$.optimize(list2, (List) ((SeqLike) list2.tail()).$colon$plus(BoxesRunTime.boxToInteger(MODULE$.EosIndex()), List$.MODULE$.canBuildFrom()), serializable, clippingSGD);
            if (i7 % i6 == 0) {
                Predef$.MODULE$.println(new StringBuilder(19).append("Iteration: ").append(i7).append(", Loss: ").append(create.elem / i6).toString());
                // NOTE(review): this lambda parameter shadows the outer i7, which is not legal
                // Java (decompiler artifact). The value is unused.
                RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(1), i3).foreach$mVc$sp(i7 -> {
                    Predef$.MODULE$.print(sampler.sample());
                });
                Predef$.MODULE$.println();
                create.elem = 0.0d;
            }
        });
    }

    // Scala default values for the parameters of model(); N in model$default$N is the 1-based
    // parameter position. Parameter 7 (i4) has no default.

    /** Default for {@code i}: the number of training iterations. */
    public int model$default$4() {
        return 35000;
    }

    /** Default for {@code i2}. */
    public int model$default$5() {
        return 50;
    }

    /** Default for {@code i3}: the number of samples printed per report. */
    public int model$default$6() {
        return 7;
    }

    /** Default for {@code i5}. */
    public int model$default$8() {
        return 50;
    }

    /** Default for {@code i6}: the reporting interval. */
    public int model$default$9() {
        return 1000;
    }

    /**
     * Computes the sequence loss.
     *
     * <p>Runs a {@code DinosaurIslandCharRnn.CrossEntropyLoss} forward pass over the per-step
     * outputs and the target indices.
     *
     * @param seq per-step outputs of the cell, as returned by {@link #rnnForward}
     * @param seq2 target indices (boxed ints), one per step
     * @return the loss variable, on which {@code backward()} can be called
     */
    public Variable rnnLoss(Seq<Variable> seq, Seq<Object> seq2) {
        return new DinosaurIslandCharRnn.CrossEntropyLoss(seq, seq2).forward();
    }

    /**
     * Runs {@code baseRnnCell} over a sequence of input indices.
     *
     * <p>Starts from the cell's initial tracking states and returns the per-step outputs in input
     * order. Each step is computed by {@link #$anonfun$rnnForward$1}.
     *
     * <p>NOTE(review): each step appends to an immutable Scala List with {@code :+}, which is linear
     * in the list length. The fold is therefore quadratic in the sequence length; this is acceptable
     * for short names.
     *
     * @param list input indices (boxed ints); {@link #BosIndex()} yields an all-zero input vector
     * @param baseRnnCell the recurrent cell to apply at each step
     * @param i number of rows of each one-hot input column vector
     * @return the cell output of every step, in order
     */
    public List<Variable> rnnForward(List<Object> list, DinosaurIslandCharRnn.BaseRnnCell baseRnnCell, int i) {
        return (List) ((Tuple2) list.foldLeft(new Tuple2(List$.MODULE$.empty(), baseRnnCell.initialTrackingStates()), (tuple2, obj) -> {
            return $anonfun$rnnForward$1(baseRnnCell, i, tuple2, BoxesRunTime.unboxToInt(obj));
        }))._1();
    }

    /**
     * Performs one optimization step on a single (inputs, targets) pair.
     *
     * <p>The step clears gradients, runs the forward pass, computes the loss, back-propagates and
     * updates the parameters.
     *
     * <p>NOTE(review): the one-hot width passed to {@link #rnnForward} is the module-level
     * {@link #vocabSize()}, not the size {@code model} gave the cell. The two only agree when
     * {@code model} is called with {@code i4 == vocabSize()}, as the script body does.
     *
     * @param list input indices (boxed ints)
     * @param list2 target indices (boxed ints), one per input
     * @param baseRnnCell the cell being trained
     * @param optimizer optimizer over the cell's parameters
     * @return the scalar loss value, obtained with {@code squeeze()} on the loss data
     */
    public double optimize(List<Object> list, List<Object> list2, DinosaurIslandCharRnn.BaseRnnCell baseRnnCell, Optimizer optimizer) {
        optimizer.zeroGrad();
        Variable rnnLoss = rnnLoss(rnnForward(list, baseRnnCell, vocabSize()), list2);
        rnnLoss.backward();
        optimizer.step();
        return rnnLoss.data().squeeze();
    }

    /**
     * Fold step of {@link #rnnForward}: feeds one input index to the cell.
     *
     * <p>The step does the following:
     * <ol>
     *   <li>Builds an {@code (i, 1)} zero column vector and sets row {@code i2} to 1.0, unless
     *       {@code i2} is the BOS sentinel, in which case the vector stays all zero.</li>
     *   <li>Calls the cell with that vector prepended to the current tracking states.</li>
     *   <li>Splits the result: its head is this step's output and its tail the next tracking
     *       states.</li>
     * </ol>
     *
     * @param baseRnnCell the recurrent cell
     * @param i number of rows of the one-hot input vector
     * @param tuple2 accumulator: (outputs so far as a List, current tracking states as a Seq)
     * @param i2 the current input index
     * @return (outputs with this step's output appended, new tracking states)
     * @throws MatchError decompiled pattern-match failure; not reachable with the tuples built here
     *     unless {@code tuple2} is null
     */
    public static final /* synthetic */ Tuple2 $anonfun$rnnForward$1(DinosaurIslandCharRnn.BaseRnnCell baseRnnCell, int i, Tuple2 tuple2, int i2) {
        Tuple2 tuple22 = new Tuple2(tuple2, BoxesRunTime.boxToInteger(i2));
        if (tuple22 != null) {
            Tuple2 tuple23 = (Tuple2) tuple22._1();
            int _2$mcI$sp = tuple22._2$mcI$sp();
            if (tuple23 != null) {
                List list = (List) tuple23._1();
                Seq seq = (Seq) tuple23._2();
                Variable variable = new Variable(package$.MODULE$.zeros(Predef$.MODULE$.wrapIntArray(new int[]{i, 1})), Variable$.MODULE$.apply$default$2(), Variable$.MODULE$.apply$default$3());
                // One-hot encode the input; the BOS sentinel is left as an all-zero vector.
                if (_2$mcI$sp != MODULE$.BosIndex()) {
                    variable.data().apply(Predef$.MODULE$.wrapIntArray(new int[]{_2$mcI$sp, 0})).$colon$eq(1.0d);
                }
                Seq<Variable> apply = baseRnnCell.apply((Seq) seq.$plus$colon(variable, Seq$.MODULE$.canBuildFrom()));
                Tuple2 tuple24 = new Tuple2(apply.head(), apply.tail());
                if (tuple24 == null) {
                    throw new MatchError(tuple24);
                }
                Tuple2 tuple25 = new Tuple2((Variable) tuple24._1(), (Seq) tuple24._2());
                return new Tuple2(list.$colon$plus((Variable) tuple25._1(), List$.MODULE$.canBuildFrom()), (Seq) tuple25._2());
            }
        }
        throw new MatchError(tuple22);
    }

    /**
     * The body of the Scala {@code App} object: the training script.
     *
     * <p>The script performs these steps:
     * <ol>
     *   <li>Seeds both numsca's generator and {@code scala.util.Random} with 231.</li>
     *   <li>Loads, lower-cases and shuffles the examples.</li>
     *   <li>Derives the vocabulary, the index map and the EOS/BOS indices.</li>
     *   <li>Trains a GRU with {@code i2 = 40}, {@code i5 = 60} and a reporting interval of 100.</li>
     * </ol>
     *
     * <p>NOTE(review): the dataset path is relative to the working directory, the file is decoded
     * with the system fallback codec, and the {@code Source} returned by {@code fromFile} is never
     * closed.
     */
    public final void delayedEndpoint$scorch$sandbox$rnn$DinosaurIslandCharRnn$1() {
        package$.MODULE$.rand().setSeed(231);
        Random$.MODULE$.setSeed(231L);
        this.examples = Random$.MODULE$.shuffle(Source$.MODULE$.fromFile("src/test/resources/dinos.txt", Codec$.MODULE$.fallbackSystemCodec()).getLines().map(str -> {
            return str.toLowerCase();
        }).toList(), List$.MODULE$.canBuildFrom());
        // Concatenate all examples, take the distinct characters, sort them, then append '\n'.
        this.chars = (char[]) new ArrayOps.ofChar(Predef$.MODULE$.charArrayOps((char[]) new ArrayOps.ofChar(Predef$.MODULE$.charArrayOps((char[]) new ArrayOps.ofChar(Predef$.MODULE$.charArrayOps(examples().mkString().toCharArray())).distinct())).sorted(Ordering$Char$.MODULE$))).$colon$plus(BoxesRunTime.boxToCharacter('\n'), ClassTag$.MODULE$.Char());
        this.vocabSize = chars().length;
        // Map each character to its position in chars (zipWithIndex followed by toMap).
        this.charToIx = new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofChar(Predef$.MODULE$.charArrayOps(chars())).zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))).toMap(Predef$.MODULE$.$conforms());
        this.EosIndex = BoxesRunTime.unboxToInt(charToIx().apply(BoxesRunTime.boxToCharacter('\n')));
        this.BosIndex = -1;
        model("gru", examples(), charToIx(), model$default$4(), 40, model$default$6(), vocabSize(), 60, 100);
    }

    /**
     * Creates the singleton.
     *
     * <p>Publishes it in {@link #MODULE$}, initializes the {@code scala.App} state, and registers
     * the script body to run later via {@link #delayedInit(Function0)}.
     */
    private DinosaurIslandCharRnn$() {
        MODULE$ = this;
        App.$init$(this);
        // Decompiler rendering of the compiler-generated DinosaurIslandCharRnn$delayedInit$body
        // closure. The real class captures the enclosing module instance; the
        // `this.$outer = this` below is an artifact of that rendering.
        delayedInit(new AbstractFunction0(this) { // from class: scorch.sandbox.rnn.DinosaurIslandCharRnn$delayedInit$body
            private final DinosaurIslandCharRnn$ $outer;

            public final Object apply() {
                this.$outer.delayedEndpoint$scorch$sandbox$rnn$DinosaurIslandCharRnn$1();
                return BoxedUnit.UNIT;
            }

            {
                if (this == null) {
                    throw null;
                }
                this.$outer = this;
            }
        });
    }
}
