package scorch.optim;

import botkop.numsca.Tensor;
import botkop.numsca.package$;
import scala.Function1;
import scala.MatchError;
import scala.Option;
import scala.Product;
import scala.Serializable;
import scala.Tuple2;
import scala.Tuple5;
import scala.collection.IterableLike;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;
import scala.runtime.Statics;
import scorch.autograd.Variable;

/* compiled from: Adam.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005ug\u0001B\u0001\u0003\u0001\u001e\u0011A!\u00113b[*\u00111\u0001B\u0001\u0006_B$\u0018.\u001c\u0006\u0002\u000b\u000511oY8sG\"\u001c\u0001a\u0005\u0003\u0001\u00111\u0011\u0002CA\u0005\u000b\u001b\u0005\u0011\u0011BA\u0006\u0003\u0005%y\u0005\u000f^5nSj,'\u000f\u0005\u0002\u000e!5\taBC\u0001\u0010\u0003\u0015\u00198-\u00197b\u0013\t\tbBA\u0004Qe>$Wo\u0019;\u0011\u00055\u0019\u0012B\u0001\u000b\u000f\u00051\u0019VM]5bY&T\u0018M\u00197f\u0011!1\u0002A!f\u0001\n\u00039\u0012A\u00039be\u0006lW\r^3sgV\t\u0001\u0004E\u0002\u001aC\u0011r!AG\u0010\u000f\u0005mqR\"\u0001\u000f\u000b\u0005u1\u0011A\u0002\u001fs_>$h(C\u0001\u0010\u0013\t\u0001c\"A\u0004qC\u000e\\\u0017mZ3\n\u0005\t\u001a#aA*fc*\u0011\u0001E\u0004\t\u0003K!j\u0011A\n\u0006\u0003O\u0011\t\u0001\"Y;u_\u001e\u0014\u0018\rZ\u0005\u0003S\u0019\u0012\u0001BV1sS\u0006\u0014G.\u001a\u0005\tW\u0001\u0011\t\u0012)A\u00051\u0005Y\u0001/\u0019:b[\u0016$XM]:!\u0011!i\u0003A!f\u0001\n\u0003q\u0013A\u00017s+\u0005y\u0003CA\u00071\u0013\t\tdB\u0001\u0004E_V\u0014G.\u001a\u0005\tg\u0001\u0011\t\u0012)A\u0005_\u0005\u0019AN\u001d\u0011\t\u0011U\u0002!Q3A\u0005\u00029\nQAY3uCFB\u0001b\u000e\u0001\u0003\u0012\u0003\u0006IaL\u0001\u0007E\u0016$\u0018-\r\u0011\t\u0011e\u0002!Q3A\u0005\u00029\nQAY3uCJB\u0001b\u000f\u0001\u0003\u0012\u0003\u0006IaL\u0001\u0007E\u0016$\u0018M\r\u0011\t\u0011u\u0002!Q3A\u0005\u00029\nq!\u001a9tS2|g\u000e\u0003\u0005@\u0001\tE\t\u0015!\u00030\u0003!)\u0007o]5m_:\u0004\u0003\"B!\u0001\t\u0003\u0011\u0015A\u0002\u001fj]&$h\b\u0006\u0004D\t\u00163u\t\u0013\t\u0003\u0013\u0001AQA\u0006!A\u0002aAQ!\f!A\u0002=Bq!\u000e!\u0011\u0002\u0003\u0007q\u0006C\u0004:\u0001B\u0005\t\u0019A\u0018\t\u000fu\u0002\u0005\u0013!a\u0001_!9!\n\u0001b\u0001\n\u0003Y\u0015AA7t+\u0005a\u0005cA\r\"\u001bB\u0011ajU\u0007\u0002\u001f*\u0011\u0001+U\u0001\u0007]Vl7oY1\u000b\u0003I\u000baAY8uW>\u0004\u0018B\u0001+P\u0005\u0019!VM\\:pe\"1a\u000b\u0001Q\u0001\n1\u000b1!\\:!\u0011\u001dA\u0006A1A
\u0005\u0002-\u000b!A^:\t\ri\u0003\u0001\u0015!\u0003M\u0003\r18\u000f\t\u0005\b9\u0002\u0001\r\u0011\"\u0001^\u0003\u0005!X#\u00010\u0011\u00055y\u0016B\u00011\u000f\u0005\rIe\u000e\u001e\u0005\bE\u0002\u0001\r\u0011\"\u0001d\u0003\u0015!x\fJ3r)\t!w\r\u0005\u0002\u000eK&\u0011aM\u0004\u0002\u0005+:LG\u000fC\u0004iC\u0006\u0005\t\u0019\u00010\u0002\u0007a$\u0013\u0007\u0003\u0004k\u0001\u0001\u0006KAX\u0001\u0003i\u0002BQ\u0001\u001c\u0001\u0005B5\fAa\u001d;faR\tA\rC\u0004p\u0001\u0005\u0005I\u0011\u00019\u0002\t\r|\u0007/\u001f\u000b\u0007\u0007F\u00148\u000f^;\t\u000fYq\u0007\u0013!a\u00011!9QF\u001cI\u0001\u0002\u0004y\u0003bB\u001bo!\u0003\u0005\ra\f\u0005\bs9\u0004\n\u00111\u00010\u0011\u001did\u000e%AA\u0002=Bqa\u001e\u0001\u0012\u0002\u0013\u0005\u00010\u0001\bd_BLH\u0005Z3gCVdG\u000fJ\u0019\u0016\u0003eT#\u0001\u0007>,\u0003m\u00042\u0001`A\u0002\u001b\u0005i(B\u0001@��\u0003%)hn\u00195fG.,GMC\u0002\u0002\u00029\t!\"\u00198o_R\fG/[8o\u0013\r\t)! \u0002\u0012k:\u001c\u0007.Z2lK\u00124\u0016M]5b]\u000e,\u0007\"CA\u0005\u0001E\u0005I\u0011AA\u0006\u00039\u0019w\u000e]=%I\u00164\u0017-\u001e7uII*\"!!\u0004+\u0005=R\b\"CA\t\u0001E\u0005I\u0011AA\u0006\u00039\u0019w\u000e]=%I\u00164\u0017-\u001e7uIMB\u0011\"!\u0006\u0001#\u0003%\t!a\u0003\u0002\u001d\r|\u0007/\u001f\u0013eK\u001a\fW\u000f\u001c;%i!I\u0011\u0011\u0004\u0001\u0012\u0002\u0013\u0005\u00111B\u0001\u000fG>\u0004\u0018\u0010\n3fM\u0006,H\u000e\u001e\u00136\u0011%\ti\u0002AA\u0001\n\u0003\ny\"A\u0007qe>$Wo\u0019;Qe\u00164\u0017\u000e_\u000b\u0003\u0003C\u0001B!a\t\u0002.5\u0011\u0011Q\u0005\u0006\u0005\u0003O\tI#\u0001\u0003mC:<'BAA\u0016\u0003\u0011Q\u0017M^1\n\t\u0005=\u0012Q\u0005\u0002\u0007'R\u0014\u0018N\\4\t\u0011\u0005M\u0002!!A\u0005\u0002u\u000bA\u0002\u001d:pIV\u001cG/\u0011:jifD\u0011\"a\u000e\u0001\u0003\u0003%\t!!\u000f\u0002\u001dA\u0014x\u000eZ;di\u0016cW-\\3oiR!\u00111HA!!\ri\u0011QH\u0005\u0004\u0003\u007fq!aA!os\"A\u0001.!\u000e\u0002\u0002\u0003\u0007a\fC\u0005\u0002F\u0001\t\t\u0011\
"\u0011\u0002H\u0005y\u0001O]8ek\u000e$\u0018\n^3sCR|'/\u0006\u0002\u0002JA1\u00111JA)\u0003wi!!!\u0014\u000b\u0007\u0005=c\"\u0001\u0006d_2dWm\u0019;j_:LA!a\u0015\u0002N\tA\u0011\n^3sCR|'\u000fC\u0005\u0002X\u0001\t\t\u0011\"\u0001\u0002Z\u0005A1-\u00198FcV\fG\u000e\u0006\u0003\u0002\\\u0005\u0005\u0004cA\u0007\u0002^%\u0019\u0011q\f\b\u0003\u000f\t{w\u000e\\3b]\"I\u0001.!\u0016\u0002\u0002\u0003\u0007\u00111\b\u0005\n\u0003K\u0002\u0011\u0011!C!\u0003O\n\u0001\u0002[1tQ\u000e{G-\u001a\u000b\u0002=\"I\u00111\u000e\u0001\u0002\u0002\u0013\u0005\u0013QN\u0001\ti>\u001cFO]5oOR\u0011\u0011\u0011\u0005\u0005\n\u0003c\u0002\u0011\u0011!C!\u0003g\na!Z9vC2\u001cH\u0003BA.\u0003kB\u0011\u0002[A8\u0003\u0003\u0005\r!a\u000f\b\u0013\u0005e$!!A\t\u0002\u0005m\u0014\u0001B!eC6\u00042!CA?\r!\t!!!A\t\u0002\u0005}4#BA?\u0003\u0003\u0013\u0002CCAB\u0003\u0013CrfL\u00180\u00076\u0011\u0011Q\u0011\u0006\u0004\u0003\u000fs\u0011a\u0002:v]RLW.Z\u0005\u0005\u0003\u0017\u000b)IA\tBEN$(/Y2u\rVt7\r^5p]VBq!QA?\t\u0003\ty\t\u0006\u0002\u0002|!Q\u00111NA?\u0003\u0003%)%!\u001c\t\u0015\u0005U\u0015QPA\u0001\n\u0003\u000b9*A\u0003baBd\u0017\u0010F\u0006D\u00033\u000bY*!(\u0002 \u0006\u0005\u0006B\u0002\f\u0002\u0014\u0002\u0007\u0001\u0004\u0003\u0004.\u0003'\u0003\ra\f\u0005\tk\u0005M\u0005\u0013!a\u0001_!A\u0011(a%\u0011\u0002\u0003\u0007q\u0006\u0003\u0005>\u0003'\u0003\n\u00111\u00010\u0011)\t)+! \u0002\u0002\u0013\u0005\u0015qU\u0001\bk:\f\u0007\u000f\u001d7z)\u0011\tI+!.\u0011\u000b5\tY+a,\n\u0007\u00055fB\u0001\u0004PaRLwN\u001c\t\t\u001b\u0005E\u0006dL\u00180_%\u0019\u00111\u0017\b\u0003\rQ+\b\u000f\\36\u0011%\t9,a)\u0002\u0002\u0003\u00071)A\u0002yIAB!\"a/\u0002~E\u0005I\u0011AA\u0006\u0003m!C.Z:tS:LG\u000fJ4sK\u0006$XM\u001d\u0013eK\u001a\fW\u000f\u001c;%g!Q\u0011qXA?#\u0003%\t!a\u0003\u00027\u0011bWm]:j]&$He\u001a:fCR,'\u000f\n3fM\u0006,H\u000e\u001e\u00135\u0011)\t\u0019-! 
\u0012\u0002\u0013\u0005\u00111B\u0001\u001cI1,7o]5oSR$sM]3bi\u0016\u0014H\u0005Z3gCVdG\u000fJ\u001b\t\u0015\u0005\u001d\u0017QPI\u0001\n\u0003\tY!A\bbaBd\u0017\u0010\n3fM\u0006,H\u000e\u001e\u00134\u0011)\tY-! \u0012\u0002\u0013\u0005\u00111B\u0001\u0010CB\u0004H.\u001f\u0013eK\u001a\fW\u000f\u001c;%i!Q\u0011qZA?#\u0003%\t!a\u0003\u0002\u001f\u0005\u0004\b\u000f\\=%I\u00164\u0017-\u001e7uIUB!\"a5\u0002~\u0005\u0005I\u0011BAk\u0003-\u0011X-\u00193SKN|GN^3\u0015\u0005\u0005]\u0007\u0003BA\u0012\u00033LA!a7\u0002&\t1qJ\u00196fGR\u0004")
/* loaded from: input_file:scorch/optim/Adam.class */
/**
 * Adam optimizer over a fixed sequence of parameters (Java view of the Scala case class
 * {@code Adam} from Adam.scala).
 *
 * <p>For every parameter it keeps a first-moment buffer ({@link #ms()}) and a second-moment buffer
 * ({@link #vs()}). Both are zero-initialised with the parameter's shape at construction time. The
 * timestep {@link #t()} starts at 1 and is used for bias correction of both moments.
 *
 * <p>Case-class members ({@code copy}, {@code productElement}, {@code equals}, {@code hashCode},
 * {@code toString}) are defined over the five constructor fields only. The moment buffers and the
 * timestep are mutable internal state and do not take part in equality.
 */
public class Adam extends Optimizer implements Product, Serializable {
    private final Seq<Variable> parameters;
    private final double lr;
    private final double beta1;
    private final double beta2;
    private final double epsilon;
    /** First-moment estimates, index-aligned with {@link #parameters}. */
    private final Seq<Tensor> ms;
    /** Second-moment estimates, index-aligned with {@link #parameters}. */
    private final Seq<Tensor> vs;
    /** Current Adam timestep, starting at 1; advanced once per {@link #step()}. */
    private int t;

    /** Extractor; delegates to the companion object {@code Adam$}. */
    public static Option<Tuple5<Seq<Variable>, Object, Object, Object, Object>> unapply(Adam adam) {
        return Adam$.MODULE$.unapply(adam);
    }

    /** Factory; delegates to the companion object {@code Adam$}. */
    public static Adam apply(Seq<Variable> seq, double d, double d2, double d3, double d4) {
        return Adam$.MODULE$.apply(seq, d, d2, d3, d4);
    }

    /** Tupled factory; delegates to the companion object {@code Adam$}. */
    public static Function1<Tuple5<Seq<Variable>, Object, Object, Object, Object>, Adam> tupled() {
        return Adam$.MODULE$.tupled();
    }

    /** Curried factory; delegates to the companion object {@code Adam$}. */
    public static Function1<Seq<Variable>, Function1<Object, Function1<Object, Function1<Object, Function1<Object, Adam>>>>> curried() {
        return Adam$.MODULE$.curried();
    }

    /** @return the parameters updated by {@link #step()} */
    public Seq<Variable> parameters() {
        return this.parameters;
    }

    /** @return the learning rate */
    public double lr() {
        return this.lr;
    }

    /** @return the decay rate of the first-moment estimate */
    public double beta1() {
        return this.beta1;
    }

    /** @return the decay rate of the second-moment estimate */
    public double beta2() {
        return this.beta2;
    }

    /** @return the constant added to the denominator of the update */
    public double epsilon() {
        return this.epsilon;
    }

    /** @return the first-moment buffers (mutated in place by {@link #step()}) */
    public Seq<Tensor> ms() {
        return this.ms;
    }

    /** @return the second-moment buffers (mutated in place by {@link #step()}) */
    public Seq<Tensor> vs() {
        return this.vs;
    }

    /** @return the current timestep used for bias correction */
    public int t() {
        return this.t;
    }

    /** Setter for the timestep (Scala {@code var t} setter). */
    public void t_$eq(int i) {
        this.t = i;
    }

    /**
     * Performs one Adam update on every parameter using its current gradient, then advances the
     * timestep by one.
     *
     * <p>The timestep is advanced once per call, after all parameters are updated. The earlier
     * code advanced it inside the per-parameter loop, so the bias correction for the k-th
     * parameter of a single step used {@code t + k}.
     */
    @Override // scorch.optim.Optimizer
    public void step() {
        ((IterableLike) ((IterableLike) parameters().zip(ms(), Seq$.MODULE$.canBuildFrom())).zip(vs(), Seq$.MODULE$.canBuildFrom())).foreach(tuple2 -> {
            $anonfun$step$1(this, tuple2);
            return BoxedUnit.UNIT;
        });
        t_$eq(t() + 1);
    }

    /** Case-class copy with explicit field values. */
    public Adam copy(Seq<Variable> seq, double d, double d2, double d3, double d4) {
        return new Adam(seq, d, d2, d3, d4);
    }

    /** Default for the first {@code copy} argument: the current parameters. */
    public Seq<Variable> copy$default$1() {
        return parameters();
    }

    /** Default for the second {@code copy} argument: the current learning rate. */
    public double copy$default$2() {
        return lr();
    }

    /** Default for the third {@code copy} argument: the current beta1. */
    public double copy$default$3() {
        return beta1();
    }

    /** Default for the fourth {@code copy} argument: the current beta2. */
    public double copy$default$4() {
        return beta2();
    }

    /** Default for the fifth {@code copy} argument: the current epsilon. */
    public double copy$default$5() {
        return epsilon();
    }

    /** @return the case-class name used by {@link #toString()} */
    public String productPrefix() {
        return "Adam";
    }

    /** @return the number of constructor fields (5) */
    public int productArity() {
        return 5;
    }

    /**
     * Returns constructor field {@code i}, with doubles boxed.
     *
     * @throws IndexOutOfBoundsException if {@code i} is not in {@code [0, 5)}
     */
    public Object productElement(int i) {
        switch (i) {
            case 0:
                return parameters();
            case 1:
                return BoxesRunTime.boxToDouble(lr());
            case 2:
                return BoxesRunTime.boxToDouble(beta1());
            case 3:
                return BoxesRunTime.boxToDouble(beta2());
            case 4:
                return BoxesRunTime.boxToDouble(epsilon());
            default:
                throw new IndexOutOfBoundsException(Integer.toString(i));
        }
    }

    /** @return an iterator over the constructor fields */
    public Iterator<Object> productIterator() {
        return ScalaRunTime$.MODULE$.typedProductIterator(this);
    }

    /** @return whether {@code obj} is an {@code Adam} and may therefore be compared for equality */
    public boolean canEqual(Object obj) {
        return obj instanceof Adam;
    }

    /**
     * Hashes the five constructor fields, starting from the seed -889275714 (0xCAFEBABE) with a
     * final arity of 5. This is consistent with {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        int h = -889275714;
        h = Statics.mix(h, Statics.anyHash(parameters()));
        h = Statics.mix(h, Statics.doubleHash(lr()));
        h = Statics.mix(h, Statics.doubleHash(beta1()));
        h = Statics.mix(h, Statics.doubleHash(beta2()));
        h = Statics.mix(h, Statics.doubleHash(epsilon()));
        return Statics.finalizeHash(h, 5);
    }

    /** Delegates to {@code ScalaRunTime._toString}, which renders the prefix and fields. */
    @Override
    public String toString() {
        return ScalaRunTime$.MODULE$._toString(this);
    }

    /**
     * Structural equality over the five constructor fields.
     *
     * <p>The decompiled version always returned {@code false} for any {@code obj != this}, because
     * the {@code z = true} path fell through to {@code z = false}. This version returns
     * {@code true} when all fields match and {@code obj.canEqual(this)} holds. Doubles are
     * compared with {@code ==}, so NaN fields never compare equal.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Adam)) {
            return false;
        }
        Adam other = (Adam) obj;
        Seq<Variable> mine = parameters();
        Seq<Variable> theirs = other.parameters();
        boolean sameParameters = mine != null ? mine.equals(theirs) : theirs == null;
        return sameParameters
                && lr() == other.lr()
                && beta1() == other.beta1()
                && beta2() == other.beta2()
                && epsilon() == other.epsilon()
                && other.canEqual(this);
    }

    /**
     * Per-parameter body of {@link #step()}. Destructures {@code ((variable, m), v)} and applies
     * one Adam update to {@code variable}.
     *
     * <p>It does not advance the timestep; {@link #step()} does that once after all parameters
     * have been updated.
     *
     * @throws MatchError if {@code tuple2} or its nested pair is {@code null}
     */
    public static final /* synthetic */ void $anonfun$step$1(Adam adam, Tuple2 tuple2) {
        if (tuple2 != null) {
            Tuple2 variableAndMean = (Tuple2) tuple2._1();
            Tensor secondMoment = (Tensor) tuple2._2();
            if (variableAndMean != null) {
                adam.updateParameter(
                        (Variable) variableAndMean._1(), (Tensor) variableAndMean._2(), secondMoment);
                return;
            }
        }
        throw new MatchError(tuple2);
    }

    /**
     * Applies one Adam update to a single parameter.
     *
     * <p>It mutates {@code m}, {@code v} and the parameter's data tensor in place, and uses the
     * current {@link #t()} for bias correction.
     */
    private void updateParameter(Variable variable, Tensor m, Tensor v) {
        Tensor x = variable.data();
        Tensor dx = variable.grad().data();
        // m <- beta1 * m + (1 - beta1) * dx   (in place)
        m.$times$eq(beta1());
        m.$plus$eq(package$.MODULE$.NumscaDoubleOps(1 - beta1()).$times(dx));
        // bias-corrected first moment
        Tensor mHat = m.$div(1 - scala.math.package$.MODULE$.pow(beta1(), t()));
        // v <- beta2 * v + (1 - beta2) * dx^2   (in place)
        v.$times$eq(beta2());
        v.$plus$eq(package$.MODULE$.NumscaDoubleOps(1 - beta2()).$times(package$.MODULE$.square(dx)));
        // x <- x - lr * mHat / (sqrt(vHat) + epsilon), with vHat the bias-corrected second moment
        Tensor scaledMean = package$.MODULE$.NumscaDoubleOps(lr()).$times(mHat);
        Tensor vHat = v.$div(1 - scala.math.package$.MODULE$.pow(beta2(), t()));
        x.$minus$eq(scaledMean.$div(package$.MODULE$.sqrt(vHat).$plus(epsilon())));
    }

    /**
     * Creates the optimizer and allocates zero-filled moment buffers shaped like each parameter.
     *
     * <p>NOTE(review): the decompiler moved {@code super(seq)} to the top. In the Scala original,
     * the fields are assigned before {@code Optimizer}'s constructor runs. This only matters if
     * {@code Optimizer}'s constructor reads these accessors — confirm.
     */
    public Adam(Seq<Variable> seq, double d, double d2, double d3, double d4) {
        super(seq);
        this.parameters = seq;
        this.lr = d;
        this.beta1 = d2;
        this.beta2 = d3;
        this.epsilon = d4;
        Product.$init$(this);
        this.ms = (Seq) seq.map(variable -> {
            return package$.MODULE$.zeros(variable.shape());
        }, Seq$.MODULE$.canBuildFrom());
        this.vs = (Seq) seq.map(variable2 -> {
            return package$.MODULE$.zeros(variable2.shape());
        }, Seq$.MODULE$.canBuildFrom());
        this.t = 1;
    }
}
