package gov.sandia.cognition.learning.algorithm.tree;

import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.VectorElementThresholdCategorizer;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.DefaultWeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;

/* loaded from: input_file:gov/sandia/cognition/learning/algorithm/tree/VectorThresholdVarianceLearner.class */
/**
 * Learns a single threshold split on one element of a vector input.
 *
 * <p>The split is chosen to maximise the reduction in variance of a real-valued
 * ({@code Double}) output. For each candidate dimension the examples are sorted
 * by that element's value. Every cut between two adjacent distinct values is
 * then scored as
 * {@code baseVariance - upperFraction * upperVariance - lowerFraction * lowerVariance},
 * where the fractions are the shares of examples on each side of the cut.
 */
public class VectorThresholdVarianceLearner extends AbstractCloneableSerializable implements VectorThresholdLearner<Double> {

    /** Default value for the minimum number of examples on each side of a split. */
    public static final int DEFAULT_MIN_SPLIT_SIZE = 1;

    /**
     * Minimum number of examples that must fall on each side of a split.
     * Validated as positive by {@link #setMinSplitSize(int)}.
     */
    protected int minSplitSize;

    /** Input dimensions to search, or {@code null} to search every dimension. */
    protected int[] dimensionsToConsider;

    /** Creates a learner with the default minimum split size that searches all dimensions. */
    public VectorThresholdVarianceLearner() {
        this(DEFAULT_MIN_SPLIT_SIZE, (int[]) null);
    }

    /**
     * Creates a learner that searches all dimensions.
     *
     * @param minSplitSize Minimum number of examples on each side of a split; must be positive.
     */
    public VectorThresholdVarianceLearner(int minSplitSize) {
        this(minSplitSize, (int[]) null);
    }

    /**
     * Creates a learner restricted to the given dimensions.
     *
     * @param minSplitSize Minimum number of examples on each side of a split; must be positive.
     * @param dimensionsToConsider Dimensions to search, or {@code null} for all dimensions.
     *        The array is stored by reference, not copied.
     */
    public VectorThresholdVarianceLearner(int minSplitSize, int... dimensionsToConsider) {
        setMinSplitSize(minSplitSize);
        setDimensionsToConsider(dimensionsToConsider);
    }

    /**
     * Finds the dimension and threshold whose split yields the largest variance
     * reduction on the output.
     *
     * <p>Ties between dimensions keep the earliest dimension searched.
     *
     * @param data The labeled examples; may be {@code null}.
     * @return The best split found, or {@code null} in any of these cases:
     *         {@code data} is {@code null}; it holds fewer than
     *         {@code 2 * minSplitSize} examples; or no searched dimension
     *         admits a split.
     */
    @Override
    public VectorElementThresholdCategorizer learn(Collection<? extends InputOutputPair<? extends Vectorizable, Double>> data) {
        if (data == null || !isLargeEnoughToSplit(data.size())) {
            return null;
        }

        final double baseVariance = DatasetUtil.computeOutputVariance(data);
        final int[] dimensions = this.dimensionsToConsider;
        final int candidateCount = dimensions == null
            ? DatasetUtil.getInputDimensionality(data)
            : dimensions.length;

        int bestDimension = -1;
        double bestGain = -1.0;
        double bestThreshold = 0.0;
        for (int k = 0; k < candidateCount; k++) {
            final int dimension = dimensions == null ? k : dimensions[k];
            final DefaultPair<Double, Double> result =
                computeBestGainThreshold(data, dimension, baseVariance);
            if (result == null) {
                continue;
            }

            final double gain = result.getFirst().doubleValue();
            // Strict '>' keeps the earliest dimension when gains tie.
            if (bestDimension < 0 || gain > bestGain) {
                bestGain = gain;
                bestDimension = dimension;
                bestThreshold = result.getSecond().doubleValue();
            }
        }

        if (bestDimension < 0) {
            return null;
        }
        return new VectorElementThresholdCategorizer(bestDimension, bestThreshold);
    }

    /**
     * Computes the best variance-reducing threshold on a single input dimension.
     *
     * <p>Examples are sorted by the value of {@code dimension}. Cuts are only
     * considered between two adjacent distinct values, and only where at least
     * {@code minSplitSize} examples remain on each side. Among cuts with equal
     * gain, the one with the most even split of examples wins.
     *
     * @param data The labeled examples; may be {@code null}.
     * @param dimension Index of the input vector element to split on.
     * @param baseVariance Output variance before splitting; {@link #learn} passes
     *        {@code DatasetUtil.computeOutputVariance(data)}.
     * @return A pair of (gain, threshold), or {@code null} in any of these cases:
     *         {@code data} is {@code null}; it has fewer than
     *         {@code 2 * minSplitSize} examples; every example has the same
     *         value on {@code dimension}; or no candidate cut has a
     *         non-negative gain. The threshold is the midpoint of the two
     *         adjacent distinct values at the chosen cut.
     */
    public DefaultPair<Double, Double> computeBestGainThreshold(Collection<? extends InputOutputPair<? extends Vectorizable, Double>> data, int dimension, double baseVariance) {
        if (data == null) {
            return null;
        }
        final int totalCount = data.size();
        if (!isLargeEnoughToSplit(totalCount)) {
            return null;
        }

        // Pair each output (value) with its input element on this dimension (weight).
        final ArrayList<DefaultWeightedValue<Double>> values =
            new ArrayList<DefaultWeightedValue<Double>>(totalCount);
        for (final InputOutputPair<? extends Vectorizable, Double> example : data) {
            final double x = example.getInput().convertToVector().getElement(dimension);
            values.add(new DefaultWeightedValue<Double>(example.getOutput(), x));
        }
        Collections.sort(values, DefaultWeightedValue.WeightComparator.getInstance());

        // If the smallest and largest inputs are equal, there is nothing to split.
        if (values.get(0).getWeight() >= values.get(totalCount - 1).getWeight()) {
            return null;
        }

        // 'upper' starts with every output; examples move into 'lower' as the cut advances.
        final UnivariateGaussian.SufficientStatistic upper = new UnivariateGaussian.SufficientStatistic();
        final UnivariateGaussian.SufficientStatistic lower = new UnivariateGaussian.SufficientStatistic();
        for (final DefaultWeightedValue<Double> entry : values) {
            upper.update(entry.getValue().doubleValue());
        }

        final int maxIndex = totalCount - this.minSplitSize;
        double bestGain = 0.0;
        double bestBalance = 0.0;
        double bestThreshold = 0.0;
        double previousX = 0.0;
        boolean found = false;
        for (int i = 0; i <= maxIndex; i++) {
            final DefaultWeightedValue<Double> entry = values.get(i);
            final double x = entry.getWeight();
            final double y = entry.getValue().doubleValue();

            // At this point 'lower' holds indices [0, i) and 'upper' holds [i, n).
            // Only cut with enough examples below, and never between equal inputs.
            if (i >= this.minSplitSize && x != previousX) {
                // Floating-point division: integer division would truncate both
                // fractions to 0 and erase the per-side variance terms.
                final double upperFraction = (double) (totalCount - i) / totalCount;
                final double lowerFraction = (double) i / totalCount;
                final double gain = baseVariance
                    - upperFraction * upper.getSampleVariance()
                    - lowerFraction * lower.getSampleVariance();
                if (gain >= bestGain) {
                    // Prefer the more even split when gains tie.
                    final double balance = 1.0 - Math.abs(upperFraction - lowerFraction);
                    if (gain > bestGain || balance > bestBalance) {
                        bestGain = gain;
                        bestBalance = balance;
                        bestThreshold = (x + previousX) / 2.0;
                        found = true;
                    }
                }
            }

            upper.remove(y);
            lower.update(y);
            previousX = x;
        }

        if (!found) {
            return null;
        }
        return new DefaultPair<Double, Double>(Double.valueOf(bestGain), Double.valueOf(bestThreshold));
    }

    /**
     * Returns whether {@code count} examples can fill both sides of a split.
     *
     * @param count Number of examples available.
     * @return {@code true} if {@code count >= 2 * minSplitSize}.
     */
    private boolean isLargeEnoughToSplit(final int count) {
        // Widen to long so a very large minSplitSize cannot overflow into a negative bound.
        return count >= 2L * this.minSplitSize;
    }

    /**
     * Gets the dimensions searched for a split.
     *
     * @return The dimension indices (the stored array itself), or {@code null} for all dimensions.
     */
    @Override
    public int[] getDimensionsToConsider() {
        return this.dimensionsToConsider;
    }

    /**
     * Sets the dimensions searched for a split.
     *
     * @param dimensionsToConsider Dimension indices, or {@code null} for all
     *        dimensions. The array is stored by reference, not copied.
     */
    @Override
    public void setDimensionsToConsider(int... dimensionsToConsider) {
        this.dimensionsToConsider = dimensionsToConsider;
    }

    /**
     * Gets the minimum number of examples required on each side of a split.
     *
     * @return The minimum split size.
     */
    public int getMinSplitSize() {
        return this.minSplitSize;
    }

    /**
     * Sets the minimum number of examples required on each side of a split.
     *
     * @param minSplitSize The new minimum; rejected by
     *        {@code ArgumentChecker.assertIsPositive} if not positive.
     */
    public void setMinSplitSize(int minSplitSize) {
        ArgumentChecker.assertIsPositive("minSplitSize", minSplitSize);
        this.minSplitSize = minSplitSize;
    }
}
