package gov.sandia.cognition.learning.function.categorization;

import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.AbstractBatchLearnerContainer;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.DefaultWeightedValueDiscriminant;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.data.ValueDiscriminantPair;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorFactoryContainer;
import gov.sandia.cognition.math.matrix.Vectorizable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * A categorizer that maps an input to a vector using a wrapped evaluator and
 * then picks the category whose vector element is largest ("winner take all").
 * The i-th element of the evaluator's output corresponds to the i-th category
 * in the iteration order of the category set.
 *
 * @param <InputType> The type of input being categorized.
 * @param <CategoryType> The type of category produced.
 */
public class WinnerTakeAllCategorizer<InputType, CategoryType> extends AbstractDiscriminantCategorizer<InputType, CategoryType, Double> {

    /** The evaluator that turns an input into one score per category. */
    protected Evaluator<? super InputType, ? extends Vectorizable> evaluator;

    /**
     * Learns a {@link WinnerTakeAllCategorizer} by reducing categorization to
     * vector regression. Each distinct output category is given an index in
     * order of first appearance in the training data. Each example's output is
     * encoded as a vector with {@code 1.0} at that category's index and
     * {@code -1.0} everywhere else. The wrapped batch learner is then trained
     * on these vector-valued examples.
     *
     * @param <InputType> The type of input being categorized.
     * @param <CategoryType> The type of category produced.
     */
    public static class Learner<InputType, CategoryType> extends AbstractBatchLearnerContainer<BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, Vector>>, ? extends Evaluator<? super InputType, ? extends Vectorizable>>> implements SupervisedBatchLearner<InputType, CategoryType, WinnerTakeAllCategorizer<InputType, CategoryType>>, VectorFactoryContainer {

        /** Factory used to create the +1/-1 target vectors. */
        protected VectorFactory<?> vectorFactory;

        /**
         * Creates a learner with no wrapped batch learner and the default
         * vector factory.
         */
        public Learner() {
            this(null);
        }

        /**
         * Creates a learner that wraps the given vector-output batch learner
         * and uses the default vector factory.
         *
         * @param batchLearner The learner trained on the +1/-1 encoded data.
         */
        public Learner(BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, Vector>>, Evaluator<? super InputType, ? extends Vectorizable>> batchLearner) {
            super(batchLearner);
            setVectorFactory(VectorFactory.getDefault());
        }

        /**
         * Learns a winner-take-all categorizer from labeled data.
         *
         * @param data The labeled training examples.
         * @return A categorizer wrapping the learned evaluator. Its categories
         *         are the distinct outputs of {@code data}, in order of first
         *         appearance.
         */
        @Override // gov.sandia.cognition.learning.algorithm.BatchLearner
        public WinnerTakeAllCategorizer<InputType, CategoryType> learn(Collection<? extends InputOutputPair<? extends InputType, CategoryType>> data) {
            // Give each distinct category an index, in order of first appearance.
            final Map<CategoryType, Integer> categoryIndices = new LinkedHashMap<CategoryType, Integer>();
            for (InputOutputPair<? extends InputType, CategoryType> example : data) {
                final CategoryType category = example.getOutput();
                if (!categoryIndices.containsKey(category)) {
                    categoryIndices.put(category, categoryIndices.size());
                }
            }

            // Encode each categorical output as a +1 (own index) / -1 (others) vector.
            final int categoryCount = categoryIndices.size();
            final VectorFactory<?> factory = getVectorFactory();
            final List<InputOutputPair<InputType, Vector>> vectorData =
                new ArrayList<InputOutputPair<InputType, Vector>>(data.size());
            for (InputOutputPair<? extends InputType, CategoryType> example : data) {
                final int index = categoryIndices.get(example.getOutput());
                final Vector target = factory.createVector(categoryCount, -1.0);
                target.setElement(index, 1.0);
                vectorData.add(new DefaultInputOutputPair<InputType, Vector>(example.getInput(), target));
            }

            return new WinnerTakeAllCategorizer<InputType, CategoryType>(
                getLearner().learn(vectorData),
                new LinkedHashSet<CategoryType>(categoryIndices.keySet()));
        }

        /**
         * Gets the vector factory used to create target vectors.
         *
         * @return The vector factory.
         */
        public VectorFactory<?> getVectorFactory() {
            return this.vectorFactory;
        }

        /**
         * Sets the vector factory used to create target vectors.
         *
         * @param vectorFactory The vector factory.
         */
        public void setVectorFactory(VectorFactory<?> vectorFactory) {
            this.vectorFactory = vectorFactory;
        }
    }

    /** Creates a categorizer with no evaluator and an empty category set. */
    public WinnerTakeAllCategorizer() {
        this(null, new LinkedHashSet<CategoryType>());
    }

    /**
     * Creates a categorizer from an evaluator and its categories.
     *
     * @param evaluator Maps an input to one score per category.
     * @param set The categories. The i-th category in iteration order
     *            corresponds to the i-th element of the evaluator's output.
     */
    public WinnerTakeAllCategorizer(Evaluator<? super InputType, ? extends Vectorizable> evaluator, Set<CategoryType> set) {
        super(set);
        setEvaluator(evaluator);
    }

    /**
     * Evaluates the input with the wrapped evaluator and returns the
     * highest-scoring category together with its score.
     *
     * @param input The input to categorize.
     * @return The winning category and its score.
     */
    @Override // gov.sandia.cognition.learning.function.categorization.DiscriminantCategorizer
    public DefaultWeightedValueDiscriminant<CategoryType> evaluateWithDiscriminant(InputType input) {
        return findBestCategory(this.evaluator.evaluate(input).convertToVector());
    }

    /**
     * Finds the category whose element in the given vector is largest. Ties
     * are resolved in favor of the earliest category in iteration order.
     *
     * @param vector One score per category. Its dimensionality must equal the
     *               number of categories.
     * @return The winning category and its score. If there are no categories,
     *         the category is {@code null} and the score is negative infinity.
     */
    public DefaultWeightedValueDiscriminant<CategoryType> findBestCategory(Vector vector) {
        vector.assertDimensionalityEquals(this.categories.size());
        CategoryType bestCategory = null;
        double bestValue = Double.NEGATIVE_INFINITY;
        int index = 0;
        for (CategoryType category : this.categories) {
            final double value = vector.getElement(index);
            // Take the first category unconditionally instead of testing
            // bestCategory == null, which misbehaves when a category is null.
            if (index == 0 || value > bestValue) {
                bestCategory = category;
                bestValue = value;
            }
            index++;
        }
        return new DefaultWeightedValueDiscriminant<CategoryType>(bestCategory, bestValue);
    }

    /**
     * Gets the evaluator that produces one score per category.
     *
     * @return The evaluator.
     */
    public Evaluator<? super InputType, ? extends Vectorizable> getEvaluator() {
        return this.evaluator;
    }

    /**
     * Sets the evaluator that produces one score per category.
     *
     * @param evaluator The evaluator.
     */
    public void setEvaluator(Evaluator<? super InputType, ? extends Vectorizable> evaluator) {
        this.evaluator = evaluator;
    }

    /**
     * Sets the categories by delegating to the superclass. The i-th category in
     * iteration order corresponds to the i-th element of the evaluator output.
     *
     * @param set The categories.
     */
    @Override // gov.sandia.cognition.learning.function.categorization.AbstractCategorizer
    public void setCategories(Set<CategoryType> set) {
        super.setCategories(set);
    }
}
