package gov.sandia.cognition.learning.experiment;

import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.data.PartitionedDataset;
import gov.sandia.cognition.learning.performance.PerformanceEvaluator;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Summarizer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * A {@link LearnerValidationExperiment} that runs each validation fold as its
 * own trial task on a {@link ThreadPoolExecutor}, instead of one after another.
 * <p>
 * Each trial trains a clone of the configured learner on the fold's training
 * set and evaluates the learned object on the fold's testing set. The resulting
 * statistics are collected into {@link #getStatistics()}.
 *
 * @param <InputDataType> input type of the {@link ValidationFoldCreator}
 * @param <FoldDataType> type of the data items in each created fold
 * @param <LearnedType> type of object produced by the learner
 * @param <StatisticType> type of per-fold statistic produced by the performance evaluator
 * @param <SummaryType> type produced by the summarizer from the statistics
 */
public class ParallelLearnerValidationExperiment<InputDataType, FoldDataType, LearnedType, StatisticType, SummaryType> extends LearnerValidationExperiment<InputDataType, FoldDataType, LearnedType, StatisticType, SummaryType> implements ParallelAlgorithm {

    /**
     * Pool the trials run on. The field is transient, so it is not serialized.
     * {@link #getThreadPool()} recreates it lazily whenever it is null.
     */
    private transient ThreadPoolExecutor threadPool;

    /**
     * One validation trial. It trains a clone of the learner on a single fold's
     * training set and evaluates the learned object on that fold's testing set.
     */
    private class TrialTask implements Callable<StatisticType> {

        /** The fold this trial trains and tests on. */
        private final PartitionedDataset<FoldDataType> fold;

        /**
         * Creates a trial for the given fold.
         *
         * @param partitionedDataset the fold to train and test on
         */
        public TrialTask(PartitionedDataset<FoldDataType> partitionedDataset) {
            this.fold = partitionedDataset;
        }

        /**
         * Runs the trial and fires the trial-started and trial-ended events
         * around it.
         * <p>
         * NOTE(review): several trials run concurrently, so these events may be
         * fired from multiple pool threads at once. Confirm that registered
         * listeners tolerate this.
         *
         * @return the evaluated statistic, or {@code null} if the trial threw.
         *         In that case the failure is logged and the trial-ended event
         *         is not fired.
         */
        @Override
        public StatisticType call() {
            try {
                fireTrialStarted();

                // Each trial learns with its own clone of the learner, so
                // parallel trials do not share one learner instance.
                final LearnedType learned = ObjectUtil.cloneSmart(getLearner()).learn(this.fold.getTrainingSet());
                final StatisticType statistic = getPerformanceEvaluator().evaluatePerformance(learned, this.fold.getTestingSet());

                fireTrialEnded();
                return statistic;
            } catch (Exception e) {
                // Best effort: a failed trial contributes a null statistic
                // instead of aborting the other trials.
                Logger.getLogger(ParallelLearnerValidationExperiment.class.getName()).log(Level.SEVERE, "Validation trial failed", e);
                return null;
            }
        }
    }

    /**
     * Creates an experiment with no fold creator, performance evaluator, or
     * summarizer. All three are passed to the superclass as {@code null}.
     */
    public ParallelLearnerValidationExperiment() {
        this(null, null, null);
    }

    /**
     * Creates an experiment from its components. All three are passed straight
     * to the superclass constructor.
     *
     * @param validationFoldCreator creates the validation folds from the input data
     * @param performanceEvaluator evaluates a learned object on a fold's testing set
     * @param summarizer summarizes the collected per-fold statistics
     */
    public ParallelLearnerValidationExperiment(ValidationFoldCreator<InputDataType, FoldDataType> validationFoldCreator, PerformanceEvaluator<? super LearnedType, Collection<? extends FoldDataType>, ? extends StatisticType> performanceEvaluator, Summarizer<? super StatisticType, ? extends SummaryType> summarizer) {
        super(validationFoldCreator, performanceEvaluator, summarizer);
    }

    /**
     * Runs one trial per fold in parallel on {@link #getThreadPool()}. The
     * resulting statistics are appended to {@link #getStatistics()}.
     * <p>
     * If parallel execution itself fails, the failure is logged and no
     * statistics are added. The experiment-ended event is still fired.
     *
     * @param collection the folds to run trials on, one trial per fold
     */
    @Override
    public void runExperiment(final Collection<PartitionedDataset<FoldDataType>> collection) {
        setNumTrials(collection.size());
        fireExperimentStarted();

        final ArrayList<TrialTask> tasks = new ArrayList<TrialTask>(collection.size());
        for (final PartitionedDataset<FoldDataType> fold : collection) {
            tasks.add(new TrialTask(fold));
        }

        Collection<StatisticType> results = null;
        try {
            results = ParallelUtil.executeInParallel(tasks, getThreadPool());
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Preserve the interrupt so callers up the stack can observe it.
                Thread.currentThread().interrupt();
            }
            Logger.getLogger(ParallelLearnerValidationExperiment.class.getName()).log(Level.SEVERE, "Parallel validation trials failed", e);
        }

        // results stays null when execution failed above; passing null to
        // addAll would throw a NullPointerException and hide the logged cause.
        if (results != null) {
            getStatistics().addAll(results);
        }
        fireExperimentEnded();
    }

    /**
     * Returns the thread pool used to run trials. If none has been set, or it
     * was dropped by serialization, a pool is created with
     * {@link ParallelUtil#createThreadPool()} and cached.
     *
     * @return the thread pool, never {@code null}
     */
    public ThreadPoolExecutor getThreadPool() {
        if (this.threadPool == null) {
            this.threadPool = ParallelUtil.createThreadPool();
        }
        return this.threadPool;
    }

    /**
     * Sets the thread pool used to run trials. Passing {@code null} makes
     * {@link #getThreadPool()} create a new pool on its next call.
     *
     * @param threadPoolExecutor the pool to use
     */
    public void setThreadPool(ThreadPoolExecutor threadPoolExecutor) {
        this.threadPool = threadPoolExecutor;
    }

    /**
     * Returns the number of threads, as reported by
     * {@link ParallelUtil#getNumThreads} for this algorithm.
     *
     * @return the number of threads
     */
    public int getNumThreads() {
        return ParallelUtil.getNumThreads(this);
    }
}
