package gov.sandia.cognition.learning.function.cost;

import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.math.UnivariateStatisticsUtil;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;

/* loaded from: input_file:gov/sandia/cognition/learning/function/cost/ParallelNegativeLogLikelihood.class */
/**
 * Parallel version of {@link NegativeLogLikelihood}. The data set
 * ({@code costParameters}) is split into contiguous chunks, one per thread of
 * the thread pool. Each chunk's summed negative log-likelihood is computed
 * concurrently, and the chunk totals are added up and divided by the number of
 * data points.
 * <p>
 * The per-chunk tasks are cached between calls to {@link #evaluate}. They are
 * rebuilt whenever the thread count, the data collection instance, or the
 * data collection's size changes.
 * <p>
 * Not safe for concurrent calls to {@link #evaluate} on the same instance,
 * because the cached tasks are mutated at the start of every evaluation.
 *
 * @param <DataType> Type of the data points evaluated by the distribution.
 */
public class ParallelNegativeLogLikelihood<DataType> extends NegativeLogLikelihood<DataType> implements ParallelAlgorithm {
    /** Thread pool used to run the tasks; created lazily by {@link #getThreadPool()}. */
    protected transient ThreadPoolExecutor threadPool;

    /** Cached per-chunk tasks; rebuilt by {@link #evaluate} when they become stale. */
    protected transient ArrayList<NegativeLogLikelihoodTask<DataType>> tasks;

    /** Data collection the cached {@link #tasks} were built from (compared by identity). */
    private transient Collection<? extends DataType> tasksData;

    /** Size of {@link #tasksData} at the time the cached {@link #tasks} were built. */
    private transient int tasksDataSize;

    /**
     * Computes the summed negative log-likelihood of one chunk of the data.
     *
     * @param <DataType> Type of the data points.
     */
    public static class NegativeLogLikelihoodTask<DataType> implements Callable<Double> {
        /** The chunk of data evaluated by this task. */
        private final Collection<? extends DataType> data;

        /** Probability function to evaluate; assigned by the owning evaluator before each run. */
        protected ProbabilityFunction<DataType> probabilityFunction;

        /**
         * Creates a task over the given chunk of data.
         *
         * @param collection The chunk of data this task evaluates.
         */
        public NegativeLogLikelihoodTask(Collection<? extends DataType> collection) {
            this.data = collection;
        }

        /**
         * Returns the chunk size multiplied by
         * {@link NegativeLogLikelihood#evaluate(ProbabilityFunction, Collection)}
         * for this chunk. An empty chunk returns 0.
         *
         * @return The chunk's contribution to the total cost.
         * @throws Exception If evaluating the probability function fails.
         */
        @Override
        public Double call() throws Exception {
            // An empty chunk contributes nothing. Without this guard the result would
            // be 0 * evaluate(empty), which is NaN if evaluate divides by the chunk
            // size (as the size scaling below implies), poisoning the whole total.
            if (this.data.isEmpty()) {
                return Double.valueOf(0.0);
            }
            // NOTE(review): evaluate(...) presumably returns the per-sample mean, so
            // multiplying by the chunk size recovers the chunk's sum -- confirm.
            return Double.valueOf(this.data.size() * NegativeLogLikelihood.evaluate(this.probabilityFunction, this.data));
        }
    }

    /**
     * Creates an instance with no data; the data is presumably supplied later
     * through the base class.
     */
    public ParallelNegativeLogLikelihood() {
        this(null);
    }

    /**
     * Creates an instance over the given data.
     *
     * @param collection Data points whose negative log-likelihood is evaluated.
     */
    public ParallelNegativeLogLikelihood(Collection<? extends DataType> collection) {
        super(collection);
    }

    /**
     * Evaluates the distribution's negative log-likelihood over the data in
     * parallel. The chunk totals are summed and divided by the total number of
     * data points. An empty data set therefore yields 0/0 (NaN).
     *
     * @param computableDistribution Distribution whose probability function is evaluated.
     * @return Sum of the chunk totals divided by the number of data points.
     * @throws RuntimeException Wrapping any exception raised while running the
     *     tasks. If the cause is an {@link InterruptedException}, the thread's
     *     interrupt status is restored first.
     */
    @Override
    public Double evaluate(ComputableDistribution<DataType> computableDistribution) {
        final ProbabilityFunction<DataType> probabilityFunction = computableDistribution.getProbabilityFunction();
        final Collection<? extends DataType> data = this.costParameters;
        final int size = data.size();
        // Guard against a degenerate pool size so partitioning never divides by zero.
        final int numTasks = Math.max(1, getNumThreads());

        // Rebuild the cached tasks when they no longer match the thread count or
        // the current data; otherwise they would evaluate a stale (or
        // concurrently modified) view of the data.
        if (this.tasks == null
                || this.tasks.size() != numTasks
                || this.tasksData != data
                || this.tasksDataSize != size) {
            this.tasks = createTasks(data, numTasks);
            this.tasksData = data;
            this.tasksDataSize = size;
        }

        for (NegativeLogLikelihoodTask<DataType> task : this.tasks) {
            task.probabilityFunction = probabilityFunction;
        }

        try {
            final double sum = UnivariateStatisticsUtil.computeSum(ParallelUtil.executeInParallel(this.tasks, getThreadPool()));
            return Double.valueOf(sum / size);
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Preserve the interrupt for callers further up the stack.
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(e);
        }
    }

    /**
     * Splits the data into {@code numTasks} contiguous chunks whose sizes
     * differ by at most one. Some chunks are empty when there are fewer data
     * points than tasks.
     *
     * @param data Data to partition.
     * @param numTasks Number of chunks to create; must be positive.
     * @return One task per chunk, in data order.
     */
    private ArrayList<NegativeLogLikelihoodTask<DataType>> createTasks(
            final Collection<? extends DataType> data, final int numTasks) {
        final ArrayList<? extends DataType> list = CollectionUtil.asArrayList(data);
        final int size = list.size();
        final ArrayList<NegativeLogLikelihoodTask<DataType>> result = new ArrayList<>(numTasks);
        for (int i = 0; i < numTasks; i++) {
            // long arithmetic keeps i * size from overflowing int for large data sets.
            final int from = (int) ((long) i * size / numTasks);
            final int to = (int) ((long) (i + 1) * size / numTasks);
            result.add(new NegativeLogLikelihoodTask<DataType>(list.subList(from, to)));
        }
        return result;
    }

    /**
     * Returns the thread pool, creating one with
     * {@link ParallelUtil#createThreadPool()} on first use.
     *
     * @return The thread pool used to run the tasks.
     */
    @Override
    public ThreadPoolExecutor getThreadPool() {
        if (this.threadPool == null) {
            this.threadPool = ParallelUtil.createThreadPool();
        }
        return this.threadPool;
    }

    /**
     * Sets the thread pool used to run the tasks.
     *
     * @param threadPoolExecutor The thread pool to use; if null, a new pool is
     *     created on the next call to {@link #getThreadPool()}.
     */
    @Override
    public void setThreadPool(ThreadPoolExecutor threadPoolExecutor) {
        this.threadPool = threadPoolExecutor;
    }

    /**
     * Returns the thread count reported by {@link ParallelUtil#getNumThreads}
     * for this algorithm. It determines how many chunks the data is split into.
     *
     * @return The number of threads.
     */
    @Override
    public int getNumThreads() {
        return ParallelUtil.getNumThreads(this);
    }
}
