package gov.sandia.cognition.learning.algorithm.perceptron;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractBatchAndIncrementalLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.learning.function.categorization.LinearMultiCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorFactoryContainer;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 * An online, multi-category Perceptron learner.
 *
 * <p>The learned {@link LinearMultiCategorizer} holds one
 * {@link LinearBinaryCategorizer} prototype per category. On each example, the
 * score of the actual category is reduced by {@link #minMargin} and the
 * highest-scoring category is found. If that category is not the actual one,
 * the input is added to the actual category's prototype weights (bias +1), and
 * subtracted from the predicted category's prototype weights (bias -1).
 *
 * <p>The nested {@link ProportionalUpdate} and {@link UniformUpdate} variants
 * instead spread the negative update over every category that violates the
 * margin.
 *
 * @param <CategoryType> The type of the output categories.
 */
public class OnlineMultiPerceptron<CategoryType>
    extends AbstractBatchAndIncrementalLearner<InputOutputPair<? extends Vectorizable, CategoryType>, LinearMultiCategorizer<CategoryType>>
    implements VectorFactoryContainer {

    /** The default minimum margin, {@value}. */
    public static final double DEFAULT_MIN_MARGIN = 0.0;

    /**
     * The amount subtracted from the actual category's score before it is
     * compared against the other categories. Validated as non-negative by
     * {@link #setMinMargin(double)}.
     */
    protected double minMargin;

    /** The factory used to create the weight vector of each new prototype. */
    protected VectorFactory<?> vectorFactory;

    /**
     * Multi-category Perceptron variant that subtracts the input from every
     * margin-violating category, weighted in proportion to how badly each one
     * violates the margin.
     *
     * @param <CategoryType> The type of the output categories.
     */
    @PublicationReference(title = "Ultraconservative online algorithms for multiclass problems", author = {"Koby Crammer", "Yoram Singer"}, year = 2003, type = PublicationType.Journal, publication = "The Journal of Machine Learning Research", pages = {951, 991}, url = "http://portal.acm.org/citation.cfm?id=944936")
    public static class ProportionalUpdate<CategoryType> extends OnlineMultiPerceptron<CategoryType> {

        /** The default minimum margin for this variant, {@value}. It must be positive. */
        public static final double DEFAULT_MIN_MARGIN = 0.001;

        /** Creates a learner with {@link ProportionalUpdate#DEFAULT_MIN_MARGIN}. */
        public ProportionalUpdate() {
            this(ProportionalUpdate.DEFAULT_MIN_MARGIN);
        }

        /**
         * Creates a learner with the given margin and the default vector factory.
         *
         * @param minMargin The minimum margin; must be positive.
         */
        public ProportionalUpdate(final double minMargin) {
            super(minMargin);
        }

        /**
         * Creates a learner with the given margin and vector factory.
         *
         * @param minMargin The minimum margin; must be positive.
         * @param vectorFactory The factory used to create new prototype weights.
         */
        public ProportionalUpdate(final double minMargin, final VectorFactory<?> vectorFactory) {
            super(minMargin, vectorFactory);
        }

        /**
         * Updates the target with one example.
         *
         * <p>The error of each category other than the actual one is computed as
         * {@code score(category) - (score(actual) - minMargin)}. If no category
         * has a positive error, nothing changes. Otherwise the input is added to
         * the actual prototype (bias +1). Each violating prototype then has
         * {@code error / totalError} times the input subtracted from its weights,
         * and the same amount subtracted from its bias.
         *
         * @param target The categorizer to update in place.
         * @param example The labeled example.
         */
        @Override
        public void update(
            final LinearMultiCategorizer<CategoryType> target,
            final InputOutputPair<? extends Vectorizable, CategoryType> example) {
            final Vector input = example.getInput().convertToVector();
            final CategoryType actual = example.getOutput();
            ensurePrototype(target, actual, input.getDimensionality(), this.getVectorFactory());

            final double threshold = target.evaluateAsDouble(input, actual) - this.minMargin;
            final List<DefaultWeightedValue<CategoryType>> violations = new ArrayList<>();
            double totalError = 0.0;
            for (final CategoryType category : target.getCategories()) {
                final double error = target.evaluateAsDouble(input, category) - threshold;
                // Only strictly positive errors are kept. A zero error would get
                // zero weight anyway. Excluding it avoids a 0/0 (NaN) weight when
                // every violation is exactly zero.
                if (error > 0.0 && !actual.equals(category)) {
                    violations.add(DefaultWeightedValue.create(category, error));
                    totalError += error;
                }
            }

            if (violations.isEmpty()) {
                return;
            }

            final LinearBinaryCategorizer actualPrototype = target.getPrototypes().get(actual);
            actualPrototype.getWeights().plusEquals(input);
            actualPrototype.setBias(actualPrototype.getBias() + 1.0);

            for (final DefaultWeightedValue<CategoryType> violation : violations) {
                final LinearBinaryCategorizer prototype = target.getPrototypes().get(violation.getValue());
                final double weight = violation.getWeight() / totalError;
                prototype.getWeights().minusEquals(input.scale(weight));
                prototype.setBias(prototype.getBias() - weight);
            }
        }

        /**
         * Sets the minimum margin. Unlike the base class, this variant requires
         * a strictly positive margin (checked with
         * {@code ArgumentChecker.assertIsPositive}). With a positive margin, a
         * category whose score merely ties the actual category still produces a
         * positive error.
         *
         * @param minMargin The new minimum margin; must be positive.
         */
        @Override
        public void setMinMargin(final double minMargin) {
            ArgumentChecker.assertIsPositive("minMargin", minMargin);
            super.setMinMargin(minMargin);
        }
    }

    /**
     * Multi-category Perceptron variant that subtracts an equal share of the
     * input from every category that violates the margin.
     *
     * @param <CategoryType> The type of the output categories.
     */
    @PublicationReference(title = "Ultraconservative online algorithms for multiclass problems", author = {"Koby Crammer", "Yoram Singer"}, year = 2003, type = PublicationType.Journal, publication = "The Journal of Machine Learning Research", pages = {951, 991}, url = "http://portal.acm.org/citation.cfm?id=944936")
    public static class UniformUpdate<CategoryType> extends OnlineMultiPerceptron<CategoryType> {

        /** Creates a learner with the default margin and vector factory. */
        public UniformUpdate() {
            this(DEFAULT_MIN_MARGIN);
        }

        /**
         * Creates a learner with the given margin and the default vector factory.
         *
         * @param minMargin The minimum margin; must be non-negative.
         */
        public UniformUpdate(final double minMargin) {
            super(minMargin);
        }

        /**
         * Creates a learner with the given margin and vector factory.
         *
         * @param minMargin The minimum margin; must be non-negative.
         * @param vectorFactory The factory used to create new prototype weights.
         */
        public UniformUpdate(final double minMargin, final VectorFactory<?> vectorFactory) {
            super(minMargin, vectorFactory);
        }

        /**
         * Updates the target with one example.
         *
         * <p>Every category other than the actual one whose score is at least
         * {@code score(actual) - minMargin} is a violator. If there are none,
         * nothing changes. Otherwise the input is added to the actual prototype
         * (bias +1). Each of the {@code k} violators then has {@code input / k}
         * subtracted from its weights and {@code 1 / k} subtracted from its bias.
         *
         * @param target The categorizer to update in place.
         * @param example The labeled example.
         */
        @Override
        public void update(
            final LinearMultiCategorizer<CategoryType> target,
            final InputOutputPair<? extends Vectorizable, CategoryType> example) {
            final Vector input = example.getInput().convertToVector();
            final CategoryType actual = example.getOutput();
            ensurePrototype(target, actual, input.getDimensionality(), this.getVectorFactory());

            final double threshold = target.evaluateAsDouble(input, actual) - this.minMargin;
            final List<CategoryType> violators = new ArrayList<>();
            for (final CategoryType category : target.getCategories()) {
                if (!actual.equals(category)
                    && target.evaluateAsDouble(input, category) >= threshold) {
                    violators.add(category);
                }
            }

            if (violators.isEmpty()) {
                return;
            }

            final LinearBinaryCategorizer actualPrototype = target.getPrototypes().get(actual);
            actualPrototype.getWeights().plusEquals(input);
            actualPrototype.setBias(actualPrototype.getBias() + 1.0);

            final double share = 1.0 / violators.size();
            // The scaled input is the same for every violator, so compute it once.
            final Vector scaledInput = input.scale(share);
            for (final CategoryType category : violators) {
                final LinearBinaryCategorizer prototype = target.getPrototypes().get(category);
                prototype.getWeights().minusEquals(scaledInput);
                prototype.setBias(prototype.getBias() - share);
            }
        }
    }

    /** Creates a learner with {@link #DEFAULT_MIN_MARGIN} and the default vector factory. */
    public OnlineMultiPerceptron() {
        this(DEFAULT_MIN_MARGIN);
    }

    /**
     * Creates a learner with the given margin and the default vector factory.
     *
     * @param minMargin The minimum margin.
     */
    public OnlineMultiPerceptron(final double minMargin) {
        this(minMargin, VectorFactory.getDefault());
    }

    /**
     * Creates a learner with the given margin and vector factory.
     *
     * @param minMargin The minimum margin.
     * @param vectorFactory The factory used to create new prototype weights.
     */
    public OnlineMultiPerceptron(final double minMargin, final VectorFactory<?> vectorFactory) {
        // Goes through the overridable setters on purpose, so subclasses such as
        // ProportionalUpdate apply their own margin validation during construction.
        setMinMargin(minMargin);
        setVectorFactory(vectorFactory);
    }

    /**
     * Creates the empty categorizer that learning starts from.
     *
     * @return A new {@link LinearMultiCategorizer} with no prototypes.
     */
    @Override
    public LinearMultiCategorizer<CategoryType> createInitialLearnedObject() {
        return new LinearMultiCategorizer<>();
    }

    /**
     * Updates the target with one example using the classic Perceptron rule.
     *
     * <p>If the actual category has no prototype yet, a zero-weight, zero-bias
     * prototype of the input's dimensionality is added first. The category with
     * the highest score is then found, after subtracting {@link #minMargin} from
     * the actual category's score. Ties go to whichever category the iteration
     * over {@code getCategories()} visits first. If that category is not the
     * actual one:
     * <ul>
     *   <li>the input is added to the actual prototype's weights and its bias is
     *       incremented by 1;</li>
     *   <li>the input is subtracted from the predicted prototype's weights and
     *       its bias is decremented by 1.</li>
     * </ul>
     *
     * @param target The categorizer to update in place.
     * @param example The labeled example.
     */
    @Override
    public void update(
        final LinearMultiCategorizer<CategoryType> target,
        final InputOutputPair<? extends Vectorizable, CategoryType> example) {
        final Vector input = example.getInput().convertToVector();
        final CategoryType actual = example.getOutput();
        ensurePrototype(target, actual, input.getDimensionality(), this.getVectorFactory());

        CategoryType predicted = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (final CategoryType category : target.getCategories()) {
            double score = target.evaluateAsDouble(input, category);
            if (actual.equals(category)) {
                // Handicap the correct answer so it must win by at least minMargin.
                score -= this.minMargin;
            }
            if (score > bestScore) {
                predicted = category;
                bestScore = score;
            }
        }

        if (ObjectUtil.equalsSafe(actual, predicted)) {
            return;
        }

        final LinearBinaryCategorizer actualPrototype = target.getPrototypes().get(actual);
        actualPrototype.getWeights().plusEquals(input);
        actualPrototype.setBias(actualPrototype.getBias() + 1.0);

        final LinearBinaryCategorizer predictedPrototype = target.getPrototypes().get(predicted);
        predictedPrototype.getWeights().minusEquals(input);
        predictedPrototype.setBias(predictedPrototype.getBias() - 1.0);
    }

    /**
     * Gets the minimum margin.
     *
     * @return The amount subtracted from the actual category's score.
     */
    public double getMinMargin() {
        return this.minMargin;
    }

    /**
     * Sets the minimum margin.
     *
     * @param minMargin The new minimum margin; must be non-negative (checked with
     *     {@code ArgumentChecker.assertIsNonNegative}).
     */
    public void setMinMargin(final double minMargin) {
        ArgumentChecker.assertIsNonNegative("minMargin", minMargin);
        this.minMargin = minMargin;
    }

    /**
     * Gets the factory used to create new prototype weight vectors.
     *
     * @return The vector factory.
     */
    @Override
    public VectorFactory<?> getVectorFactory() {
        return this.vectorFactory;
    }

    /**
     * Sets the factory used to create new prototype weight vectors.
     *
     * @param vectorFactory The vector factory.
     */
    public void setVectorFactory(final VectorFactory<?> vectorFactory) {
        this.vectorFactory = vectorFactory;
    }

    /**
     * Adds a zero-weight, zero-bias prototype for {@code category} to
     * {@code target} if the target does not already contain that category.
     * It is static so the nested static subclasses can call it directly.
     */
    private static <C> void ensurePrototype(
        final LinearMultiCategorizer<C> target,
        final C category,
        final int dimensionality,
        final VectorFactory<?> factory) {
        if (!target.getCategories().contains(category)) {
            target.getPrototypes().put(category,
                new LinearBinaryCategorizer(factory.createVector(dimensionality), 0.0));
        }
    }
}
