package gov.sandia.cognition.learning.algorithm.tree;

import gov.sandia.cognition.math.matrix.mtj.Vector2;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import java.util.ArrayList;
import java.util.Map;

/* loaded from: input_file:gov/sandia/cognition/learning/algorithm/tree/VectorThresholdInformationGainLearner.class */
/**
 * Learns a threshold split on a vector dimension by maximizing information gain.
 * <p>
 * Before {@link #configure(Map, Map)} has been called, the gain is computed from the
 * plain entropies and totals of the supplied data distributions. After configuration,
 * each class's count within a node is reweighted by {@code prior / configuredCount}
 * for that class. Entropies and split proportions are then computed from those
 * weighted values.
 * <p>
 * NOTE(review): split evaluation writes into the instance-level {@code klassProbs}
 * scratch array without synchronization. A single instance should therefore not be
 * used to evaluate splits from several threads at once.
 *
 * @param <OutputType> the class label type
 */
public class VectorThresholdInformationGainLearner<OutputType> extends AbstractVectorThresholdMaximumGainLearner<OutputType> implements PriorWeightedNodeLearner<OutputType> {

    /** Natural logarithm of 2, used to convert natural logs to base-2 logs. */
    private static final double LOG2 = Math.log(2.0d);

    /** Class labels in the index order used by the parallel arrays below; null until configured. */
    private ArrayList<OutputType> klasses = null;

    /** Prior weight of each class, parallel to {@code klasses}; null until configured. */
    private double[] klassPriors = null;

    /** Count of each class as supplied to {@code configure}, parallel to {@code klasses}. */
    private int[] klassCounts = null;

    /** Scratch buffer reused by {@code weightedEntropy} for the per-class weighted values. */
    private double[] klassProbs = null;

    /**
     * Computes the information gain of splitting {@code baseCounts} into the two given
     * partitions.
     * <p>
     * The gain is the entropy of the base distribution minus the entropy of each
     * partition, weighted by that partition's share of the base total. Once priors have
     * been configured, the totals and entropies are the prior-weighted ones.
     *
     * @param baseCounts the class distribution of the node being split
     * @param firstSplitCounts the class distribution of one side of the split
     * @param secondSplitCounts the class distribution of the other side of the split
     * @return the information gain of the split
     */
    @Override
    public double computeSplitGain(DefaultDataDistribution<OutputType> baseCounts, DefaultDataDistribution<OutputType> firstSplitCounts, DefaultDataDistribution<OutputType> secondSplitCounts) {
        if (this.klassPriors == null) {
            return legacyComputeSplitGain(baseCounts, firstSplitCounts, secondSplitCounts);
        }
        final Vector2 base = weightedEntropy(baseCounts);
        final Vector2 first = weightedEntropy(firstSplitCounts);
        final Vector2 second = weightedEntropy(secondSplitCounts);
        final double baseWeight = base.getSecond();
        return base.getFirst()
            - (first.getSecond() / baseWeight) * first.getFirst()
            - (second.getSecond() / baseWeight) * second.getFirst();
    }

    /**
     * Computes the base-2 entropy of the prior-weighted class distribution of a node.
     * <p>
     * Each class gets the weight {@code prior * countInNode / configuredCount}. The
     * weights are normalized to probabilities before the entropy is taken.
     *
     * @param counts the class distribution of the node
     * @return a vector whose first element is the entropy and whose second element is
     *         the total (unnormalized) weight of the node
     */
    private Vector2 weightedEntropy(final DefaultDataDistribution<OutputType> counts) {
        double totalWeight = 0.0d;
        for (int k = 0; k < this.klassProbs.length; k++) {
            final int configuredCount = this.klassCounts[k];
            // A class configured with zero examples contributes no weight. Dividing by
            // its zero count would otherwise produce NaN and poison the whole gain.
            final double weight = (configuredCount == 0)
                ? 0.0d
                : this.klassPriors[k] * counts.get(this.klasses.get(k)) / configuredCount;
            this.klassProbs[k] = weight;
            totalWeight += weight;
        }

        double entropy = 0.0d;
        for (int k = 0; k < this.klassProbs.length; k++) {
            // When totalWeight is 0 this is 0/0 = NaN, which fails the > 0 test below,
            // so an empty node has entropy 0.
            final double p = this.klassProbs[k] / totalWeight;
            if (p > 0.0d) {
                entropy -= p * lb(p);
            }
        }
        return new Vector2(entropy, totalWeight);
    }

    /**
     * Returns the base-2 logarithm of {@code d}.
     *
     * @param d the argument
     * @return {@code log2(d)}
     */
    private static double lb(final double d) {
        return Math.log(d) / LOG2;
    }

    /**
     * Computes the unweighted information gain, used when no priors have been configured.
     *
     * @param baseCounts the class distribution of the node being split
     * @param firstSplitCounts the class distribution of one side of the split
     * @param secondSplitCounts the class distribution of the other side of the split
     * @return base entropy minus the total-weighted entropies of the two partitions
     */
    private double legacyComputeSplitGain(final DefaultDataDistribution<OutputType> baseCounts, final DefaultDataDistribution<OutputType> firstSplitCounts, final DefaultDataDistribution<OutputType> secondSplitCounts) {
        final double baseTotal = baseCounts.getTotal();
        return baseCounts.getEntropy()
            - (firstSplitCounts.getTotal() / baseTotal) * firstSplitCounts.getEntropy()
            - (secondSplitCounts.getTotal() / baseTotal) * secondSplitCounts.getEntropy();
    }

    /**
     * Configures the per-class priors and counts used for prior-weighted entropy.
     * <p>
     * The classes are taken from the key set of {@code counts}. If {@code priors} is
     * null, each class's prior is its share of the total count. If that total is zero,
     * every class gets the uniform prior {@code 1 / numberOfClasses}.
     *
     * @param priors prior weight for each class, or null to derive priors from counts
     * @param counts number of examples for each class; its key set defines the classes
     * @throws NullPointerException if {@code priors} is non-null but lacks an entry for
     *         one of the classes in {@code counts}
     */
    @Override
    public void configure(final Map<OutputType, Double> priors, final Map<OutputType, Integer> counts) {
        this.klasses = new ArrayList<>(counts.keySet());
        final int numClasses = this.klasses.size();

        this.klassCounts = new int[numClasses];
        int totalCount = 0;
        for (int k = 0; k < numClasses; k++) {
            this.klassCounts[k] = counts.get(this.klasses.get(k)).intValue();
            totalCount += this.klassCounts[k];
        }

        this.klassPriors = new double[numClasses];
        for (int k = 0; k < numClasses; k++) {
            if (priors != null) {
                final OutputType klass = this.klasses.get(k);
                final Double prior = priors.get(klass);
                if (prior == null) {
                    throw new NullPointerException("No prior supplied for class " + klass);
                }
                this.klassPriors[k] = prior.doubleValue();
            } else if (totalCount > 0) {
                // Cast before dividing: int / int would truncate each prior to 0 (or 1).
                this.klassPriors[k] = (double) this.klassCounts[k] / totalCount;
            } else {
                this.klassPriors[k] = 1.0d / numClasses;
            }
        }

        this.klassProbs = new double[numClasses];
    }
}
