package gov.sandia.cognition.learning.algorithm.ensemble;

import gov.sandia.cognition.algorithm.IterativeAlgorithm;
import gov.sandia.cognition.algorithm.event.AbstractIterativeAlgorithmListener;
import gov.sandia.cognition.collection.FiniteCapacityBuffer;
import gov.sandia.cognition.math.UnivariateStatisticsUtil;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;

/* loaded from: input_file:gov/sandia/cognition/learning/algorithm/ensemble/AbstractCategorizerOutOfBagStoppingCriteria.class */
/**
 * Base class for an iterative-algorithm listener that stops a
 * {@link BagBasedCategorizerEnsembleLearner} once its out-of-bag error rate stops
 * improving.
 * <p>
 * After every step the listener re-evaluates the out-of-bag prediction of each
 * example that the learner reports as out of bag. It then computes the raw
 * out-of-bag error rate over the whole data set and smooths it with a moving
 * average over the most recent raw rates, held in a buffer of capacity
 * {@link #getSmoothingWindowSize()}. When the smoothed rate is not lower than the
 * previous step's smoothed rate, the learner is stopped. The ensemble is then
 * truncated back to the step with the lowest raw error rate among the steps
 * still in the smoothing buffer.
 * <p>
 * Subclasses supply the per-example out-of-bag estimate through
 * {@link #getOutOfBagEstimate(int)}.
 *
 * @param <InputType> the input type of the learner's examples
 * @param <CategoryType> the category (output) type of the learner's examples
 */
public abstract class AbstractCategorizerOutOfBagStoppingCriteria<InputType, CategoryType> extends AbstractIterativeAlgorithmListener {

    /** The default number of recent raw error rates averaged into the smoothed error rate. */
    public static final int DEFAULT_SMOOTHING_WINDOW_SIZE = 25;

    /** Capacity of the smoothing buffer; validated as positive by {@link #setSmoothingWindowSize(int)}. */
    protected int smoothingWindowSize;

    /** The learner being monitored; set in {@code algorithmStarted}, cleared in {@code algorithmEnded}. */
    protected transient BagBasedCategorizerEnsembleLearner<InputType, CategoryType> learner;

    /** Per example: whether its most recent out-of-bag prediction matched its output. Starts all false. */
    protected transient boolean[] outOfBagCorrect;

    /**
     * Number of examples whose {@link #outOfBagCorrect} flag is false. It starts at the
     * data size, so examples that have never been evaluated out of bag count as errors.
     */
    protected transient int outOfBagErrorCount;

    /** Raw out-of-bag error rate recorded at the end of each step, in step order. */
    protected transient ArrayList<Double> rawErrorRates;

    /** Smoothed out-of-bag error rate recorded at the end of each step, in step order. */
    protected transient ArrayList<Double> smoothedErrorRates;

    /**
     * Recent raw error rates averaged into the smoothed rate. Its capacity is
     * {@link #smoothingWindowSize}; presumably the oldest value is dropped once it is full.
     */
    protected transient FiniteCapacityBuffer<Double> smoothingBuffer;

    /** Smoothed error rate of the previous step; {@code Double.MAX_VALUE} before the first step. */
    protected transient double previousSmoothedErrorRate;

    /** Creates a new stopping criteria using {@link #DEFAULT_SMOOTHING_WINDOW_SIZE}. */
    public AbstractCategorizerOutOfBagStoppingCriteria() {
        this(DEFAULT_SMOOTHING_WINDOW_SIZE);
    }

    /**
     * Creates a new stopping criteria.
     *
     * @param smoothingWindowSize number of recent raw error rates to average; must be positive
     * @throws IllegalArgumentException if {@code smoothingWindowSize} is not positive
     */
    public AbstractCategorizerOutOfBagStoppingCriteria(int smoothingWindowSize) {
        setSmoothingWindowSize(smoothingWindowSize);
    }

    /**
     * Binds this listener to the starting learner and resets all tracking state.
     *
     * @param algorithm the algorithm that started; must be a
     *        {@link BagBasedCategorizerEnsembleLearner}, otherwise a
     *        {@link ClassCastException} is thrown
     */
    @Override
    @SuppressWarnings("unchecked") // generic arguments of the learner cannot be checked at runtime
    public void algorithmStarted(IterativeAlgorithm algorithm) {
        this.learner = (BagBasedCategorizerEnsembleLearner<InputType, CategoryType>) algorithm;
        final int dataSize = this.learner.getData().size();
        // Nothing has been predicted correctly yet, so every example starts as an error.
        this.outOfBagCorrect = new boolean[dataSize];
        this.outOfBagErrorCount = dataSize;
        this.rawErrorRates = new ArrayList<>();
        this.smoothedErrorRates = new ArrayList<>();
        this.smoothingBuffer = new FiniteCapacityBuffer<>(this.smoothingWindowSize);
        // Sentinel: the first step's smoothed rate is always treated as an improvement.
        this.previousSmoothedErrorRate = Double.MAX_VALUE;
    }

    /**
     * Releases the references held for the finished run.
     *
     * @param algorithm the algorithm that ended
     */
    @Override
    public void algorithmEnded(IterativeAlgorithm algorithm) {
        this.learner = null;
        this.outOfBagCorrect = null;
        this.rawErrorRates = null;
        this.smoothedErrorRates = null;
        this.smoothingBuffer = null;
    }

    /**
     * Gets the current out-of-bag estimate of the category distribution for one
     * example. The category with the maximum value ({@code getMaxValueKey()}) is
     * used as the out-of-bag prediction.
     *
     * @param index index of the example in the learner's data
     * @return the out-of-bag category estimate for that example
     */
    public abstract DataDistribution<CategoryType> getOutOfBagEstimate(int index);

    /**
     * Updates the out-of-bag error statistics after a learning step. If the smoothed
     * error rate failed to decrease, it stops the learner and prunes the ensemble.
     *
     * @param algorithm the algorithm whose step ended
     */
    @Override
    public void stepEnded(IterativeAlgorithm algorithm) {
        updateOutOfBagCorrectness();

        // Cast before dividing. Both operands are ints, and integer division would
        // truncate every error rate below 1.0 to zero. An empty data set yields NaN,
        // and NaN comparisons are false, so this criterion then never triggers.
        final double errorRate = (double) this.outOfBagErrorCount / this.learner.getData().size();
        this.rawErrorRates.add(errorRate);
        this.smoothingBuffer.add(errorRate);

        final double smoothedErrorRate = UnivariateStatisticsUtil.computeMean(this.smoothingBuffer);
        this.smoothedErrorRates.add(smoothedErrorRate);

        if (smoothedErrorRate >= this.previousSmoothedErrorRate) {
            this.learner.stop();
            truncateEnsembleAfter(findBestRecentStep());
        }
        this.previousSmoothedErrorRate = smoothedErrorRate;
    }

    /**
     * Re-evaluates each example the learner reports as out of bag
     * ({@code dataInBag[i] <= 0}). It updates {@link #outOfBagCorrect} and keeps
     * {@link #outOfBagErrorCount} in sync with it.
     */
    private void updateOutOfBagCorrectness() {
        final int dataSize = this.learner.getData().size();
        final int[] dataInBag = this.learner.getDataInBag();
        for (int i = 0; i < dataSize; i++) {
            if (dataInBag[i] > 0) {
                continue; // only out-of-bag examples are re-evaluated
            }
            final CategoryType actual = this.learner.getExample(i).getOutput();
            final CategoryType predicted = getOutOfBagEstimate(i).getMaxValueKey();
            final boolean nowCorrect = ObjectUtil.equalsSafe(actual, predicted);
            if (this.outOfBagCorrect[i] != nowCorrect) {
                this.outOfBagCorrect[i] = nowCorrect;
                // Only a change of status moves the error count.
                this.outOfBagErrorCount += nowCorrect ? -1 : 1;
            }
        }
    }

    /**
     * Finds the step with the lowest raw error rate among the most recent
     * {@code smoothingBuffer.size()} steps. Ties go to the earliest such step.
     *
     * @return index into {@link #rawErrorRates} of the best recent step
     */
    private int findBestRecentStep() {
        final int stepCount = this.rawErrorRates.size();
        final int windowCount = this.smoothingBuffer.size();
        int bestStep = 0;
        double bestErrorRate = Double.MAX_VALUE;
        // Walk backwards from the latest step; "<=" lets earlier steps win ties.
        for (int offset = 0; offset < windowCount; offset++) {
            final int step = stepCount - 1 - offset;
            final double stepErrorRate = this.rawErrorRates.get(step);
            if (stepErrorRate <= bestErrorRate) {
                bestStep = step;
                bestErrorRate = stepErrorRate;
            }
        }
        return bestStep;
    }

    /**
     * Removes every ensemble member whose index is greater than {@code lastKeptStep}.
     * Removal starts from the end of the member list.
     *
     * @param lastKeptStep index of the last step whose ensemble member is kept
     */
    private void truncateEnsembleAfter(int lastKeptStep) {
        // NOTE(review): assumes exactly one ensemble member is added per step, so step
        // index i corresponds to members index i. Confirm against the learner.
        for (int i = this.rawErrorRates.size() - 1; i > lastKeptStep; i--) {
            this.learner.getResult2().members.remove(i);
        }
    }

    /**
     * Gets the number of recent raw error rates averaged into the smoothed error rate.
     *
     * @return the smoothing window size
     */
    public int getSmoothingWindowSize() {
        return this.smoothingWindowSize;
    }

    /**
     * Sets the number of recent raw error rates averaged into the smoothed error rate.
     *
     * @param smoothingWindowSize the new smoothing window size; must be positive
     * @throws IllegalArgumentException if {@code smoothingWindowSize} is zero or negative
     */
    public void setSmoothingWindowSize(int smoothingWindowSize) {
        // A zero-capacity buffer has no values to average, so zero is rejected along
        // with negative sizes, as the error message states.
        if (smoothingWindowSize <= 0) {
            throw new IllegalArgumentException("smoothingWindowSize must be positive.");
        }
        this.smoothingWindowSize = smoothingWindowSize;
    }
}
