package gov.sandia.cognition.learning.algorithm.ensemble;

import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Random;
import java.util.Set;

/* loaded from: input_file:gov/sandia/cognition/learning/algorithm/ensemble/CategoryBalancedBaggingLearner.class */
/**
 * A bagging categorizer learner that balances categories within each bag.
 * Instead of sampling uniformly from the whole dataset, {@link #fillBag(int)}
 * splits the requested number of samples as evenly as possible across the
 * distinct outputs (categories) found in {@code dataList}. It then draws that
 * many examples from each category uniformly at random, with replacement.
 *
 * @param <InputType> the input type of the examples
 * @param <CategoryType> the category (output) type of the examples
 */
public class CategoryBalancedBaggingLearner<InputType, CategoryType> extends BaggingCategorizerLearner<InputType, CategoryType> {

    /**
     * The distinct categories found in {@code dataList}. It is built by
     * {@link #initializeAlgorithm()} and may be shuffled in place by
     * {@link #fillBag(int)}.
     */
    protected ArrayList<CategoryType> categoryList;

    /**
     * Maps each category to the indices (into {@code dataList}) of the
     * examples whose output is that category.
     */
    protected HashMap<CategoryType, ArrayList<Integer>> dataPerCategory;

    /**
     * Creates a new learner with a {@code null} base learner and the default
     * settings of {@link #CategoryBalancedBaggingLearner(BatchLearner)}.
     */
    public CategoryBalancedBaggingLearner() {
        this(null);
    }

    /**
     * Creates a new learner around the given base learner. It passes
     * {@code 100}, {@code 1.0} and a fresh {@link Random} for the remaining
     * constructor arguments.
     *
     * @param batchLearner the base learner, forwarded to the superclass
     */
    public CategoryBalancedBaggingLearner(BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> batchLearner) {
        this(batchLearner, 100, 1.0d, new Random());
    }

    /**
     * Creates a new learner.
     *
     * @param batchLearner the base learner, forwarded to the superclass
     *        (presumably trained on each bag; confirm against
     *        {@code BaggingCategorizerLearner})
     * @param maxIterations forwarded to the superclass; presumably the maximum
     *        number of bagging iterations (ensemble members) - TODO confirm
     * @param percentToSample forwarded to the superclass; presumably the bag
     *        size as a fraction of the data size - TODO confirm
     * @param random the random number generator, forwarded to the superclass
     */
    public CategoryBalancedBaggingLearner(BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> batchLearner, int maxIterations, double percentToSample, Random random) {
        super(batchLearner, maxIterations, percentToSample, random);
    }

    /**
     * Runs the superclass initialization. If that succeeds, it indexes every
     * example in {@code dataList} by its category.
     *
     * <p>This method is public here (the decompiler noted the original access
     * was protected). It is kept public so its visibility is never narrowed.
     *
     * @return the result of the superclass initialization; the category index
     *         is built only when that result is {@code true}
     */
    @Override
    public boolean initializeAlgorithm() {
        final boolean initialized = super.initializeAlgorithm();
        if (initialized) {
            final int dataSize = this.dataList.size();
            final Set<CategoryType> categories = DatasetUtil.findUniqueOutputs(this.dataList);
            this.categoryList = new ArrayList<>(categories);

            // A LinkedHashMap keeps the map's iteration order identical to the
            // order in which the categories were inserted.
            this.dataPerCategory = new LinkedHashMap<>(categories.size());
            for (CategoryType category : categories) {
                this.dataPerCategory.put(category, new ArrayList<Integer>());
            }

            // Record the position of each example under its own category.
            for (int index = 0; index < dataSize; index++) {
                final CategoryType category = this.dataList.get(index).getOutput();
                this.dataPerCategory.get(category).add(index);
            }
        }
        return initialized;
    }

    /**
     * Adds {@code sampleCount} examples to {@code bag}, spread as evenly as
     * possible across the categories.
     *
     * <p>Each category in turn receives {@code max(1, remaining / categoriesLeft)}
     * samples. These are drawn uniformly with replacement from that category's
     * examples, and each draw increments {@code dataInBag} at the drawn index.
     * Sampling stops once {@code sampleCount} samples have been added. When
     * {@code sampleCount} is not a multiple of the category count, the category
     * order is shuffled first, so the categories that receive the uneven share
     * are chosen at random.
     *
     * <p>If there are no categories, nothing is added. This also avoids a
     * division by zero.
     *
     * @param sampleCount the number of samples to add to the bag
     */
    @Override
    protected void fillBag(int sampleCount) {
        final int categoryCount = this.categoryList.size();
        if (categoryCount == 0) {
            // No categories means no examples to draw from.
            return;
        }

        if (sampleCount % categoryCount != 0) {
            Collections.shuffle(this.categoryList, this.random);
        }

        int remaining = sampleCount;
        for (int c = 0; c < categoryCount && remaining > 0; c++) {
            final ArrayList<Integer> examples = this.dataPerCategory.get(this.categoryList.get(c));
            final int exampleCount = examples.size();
            // Split what is left evenly over the categories not yet visited,
            // but always take at least one sample from this category.
            final int toDraw = Math.max(1, remaining / (categoryCount - c));
            for (int j = 0; j < toDraw; j++) {
                final int index = examples.get(this.random.nextInt(exampleCount));
                this.bag.add(this.dataList.get(index));
                this.dataInBag[index]++;
            }
            remaining -= toDraw;
        }
    }

    /**
     * Releases the per-category index, then delegates to the superclass
     * cleanup.
     *
     * <p>This method is public here (the decompiler noted the original access
     * was protected). It is kept public so its visibility is never narrowed.
     */
    @Override
    public void cleanupAlgorithm() {
        this.dataPerCategory = null;
        this.categoryList = null;
        super.cleanupAlgorithm();
    }
}
