package gov.sandia.cognition.learning.algorithm.perceptron;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractSupervisedBatchAndIncrementalLearner;
import gov.sandia.cognition.learning.algorithm.ensemble.WeightedBinaryEnsemble;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorFactoryContainer;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.DefaultWeightedValue;

/**
 * Online (incremental) learner that builds a weighted ensemble of linear
 * binary categorizers in the style of the voted perceptron.
 *
 * <p>Each training example is evaluated against the current ensemble. When the
 * sign of the ensemble output agrees with the label, the weight of the most
 * recently added member is increased by one. Otherwise a new member is derived
 * from the most recent one (or created from scratch when the ensemble has no
 * members), moved toward the example, and appended with weight one.
 */
@PublicationReference(title = "Large Margin Classification Using the Perceptron Algorithm", author = {"Yoav Freund", "Robert E. Schapire"}, year = 1999, type = PublicationType.Journal, publication = "Machine Learning", pages = {277, 296}, url = "http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.48.8200")
public class OnlineVotedPerceptron extends AbstractSupervisedBatchAndIncrementalLearner<Vectorizable, Boolean, WeightedBinaryEnsemble<Vectorizable, LinearBinaryCategorizer>> implements VectorFactoryContainer {

    /** Factory used to create the weight vector of the first ensemble member. */
    protected VectorFactory<?> vectorFactory;

    /**
     * Creates a learner that uses the default dense vector factory.
     */
    public OnlineVotedPerceptron() {
        this(VectorFactory.getDenseDefault());
    }

    /**
     * Creates a learner that uses the given vector factory.
     *
     * @param vectorFactory the factory to store via {@link #setVectorFactory}
     */
    public OnlineVotedPerceptron(VectorFactory<?> vectorFactory) {
        setVectorFactory(vectorFactory);
    }

    /**
     * Creates the starting point for incremental learning.
     *
     * @return a newly constructed weighted binary ensemble
     */
    @Override
    public WeightedBinaryEnsemble<Vectorizable, LinearBinaryCategorizer> createInitialLearnedObject() {
        return new WeightedBinaryEnsemble<>();
    }

    /**
     * Updates the ensemble with one labeled example. The call does nothing
     * when either the input or the label is {@code null}.
     *
     * @param weightedBinaryEnsemble the ensemble to update in place
     * @param vectorizable the example input; converted to a vector before use
     * @param bool the example label
     */
    @Override
    public void update(WeightedBinaryEnsemble<Vectorizable, LinearBinaryCategorizer> weightedBinaryEnsemble, Vectorizable vectorizable, Boolean bool) {
        if (vectorizable != null && bool != null) {
            // Unbox explicitly so the call binds to the (Vector, boolean) overload.
            update(weightedBinaryEnsemble, vectorizable.convertToVector(), bool.booleanValue());
        }
    }

    /**
     * Updates the ensemble with one labeled example given as a vector.
     *
     * <p>The example counts as correctly classified only when the ensemble
     * output is strictly positive for a {@code true} label or strictly negative
     * for a {@code false} label; an output of exactly zero is handled as a
     * mistake for either label.
     *
     * @param weightedBinaryEnsemble the ensemble to update in place
     * @param vector the example input
     * @param z the example label
     */
    public void update(WeightedBinaryEnsemble<Vectorizable, LinearBinaryCategorizer> weightedBinaryEnsemble, Vector vector, boolean z) {
        final double prediction = weightedBinaryEnsemble.evaluateAsDouble(vector);
        final DefaultWeightedValue<LinearBinaryCategorizer> newest = getLastMember(weightedBinaryEnsemble);

        final boolean correct = z ? prediction > 0.0d : prediction < 0.0d;
        if (correct) {
            // Reward the newest member for surviving another example.
            // NOTE(review): this dereference assumes an ensemble without
            // members never yields a non-zero output - confirm against
            // WeightedBinaryEnsemble.evaluateAsDouble.
            newest.setWeight(newest.getWeight() + 1.0d);
            return;
        }

        // Mistake: start from a copy of the newest member, or from a fresh
        // categorizer with zero bias when the ensemble has no members yet.
        final LinearBinaryCategorizer successor;
        if (newest != null) {
            successor = ((LinearBinaryCategorizer) newest.getValue()).m167clone();
        } else {
            successor = new LinearBinaryCategorizer(getVectorFactory().createVector(vector.getDimensionality()), 0.0d);
        }

        // Move the weights toward the example for a true label and away from
        // it for a false label, then shift the bias by one in the same direction.
        if (z) {
            successor.getWeights().plusEquals(vector);
        } else {
            successor.getWeights().minusEquals(vector);
        }
        final double biasStep = z ? 1.0d : -1.0d;
        successor.setBias(successor.getBias() + biasStep);

        weightedBinaryEnsemble.add(successor, 1.0d);
    }

    /**
     * Gets the member at the end of the ensemble's member list.
     *
     * @param weightedBinaryEnsemble the ensemble to inspect
     * @return the last member together with its weight, or {@code null} when
     *         the ensemble has no members
     */
    public static DefaultWeightedValue<LinearBinaryCategorizer> getLastMember(WeightedBinaryEnsemble<Vectorizable, LinearBinaryCategorizer> weightedBinaryEnsemble) {
        final int lastIndex = weightedBinaryEnsemble.getMembers().size() - 1;
        return lastIndex < 0 ? null : weightedBinaryEnsemble.getMembers().get(lastIndex);
    }

    /**
     * Gets the vector factory used by this learner.
     *
     * @return the current vector factory
     */
    public VectorFactory<?> getVectorFactory() {
        return this.vectorFactory;
    }

    /**
     * Sets the vector factory used by this learner.
     *
     * @param vectorFactory the vector factory to store
     */
    public void setVectorFactory(VectorFactory<?> vectorFactory) {
        this.vectorFactory = vectorFactory;
    }
}
