package jigg.ml;

import jigg.ml.LinearClassifier;
import jigg.ml.LogLinearClassifier;
import jigg.ml.OnlineLogLinearTrainer;
import jigg.ml.OnlineTrainer;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.math.Numeric$FloatIsFractional$;
import scala.reflect.ScalaSignature;
import scala.runtime.RichInt$;

/* compiled from: LogLinearAdaGradL1.scala */
@ScalaSignature(bytes = "\u0006\u0001E3Q!\u0001\u0002\u0002\u0002\u001d\u0011!\u0003T8h\u0019&tW-\u0019:BI\u0006<%/\u00193Mc)\u00111\u0001B\u0001\u0003[2T\u0011!B\u0001\u0005U&<wm\u0001\u0001\u0016\u0005!)2c\u0001\u0001\n\u001fA\u0011!\"D\u0007\u0002\u0017)\tA\"A\u0003tG\u0006d\u0017-\u0003\u0002\u000f\u0017\t1\u0011I\\=SK\u001a\u00042\u0001E\t\u0014\u001b\u0005\u0011\u0011B\u0001\n\u0003\u0005Yye\u000e\\5oK2{w\rT5oK\u0006\u0014HK]1j]\u0016\u0014\bC\u0001\u000b\u0016\u0019\u0001!QA\u0006\u0001C\u0002]\u0011\u0011\u0001T\t\u00031m\u0001\"AC\r\n\u0005iY!a\u0002(pi\"Lgn\u001a\t\u0003\u0015qI!!H\u0006\u0003\u0007\u0005s\u0017\u0010\u0003\u0005 \u0001\t\u0015\r\u0011\"\u0001!\u0003\u0019a\u0017-\u001c2eCV\t\u0011\u0005\u0005\u0002\u000bE%\u00111e\u0003\u0002\u0006\r2|\u0017\r\u001e\u0005\tK\u0001\u0011\t\u0011)A\u0005C\u00059A.Y7cI\u0006\u0004\u0003\u0002C\u0014\u0001\u0005\u000b\u0007I\u0011\u0001\u0011\u0002\u0007\u0015$\u0018\r\u0003\u0005*\u0001\t\u0005\t\u0015!\u0003\"\u0003\u0011)G/\u0019\u0011\t\u000b-\u0002A\u0011\u0001\u0017\u0002\rqJg.\u001b;?)\ricf\f\t\u0004!\u0001\u0019\u0002\"B\u0010+\u0001\u0004\t\u0003\"B\u0014+\u0001\u0004\t\u0003BB\u0019\u0001A\u0003%!'A\u0006mCN$X\u000b\u001d3bi\u0016\u001c\bc\u0001\t4C%\u0011AG\u0001\u0002\u0015\u000fJ|w/\u00192mK^+\u0017n\u001a5u-\u0016\u001cGo\u001c:\t\rY\u0002\u0001\u0015!\u00033\u0003\u0019!\u0017.Y4Hi\")\u0001\b\u0001C)s\u00051q/Z5hQR$\"!\t\u001e\t\u000bm:\u0004\u0019\u0001\u001f\u0002\u0007%$\u0007\u0010\u0005\u0002\u000b{%\u0011ah\u0003\u0002\u0004\u0013:$\b\"\u0002!\u0001\t\u0003\n\u0015\u0001F;qI\u0006$X-\u0012=b[BdWmV3jO\"$8\u000f\u0006\u0003C\u000b*c\u0005C\u0001\u0006D\u0013\t!5B\u0001\u0003V]&$\b\"\u0002$@\u0001\u00049\u0015!A3\u0011\u0007AA5#\u0003\u0002J\u0005\t9Q\t_1na2,\u0007\"B&@\u0001\u0004\u0019\u0012\u0001B4pY\u0012DQ!T A\u0002\u0005\n!\u0002Z3sSZ\fG/\u001b<f\u0011\u0015y\u0005\u0001\"\u0011Q\u0003-\u0001xn\u001d;Qe>\u001cWm]:\u0016\u0003\t\u0003")
/* loaded from: input_file:jigg/ml/LogLinearAdaGradL1.class */
/*
 * Online trainer for a log-linear classifier that uses AdaGrad-scaled step sizes with L1
 * regularization applied by soft-thresholding.
 *
 * Per-feature state kept by this class:
 *  - diagGt:      running sum of squared gradients for each feature; the per-feature step size is
 *                 eta / (1 + sqrt(diagGt[i])).
 *  - lastUpdates: the step counter value at which each stored weight was last brought up to date.
 *                 L1 shrinkage for steps in which a feature was not touched is applied lazily, in
 *                 weight(int).
 *
 * This source was decompiled from LogLinearAdaGradL1.scala. The names Cclass, MODULE$ and $mcF$sp
 * are Scala 2.x compiler artifacts (trait implementation classes, singleton objects and
 * Float-specialized methods). They are kept unchanged so the class matches the Scala build.
 */
public abstract class LogLinearAdaGradL1<L> implements OnlineLogLinearTrainer<L> {
    /** L1 regularization strength (constructor argument {@code f}). */
    private final float lambda;
    /** Base learning rate (constructor argument {@code f2}). */
    private final float eta;
    /**
     * Step counter value at which each weight was last made current. Stored as float, so it is
     * exact only for step counts up to 2^24. NOTE(review): confirm runs stay below that.
     */
    private final GrowableWeightVector<Object> lastUpdates;
    /** Per-feature accumulated squared gradient (the diagonal of AdaGrad's G_t). */
    private final GrowableWeightVector<Object> diagGt;
    /** Step counter. Set to 0 in the constructor; presumably advanced by the trainer — confirm. */
    private int time;

    @Override // jigg.ml.OnlineLogLinearTrainer
    public int time() {
        return this.time;
    }

    /** Scala-generated setter for {@code time}. */
    @Override // jigg.ml.OnlineLogLinearTrainer
    public void time_$eq(int i) {
        this.time = i;
    }

    /** Delegates to the {@code OnlineLogLinearTrainer} trait implementation. */
    @Override // jigg.ml.OnlineLogLinearTrainer, jigg.ml.OnlineTrainer
    public void update(Seq<Example<L>> seq, L l) {
        OnlineLogLinearTrainer.Cclass.update(this, seq, l);
    }

    /** Delegates to the {@code OnlineLogLinearTrainer} trait implementation. */
    @Override // jigg.ml.OnlineLogLinearTrainer
    public void reguralizeWeights(Seq<Example<L>> seq) {
        OnlineLogLinearTrainer.Cclass.reguralizeWeights(this, seq);
    }

    /** Delegates to the {@code LogLinearClassifier} trait implementation. */
    @Override // jigg.ml.LogLinearClassifier
    public float[] labelProbs(Seq<Example<L>> seq) {
        return LogLinearClassifier.Cclass.labelProbs(this, seq);
    }

    /** Delegates to the {@code LinearClassifier} trait implementation. */
    @Override // jigg.ml.Classifier
    public Tuple2<L, Object> predict(Seq<Example<L>> seq) {
        return LinearClassifier.Cclass.predict(this, seq);
    }

    /** Delegates to the {@code LinearClassifier} trait implementation. */
    @Override // jigg.ml.LinearClassifier
    public float featureScore(int[] iArr) {
        return LinearClassifier.Cclass.featureScore(this, iArr);
    }

    /** @return the L1 regularization strength */
    public float lambda() {
        return this.lambda;
    }

    /** @return the base learning rate */
    public float eta() {
        return this.eta;
    }

    /**
     * Returns the weight of feature {@code i}, first applying any pending L1 shrinkage.
     *
     * <p>If the weight is not current as of {@code time()}, it is shrunk toward zero by
     * {@code lambda * eta / (1 + sqrt(diagGt[i]))} per elapsed step
     * ({@code time() - lastUpdates[i]}), clamped at zero. The result is written back to the
     * weight vector, and {@code lastUpdates[i]} is set to {@code time()}. A weight that is exactly
     * zero is returned as-is, because shrinkage cannot move it.
     *
     * @param i feature index
     * @return the up-to-date weight of feature {@code i}
     */
    @Override // jigg.ml.LinearClassifier
    public float weight(int i) {
        float lastUpdate = this.lastUpdates.apply$mcF$sp(i);
        if (lastUpdate == time()) {
            return weights().apply$mcF$sp(i);
        }
        // The decompiled code used an undeclared register `r0` here. It is the currently stored
        // weight, i.e. the same value tested against zero below.
        float current = weights().apply$mcF$sp(i);
        if (current == 0.0f) {
            return 0.0f;
        }
        Predef$.MODULE$.assert(time() != 0);
        // diagGt[i] only changes when feature i is updated, so the per-step threshold is constant
        // over the elapsed interval, and the repeated shrinkage collapses into a single step.
        double perStepThreshold =
            (lambda() * eta()) / (1.0d + Math.sqrt(this.diagGt.apply$mcF$sp(i)));
        double shrunk = Math.signum(current)
            * Math.max(0.0d, Math.abs(current) - (perStepThreshold * (time() - lastUpdate)));
        weights().update$mcF$sp(i, (float) shrunk);
        this.lastUpdates.update$mcF$sp(i, time());
        return (float) shrunk;
    }

    /**
     * Applies one AdaGrad step followed by L1 soft-thresholding to every feature index in
     * {@code example.featVec()}.
     *
     * <p>For each active feature j, with {@code g = -f}:
     * <pre>
     *   diagGt[j]     += g * g
     *   step           = eta / (1 + sqrt(diagGt[j]))
     *   z              = w[j] - step * g
     *   w[j]           = sign(z) * max(0, |z| - lambda * step)
     *   lastUpdates[j] = time() + 1
     * </pre>
     *
     * @param example the example whose feature indices are updated
     * @param l       label argument required by the trainer interface; not used here
     * @param f       scalar derivative for this example; its negation is used as the gradient
     *                for every active feature
     */
    @Override // jigg.ml.OnlineLogLinearTrainer
    public void updateExampleWeights(Example<L> example, L l, float f) {
        float grad = -f;
        float gradSq = grad * grad;
        int[] featVec = example.featVec();
        for (int k = 0; k < featVec.length; k++) {
            int j = featVec[k];
            // Reads the stored weight directly rather than calling weight(j). Once a feature has
            // been updated in this step, lastUpdates[j] == time() + 1, and weight(j) would see a
            // negative elapsed step count. NOTE(review): pending shrinkage is presumably applied
            // earlier, when the example is scored; confirm against OnlineLogLinearTrainer.update.
            float w = weights().apply$mcF$sp(j);
            this.diagGt.update$mcF$sp(j, this.diagGt.apply$mcF$sp(j) + gradSq);
            double step = eta() / (1.0d + Math.sqrt(this.diagGt.apply$mcF$sp(j)));
            double z = w - (step * grad);
            weights().update$mcF$sp(j,
                (float) (Math.signum(z) * Math.max(0.0d, Math.abs(z) - (lambda() * step))));
            this.lastUpdates.update$mcF$sp(j, time() + 1);
        }
    }

    /**
     * Runs a compiler-generated closure over every index in {@code [0, weights().size())}.
     * NOTE(review): the closure body is not visible here. It is presumably {@code weight(i)},
     * which flushes pending lazy L1 shrinkage into every stored weight; confirm.
     */
    @Override // jigg.ml.OnlineTrainer
    public void postProcess() {
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), weights().size()).foreach(new LogLinearAdaGradL1$$anonfun$postProcess$1(this));
    }

    /**
     * @param f  L1 regularization strength ({@code lambda})
     * @param f2 base learning rate ({@code eta})
     */
    public LogLinearAdaGradL1(float f, float f2) {
        this.lambda = f;
        this.eta = f2;
        // Scala trait initializers, in the order emitted by the Scala compiler.
        OnlineTrainer.Cclass.$init$(this);
        LinearClassifier.Cclass.$init$(this);
        LogLinearClassifier.Cclass.$init$(this);
        time_$eq(0);
        this.lastUpdates = WeightVector$.MODULE$.growable(WeightVector$.MODULE$.growable$default$1(), Numeric$FloatIsFractional$.MODULE$);
        this.diagGt = WeightVector$.MODULE$.growable(WeightVector$.MODULE$.growable$default$1(), Numeric$FloatIsFractional$.MODULE$);
    }
}
