package gov.sandia.cognition.learning.algorithm.bayes;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DefaultWeightedValueDiscriminant;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.DiscriminantCategorizer;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * A Naive Bayes categorizer for discrete-valued inputs.
 *
 * <p>Each input is a {@link Collection} whose size must equal the input
 * dimensionality. The element at each position is treated as an independent
 * discrete feature given the category. The model keeps:
 * <ul>
 *   <li>a prior distribution over the categories, updated via
 *       {@code increment} and queried via {@code getFraction};</li>
 *   <li>for each category, one distribution per input position, likewise
 *       updated via {@code increment} and queried via {@code getFraction}.</li>
 * </ul>
 *
 * <p>A {@code null} element in an input is treated as missing. It is not
 * counted during {@link #update} and contributes no factor during
 * {@link #computeConditionalProbability}.
 *
 * @param <InputType> type of each discrete input element
 * @param <CategoryType> type of the output categories
 */
@PublicationReferences(references = {@PublicationReference(author = {"Richard O. Duda", "Peter E. Hart", "David G. Stork"}, title = "Pattern Classification: Second Edition", type = PublicationType.Book, year = 2001, pages = {56, 62}), @PublicationReference(author = {"Wikipedia"}, title = "Naive Bayes classifier", type = PublicationType.WebPage, year = 2009, url = "http://en.wikipedia.org/wiki/Naive_bayes")})
public class DiscreteNaiveBayesCategorizer<InputType, CategoryType> extends AbstractCloneableSerializable implements DiscriminantCategorizer<Collection<InputType>, CategoryType, Double> {

    /** For each category, one distribution per input position. */
    private Map<CategoryType, List<DefaultDataDistribution<InputType>>> conditionalProbabilities;

    /** Distribution over the categories observed so far. */
    private DefaultDataDistribution<CategoryType> priorProbabilities;

    /** Required size of every input collection. */
    private int inputDimensionality;

    /**
     * Batch learner that builds a {@link DiscreteNaiveBayesCategorizer} by
     * passing every training pair to {@link DiscreteNaiveBayesCategorizer#update}.
     * The categorizer is created with an unspecified dimensionality, so the
     * dimensionality is taken from the first training example.
     *
     * @param <InputType> type of each discrete input element
     * @param <CategoryType> type of the output categories
     */
    public static class Learner<InputType, CategoryType> extends AbstractCloneableSerializable implements SupervisedBatchLearner<Collection<InputType>, CategoryType, DiscreteNaiveBayesCategorizer<InputType, CategoryType>> {

        /**
         * Learns a categorizer from the given labeled examples.
         *
         * @param collection the training pairs (input collection, category)
         * @return a new categorizer whose counts reflect every given pair
         * @throws IllegalArgumentException if the inputs do not all have the
         *     same size
         */
        @Override
        public DiscreteNaiveBayesCategorizer<InputType, CategoryType> learn(Collection<? extends InputOutputPair<? extends Collection<InputType>, CategoryType>> collection) {
            final DiscreteNaiveBayesCategorizer<InputType, CategoryType> result =
                new DiscreteNaiveBayesCategorizer<InputType, CategoryType>();
            for (final InputOutputPair<? extends Collection<InputType>, CategoryType> example : collection) {
                result.update(example.getInput(), example.getOutput());
            }
            return result;
        }
    }

    /**
     * Creates an empty categorizer whose dimensionality is fixed by the first
     * call to {@link #update}.
     */
    public DiscreteNaiveBayesCategorizer() {
        this(0);
    }

    /**
     * Creates an empty categorizer with the given input dimensionality.
     *
     * @param i the required input size; a value {@code <= 0} means "take it
     *     from the first example passed to {@link #update}"
     */
    public DiscreteNaiveBayesCategorizer(int i) {
        setInputDimensionality(i);
    }

    /**
     * Creates a categorizer directly from existing model state. The given
     * objects are stored by reference, not copied.
     *
     * @param i the input dimensionality
     * @param defaultDataDistribution the category prior distribution
     * @param map the per-category, per-position conditional distributions
     */
    protected DiscreteNaiveBayesCategorizer(int i, DefaultDataDistribution<CategoryType> defaultDataDistribution, Map<CategoryType, List<DefaultDataDistribution<InputType>>> map) {
        setInputDimensionality(i);
        this.priorProbabilities = defaultDataDistribution;
        this.conditionalProbabilities = map;
    }

    /**
     * Returns a copy of this categorizer. The prior and every conditional
     * distribution are cloned, so updating the copy does not affect this
     * instance. (Decompiled name of {@code clone}.)
     *
     * @return the copy
     */
    @Override
    @SuppressWarnings("unchecked")
    public DiscreteNaiveBayesCategorizer<InputType, CategoryType> mo0clone() {
        // super's clone is typed as the base class; the runtime type is this class.
        final DiscreteNaiveBayesCategorizer<InputType, CategoryType> copy =
            (DiscreteNaiveBayesCategorizer<InputType, CategoryType>) super.mo0clone();
        copy.conditionalProbabilities =
            new LinkedHashMap<CategoryType, List<DefaultDataDistribution<InputType>>>();
        // Iterate the conditionals themselves rather than the prior's domain so
        // the copy does not rely on the two maps having identical key sets.
        for (final Map.Entry<CategoryType, List<DefaultDataDistribution<InputType>>> entry
            : this.conditionalProbabilities.entrySet()) {
            copy.conditionalProbabilities.put(
                entry.getKey(), ObjectUtil.cloneSmartElementsAsArrayList(entry.getValue()));
        }
        copy.priorProbabilities =
            (DefaultDataDistribution<CategoryType>) ObjectUtil.cloneSafe(this.priorProbabilities);
        return copy;
    }

    /**
     * Returns the categories that have been observed.
     *
     * @return the domain of the prior distribution
     */
    @Override
    public Set<CategoryType> getCategories() {
        return this.priorProbabilities.getDomain();
    }

    /**
     * Computes the evidence term: the sum over all known categories of the
     * joint probability of the input and that category.
     *
     * @param collection the input
     * @return the summed joint probability
     */
    public double computeEvidenceProbabilty(Collection<InputType> collection) {
        double evidence = 0.0;
        for (final CategoryType category : this.getCategories()) {
            evidence += this.computeConjuctiveProbability(collection, category);
        }
        return evidence;
    }

    /**
     * Computes the posterior probability of the category given the input:
     * the joint probability divided by the evidence.
     *
     * @param collection the input
     * @param categorytype the category
     * @return the posterior, or 0.0 when the evidence is not positive
     */
    public double computePosterior(Collection<InputType> collection, CategoryType categorytype) {
        final double evidence = this.computeEvidenceProbabilty(collection);
        if (evidence > 0.0) {
            return this.computeConjuctiveProbability(collection, categorytype) / evidence;
        }
        return 0.0;
    }

    /**
     * Computes the product, over input positions, of the fraction of the
     * category's examples that had the given value at that position.
     * {@code null} input elements are skipped.
     *
     * @param collection the input; must have size equal to the dimensionality
     * @param categorytype the category
     * @return the conditional probability; 0.0 for a category that has never
     *     been observed
     * @throws IllegalArgumentException if the input size does not match the
     *     input dimensionality
     */
    public double computeConditionalProbability(Collection<InputType> collection, CategoryType categorytype) {
        if (collection.size() != this.getInputDimensionality()) {
            throw new IllegalArgumentException("Input dimensionality doesn't match " + getInputDimensionality());
        }
        final List<DefaultDataDistribution<InputType>> perPosition =
            this.conditionalProbabilities.get(categorytype);
        if (perPosition == null) {
            // Category never observed: previously this dereferenced null.
            return 0.0;
        }
        final Iterator<DefaultDataDistribution<InputType>> distributions = perPosition.iterator();
        double product = 1.0;
        for (final InputType value : collection) {
            final DefaultDataDistribution<InputType> distribution = distributions.next();
            if (value != null) {
                product *= distribution.getFraction(value);
            }
            if (product <= 0.0) {
                // A zero factor makes the whole product zero; stop early.
                break;
            }
        }
        return product;
    }

    /**
     * Adds one labeled example to the model. If the dimensionality is still
     * unspecified ({@code <= 0}) and no data has been seen, it is set to the
     * size of this input.
     *
     * @param collection the input; {@code null} elements are not counted
     * @param categorytype the category of the input
     * @throws IllegalArgumentException if the input size does not match the
     *     input dimensionality
     */
    public void update(Collection<InputType> collection, CategoryType categorytype) {
        // setInputDimensionality discards all learned counts, so only use it to
        // fix the dimensionality before any data exists. Otherwise a stream of
        // zero-length inputs would reset the model on every call.
        if (this.getInputDimensionality() <= 0 && this.conditionalProbabilities.isEmpty()) {
            this.setInputDimensionality(collection.size());
        }
        if (collection.size() != this.getInputDimensionality()) {
            throw new IllegalArgumentException("Input dimensionality doesn't match " + getInputDimensionality());
        }
        List<DefaultDataDistribution<InputType>> perPosition =
            this.conditionalProbabilities.get(categorytype);
        if (perPosition == null) {
            perPosition = new ArrayList<DefaultDataDistribution<InputType>>(this.getInputDimensionality());
            for (int i = 0; i < this.getInputDimensionality(); i++) {
                perPosition.add(new DefaultDataDistribution<InputType>());
            }
            this.conditionalProbabilities.put(categorytype, perPosition);
        }
        this.priorProbabilities.increment(categorytype);
        final Iterator<DefaultDataDistribution<InputType>> distributions = perPosition.iterator();
        for (final InputType value : collection) {
            final DefaultDataDistribution<InputType> distribution = distributions.next();
            if (value != null) {
                distribution.increment(value);
            }
        }
    }

    /**
     * Computes the joint probability of the input and the category: the prior
     * of the category times the conditional probability of the input.
     *
     * @param collection the input
     * @param categorytype the category
     * @return the joint probability, or 0.0 when the prior is not positive
     */
    public double computeConjuctiveProbability(Collection<InputType> collection, CategoryType categorytype) {
        final double prior = this.getPriorProbability(categorytype);
        if (prior > 0.0) {
            return this.computeConditionalProbability(collection, categorytype) * prior;
        }
        return 0.0;
    }

    /**
     * Returns the most probable category for the input.
     *
     * @param collection the input
     * @return the category with the highest joint probability, or
     *     {@code null} if no category has been observed
     */
    @Override
    public CategoryType evaluate(Collection<InputType> collection) {
        return this.evaluateWithDiscriminant(collection).getValue();
    }

    /**
     * Returns the category with the highest joint probability together with
     * that probability. Ties keep the first category encountered. With no
     * known categories the result holds a {@code null} category and -1.0.
     *
     * @param collection the input
     * @return the best category and its joint probability
     */
    @Override
    public DefaultWeightedValueDiscriminant<CategoryType> evaluateWithDiscriminant(Collection<InputType> collection) {
        // Start below any probability so the first category always wins initially.
        double bestProbability = -1.0;
        CategoryType bestCategory = null;
        for (final CategoryType category : this.getCategories()) {
            final double probability = this.computeConjuctiveProbability(collection, category);
            if (bestProbability < probability) {
                bestProbability = probability;
                bestCategory = category;
            }
        }
        return DefaultWeightedValueDiscriminant.create(bestCategory, bestProbability);
    }

    /**
     * Returns the fraction of the category's examples whose element at the
     * given position equals the given value.
     *
     * @param i the input position
     * @param inputtype the value at that position
     * @param categorytype the category
     * @return the conditional fraction; 0.0 for a category that has never
     *     been observed
     * @throws IndexOutOfBoundsException if the position is out of range for
     *     a known category
     */
    public double getConditionalProbability(int i, InputType inputtype, CategoryType categorytype) {
        final List<DefaultDataDistribution<InputType>> perPosition =
            this.conditionalProbabilities.get(categorytype);
        if (perPosition == null) {
            return 0.0;
        }
        return perPosition.get(i).getFraction(inputtype);
    }

    /**
     * Returns the prior probability of the category.
     *
     * @param categorytype the category
     * @return the category's fraction of all observed examples
     */
    public double getPriorProbability(CategoryType categorytype) {
        return this.priorProbabilities.getFraction(categorytype);
    }

    /**
     * Returns the required input size.
     *
     * @return the input dimensionality
     */
    public int getInputDimensionality() {
        return this.inputDimensionality;
    }

    /**
     * Sets the input dimensionality. This replaces the prior and all
     * conditional distributions with empty ones, discarding anything learned.
     *
     * @param i the new input dimensionality
     */
    public void setInputDimensionality(int i) {
        this.conditionalProbabilities = new LinkedHashMap<CategoryType, List<DefaultDataDistribution<InputType>>>();
        this.priorProbabilities = new DefaultDataDistribution<CategoryType>();
        this.inputDimensionality = i;
    }
}
