package gov.sandia.cognition.learning.function.categorization;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DefaultWeightedValueDiscriminant;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.data.ValueDiscriminantPair;
import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.statistics.AbstractDistribution;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.statistics.DistributionWithMean;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.WeightedValue;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.Random;
import java.util.Set;

@PublicationReference(author = {"Wikipedia"}, title = "Maximum a posteriori estimation", type = PublicationType.WebPage, year = 2010, url = "http://en.wikipedia.org/wiki/Maximum_a_posteriori_estimation")
/* loaded from: input_file:gov/sandia/cognition/learning/function/categorization/MaximumAPosterioriCategorizer.class */
public class MaximumAPosterioriCategorizer<ObservationType, CategoryType> extends AbstractDistribution<ObservationType> implements DiscriminantCategorizer<ObservationType, CategoryType, Double> {
    DataDistribution.PMF<CategoryType> categoryPriors = new DefaultDataDistribution.PMF(2);
    Map<CategoryType, ProbabilityFunction<ObservationType>> categoryConditionals = new HashMap(2);

    /* loaded from: input_file:gov/sandia/cognition/learning/function/categorization/MaximumAPosterioriCategorizer$Learner.class */
    public static class Learner<ObservationType, CategoryType> extends AbstractCloneableSerializable implements SupervisedBatchLearner<ObservationType, CategoryType, MaximumAPosterioriCategorizer<ObservationType, CategoryType>> {
        private BatchLearner<Collection<? extends ObservationType>, ? extends ComputableDistribution<ObservationType>> conditionalLearner;

        public Learner() {
            this(null);
        }

        public Learner(BatchLearner<Collection<? extends ObservationType>, ? extends ComputableDistribution<ObservationType>> batchLearner) {
            this.conditionalLearner = batchLearner;
        }

        @Override // gov.sandia.cognition.util.AbstractCloneableSerializable
        /* renamed from: clone */
        public Learner<ObservationType, CategoryType> mo0clone() {
            Learner<ObservationType, CategoryType> learner = (Learner) super.mo0clone();
            learner.setConditionalLearner((BatchLearner) ObjectUtil.cloneSmart(getConditionalLearner()));
            return learner;
        }

        /* JADX WARN: Multi-variable type inference failed */
        @Override // gov.sandia.cognition.learning.algorithm.BatchLearner
        public MaximumAPosterioriCategorizer<ObservationType, CategoryType> learn(Collection<? extends InputOutputPair<? extends ObservationType, CategoryType>> collection) {
            DefaultDataDistribution.PMF pmf = new DefaultDataDistribution.PMF();
            HashMap hashMap = new HashMap();
            for (InputOutputPair<? extends ObservationType, CategoryType> inputOutputPair : collection) {
                pmf.increment(inputOutputPair.getOutput());
                LinkedList linkedList = (LinkedList) hashMap.get(inputOutputPair.getOutput());
                if (linkedList == null) {
                    linkedList = new LinkedList();
                    hashMap.put(inputOutputPair.getOutput(), linkedList);
                }
                linkedList.add(inputOutputPair.getInput());
            }
            MaximumAPosterioriCategorizer<ObservationType, CategoryType> maximumAPosterioriCategorizer = (MaximumAPosterioriCategorizer<ObservationType, CategoryType>) new MaximumAPosterioriCategorizer();
            for (Object obj : pmf.getDomain2()) {
                maximumAPosterioriCategorizer.addCategory(obj, pmf.get(obj), this.conditionalLearner.learn((LinkedList) hashMap.get(obj)).getProbabilityFunction());
            }
            return maximumAPosterioriCategorizer;
        }

        public BatchLearner<Collection<? extends ObservationType>, ? extends ComputableDistribution<ObservationType>> getConditionalLearner() {
            return this.conditionalLearner;
        }

        public void setConditionalLearner(BatchLearner<Collection<? extends ObservationType>, ? extends ComputableDistribution<ObservationType>> batchLearner) {
            this.conditionalLearner = batchLearner;
        }
    }

    @Override // gov.sandia.cognition.util.AbstractCloneableSerializable
    /* renamed from: clone */
    public MaximumAPosterioriCategorizer<ObservationType, CategoryType> mo0clone() {
        return (MaximumAPosterioriCategorizer) super.mo0clone();
    }

    public void addCategory(CategoryType categorytype, double d, ProbabilityFunction<ObservationType> probabilityFunction) {
        this.categoryPriors.increment(categorytype, d);
        this.categoryConditionals.put(categorytype, probabilityFunction);
    }

    public WeightedValue<ProbabilityFunction<ObservationType>> getCategory(CategoryType categorytype) {
        return new DefaultWeightedValue(this.categoryConditionals.get(categorytype), this.categoryPriors.evaluate(categorytype).doubleValue());
    }

    @Override // gov.sandia.cognition.learning.function.categorization.Categorizer
    public Set<? extends CategoryType> getCategories() {
        return this.categoryConditionals.keySet();
    }

    @Override // gov.sandia.cognition.evaluator.Evaluator
    public CategoryType evaluate(ObservationType observationtype) {
        return evaluateWithDiscriminant((MaximumAPosterioriCategorizer<ObservationType, CategoryType>) observationtype).getValue();
    }

    @Override // gov.sandia.cognition.learning.function.categorization.DiscriminantCategorizer
    public DefaultWeightedValueDiscriminant<CategoryType> evaluateWithDiscriminant(ObservationType observationtype) {
        CategoryType categorytype = null;
        double d = Double.NEGATIVE_INFINITY;
        for (CategoryType categorytype2 : getCategories()) {
            double computePosterior = computePosterior(observationtype, categorytype2);
            if (d < computePosterior) {
                d = computePosterior;
                categorytype = categorytype2;
            }
        }
        return DefaultWeightedValueDiscriminant.create(categorytype, d);
    }

    public double computePosterior(ObservationType observationtype, CategoryType categorytype) {
        double d;
        ProbabilityFunction<ObservationType> probabilityFunction = this.categoryConditionals.get(categorytype);
        if (probabilityFunction != null) {
            d = probabilityFunction.evaluate(observationtype).doubleValue() * this.categoryPriors.evaluate(categorytype).doubleValue();
        } else {
            d = 0.0d;
        }
        return d;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v21, types: [java.lang.Double] */
    /* JADX WARN: Type inference failed for: r0v22, types: [java.lang.Double] */
    /* JADX WARN: Type inference failed for: r0v28, types: [gov.sandia.cognition.math.Ring] */
    public ObservationType getMean() {
        ObservationType observationtype = null;
        for (CategoryType categorytype : getCategories()) {
            ObservationType mean = getMean();
            double doubleValue = this.categoryPriors.evaluate(categorytype).doubleValue();
            if (mean instanceof Number) {
                if (observationtype == null) {
                    observationtype = new Double(0.0d);
                }
                observationtype = new Double(((Number) observationtype).doubleValue() + (doubleValue * ((Number) mean).doubleValue()));
            } else {
                if (!(mean instanceof Ring)) {
                    throw new UnsupportedOperationException("Mean not supported for type " + mean.getClass());
                }
                ?? scale = ((Ring) mean).scale(doubleValue);
                if (observationtype == null) {
                    observationtype = scale;
                } else {
                    ((Ring) observationtype).plusEquals(scale);
                }
            }
        }
        return observationtype;
    }

    @Override // gov.sandia.cognition.statistics.Distribution
    public void sampleInto(Random random, int i, Collection<? super ObservationType> collection) {
        Iterator<CategoryType> it = this.categoryPriors.sample(random, i).iterator();
        while (it.hasNext()) {
            collection.add(this.categoryConditionals.get(it.next()).sample(random));
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    @Override // gov.sandia.cognition.learning.function.categorization.DiscriminantCategorizer
    public /* bridge */ /* synthetic */ ValueDiscriminantPair evaluateWithDiscriminant(Object obj) {
        return evaluateWithDiscriminant((MaximumAPosterioriCategorizer<ObservationType, CategoryType>) obj);
    }
}
