package gov.sandia.cognition.learning.algorithm.confidence;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractSupervisedBatchAndIncrementalLearner;
import gov.sandia.cognition.learning.function.categorization.DefaultConfidenceWeightedBinaryCategorizer;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.ArgumentChecker;

/**
 * Implements the Adaptive Regularization of Weight Vectors (AROW) online
 * learning algorithm for binary categorization.
 * <p>
 * The learned {@link DefaultConfidenceWeightedBinaryCategorizer} carries a
 * mean weight vector and a covariance matrix. If the categorizer is not yet
 * initialized when an example arrives, they are created as a zero vector and
 * an identity matrix sized to that example's dimensionality. An example
 * {@code (x, y)} with {@code y} in {@code {-1, +1}} changes the model only
 * when its margin {@code y * (x . mean)} is below 1; otherwise the categorizer
 * is left untouched.
 * <p>
 * The regularization parameter {@code r} must be positive and defaults to
 * {@link #DEFAULT_R}. Larger values of {@code r} make each update smaller.
 */
@PublicationReference(author = {"Koby Crammer", "Alex Kulesza", "Mark Dredze"}, title = "Adaptive Regularization of Weight Vectors", year = 2009, type = PublicationType.Conference, publication = "Advances in Neural Information Processing Systems", url = "http://papers.nips.cc/paper/3848-adaptive-regularization-of-weight-vectors.pdf")
/* loaded from: input_file:gov/sandia/cognition/learning/algorithm/confidence/AdaptiveRegularizationOfWeights.class */
public class AdaptiveRegularizationOfWeights extends AbstractSupervisedBatchAndIncrementalLearner<Vectorizable, Boolean, DefaultConfidenceWeightedBinaryCategorizer> {
    /** The default value of the regularization parameter {@code r}. */
    public static final double DEFAULT_R = 0.001d;

    /**
     * The regularization parameter. {@link #setR(double)} requires it to be
     * positive. It appears in the denominator of the step size, so larger
     * values give smaller updates.
     */
    protected double r;

    /**
     * Creates a learner using the default regularization parameter
     * {@link #DEFAULT_R}.
     */
    public AdaptiveRegularizationOfWeights() {
        this(DEFAULT_R);
    }

    /**
     * Creates a learner with the given regularization parameter.
     *
     * @param r the regularization parameter; must be positive
     */
    public AdaptiveRegularizationOfWeights(double r) {
        setR(r);
    }

    /**
     * Creates a new, default-constructed categorizer to learn into. If it
     * reports that it is not initialized, its mean and covariance are
     * allocated by the first call to {@code update}.
     *
     * @return a new categorizer
     */
    @Override // gov.sandia.cognition.learning.algorithm.IncrementalLearner
    public DefaultConfidenceWeightedBinaryCategorizer createInitialLearnedObject() {
        return new DefaultConfidenceWeightedBinaryCategorizer();
    }

    /**
     * Updates the categorizer with one labeled example. Examples whose input
     * or label is {@code null} are silently ignored.
     *
     * @param target the categorizer to update in place
     * @param input the example input; converted to a vector
     * @param label the example label; {@code true} maps to +1, {@code false} to -1
     */
    @Override // gov.sandia.cognition.learning.algorithm.SupervisedIncrementalLearner
    public void update(DefaultConfidenceWeightedBinaryCategorizer target, Vectorizable input, Boolean label) {
        if (input == null || label == null) {
            return;
        }
        update(target, input.convertToVector(), label.booleanValue());
    }

    /**
     * Applies one AROW update step for the example {@code (input, label)}.
     * <p>
     * With {@code y = label ? +1 : -1}, {@code m = y * (x . mean)} and
     * {@code v = x^T Sigma x}, an update happens only when {@code m < 1}:
     * <pre>
     *   beta  = 1 / (v + r)
     *   alpha = max(0, 1 - m) * beta
     *   mean  = mean + alpha * y * Sigma x
     *   Sigma = Sigma - beta * (Sigma x)(Sigma x)^T
     * </pre>
     * The mean vector and covariance matrix held by {@code target} are
     * modified in place.
     *
     * @param target the categorizer to update
     * @param input the example input vector
     * @param label the example label; {@code true} maps to +1, {@code false} to -1
     */
    public void update(DefaultConfidenceWeightedBinaryCategorizer target, Vector input, boolean label) {
        final Vector mean;
        final Matrix covariance;
        if (target.isInitialized()) {
            mean = target.getMean();
            covariance = target.mo163getCovariance();
        } else {
            // First example: start from a zero mean and an identity covariance
            // sized to this input.
            final int dimensionality = input.getDimensionality();
            mean = VectorFactory.getDenseDefault().createVector(dimensionality);
            covariance = MatrixFactory.getDenseDefault().createIdentity(dimensionality, dimensionality);
            target.setMean(mean);
            target.setCovariance(covariance);
        }

        final double actual = label ? 1.0d : -1.0d;
        final double margin = actual * input.dotProduct(mean);

        // Only a margin violation (hinge loss > 0) changes the model.
        if (margin < 1.0d) {
            // x^T * Sigma, which equals Sigma * x for the symmetric covariance
            // this algorithm maintains. It is computed once and reused for
            // both the mean and the covariance updates.
            final Vector covarianceTimesInput = input.times(covariance);
            final double confidence = covarianceTimesInput.dotProduct(input);
            final double beta = 1.0d / (confidence + this.r);
            final double alpha = Math.max(0.0d, 1.0d - margin) * beta;

            // mean += alpha * y * (Sigma x). scale() returns a scaled copy, so
            // covarianceTimesInput remains intact for the covariance update.
            mean.plusEquals(covarianceTimesInput.scale(alpha * actual));

            // Sigma -= beta * (Sigma x)(Sigma x)^T
            final Matrix covarianceUpdate = covarianceTimesInput.outerProduct(covarianceTimesInput);
            covarianceUpdate.scaleEquals(-beta);
            covariance.plusEquals(covarianceUpdate);
        }
    }

    /**
     * Gets the regularization parameter.
     *
     * @return the regularization parameter {@code r}
     */
    public double getR() {
        return this.r;
    }

    /**
     * Sets the regularization parameter.
     *
     * @param r the regularization parameter; must be positive (checked with
     *          {@code ArgumentChecker.assertIsPositive})
     */
    public void setR(double r) {
        ArgumentChecker.assertIsPositive("r", r);
        this.r = r;
    }
}
