package edu.arizona.sista.learning;

import edu.arizona.sista.learning.RankingClassifier;
import edu.arizona.sista.struct.Counter;
import edu.arizona.sista.struct.Counters$;
import edu.arizona.sista.struct.Lexicon;
import edu.arizona.sista.struct.Lexicon$;
import edu.arizona.sista.utils.StringUtils$;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.UncheckedIOException;
import java.util.Properties;
import org.slf4j.Logger;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.Iterable;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BooleanRef;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;
import scala.util.Random;

/* compiled from: PerceptronRankingClassifier.scala */
@ScalaSignature(bytes = "\u0006\u0001\t]d\u0001B\u0001\u0003\u0001-\u00111\u0004U3sG\u0016\u0004HO]8o%\u0006t7.\u001b8h\u00072\f7o]5gS\u0016\u0014(BA\u0002\u0005\u0003!aW-\u0019:oS:<'BA\u0003\u0007\u0003\u0015\u0019\u0018n\u001d;b\u0015\t9\u0001\"A\u0004be&TxN\\1\u000b\u0003%\t1!\u001a3v\u0007\u0001)\"\u0001D\r\u0014\t\u0001i1C\t\t\u0003\u001dEi\u0011a\u0004\u0006\u0002!\u0005)1oY1mC&\u0011!c\u0004\u0002\u0007\u0003:L(+\u001a4\u0011\u0007Q)r#D\u0001\u0003\u0013\t1\"AA\tSC:\\\u0017N\\4DY\u0006\u001c8/\u001b4jKJ\u0004\"\u0001G\r\r\u0001\u0011)!\u0004\u0001b\u00017\t\ta)\u0005\u0002\u001d?A\u0011a\"H\u0005\u0003==\u0011qAT8uQ&tw\r\u0005\u0002\u000fA%\u0011\u0011e\u0004\u0002\u0004\u0003:L\bC\u0001\b$\u0013\t!sB\u0001\u0007TKJL\u0017\r\\5{C\ndW\r\u0003\u0005'\u0001\t\u0015\r\u0011\"\u0001(\u0003\u0019)\u0007o\\2igV\t\u0001\u0006\u0005\u0002\u000fS%\u0011!f\u0004\u0002\u0004\u0013:$\b\u0002\u0003\u0017\u0001\u0005\u0003\u0005\u000b\u0011\u0002\u0015\u0002\u000f\u0015\u0004xn\u00195tA!Aa\u0006\u0001BC\u0002\u0013\u0005q%\u0001\tckJt\u0017J\\%uKJ\fG/[8og\"A\u0001\u0007\u0001B\u0001B\u0003%\u0001&A\tckJt\u0017J\\%uKJ\fG/[8og\u0002B\u0001B\r\u0001\u0003\u0006\u0004%\taM\u0001\f[\u0006\u0014x-\u001b8SCRLw.F\u00015!\tqQ'\u0003\u00027\u001f\t1Ai\\;cY\u0016D\u0001\u0002\u000f\u0001\u0003\u0002\u0003\u0006I\u0001N\u0001\r[\u0006\u0014x-\u001b8SCRLw\u000e\t\u0005\u0006u\u0001!\taO\u0001\u0007y%t\u0017\u000e\u001e 
\u0015\tqjdh\u0010\t\u0004)\u00019\u0002b\u0002\u0014:!\u0003\u0005\r\u0001\u000b\u0005\b]e\u0002\n\u00111\u0001)\u0011\u001d\u0011\u0014\b%AA\u0002QBQA\u000f\u0001\u0005\u0002\u0005#\"\u0001\u0010\"\t\u000b\r\u0003\u0005\u0019\u0001#\u0002\u000bA\u0014x\u000e]:\u0011\u0005\u0015SU\"\u0001$\u000b\u0005\u001dC\u0015\u0001B;uS2T\u0011!S\u0001\u0005U\u00064\u0018-\u0003\u0002L\r\nQ\u0001K]8qKJ$\u0018.Z:\t\u000f5\u0003\u0001\u0019!C\u0001\u001d\u0006qa-Z1ukJ,G*\u001a=jG>tW#A(\u0011\u0007A\u001bv#D\u0001R\u0015\t\u0011F!\u0001\u0004tiJ,8\r^\u0005\u0003)F\u0013q\u0001T3yS\u000e|g\u000eC\u0004W\u0001\u0001\u0007I\u0011A,\u0002%\u0019,\u0017\r^;sK2+\u00070[2p]~#S-\u001d\u000b\u00031n\u0003\"AD-\n\u0005i{!\u0001B+oSRDq\u0001X+\u0002\u0002\u0003\u0007q*A\u0002yIEBaA\u0018\u0001!B\u0013y\u0015a\u00044fCR,(/\u001a'fq&\u001cwN\u001c\u0011\t\u000f\u0001\u0004\u0001\u0019!C\u0001C\u00069q/Z5hQR\u001cX#\u00012\u0011\u00079\u0019G'\u0003\u0002e\u001f\t)\u0011I\u001d:bs\"9a\r\u0001a\u0001\n\u00039\u0017aC<fS\u001eDGo]0%KF$\"\u0001\u00175\t\u000fq+\u0017\u0011!a\u0001E\"1!\u000e\u0001Q!\n\t\f\u0001b^3jO\"$8\u000f\t\u0005\bY\u0002\u0001\r\u0011\"\u0001(\u0003I\u0019XO\u001d<jm\u0016$\u0017\n^3sCRLwN\\:\t\u000f9\u0004\u0001\u0019!C\u0001_\u000612/\u001e:wSZ,G-\u0013;fe\u0006$\u0018n\u001c8t?\u0012*\u0017\u000f\u0006\u0002Ya\"9A,\\A\u0001\u0002\u0004A\u0003B\u0002:\u0001A\u0003&\u0001&A\ntkJ4\u0018N^3e\u0013R,'/\u0019;j_:\u001c\b\u0005C\u0004u\u0001\u0001\u0007I\u0011A1\u0002\u0015\u00054xmV3jO\"$8\u000fC\u0004w\u0001\u0001\u0007I\u0011A<\u0002\u001d\u00054xmV3jO\"$8o\u0018\u0013fcR\u0011\u0001\f\u001f\u0005\b9V\f\t\u00111\u0001c\u0011\u0019Q\b\u0001)Q\u0005E\u0006Y\u0011M^4XK&<\u0007\u000e^:!\u0011\u001da\b\u00011A\u0005\u0002\u001d\nA\u0002^8uC2\fV/\u001a:jKNDqA 
\u0001A\u0002\u0013\u0005q0\u0001\tu_R\fG.U;fe&,7o\u0018\u0013fcR\u0019\u0001,!\u0001\t\u000fqk\u0018\u0011!a\u0001Q!9\u0011Q\u0001\u0001!B\u0013A\u0013!\u0004;pi\u0006d\u0017+^3sS\u0016\u001c\b\u0005\u0003\u0005\u0002\n\u0001\u0001\r\u0011\"\u0001(\u00031!x\u000e^1m+B$\u0017\r^3t\u0011%\ti\u0001\u0001a\u0001\n\u0003\ty!\u0001\tu_R\fG.\u00169eCR,7o\u0018\u0013fcR\u0019\u0001,!\u0005\t\u0011q\u000bY!!AA\u0002!Bq!!\u0006\u0001A\u0003&\u0001&A\u0007u_R\fG.\u00169eCR,7\u000f\t\u0005\t\u00033\u0001\u0001\u0019!C\u0001O\u0005yQ\u000f\u001d3bi\u0016\u001c\b+\u001a:Fa>\u001c\u0007\u000eC\u0005\u0002\u001e\u0001\u0001\r\u0011\"\u0001\u0002 \u0005\u0019R\u000f\u001d3bi\u0016\u001c\b+\u001a:Fa>\u001c\u0007n\u0018\u0013fcR\u0019\u0001,!\t\t\u0011q\u000bY\"!AA\u0002!Bq!!\n\u0001A\u0003&\u0001&\u0001\tva\u0012\fG/Z:QKJ,\u0005o\\2iA!A\u0011\u0011\u0006\u0001A\u0002\u0013\u00051'\u0001\u0004nCJ<\u0017N\u001c\u0005\n\u0003[\u0001\u0001\u0019!C\u0001\u0003_\t!\"\\1sO&tw\fJ3r)\rA\u0016\u0011\u0007\u0005\t9\u0006-\u0012\u0011!a\u0001i!9\u0011Q\u0007\u0001!B\u0013!\u0014aB7be\u001eLg\u000e\t\u0005\b\u0003s\u0001A\u0011IA\u001e\u0003\u0015!(/Y5o)\u0015A\u0016QHA$\u0011!\ty$a\u000eA\u0002\u0005\u0005\u0013a\u00023bi\u0006\u001cX\r\u001e\t\u0005)\u0005\rs#C\u0002\u0002F\t\u0011aBU1oW&tw\rR1uCN,G\u000f\u0003\u0006\u0002J\u0005]\u0002\u0013!a\u0001\u0003\u0017\nQa\u001d9b]N\u0004RADA'\u0003#J1!a\u0014\u0010\u0005\u0019y\u0005\u000f^5p]B1\u00111KA2\u0003SrA!!\u0016\u0002`9!\u0011qKA/\u001b\t\tIFC\u0002\u0002\\)\ta\u0001\u0010:p_Rt\u0014\"\u0001\t\n\u0007\u0005\u0005t\"A\u0004qC\u000e\\\u0017mZ3\n\t\u0005\u0015\u0014q\r\u0002\t\u0013R,'/\u00192mK*\u0019\u0011\u0011M\b\u0011\u000b9\tY\u0007\u000b\u0015\n\u0007\u00055tB\u0001\u0004UkBdWM\r\u0005\b\u0003c\u0002A\u0011AA:\u0003i\u0019w.\u001c9vi\u0016\fe/\u001a:bO\u00164Vm\u0019;pe2+gn\u001a;i)\u0015!\u0014QOA<\u0011!\ty$a\u001cA\u0002\u0005\u0005\u0003\u0002CA=\u0003_\u0002\r!a\u001f\u0002\u000f%tG-[2fgB\u0019ab\u0019\u0015\t\u000f\u0005}\u0004\u
0001\"\u0001\u0002\u0002\u00061Q\u000f\u001d3bi\u0016$R\u0001WAB\u0003\u001bC\u0001\"!\"\u0002~\u0001\u0007\u0011qQ\u0001\u0007E\u0016$H/\u001a:\u0011\tA\u000bI\tK\u0005\u0004\u0003\u0017\u000b&aB\"pk:$XM\u001d\u0005\t\u0003\u001f\u000bi\b1\u0001\u0002\b\u0006)qo\u001c:tK\"9\u00111\u0013\u0001\u0005\u0002\u0005U\u0015\u0001C1eIR{\u0017I^4\u0015\u0003aCq!!'\u0001\t\u0003\tY*A\u0007va\u0012\fG/Z,fS\u001eDGo\u001d\u000b\u00061\u0006u\u0015\u0011\u0015\u0005\t\u0003?\u000b9\n1\u0001\u0002\b\u0006\ta\u000fC\u0004\u0002$\u0006]\u0005\u0019\u0001\u001b\u0002\u0003]Dq!a*\u0001\t\u0003\nI+\u0001\u0005tG>\u0014Xm](g)\u0011\tY+!,\u0011\u000b\u0005M\u00131\r\u001b\t\u0011\u0005=\u0016Q\u0015a\u0001\u0003c\u000b1\"];fef$\u0015\r^;ngB1\u00111KA2\u0003g\u0003R\u0001FA[Q]I1!a.\u0003\u0005\u0015!\u0015\r^;n\u0011\u001d\tY\f\u0001C\u0001\u0003{\u000bq\u0002Z1uk6$u\u000e\u001e)s_\u0012,8\r\u001e\u000b\u0004i\u0005}\u0006\u0002CAa\u0003s\u0003\r!a1\u0002\u0003\r\u0004B\u0001UAE/!9\u0011q\u0019\u0001\u0005B\u0005%\u0017AB:bm\u0016$v\u000eF\u0002Y\u0003\u0017D\u0001\"!4\u0002F\u0002\u0007\u0011qZ\u0001\tM&dWMT1nKB!\u0011\u0011[Al\u001d\rq\u00111[\u0005\u0004\u0003+|\u0011A\u0002)sK\u0012,g-\u0003\u0003\u0002Z\u0006m'AB*ue&twMC\u0002\u0002V>Aq!a8\u0001\t\u0003\n\t/\u0001\u0007eSN\u0004H.Y=N_\u0012,G\u000eF\u0002Y\u0003GD\u0001\"!:\u0002^\u0002\u0007\u0011q]\u0001\u0003a^\u0004B!!;\u0002p6\u0011\u00111\u001e\u0006\u0004\u0003[D\u0015AA5p\u0013\u0011\t\t0a;\u0003\u0017A\u0013\u0018N\u001c;Xe&$XM\u001d\u0005\n\u0003k\u0004\u0011\u0013!C!\u0003o\fq\u0002\u001e:bS:$C-\u001a4bk2$HEM\u000b\u0003\u0003sTC!a\u0013\u0002|.\u0012\u0011Q 
\t\u0005\u0003\u007f\u0014I!\u0004\u0002\u0003\u0002)!!1\u0001B\u0003\u0003%)hn\u00195fG.,GMC\u0002\u0003\b=\t!\"\u00198o_R\fG/[8o\u0013\u0011\u0011YA!\u0001\u0003#Ut7\r[3dW\u0016$g+\u0019:jC:\u001cWmB\u0004\u0003\u0010\tA\tA!\u0005\u00027A+'oY3qiJ|gNU1oW&twm\u00117bgNLg-[3s!\r!\"1\u0003\u0004\u0007\u0003\tA\tA!\u0006\u0014\t\tMQB\t\u0005\bu\tMA\u0011\u0001B\r)\t\u0011\t\u0002\u0003\u0006\u0003\u001e\tM!\u0019!C\u0001\u0005?\ta\u0001\\8hO\u0016\u0014XC\u0001B\u0011!\u0011\u0011\u0019C!\f\u000e\u0005\t\u0015\"\u0002\u0002B\u0014\u0005S\tQa\u001d7gi)T!Aa\u000b\u0002\u0007=\u0014x-\u0003\u0003\u00030\t\u0015\"A\u0002'pO\u001e,'\u000fC\u0005\u00034\tM\u0001\u0015!\u0003\u0003\"\u00059An\\4hKJ\u0004\u0003\u0002\u0003B\u001c\u0005'!\tA!\u000f\u0002\u00111|\u0017\r\u001a$s_6,BAa\u000f\u0003BQ!!Q\bB\"!\u0011!\u0002Aa\u0010\u0011\u0007a\u0011\t\u0005\u0002\u0004\u001b\u0005k\u0011\ra\u0007\u0005\t\u0003\u001b\u0014)\u00041\u0001\u0002P\"Q!q\tB\n#\u0003%\tA!\u0013\u00027\u0011bWm]:j]&$He\u001a:fCR,'\u000f\n3fM\u0006,H\u000e\u001e\u00132+\u0011\u0011YEa\u0014\u0016\u0005\t5#f\u0001\u0015\u0002|\u00121!D!\u0012C\u0002mA!Ba\u0015\u0003\u0014E\u0005I\u0011\u0001B+\u0003m!C.Z:tS:LG\u000fJ4sK\u0006$XM\u001d\u0013eK\u001a\fW\u000f\u001c;%eU!!1\nB,\t\u0019Q\"\u0011\u000bb\u00017!Q!1\fB\n#\u0003%\tA!\u0018\u00027\u0011bWm]:j]&$He\u001a:fCR,'\u000f\n3fM\u0006,H\u000e\u001e\u00134+\u0011\u0011yFa\u0019\u0016\u0005\t\u0005$f\u0001\u001b\u0002|\u00121!D!\u0017C\u0002mA!Ba\u001a\u0003\u0014\u0005\u0005I\u0011\u0002B5\u0003-\u0011X-\u00193SKN|GN^3\u0015\u0005\t-\u0004\u0003\u0002B7\u0005gj!Aa\u001c\u000b\u0007\tE\u0004*\u0001\u0003mC:<\u0017\u0002\u0002B;\u0005_\u0012aa\u00142kK\u000e$\b")
/* loaded from: input_file:edu/arizona/sista/learning/PerceptronRankingClassifier.class */
/**
 * Ranking classifier trained with a pairwise perceptron and weight averaging.
 *
 * <p>Training compares pairs of feature vectors via {@link #update(Counter, Counter)}. The first
 * vector of a pair is expected to outscore the second by more than {@link #margin()}. Otherwise
 * the weights are moved towards the first and away from the second. Before each change, the
 * current weights are folded into {@link #avgWeights()}. They are weighted by how many
 * comparisons they survived unchanged (see {@link #addToAvg()}).
 *
 * <p>How pairs are generated and how epochs are driven lives in compiler-generated
 * {@code $$anonfun} classes that are not visible here.
 */
public class PerceptronRankingClassifier<F> implements RankingClassifier<F>, Serializable {
    // Hyper-parameters, fixed at construction time.
    private final int epochs;
    private final int burnInIterations;
    private final double marginRatio;

    // Mutable model / training state. Most of it is (re)initialised by train().
    private Lexicon<F> featureLexicon;
    private double[] weights;
    private int survivedIterations;
    private double[] avgWeights;
    private int totalQueries;
    private int totalUpdates;
    private int updatesPerEpoch;
    private double margin;

    /**
     * Loads a classifier from the file {@code str}.
     *
     * <p>Delegates to the Scala companion object. NOTE(review): presumably the inverse of
     * {@link #saveTo(String)}; confirm against the companion's implementation.
     */
    public static <F> PerceptronRankingClassifier<F> loadFrom(String str) {
        return PerceptronRankingClassifier$.MODULE$.loadFrom(str);
    }

    /** Returns the logger owned by the Scala companion object. */
    public static Logger logger() {
        return PerceptronRankingClassifier$.MODULE$.logger();
    }

    /** Delegates to the {@code RankingClassifier} trait's default implementation. */
    @Override // edu.arizona.sista.learning.RankingClassifier
    public Iterable<Object> probabilitiesOf(Iterable<Datum<Object, F>> iterable, double d) {
        return RankingClassifier.Cclass.probabilitiesOf(this, iterable, d);
    }

    /** Default for the second argument of {@code probabilitiesOf}, as defined by the trait. */
    @Override // edu.arizona.sista.learning.RankingClassifier
    public double probabilitiesOf$default$2() {
        return RankingClassifier.Cclass.probabilitiesOf$default$2(this);
    }

    /** Number of training epochs requested at construction time. */
    public int epochs() {
        return this.epochs;
    }

    /**
     * Number of updates that must happen before weights start contributing to the average (see
     * {@link #addToAvg()}).
     */
    public int burnInIterations() {
        return this.burnInIterations;
    }

    /**
     * Scale factor for the update margin. When positive, {@code train} sets the margin to
     * {@code marginRatio * avgVectorLength^2}.
     */
    public double marginRatio() {
        return this.marginRatio;
    }

    /** Feature lexicon copied from the training dataset, or {@code null} before training. */
    public Lexicon<F> featureLexicon() {
        return this.featureLexicon;
    }

    /** Scala-style setter for {@link #featureLexicon()}. */
    public void featureLexicon_$eq(Lexicon<F> lexicon) {
        this.featureLexicon = lexicon;
    }

    /** Current (non-averaged) perceptron weights, or {@code null} before training. */
    public double[] weights() {
        return this.weights;
    }

    /** Scala-style setter for {@link #weights()}. */
    public void weights_$eq(double[] dArr) {
        this.weights = dArr;
    }

    /** Number of consecutive comparisons the current weights have survived without an update. */
    public int survivedIterations() {
        return this.survivedIterations;
    }

    /** Scala-style setter for {@link #survivedIterations()}. */
    public void survivedIterations_$eq(int i) {
        this.survivedIterations = i;
    }

    /** Accumulated (averaged-perceptron) weights, or {@code null} before training. */
    public double[] avgWeights() {
        return this.avgWeights;
    }

    /** Scala-style setter for {@link #avgWeights()}. */
    public void avgWeights_$eq(double[] dArr) {
        this.avgWeights = dArr;
    }

    /** Number of training query indices selected in the last call to {@code train}. */
    public int totalQueries() {
        return this.totalQueries;
    }

    /** Scala-style setter for {@link #totalQueries()}. */
    public void totalQueries_$eq(int i) {
        this.totalQueries = i;
    }

    /**
     * Update counter reset to 0 by {@code train}. NOTE(review): presumably accumulates
     * {@link #updatesPerEpoch()} at the end of each epoch in the epoch body; confirm.
     */
    public int totalUpdates() {
        return this.totalUpdates;
    }

    /** Scala-style setter for {@link #totalUpdates()}. */
    public void totalUpdates_$eq(int i) {
        this.totalUpdates = i;
    }

    /** Number of weight updates performed by {@link #update} (incremented on every update). */
    public int updatesPerEpoch() {
        return this.updatesPerEpoch;
    }

    /** Scala-style setter for {@link #updatesPerEpoch()}. */
    public void updatesPerEpoch_$eq(int i) {
        this.updatesPerEpoch = i;
    }

    /** Minimum score gap a pair must exceed to avoid an update (0 unless set by {@code train}). */
    public double margin() {
        return this.margin;
    }

    /** Scala-style setter for {@link #margin()}. */
    public void margin_$eq(double d) {
        this.margin = d;
    }

    /**
     * Trains the perceptron on {@code rankingDataset}.
     *
     * <p>The method does the following, in order:
     * <ul>
     *   <li>Selects the training query indices from {@code option} via
     *       {@code Datasets.mkTrainIndices}.</li>
     *   <li>Copies the dataset's feature lexicon.</li>
     *   <li>Allocates zeroed weight and averaged-weight vectors and resets
     *       {@link #totalUpdates()}.</li>
     *   <li>If {@link #marginRatio()} is positive, sets the margin to
     *       {@code marginRatio * avgVectorLength^2}.</li>
     *   <li>Runs the epoch loop over {@code 1..epochs()} with a {@code Random} seeded with 1.</li>
     * </ul>
     *
     * <p>The loop is filtered by a predicate that shares a boolean flag with the epoch body.
     * NOTE(review): that flag is presumably an early-stop signal; confirm.
     *
     * @param rankingDataset the training data
     * @param option optional spans restricting which queries are used for training
     */
    @Override // edu.arizona.sista.learning.RankingClassifier
    public void train(RankingDataset<F> rankingDataset, Option<Iterable<Tuple2<Object, Object>>> option) {
        int[] trainIndices = Datasets$.MODULE$.mkTrainIndices(rankingDataset.size(), option);
        totalQueries_$eq(trainIndices.length);
        featureLexicon_$eq(Lexicon$.MODULE$.apply(rankingDataset.featureLexicon()));
        weights_$eq(new double[featureLexicon().size()]);
        avgWeights_$eq(new double[featureLexicon().size()]);
        totalUpdates_$eq(0);
        double avgVectorLength = computeAverageVectorLength(rankingDataset, trainIndices);
        // Parameterized logging: the message is only formatted when DEBUG is enabled.
        Datasets$.MODULE$.logger().debug("Average vector length in training: {}", avgVectorLength);
        if (marginRatio() > 0) {
            margin_$eq(marginRatio() * avgVectorLength * avgVectorLength);
        }
        Random random = new Random(1);
        BooleanRef create = BooleanRef.create(false);
        RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(1), epochs()).withFilter(new PerceptronRankingClassifier$$anonfun$train$1(this, create)).foreach(new PerceptronRankingClassifier$$anonfun$train$2(this, rankingDataset, trainIndices, random, create));
    }

    /** Default for {@code train}'s second argument: no span restriction. */
    @Override // edu.arizona.sista.learning.RankingClassifier
    public Option<Iterable<Tuple2<Object, Object>>> train$default$2() {
        return None$.MODULE$;
    }

    /**
     * Returns the average vector length over the queries at {@code iArr}.
     *
     * <p>The per-query accumulation into the running total and count is done by a generated
     * closure. NOTE(review): presumably it adds each datum's vector length and counts the data;
     * confirm.
     *
     * @param rankingDataset the dataset the indices refer to
     * @param iArr query indices to average over
     * @return the average length, or {@code 0.0} when nothing was counted (previously
     *     {@code NaN}, which made every later margin comparison false)
     */
    public double computeAverageVectorLength(RankingDataset<F> rankingDataset, int[] iArr) {
        DoubleRef totalLength = DoubleRef.create(0.0d);
        IntRef vectorCount = IntRef.create(0);
        Predef$.MODULE$.intArrayOps(iArr).foreach(new PerceptronRankingClassifier$$anonfun$computeAverageVectorLength$1(this, rankingDataset, totalLength, vectorCount));
        if (vectorCount.elem == 0) {
            return 0.0d;
        }
        return totalLength.elem / vectorCount.elem;
    }

    /**
     * Performs one pairwise perceptron step.
     *
     * <p>If {@code better} already outscores {@code worse} by more than {@link #margin()}, the
     * current weights survive one more comparison. Otherwise the outgoing weights are folded into
     * the average first. Then {@code better} is added to the weights, {@code worse} is subtracted,
     * and the survival count restarts.
     *
     * @param better features of the item that should rank higher
     * @param worse features of the item that should rank lower
     */
    public void update(Counter<Object> better, Counter<Object> worse) {
        double scoreGap = Counters$.MODULE$.dotProduct(weights(), better) - Counters$.MODULE$.dotProduct(weights(), worse);
        if (scoreGap > margin()) {
            survivedIterations_$eq(survivedIterations() + 1);
            return;
        }
        // Credit the outgoing weights for the comparisons they survived before changing them.
        addToAvg();
        updateWeights(better, 1.0d);
        updateWeights(worse, -1.0d);
        survivedIterations_$eq(0);
        updatesPerEpoch_$eq(updatesPerEpoch() + 1);
    }

    /**
     * Adds the current weights to {@link #avgWeights()}, scaled by
     * {@code survivedIterations * totalQueries}.
     *
     * <p>Does nothing when the weights survived no comparisons, or while the number of updates so
     * far ({@code totalUpdates + updatesPerEpoch}) has not exceeded {@link #burnInIterations()}.
     */
    public void addToAvg() {
        if (survivedIterations() <= 0 || totalUpdates() + updatesPerEpoch() <= burnInIterations()) {
            return;
        }
        // Widen before multiplying: the int product overflows once it exceeds 2^31 - 1.
        final double scale = (double) survivedIterations() * totalQueries();
        final double[] current = weights();
        final double[] average = avgWeights();
        for (int i = 0; i < current.length; i++) {
            average[i] += current[i] * scale;
        }
    }

    /**
     * Adds {@code d} times {@code counter} to the weights.
     *
     * <p>The per-key work is done in a generated closure. NOTE(review): presumably the key is the
     * weight index; confirm.
     */
    public void updateWeights(Counter<Object> counter, double d) {
        counter.keySet().foreach(new PerceptronRankingClassifier$$anonfun$updateWeights$1(this, counter, d));
    }

    /**
     * Scores each datum in {@code iterable}, in iteration order, and returns the scores as a
     * wrapped {@code double[]}.
     *
     * <p>NOTE(review): the per-datum score is produced by a generated closure, presumably via
     * {@link #datumDotProduct(Counter)}; confirm.
     */
    @Override // edu.arizona.sista.learning.RankingClassifier
    public Iterable<Object> scoresOf(Iterable<Datum<Object, F>> iterable) {
        ArrayBuffer<Object> scores = new ArrayBuffer<Object>();
        iterable.foreach(new PerceptronRankingClassifier$$anonfun$scoresOf$1(this, scores));
        return Predef$.MODULE$.wrapDoubleArray((double[]) scores.toArray(ClassTag$.MODULE$.Double()));
    }

    /**
     * Returns the dot product of the model with a feature counter.
     *
     * <p>The sum is accumulated by a generated closure over the counter's keys. NOTE(review):
     * presumably it looks features up in {@link #featureLexicon()} and uses the averaged weights;
     * confirm.
     */
    public double datumDotProduct(Counter<F> counter) {
        DoubleRef create = DoubleRef.create(0.0d);
        counter.keySet().foreach(new PerceptronRankingClassifier$$anonfun$datumDotProduct$1(this, counter, create));
        return create.elem;
    }

    /**
     * Serializes this classifier to the file {@code str} with Java object serialization.
     *
     * <p>Both streams are closed even if writing fails.
     *
     * @throws UncheckedIOException wrapping the underlying {@link IOException} if the file cannot
     *     be opened or written
     */
    @Override // edu.arizona.sista.learning.RankingClassifier
    public void saveTo(String str) {
        try (FileOutputStream fileStream = new FileOutputStream(str);
             ObjectOutputStream objectStream = new ObjectOutputStream(fileStream)) {
            objectStream.writeObject(this);
        } catch (IOException e) {
            // The overridden interface method declares no checked exceptions.
            throw new UncheckedIOException("Failed to save model to " + str, e);
        }
    }

    /**
     * Prints a header line followed by one entry per averaged weight.
     *
     * <p>Requires a trained model: {@link #avgWeights()} is {@code null} until {@code train} runs.
     */
    @Override // edu.arizona.sista.learning.RankingClassifier
    public void displayModel(PrintWriter printWriter) {
        printWriter.println("Perceptron weights:");
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), avgWeights().length).foreach$mVc$sp(new PerceptronRankingClassifier$$anonfun$displayModel$1(this, printWriter));
    }

    /**
     * Creates an untrained classifier.
     *
     * @param i number of epochs
     * @param i2 number of burn-in updates before averaging starts
     * @param d margin ratio (margin is only set when this is positive)
     */
    public PerceptronRankingClassifier(int i, int i2, double d) {
        this.epochs = i;
        this.burnInIterations = i2;
        this.marginRatio = d;
        RankingClassifier.Cclass.$init$(this);
        this.featureLexicon = null;
        this.weights = null;
        this.survivedIterations = 0;
        this.avgWeights = null;
        this.totalQueries = 0;
        this.totalUpdates = 0;
        this.updatesPerEpoch = 0;
        this.margin = 0.0d;
    }

    /**
     * Creates an untrained classifier from properties.
     *
     * <p>Recognised keys and defaults: {@code epochs} (2), {@code burnInIterations} (0),
     * {@code marginRatio} (1.0).
     */
    public PerceptronRankingClassifier(Properties properties) {
        this(StringUtils$.MODULE$.getInt(properties, "epochs", 2), StringUtils$.MODULE$.getInt(properties, "burnInIterations", 0), StringUtils$.MODULE$.getDouble(properties, "marginRatio", 1.0d));
    }
}
