package gov.sandia.cognition.learning.algorithm.ensemble;

import gov.sandia.cognition.algorithm.IterativeAlgorithm;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.factory.Factory;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.BatchLearnerContainer;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Randomized;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * Learns a {@link WeightedVotingCategorizerEnsemble} using Breiman's IVoting
 * procedure ("Pasting small votes for classification in large databases and
 * on-line").
 * <p>
 * Each iteration does the following:
 * <ul>
 * <li>Draws a bag, with replacement, of
 * {@code max(1, (int) (percentToSample * n))} examples from the {@code n}
 * training examples.</li>
 * <li>Takes a {@code proportionIncorrectInSample} share of that bag from
 * examples the current ensemble misclassifies. The rest comes from examples it
 * classifies correctly. If one of those two groups is empty, the other group
 * supplies the whole bag.</li>
 * <li>Trains the wrapped batch learner on the bag and adds the resulting
 * member to the ensemble with weight 1.0.</li>
 * <li>Updates each example's vote counts and recomputes which examples the
 * ensemble now gets right.</li>
 * </ul>
 *
 * @param <InputType> The type of input the learned categorizers take.
 * @param <CategoryType> The type of category the learned categorizers output.
 */
@PublicationReference(
    author = {"Leo Breiman"},
    title = "Pasting small votes for classification in large databases and on-line",
    year = 1999,
    type = PublicationType.Journal,
    publication = "Machine Learning",
    pages = {85, 103},
    url = "http://www.springerlink.com/content/mnu2r28218651707/fulltext.pdf")
public class IVotingCategorizerLearner<InputType, CategoryType>
    extends AbstractAnytimeSupervisedBatchLearner<InputType, CategoryType, WeightedVotingCategorizerEnsemble<InputType, CategoryType, Evaluator<? super InputType, ? extends CategoryType>>>
    implements Randomized,
        BatchLearnerContainer<BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>>>,
        BagBasedCategorizerEnsembleLearner<InputType, CategoryType> {

    /** The default maximum number of iterations: {@value}. */
    public static final int DEFAULT_MAX_ITERATIONS = 100;

    /** The default fraction of the dataset size used as the bag size: {@value}. */
    public static final double DEFAULT_PERCENT_TO_SAMPLE = 0.1d;

    /** The default fraction of each bag drawn from misclassified examples: {@value}. */
    public static final double DEFAULT_PROPORTION_INCORRECT_IN_SAMPLE = 0.5d;

    /** By default, the ensemble estimate uses out-of-bag votes only: {@value}. */
    public static final boolean DEFAULT_VOTE_OUT_OF_BAG_ONLY = true;

    /** The batch learner used to train each ensemble member on a bag. */
    protected BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> learner;

    /** The bag size as a fraction of the dataset size; must be greater than zero. */
    protected double percentToSample;

    /** The fraction of each bag drawn from misclassified examples; in [0, 1]. */
    protected double proportionIncorrectInSample;

    /**
     * When true, each example's ensemble estimate uses only votes from members
     * whose bag did not contain it. If there are no such votes yet, all votes
     * are used instead.
     */
    protected boolean voteOutOfBagOnly;

    /** Creates the per-example vote counters. */
    protected Factory<? extends DataDistribution<CategoryType>> counterFactory;

    /** The random number generator used for sampling bags. */
    protected Random random;

    /** The ensemble being learned; this is the algorithm's result. */
    protected transient WeightedVotingCategorizerEnsemble<InputType, CategoryType, Evaluator<? super InputType, ? extends CategoryType>> ensemble;

    /** The training data as a random-access list. */
    protected transient ArrayList<? extends InputOutputPair<? extends InputType, CategoryType>> dataList;

    /** Per-example vote counts over all ensemble members. */
    protected transient ArrayList<DataDistribution<CategoryType>> dataFullEstimates;

    /** Per-example vote counts over members whose bag did not contain the example. */
    protected transient ArrayList<DataDistribution<CategoryType>> dataOutOfBagEstimates;

    /** Whether the current ensemble classifies each example correctly. */
    protected transient boolean[] currentEnsembleCorrect;

    /** Indices of the examples the current ensemble classifies correctly. */
    protected transient ArrayList<Integer> currentCorrectIndices;

    /** Indices of the examples the current ensemble misclassifies. */
    protected transient ArrayList<Integer> currentIncorrectIndices;

    /** The number of examples drawn into each bag. */
    protected transient int sampleSize;

    /** How many bag draws come from the correctly classified examples. */
    protected transient int numCorrectToSample;

    /** How many bag draws come from the misclassified examples. */
    protected transient int numIncorrectToSample;

    /** The bag of examples for the current iteration. */
    protected transient ArrayList<InputOutputPair<? extends InputType, CategoryType>> currentBag;

    /** How many times each example was drawn into the current bag. */
    protected transient int[] dataInBag;

    /** The member learned in the current iteration. */
    protected transient Evaluator<? super InputType, ? extends CategoryType> currentMember;

    /** The current member's output on each example (null entries before the first step). */
    protected transient ArrayList<CategoryType> currentMemberEstimates;

    /**
     * A stopping criterion based on the out-of-bag error of an
     * {@link IVotingCategorizerLearner}. It reads the per-example out-of-bag
     * vote counts from the learner it is attached to.
     *
     * @param <InputType> The input type of the learner.
     * @param <CategoryType> The category type of the learner.
     */
    public static class OutOfBagErrorStoppingCriteria<InputType, CategoryType>
        extends AbstractCategorizerOutOfBagStoppingCriteria<InputType, CategoryType> {

        /** The learner being monitored; set when the algorithm starts. */
        protected transient IVotingCategorizerLearner<InputType, CategoryType> learner;

        /** Creates a new criterion, passing 25 to the superclass constructor. */
        public OutOfBagErrorStoppingCriteria() {
            this(25);
        }

        /**
         * Creates a new criterion.
         *
         * @param smoothingWindowSize Passed to the superclass constructor; see
         *     {@link AbstractCategorizerOutOfBagStoppingCriteria} for its meaning.
         */
        public OutOfBagErrorStoppingCriteria(final int smoothingWindowSize) {
            super(smoothingWindowSize);
        }

        /**
         * Records the learner that has started.
         *
         * @param algorithm The started algorithm; must be an
         *     {@link IVotingCategorizerLearner}, otherwise a
         *     {@link ClassCastException} is thrown.
         */
        @Override
        @SuppressWarnings("unchecked") // The listener is only attached to IVoting learners of matching types.
        public void algorithmStarted(final IterativeAlgorithm algorithm) {
            this.learner = (IVotingCategorizerLearner<InputType, CategoryType>) algorithm;
            super.algorithmStarted(algorithm);
        }

        /**
         * Gets the out-of-bag vote counts for one example.
         *
         * @param index The example index.
         * @return The learner's out-of-bag vote counts for that example.
         */
        @Override
        public DataDistribution<CategoryType> getOutOfBagEstimate(final int index) {
            return this.learner.getDataOutOfBagEstimates().get(index);
        }
    }

    /** Creates a learner with no member learner and the default parameters. */
    public IVotingCategorizerLearner() {
        this(null, DEFAULT_MAX_ITERATIONS, DEFAULT_PERCENT_TO_SAMPLE, new Random());
    }

    /**
     * Creates a learner with the default incorrect proportion, out-of-bag
     * voting, and a {@link DefaultDataDistribution} counter factory.
     *
     * @param learner The batch learner used to train each member.
     * @param maxIterations The maximum number of iterations (ensemble members).
     * @param percentToSample The bag size as a fraction of the dataset size.
     * @param random The random number generator used for sampling.
     */
    public IVotingCategorizerLearner(
        final BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> learner,
        final int maxIterations,
        final double percentToSample,
        final Random random) {
        this(learner, maxIterations, percentToSample,
            DEFAULT_PROPORTION_INCORRECT_IN_SAMPLE, DEFAULT_VOTE_OUT_OF_BAG_ONLY,
            new DefaultDataDistribution.DefaultFactory<CategoryType>(2), random);
    }

    /**
     * Creates a fully configured learner.
     *
     * @param learner The batch learner used to train each member.
     * @param maxIterations The maximum number of iterations (ensemble members).
     * @param percentToSample The bag size as a fraction of the dataset size.
     * @param proportionIncorrectInSample The fraction of each bag drawn from
     *     misclassified examples.
     * @param voteOutOfBagOnly Whether ensemble estimates use out-of-bag votes only.
     * @param counterFactory Creates the per-example vote counters.
     * @param random The random number generator used for sampling.
     * @throws IllegalArgumentException If {@code percentToSample} or
     *     {@code proportionIncorrectInSample} is out of range.
     */
    public IVotingCategorizerLearner(
        final BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> learner,
        final int maxIterations,
        final double percentToSample,
        final double proportionIncorrectInSample,
        final boolean voteOutOfBagOnly,
        final Factory<? extends DataDistribution<CategoryType>> counterFactory,
        final Random random) {
        super(maxIterations);
        setLearner(learner);
        setPercentToSample(percentToSample);
        setProportionIncorrectInSample(proportionIncorrectInSample);
        setVoteOutOfBagOnly(voteOutOfBagOnly);
        setCounterFactory(counterFactory);
        setRandom(random);
    }

    /**
     * Sets up the per-run state.
     * <p>
     * Every example starts in the "incorrect" group. The bag size and the split
     * between correct and incorrect draws are computed here.
     *
     * @return False if there is no data to learn from, true otherwise.
     */
    @Override
    protected boolean initializeAlgorithm() {
        // Drop any previous result so an empty or missing dataset does not
        // report a stale ensemble through getResult().
        this.ensemble = null;
        if (this.data == null || this.data.isEmpty()) {
            return false;
        }
        if (this.random == null) {
            this.random = new Random();
        }

        final int dataSize = this.data.size();
        this.ensemble = new WeightedVotingCategorizerEnsemble<>(DatasetUtil.findUniqueOutputs(this.data));
        this.dataList = CollectionUtil.asArrayList(this.data);
        this.dataFullEstimates = new ArrayList<>(dataSize);
        this.dataOutOfBagEstimates = new ArrayList<>(dataSize);
        this.currentEnsembleCorrect = new boolean[dataSize];
        this.currentCorrectIndices = new ArrayList<>(dataSize);
        this.currentIncorrectIndices = new ArrayList<>(dataSize);
        this.currentMemberEstimates = new ArrayList<>(dataSize);
        for (int i = 0; i < dataSize; i++) {
            // Both vote tallies use the configured counter type.
            this.dataFullEstimates.add(this.counterFactory.create());
            this.dataOutOfBagEstimates.add(this.counterFactory.create());
            // With no members yet, no example is classified correctly.
            this.currentIncorrectIndices.add(i);
            this.currentMemberEstimates.add(null);
        }

        this.sampleSize = Math.max(1, (int) (this.percentToSample * dataSize));
        this.numIncorrectToSample = (int) (this.proportionIncorrectInSample * this.sampleSize);
        this.numCorrectToSample = this.sampleSize - this.numIncorrectToSample;
        this.currentBag = new ArrayList<>(this.sampleSize);
        this.dataInBag = new int[dataSize];
        this.currentMember = null;
        return true;
    }

    /**
     * Performs one IVoting iteration.
     * <p>
     * The iteration samples a bag, trains a member on it, adds that member to
     * the ensemble with weight 1.0, and updates the vote counts and the
     * correct/incorrect groups.
     *
     * @return Always true; stopping is left to the iteration limit or listeners.
     */
    @Override
    protected boolean step() {
        final int dataSize = this.dataList.size();

        // Reset the previous iteration's bag.
        this.currentBag.clear();
        Arrays.fill(this.dataInBag, 0);

        // After the first step every example is in exactly one group, so at
        // most one is empty. Substituting the other group avoids sampling
        // from an empty list.
        ArrayList<Integer> correctIndices = this.currentCorrectIndices;
        ArrayList<Integer> incorrectIndices = this.currentIncorrectIndices;
        if (incorrectIndices.isEmpty()) {
            incorrectIndices = correctIndices;
        } else if (correctIndices.isEmpty()) {
            correctIndices = incorrectIndices;
        }
        this.createBag(correctIndices, incorrectIndices);

        this.currentMember = this.learner.learn(this.currentBag);
        this.ensemble.add(this.currentMember, 1.0d);

        // Rebuild the correct/incorrect groups from the updated vote counts.
        this.currentCorrectIndices.clear();
        this.currentIncorrectIndices.clear();
        for (int i = 0; i < dataSize; i++) {
            final InputOutputPair<? extends InputType, CategoryType> example = this.dataList.get(i);
            final CategoryType actual = example.getOutput();
            final CategoryType estimate = this.currentMember.evaluate(example.getInput());
            this.currentMemberEstimates.set(i, estimate);

            final DataDistribution<CategoryType> fullVotes = this.dataFullEstimates.get(i);
            final DataDistribution<CategoryType> outOfBagVotes = this.dataOutOfBagEstimates.get(i);
            if (estimate != null) {
                fullVotes.increment(estimate);
                if (this.dataInBag[i] <= 0) {
                    // The member never saw this example, so its vote is out-of-bag.
                    outOfBagVotes.increment(estimate);
                }
            }

            final CategoryType ensembleEstimate;
            if (this.voteOutOfBagOnly && outOfBagVotes.getTotal() > 0.0d) {
                ensembleEstimate = outOfBagVotes.getMaxValueKey();
            } else {
                ensembleEstimate = fullVotes.getMaxValueKey();
            }

            // An example with no ensemble estimate (null) is counted as correct.
            final boolean correct = ensembleEstimate == null
                || ObjectUtil.equalsSafe(actual, ensembleEstimate);
            this.currentEnsembleCorrect[i] = correct;
            if (correct) {
                this.currentCorrectIndices.add(i);
            } else {
                this.currentIncorrectIndices.add(i);
            }
        }
        return true;
    }

    /**
     * Fills {@link #currentBag} and {@link #dataInBag}.
     * <p>
     * {@link #numCorrectToSample} draws come from the first list and
     * {@link #numIncorrectToSample} draws from the second.
     *
     * @param correctIndices Indices to draw the "correct" share from.
     * @param incorrectIndices Indices to draw the "incorrect" share from.
     */
    protected void createBag(
        final ArrayList<Integer> correctIndices,
        final ArrayList<Integer> incorrectIndices) {
        sampleIndicesWithReplacementInto(correctIndices, this.dataList,
            this.numCorrectToSample, this.random, this.currentBag, this.dataInBag);
        sampleIndicesWithReplacementInto(incorrectIndices, this.dataList,
            this.numIncorrectToSample, this.random, this.currentBag, this.dataInBag);
    }

    /**
     * Draws indices uniformly with replacement from {@code fromIndices}.
     * <p>
     * For each draw, it appends the corresponding element of {@code baseData}
     * to {@code output} and increments that index's entry in {@code dataInBag}.
     *
     * @param <DataType> The element type.
     * @param fromIndices The candidate indices into {@code baseData}; must be
     *     non-empty when {@code numToSample > 0}, otherwise
     *     {@link Random#nextInt(int)} throws {@link IllegalArgumentException}.
     * @param baseData The data the indices refer to.
     * @param numToSample The number of draws.
     * @param random The random number generator.
     * @param output Receives the sampled elements.
     * @param dataInBag Per-index draw counts, incremented in place.
     */
    public static <DataType> void sampleIndicesWithReplacementInto(
        final ArrayList<Integer> fromIndices,
        final ArrayList<? extends DataType> baseData,
        final int numToSample,
        final Random random,
        final ArrayList<DataType> output,
        final int[] dataInBag) {
        final int candidateCount = fromIndices.size();
        for (int draw = 0; draw < numToSample; draw++) {
            final int index = fromIndices.get(random.nextInt(candidateCount));
            output.add(baseData.get(index));
            dataInBag[index]++;
        }
    }

    /** Releases the per-run state; the learned ensemble is kept as the result. */
    @Override
    protected void cleanupAlgorithm() {
        this.dataList = null;
        this.dataFullEstimates = null;
        this.dataOutOfBagEstimates = null;
        this.dataInBag = null;
        this.currentMember = null;
        this.currentCorrectIndices = null;
        this.currentIncorrectIndices = null;
        this.currentBag = null;
        this.currentEnsembleCorrect = null;
        this.currentMemberEstimates = null;
    }

    /**
     * Gets the ensemble learned so far.
     *
     * @return The learned ensemble, or null if no run has produced one.
     */
    @Override
    public WeightedVotingCategorizerEnsemble<InputType, CategoryType, Evaluator<? super InputType, ? extends CategoryType>> getResult() {
        return this.ensemble;
    }

    /**
     * Gets the per-example draw counts for the current bag.
     *
     * @return The internal count array, or null outside a run.
     */
    @Override
    public int[] getDataInBag() {
        return this.dataInBag;
    }

    /**
     * Gets a training example by index.
     *
     * @param index The example index.
     * @return The example at {@code index}.
     */
    @Override
    public InputOutputPair<? extends InputType, CategoryType> getExample(final int index) {
        return this.dataList.get(index);
    }

    /**
     * Gets the batch learner used to train each ensemble member.
     *
     * @return The member learner.
     */
    @Override
    public BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> getLearner() {
        return this.learner;
    }

    /**
     * Sets the batch learner used to train each ensemble member.
     *
     * @param learner The member learner.
     */
    public void setLearner(final BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> learner) {
        this.learner = learner;
    }

    /**
     * Gets the bag size as a fraction of the dataset size.
     *
     * @return The sampling fraction.
     */
    public double getPercentToSample() {
        return this.percentToSample;
    }

    /**
     * Sets the bag size as a fraction of the dataset size.
     *
     * @param percentToSample The sampling fraction; must be greater than zero.
     * @throws IllegalArgumentException If the value is not greater than zero
     *     (including NaN).
     */
    public void setPercentToSample(final double percentToSample) {
        if (!(percentToSample > 0.0d)) {
            throw new IllegalArgumentException("percentToSample must be greater than zero.");
        }
        this.percentToSample = percentToSample;
    }

    /**
     * Gets the fraction of each bag drawn from misclassified examples.
     *
     * @return The incorrect proportion.
     */
    public double getProportionIncorrectInSample() {
        return this.proportionIncorrectInSample;
    }

    /**
     * Sets the fraction of each bag drawn from misclassified examples.
     *
     * @param proportionIncorrectInSample The proportion, in [0, 1].
     * @throws IllegalArgumentException If the value is outside [0, 1] or NaN.
     */
    public void setProportionIncorrectInSample(final double proportionIncorrectInSample) {
        if (!(proportionIncorrectInSample >= 0.0d && proportionIncorrectInSample <= 1.0d)) {
            throw new IllegalArgumentException("proportionIncorrectInSample must be between 0.0 and 1.0 (inclusive).");
        }
        this.proportionIncorrectInSample = proportionIncorrectInSample;
    }

    /**
     * Gets whether ensemble estimates use out-of-bag votes only.
     *
     * @return True if only out-of-bag votes are used when available.
     */
    public boolean isVoteOutOfBagOnly() {
        return this.voteOutOfBagOnly;
    }

    /**
     * Sets whether ensemble estimates use out-of-bag votes only.
     *
     * @param voteOutOfBagOnly True to use only out-of-bag votes when available.
     */
    public void setVoteOutOfBagOnly(final boolean voteOutOfBagOnly) {
        this.voteOutOfBagOnly = voteOutOfBagOnly;
    }

    /**
     * Gets the factory for the per-example vote counters.
     *
     * @return The counter factory.
     */
    public Factory<? extends DataDistribution<CategoryType>> getCounterFactory() {
        return this.counterFactory;
    }

    /**
     * Sets the factory for the per-example vote counters.
     *
     * @param counterFactory The counter factory.
     */
    public void setCounterFactory(final Factory<? extends DataDistribution<CategoryType>> counterFactory) {
        this.counterFactory = counterFactory;
    }

    @Override
    public Random getRandom() {
        return this.random;
    }

    @Override
    public void setRandom(final Random random) {
        this.random = random;
    }

    /**
     * Gets the per-example vote counts over all members.
     *
     * @return An unmodifiable view, or an empty list outside a run.
     */
    public List<DataDistribution<CategoryType>> getDataFullEstimates() {
        return this.dataFullEstimates == null
            ? Collections.<DataDistribution<CategoryType>>emptyList()
            : Collections.unmodifiableList(this.dataFullEstimates);
    }

    /**
     * Gets the per-example vote counts over out-of-bag members.
     *
     * @return An unmodifiable view, or an empty list outside a run.
     */
    public List<DataDistribution<CategoryType>> getDataOutOfBagEstimates() {
        return this.dataOutOfBagEstimates == null
            ? Collections.<DataDistribution<CategoryType>>emptyList()
            : Collections.unmodifiableList(this.dataOutOfBagEstimates);
    }

    /**
     * Gets whether the current ensemble classifies each example correctly.
     *
     * @return The internal flag array, or null outside a run.
     */
    public boolean[] getCurrentEnsembleCorrect() {
        return this.currentEnsembleCorrect;
    }
}
