package gov.sandia.cognition.learning.algorithm.tree;

import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.Categorizer;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.ArgumentChecker;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;

/* loaded from: input_file:gov/sandia/cognition/learning/algorithm/tree/CategorizationTreeLearner.class */
/**
 * Learns a {@link CategorizationTree} by recursively splitting labeled data.
 * <p>
 * Each node is labeled with the most probable output category of the data that
 * reaches it. A node becomes a leaf when any of these holds:
 * <ul>
 *   <li>all outputs reaching it are equal;</li>
 *   <li>at most {@code leafCountThreshold} examples reach it;</li>
 *   <li>{@code maxDepth} is positive and the node's depth has reached it.</li>
 * </ul>
 * Otherwise the configured decider learner is asked for a split. If it returns
 * one, child nodes are learned for the partitioned data.
 * <p>
 * When category priors are configured, the node label is
 * {@code argmax_c prior(c) * nodeCount(c) / trainCount(c)} instead of the
 * plain majority vote.
 *
 * @param <InputType> the input type of the examples
 * @param <OutputType> the output (category) type of the examples
 */
public class CategorizationTreeLearner<InputType, OutputType> extends AbstractDecisionTreeLearner<InputType, OutputType> implements SupervisedBatchLearner<InputType, OutputType, CategorizationTree<InputType, OutputType>> {
    /** Default value for {@link #getLeafCountThreshold()}. */
    public static final int DEFAULT_LEAF_COUNT_THRESHOLD = 1;

    /** Default value for {@link #getMaxDepth()}; non-positive means unlimited depth. */
    public static final int DEFAULT_MAX_DEPTH = -1;

    /** A node reached by this many examples or fewer becomes a leaf. */
    protected int leafCountThreshold;

    /** Maximum tree depth; a non-positive value disables the depth limit. */
    protected int maxDepth;

    /** Optional category priors; {@code null} means use majority vote. */
    protected Map<OutputType, Double> priors;

    /** Per-category counts over the whole training set; only non-null during {@link #learn}. */
    protected Map<OutputType, Integer> trainCounts;

    /**
     * Creates a learner with no decider learner and default settings.
     */
    public CategorizationTreeLearner() {
        this(null);
    }

    /**
     * Creates a learner with the given decider learner and default settings.
     *
     * @param deciderLearner learner used to find the split at each internal node
     */
    public CategorizationTreeLearner(DeciderLearner<? super InputType, OutputType, ?, ?> deciderLearner) {
        this(deciderLearner, DEFAULT_LEAF_COUNT_THRESHOLD, DEFAULT_MAX_DEPTH, null);
    }

    /**
     * Creates a learner with no category priors.
     *
     * @param deciderLearner learner used to find the split at each internal node
     * @param leafCountThreshold nodes with at most this many examples become leaves; must be non-negative
     * @param maxDepth maximum tree depth; non-positive means unlimited
     */
    public CategorizationTreeLearner(DeciderLearner<? super InputType, OutputType, ?, ?> deciderLearner, int leafCountThreshold, int maxDepth) {
        this(deciderLearner, leafCountThreshold, maxDepth, null);
    }

    /**
     * Creates a fully configured learner.
     *
     * @param deciderLearner learner used to find the split at each internal node
     * @param leafCountThreshold nodes with at most this many examples become leaves; must be non-negative
     * @param maxDepth maximum tree depth; non-positive means unlimited
     * @param priors category priors (copied), or {@code null} to use majority vote
     */
    public CategorizationTreeLearner(DeciderLearner<? super InputType, OutputType, ?, ?> deciderLearner, int leafCountThreshold, int maxDepth, Map<OutputType, Double> priors) {
        super(deciderLearner);
        setLeafCountThreshold(leafCountThreshold);
        setMaxDepth(maxDepth);
        setCategoryPriors(priors);
    }

    /**
     * Learns a categorization tree from the given labeled examples.
     *
     * @param collection the training examples; may be {@code null}
     * @return the learned tree, or {@code null} if {@code collection} is {@code null}
     */
    @Override
    @SuppressWarnings({"rawtypes", "unchecked"})
    public CategorizationTree<InputType, OutputType> learn(Collection<? extends InputOutputPair<? extends InputType, OutputType>> collection) {
        if (collection == null) {
            return null;
        }

        // Per-category counts over the full training set, used to normalize priors.
        final DefaultDataDistribution<OutputType> outputCounts = getOutputCounts(collection);
        this.trainCounts = new HashMap<>();
        for (OutputType category : outputCounts.getDomain()) {
            this.trainCounts.put(category, (int) outputCounts.get(category));
        }

        try {
            if (this.deciderLearner instanceof PriorWeightedNodeLearner) {
                // Raw cast kept: the decider learner's category type is not visible here.
                ((PriorWeightedNodeLearner) this.deciderLearner).configure(this.priors, this.trainCounts);
            }
            final CategorizationTreeNode<InputType, OutputType, ?> root = this.learnNode(collection, null);
            return new CategorizationTree<>(root, new HashSet<>(outputCounts.getDomain()));
        } finally {
            // Training counts are only meaningful for the current call; always clear them,
            // even if learning failed, so stale counts never leak into a later run.
            this.trainCounts = null;
        }
    }

    /**
     * Learns a single node, and recursively its children, from the given examples.
     *
     * @param collection the examples reaching this node
     * @param abstractDecisionTreeNode the parent node, or {@code null} for the root
     * @return the learned node, or {@code null} if {@code collection} is {@code null} or empty
     */
    @Override
    public CategorizationTreeNode<InputType, OutputType, ?> learnNode(Collection<? extends InputOutputPair<? extends InputType, OutputType>> collection, AbstractDecisionTreeNode<InputType, OutputType, ?> abstractDecisionTreeNode) {
        if (collection == null || collection.isEmpty()) {
            return null;
        }

        final CategorizationTreeNode<InputType, OutputType, Object> node =
            new CategorizationTreeNode<>(abstractDecisionTreeNode, this.computeMaxProbPrediction(collection));

        final boolean isLeaf = this.areAllOutputsEqual(collection)
            || collection.size() <= this.leafCountThreshold
            || (this.maxDepth > 0 && node.getDepth() >= this.maxDepth);

        if (!isLeaf) {
            @SuppressWarnings("unchecked")
            final Categorizer<? super InputType, ?> decider =
                (Categorizer<? super InputType, ?>) this.getDeciderLearner().learn(collection);
            // The decider learner may fail to find a useful split; the node then stays a leaf.
            if (decider != null) {
                node.setDecider(decider);
                super.learnChildNodes(node, collection, decider);
            }
        }
        return node;
    }

    /**
     * Counts how many times each output value occurs in the given examples.
     *
     * @param collection the examples; may be {@code null}
     * @param <OutputType> the output type
     * @return a distribution of output counts (empty for {@code null} input)
     */
    public static <OutputType> DefaultDataDistribution<OutputType> getOutputCounts(Collection<? extends InputOutputPair<?, OutputType>> collection) {
        final DefaultDataDistribution<OutputType> counts = new DefaultDataDistribution<>();
        if (collection == null) {
            return counts;
        }
        for (InputOutputPair<?, OutputType> pair : collection) {
            counts.increment(pair.getOutput());
        }
        return counts;
    }

    /**
     * Picks the label for a node.
     * <p>
     * Without priors this is the most frequent output. With priors it is the
     * category maximizing {@code prior(c) * nodeCount(c) / trainCount(c)}.
     *
     * @throws IllegalStateException if priors are set but a category has no prior,
     *         or training counts are unavailable for it
     */
    private OutputType computeMaxProbPrediction(Collection<? extends InputOutputPair<?, OutputType>> collection) {
        final DefaultDataDistribution<OutputType> nodeCounts = getOutputCounts(collection);
        if (this.priors == null) {
            return nodeCounts.getMaxValueKey();
        }
        if (this.trainCounts == null) {
            throw new IllegalStateException(
                "Category priors require training counts; call learn() rather than learnNode() directly.");
        }

        double bestScore = -1.0;
        OutputType bestCategory = null;
        for (OutputType category : nodeCounts.getDomain()) {
            final Double prior = this.priors.get(category);
            if (prior == null) {
                throw new IllegalStateException("No prior probability given for category: " + category);
            }
            final Integer trainCount = this.trainCounts.get(category);
            if (trainCount == null || trainCount <= 0) {
                throw new IllegalStateException("No training count available for category: " + category);
            }
            final double score = prior * (nodeCounts.get(category) / trainCount);
            if (score > bestScore) {
                bestScore = score;
                bestCategory = category;
            }
        }
        return bestCategory;
    }

    /**
     * Gets the leaf count threshold.
     *
     * @return nodes with at most this many examples become leaves
     */
    public int getLeafCountThreshold() {
        return this.leafCountThreshold;
    }

    /**
     * Sets the leaf count threshold.
     *
     * @param leafCountThreshold nodes with at most this many examples become leaves; must be non-negative
     */
    public void setLeafCountThreshold(int leafCountThreshold) {
        ArgumentChecker.assertIsNonNegative("leafCountThreshold", leafCountThreshold);
        this.leafCountThreshold = leafCountThreshold;
    }

    /**
     * Gets the maximum tree depth.
     *
     * @return the maximum depth; a non-positive value means unlimited
     */
    public int getMaxDepth() {
        return this.maxDepth;
    }

    /**
     * Sets the maximum tree depth.
     *
     * @param maxDepth the maximum depth; a non-positive value means unlimited
     */
    public void setMaxDepth(int maxDepth) {
        this.maxDepth = maxDepth;
    }

    /**
     * Sets the category priors used to weight node predictions.
     *
     * @param priors map from category to prior weight (defensively copied),
     *        or {@code null} to use plain majority vote
     */
    public void setCategoryPriors(Map<OutputType, Double> priors) {
        this.priors = (priors == null) ? null : new HashMap<>(priors);
    }
}
