package gov.sandia.cognition.learning.algorithm.perceptron;

import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.learning.function.categorization.LinearMultiCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorFactoryContainer;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Iterator;

/**
 * Batch learner for a multi-category linear classifier trained with
 * perceptron-style updates.
 *
 * <p>One {@link LinearBinaryCategorizer} prototype (weight vector plus bias)
 * is kept per category found in the training data. On every pass over the
 * data each example is scored against every prototype; when the
 * highest-scoring category differs from the example's label, the label's
 * prototype is moved toward the input and the wrongly predicted prototype is
 * moved away from it. A pass reports that more work remains as long as at
 * least one example was misclassified.
 *
 * <p>An optional minimum margin is subtracted from the score of the correct
 * category before the comparison, so the correct category must win by more
 * than that margin for the example to count as correctly classified.
 *
 * @param <CategoryType> the type of the output categories (labels).
 */
@PublicationReference(title = "Ultraconservative Online Algorithms for Multiclass Problems", author = {"Koby Crammer", "Yoram Singer"}, year = 2003, type = PublicationType.Journal, publication = "Journal of Machine Learning Research", pages = {951, 991}, url = "http://portal.acm.org/citation.cfm?id=944936")
/* loaded from: input_file:gov/sandia/cognition/learning/algorithm/perceptron/BatchMultiPerceptron.class */
public class BatchMultiPerceptron<CategoryType> extends AbstractAnytimeSupervisedBatchLearner<Vectorizable, CategoryType, LinearMultiCategorizer<CategoryType>> implements MeasurablePerformanceAlgorithm, VectorFactoryContainer {
    /** Default maximum number of passes over the data: {@value}. */
    public static final int DEFAULT_MAX_ITERATIONS = 100;

    /** Default minimum margin: {@value}. */
    public static final double DEFAULT_MIN_MARGIN = 0.0d;

    /** Margin subtracted from the correct category's score; never negative. */
    protected double minMargin;

    /** Factory used to create the initial weight vector of each prototype. */
    protected VectorFactory<?> vectorFactory;

    /** The categorizer being learned; created by {@link #initializeAlgorithm()}. */
    protected transient LinearMultiCategorizer<CategoryType> result;

    /** Number of misclassified examples seen in the most recent pass. */
    protected transient int errorCount;

    /**
     * Creates a learner with the default iteration limit, margin and vector
     * factory.
     */
    public BatchMultiPerceptron() {
        this(DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates a learner with the default margin and vector factory.
     *
     * @param i the maximum number of iterations, passed to the superclass.
     */
    public BatchMultiPerceptron(int i) {
        this(i, DEFAULT_MIN_MARGIN);
    }

    /**
     * Creates a learner with the default vector factory.
     *
     * @param i the maximum number of iterations, passed to the superclass.
     * @param d the minimum margin; validated by {@link #setMinMargin(double)}.
     */
    public BatchMultiPerceptron(int i, double d) {
        this(i, d, VectorFactory.getDefault());
    }

    /**
     * Creates a learner with every parameter given explicitly.
     *
     * @param i the maximum number of iterations, passed to the superclass.
     * @param d the minimum margin; validated by {@link #setMinMargin(double)}.
     * @param vectorFactory the factory used to create weight vectors.
     */
    public BatchMultiPerceptron(int i, double d, VectorFactory<?> vectorFactory) {
        super(i);
        setMinMargin(d);
        setVectorFactory(vectorFactory);
    }

    /**
     * Builds a fresh result containing one prototype per distinct output in
     * the data, each starting with a factory-created weight vector of the
     * input dimensionality and a bias of zero.
     *
     * @return {@code false} when there is no data to learn from, otherwise
     *         {@code true}.
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected boolean initializeAlgorithm() {
        if (CollectionUtil.isEmpty(getData())) {
            return false;
        }
        final int dimensionality = DatasetUtil.getInputDimensionality(getData());
        this.result = new LinearMultiCategorizer<>();
        // A typed for-each keeps the map key statically a CategoryType
        // instead of the Object a raw Iterator would yield.
        for (CategoryType category : DatasetUtil.findUniqueOutputs(getData())) {
            this.result.getPrototypes().put(
                category,
                new LinearBinaryCategorizer(getVectorFactory().createVector(dimensionality), 0.0d));
        }
        return true;
    }

    /**
     * Performs one pass over the training data, updating the prototypes for
     * every misclassified example and counting the mistakes.
     *
     * @return {@code true} when at least one example was misclassified during
     *         this pass, {@code false} when the pass was error free.
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected boolean step() {
        setErrorCount(0);
        for (InputOutputPair<? extends Vectorizable, CategoryType> example : getData()) {
            if (example == null) {
                // Null entries in the data set are skipped.
                continue;
            }

            final Vector input = example.getInput().convertToVector();
            final CategoryType actual = example.getOutput();

            // Find the best-scoring category, handicapping the correct one by
            // the minimum margin.
            CategoryType predicted = null;
            double bestScore = Double.NEGATIVE_INFINITY;
            for (CategoryType category : this.result.getCategories()) {
                double score = this.result.evaluateAsDouble(input, category);
                // Null-safe comparison, matching the misclassification check
                // below, so a null label does not fail only when a margin is
                // configured.
                if (this.minMargin != 0.0d && ObjectUtil.equalsSafe(actual, category)) {
                    score -= this.minMargin;
                }
                if (score > bestScore) {
                    predicted = category;
                    bestScore = score;
                }
            }

            if (ObjectUtil.equalsSafe(actual, predicted)) {
                continue;
            }

            // Mistake: pull the correct prototype toward the input and push
            // the wrongly predicted prototype away from it.
            setErrorCount(getErrorCount() + 1);

            final LinearBinaryCategorizer actualPrototype = this.result.getPrototypes().get(actual);
            actualPrototype.getWeights().plusEquals(input);
            actualPrototype.setBias(actualPrototype.getBias() + 1.0d);

            final LinearBinaryCategorizer predictedPrototype = this.result.getPrototypes().get(predicted);
            predictedPrototype.getWeights().minusEquals(input);
            predictedPrototype.setBias(predictedPrototype.getBias() - 1.0d);
        }
        return getErrorCount() > 0;
    }

    /**
     * No cleanup is performed; the learned result is left in place.
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected void cleanupAlgorithm() {
    }

    /**
     * Gets the categorizer learned so far.
     *
     * <p>The decompiler had renamed this accessor to {@code m65getResult}
     * when merging it with its bridge method; it is restored here under its
     * original name.
     *
     * @return the current result; {@code null} if none has been created yet.
     */
    public LinearMultiCategorizer<CategoryType> getResult() {
        return this.result;
    }

    /**
     * Alias of {@link #getResult()} under the name the decompiler generated.
     *
     * @return the current result.
     * @deprecated use {@link #getResult()}; retained only so callers of the
     *             decompiler-generated name keep working.
     */
    @Deprecated
    public LinearMultiCategorizer<CategoryType> m65getResult() {
        return getResult();
    }

    /**
     * Sets the categorizer being learned.
     *
     * @param linearMultiCategorizer the new result.
     */
    protected void setResult(LinearMultiCategorizer<CategoryType> linearMultiCategorizer) {
        this.result = linearMultiCategorizer;
    }

    /**
     * Gets the minimum margin.
     *
     * @return the margin subtracted from the correct category's score.
     */
    public double getMinMargin() {
        return this.minMargin;
    }

    /**
     * Sets the minimum margin.
     *
     * @param d the new margin; checked with
     *          {@code ArgumentChecker.assertIsNonNegative} before being stored.
     */
    public void setMinMargin(double d) {
        ArgumentChecker.assertIsNonNegative("minMargin", d);
        this.minMargin = d;
    }

    /**
     * Gets the factory used to create weight vectors.
     *
     * @return the vector factory.
     */
    public VectorFactory<?> getVectorFactory() {
        return this.vectorFactory;
    }

    /**
     * Sets the factory used to create weight vectors.
     *
     * @param vectorFactory the vector factory.
     */
    public void setVectorFactory(VectorFactory<?> vectorFactory) {
        this.vectorFactory = vectorFactory;
    }

    /**
     * Gets the number of misclassified examples in the most recent pass.
     *
     * @return the error count.
     */
    public int getErrorCount() {
        return this.errorCount;
    }

    /**
     * Sets the error count.
     *
     * @param i the new error count.
     */
    protected void setErrorCount(int i) {
        this.errorCount = i;
    }

    /**
     * Reports the current error count as the algorithm's performance value.
     *
     * @return a value named {@code "error count"} holding
     *         {@link #getErrorCount()}.
     */
    public NamedValue<Integer> getPerformance() {
        return new DefaultNamedValue<Integer>("error count", Integer.valueOf(getErrorCount()));
    }
}
