package gov.sandia.cognition.statistics.bayesian;

import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.statistics.bayesian.DirichletProcessMixtureModel;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * A {@link DirichletProcessMixtureModel} that runs its observation-assignment
 * and cluster-update stages on a thread pool (see {@link ParallelAlgorithm}).
 * <p>
 * The data is split into at most one contiguous chunk per worker thread. Each
 * chunk is assigned to clusters by its own {@link ObservationAssignmentTask}.
 * Clusters are then rebuilt concurrently, with one {@link ClusterUpdaterTask}
 * per cluster. Each of those tasks holds its own clone of the model's updater.
 * <p>
 * The worker tasks and the thread pool are transient and are (re)built lazily.
 * <p>
 * NOTE(review): the assignment tasks run concurrently and all read the shared
 * {@code currentParameter}. This assumes {@code assignObservationToCluster}
 * and {@code createCluster} do not mutate shared model state. Confirm this in
 * the base class.
 *
 * @param <ObservationType> Type of observations handled by the mixture model.
 */
public class ParallelDirichletProcessMixtureModel<ObservationType>
    extends DirichletProcessMixtureModel<ObservationType>
    implements ParallelAlgorithm {

    /** Thread pool that runs the tasks; created on demand by {@link #getThreadPool()}. */
    private transient ThreadPoolExecutor threadPool;

    /** Cached observation-assignment tasks, one per data chunk; built lazily. */
    protected transient ArrayList<ObservationAssignmentTask> assignmentTasks;

    /** The {@code data} reference that {@link #assignmentTasks} were built from. */
    private transient Object assignmentTasksData;

    /** Cached cluster-update tasks, one per cluster; resized as needed. */
    protected transient ArrayList<ClusterUpdaterTask> clusterUpdaterTasks;

    /** The {@code updater} reference that {@link #clusterUpdaterTasks} were cloned from. */
    private transient Object clusterUpdaterTasksUpdater;

    /**
     * Builds a single cluster from the observations currently assigned to it.
     * Each task uses a private clone of the model's updater, so several tasks
     * can run at the same time.
     */
    protected class ClusterUpdaterTask
        extends AbstractCloneableSerializable
        implements Callable<DirichletProcessMixtureModel.DPMMCluster<ObservationType>> {

        /** Observations for this cluster. Null when the cluster has at most one member. */
        Collection<ObservationType> observations;

        /** Task-local clone of the enclosing model's updater. */
        DirichletProcessMixtureModel.Updater<ObservationType> localUpdater;

        /** Creates a task holding a clone of the enclosing model's current updater. */
        @SuppressWarnings("unchecked")
        public ClusterUpdaterTask() {
            this.localUpdater = (DirichletProcessMixtureModel.Updater<ObservationType>)
                ObjectUtil.cloneSafe(ParallelDirichletProcessMixtureModel.this.updater);
        }

        /**
         * Creates the cluster through the enclosing model's {@code createCluster}.
         *
         * @return The new cluster. May be null; null results are discarded by
         *     {@code updateClusters}.
         */
        @Override
        public DirichletProcessMixtureModel.DPMMCluster<ObservationType> call() {
            return ParallelDirichletProcessMixtureModel.this.createCluster(
                this.observations, this.localUpdater);
        }
    }

    /**
     * Result of one {@link ObservationAssignmentTask}. It holds the cluster index
     * chosen for each observation of the task's chunk, in iteration order. It also
     * holds the log-conditional object that was passed to every assignment call
     * for that chunk.
     */
    public static class DPMMAssignments {

        /** Cluster index for each observation of the chunk, in iteration order. */
        protected ArrayList<Integer> assignments;

        /** Log-conditional accumulated by the task. */
        protected DirichletProcessMixtureModel.DPMMLogConditional logConditional;

        /**
         * Creates a new result holder.
         *
         * @param assignments Cluster index per observation.
         * @param logConditional Log-conditional produced by the task.
         */
        public DPMMAssignments(
            final ArrayList<Integer> assignments,
            final DirichletProcessMixtureModel.DPMMLogConditional logConditional) {
            this.assignments = assignments;
            this.logConditional = logConditional;
        }
    }

    /**
     * Assigns every observation in one contiguous chunk of the data to a cluster.
     * The weight buffer and the assignment list are reused across calls.
     */
    protected class ObservationAssignmentTask
        extends AbstractCloneableSerializable
        implements Callable<DPMMAssignments> {

        /** The chunk of observations handled by this task. */
        private Collection<? extends ObservationType> observations;

        /** Scratch weights passed to {@code assignObservationToCluster}; reused across calls. */
        private double[] weights = null;

        /** Cluster index per observation; reused (overwritten) across calls. */
        private ArrayList<Integer> assignments;

        /** Log-conditional for the most recent call. */
        private DirichletProcessMixtureModel.DPMMLogConditional logConditional;

        /**
         * Creates a task for one chunk of the data.
         *
         * @param observations The chunk of observations this task assigns.
         */
        public ObservationAssignmentTask(final Collection<? extends ObservationType> observations) {
            this.observations = observations;
        }

        /**
         * Assigns each observation of the chunk using the enclosing model's
         * {@code assignObservationToCluster}.
         *
         * @return The assignments and the log-conditional for this call. The
         *     assignment list is shared with this task and is overwritten by its
         *     next call.
         * @throws Exception As declared by {@link Callable#call()}.
         */
        @Override
        @SuppressWarnings("rawtypes")
        public DPMMAssignments call() throws Exception {
            final int numClusters = ((DirichletProcessMixtureModel.Sample)
                ParallelDirichletProcessMixtureModel.this.currentParameter).getNumClusters();

            // One weight per existing cluster plus one extra slot
            // (presumably for opening a new cluster).
            if (this.weights == null || this.weights.length != numClusters + 1) {
                this.weights = new double[numClusters + 1];
            }

            // Pre-size the list once so results can be written by position.
            if (this.assignments == null) {
                final int count = this.observations.size();
                this.assignments = new ArrayList<>(count);
                for (int n = 0; n < count; n++) {
                    this.assignments.add(null);
                }
            }

            this.logConditional = new DirichletProcessMixtureModel.DPMMLogConditional();
            int index = 0;
            for (final ObservationType observation : this.observations) {
                this.assignments.set(index++,
                    ParallelDirichletProcessMixtureModel.this.assignObservationToCluster(
                        observation, this.weights, this.logConditional));
            }
            return new DPMMAssignments(this.assignments, this.logConditional);
        }
    }

    /**
     * Returns the number of threads, as reported by {@link ParallelUtil#getNumThreads}.
     *
     * @return Number of threads used for the parallel stages.
     */
    @Override
    public int getNumThreads() {
        return ParallelUtil.getNumThreads(this);
    }

    /**
     * Returns the thread pool. If none is set yet, a default one is created
     * through {@link ParallelUtil#createThreadPool()}.
     *
     * @return The thread pool that runs the tasks.
     */
    @Override
    public ThreadPoolExecutor getThreadPool() {
        if (this.threadPool == null) {
            setThreadPool(ParallelUtil.createThreadPool());
        }
        return this.threadPool;
    }

    /**
     * Sets the thread pool.
     *
     * @param threadPool The pool to use. If null, a default pool is created on
     *     the next {@link #getThreadPool()} call.
     */
    @Override
    public void setThreadPool(final ThreadPoolExecutor threadPool) {
        this.threadPool = threadPool;
    }

    /**
     * Assigns all observations to clusters in parallel, one task per data chunk.
     * The assignment tasks are rebuilt whenever {@code data} refers to a
     * different collection than the one they were built from.
     *
     * @param numClusters Number of existing clusters; {@code numClusters + 1}
     *     buckets are returned.
     * @param logConditional Receives the sum of the tasks' log-conditionals.
     * @return For each bucket index, the observations assigned to it.
     * @throws RuntimeException If a task fails or the calling thread is interrupted.
     *     On interruption, the thread's interrupt status is restored.
     */
    @Override
    protected ArrayList<Collection<ObservationType>> assignObservationsToClusters(
        final int numClusters,
        final DirichletProcessMixtureModel.DPMMLogConditional logConditional) {
        if (this.assignmentTasks == null || this.assignmentTasksData != this.data) {
            createAssignmentTasks();
        }

        final ArrayList<DPMMAssignments> results;
        try {
            results = ParallelUtil.executeInParallel(this.assignmentTasks, getThreadPool());
        } catch (Exception e) {
            throw asUnchecked(e);
        }

        final ArrayList<Collection<ObservationType>> clusterAssignments =
            new ArrayList<>(numClusters + 1);
        for (int k = 0; k <= numClusters; k++) {
            clusterAssignments.add(new LinkedList<ObservationType>());
        }

        // Results are assumed to come back in task order, matching the chunk order.
        for (int t = 0; t < results.size(); t++) {
            final DPMMAssignments result = results.get(t);
            logConditional.logConditional += result.logConditional.logConditional;
            final Iterator<Integer> clusterIndex = result.assignments.iterator();
            for (final ObservationType observation : this.assignmentTasks.get(t).observations) {
                clusterAssignments.get(clusterIndex.next().intValue()).add(observation);
            }
        }
        return clusterAssignments;
    }

    /**
     * Splits the current data into balanced contiguous chunks and creates one
     * assignment task per chunk.
     */
    private void createAssignmentTasks() {
        final ArrayList<? extends ObservationType> observations =
            CollectionUtil.asArrayList(this.data);
        final int numObservations = observations.size();

        // Use at least one task, and never more tasks than observations.
        final int numTasks = Math.max(1, Math.min(getNumThreads(), numObservations));
        this.assignmentTasks = new ArrayList<>(numTasks);
        for (int t = 0; t < numTasks; t++) {
            // Balanced split: chunk sizes differ by at most one.
            // Long math avoids int overflow in the products.
            final int from = (int) ((long) t * numObservations / numTasks);
            final int to = (int) ((long) (t + 1) * numObservations / numTasks);
            this.assignmentTasks.add(new ObservationAssignmentTask(observations.subList(from, to)));
        }
        this.assignmentTasksData = this.data;
    }

    /**
     * Rebuilds the clusters in parallel, one task per cluster assignment.
     * Assignments with at most one member are passed as null. Any null cluster
     * returned by a task is dropped from the result.
     *
     * @param clusterAssignments Observations assigned to each cluster.
     * @return The non-null clusters, in cluster-index order.
     * @throws RuntimeException If a task fails or the calling thread is interrupted.
     *     On interruption, the thread's interrupt status is restored.
     */
    @Override
    protected ArrayList<DirichletProcessMixtureModel.DPMMCluster<ObservationType>> updateClusters(
        final ArrayList<Collection<ObservationType>> clusterAssignments) {
        final int numClusters = clusterAssignments.size();
        prepareClusterUpdaterTasks(numClusters);

        for (int k = 0; k < numClusters; k++) {
            final Collection<ObservationType> members = clusterAssignments.get(k);
            this.clusterUpdaterTasks.get(k).observations = (members.size() > 1) ? members : null;
        }

        final ArrayList<DirichletProcessMixtureModel.DPMMCluster<ObservationType>> results;
        try {
            results = ParallelUtil.executeInParallel(this.clusterUpdaterTasks, getThreadPool());
        } catch (Exception e) {
            throw asUnchecked(e);
        }

        final ArrayList<DirichletProcessMixtureModel.DPMMCluster<ObservationType>> clusters =
            new ArrayList<>(numClusters);
        for (final DirichletProcessMixtureModel.DPMMCluster<ObservationType> cluster : results) {
            if (cluster != null) {
                clusters.add(cluster);
            }
        }
        return clusters;
    }

    /**
     * Makes sure there are exactly {@code numTasks} cluster-update tasks.
     * Existing tasks are reused, except that all of them are discarded when the
     * model's updater has been replaced since they were cloned.
     *
     * @param numTasks Required number of tasks.
     */
    private void prepareClusterUpdaterTasks(final int numTasks) {
        if (this.clusterUpdaterTasks == null || this.clusterUpdaterTasksUpdater != this.updater) {
            this.clusterUpdaterTasks = new ArrayList<>(numTasks);
            this.clusterUpdaterTasksUpdater = this.updater;
        }
        while (this.clusterUpdaterTasks.size() < numTasks) {
            this.clusterUpdaterTasks.add(new ClusterUpdaterTask());
        }
        if (this.clusterUpdaterTasks.size() > numTasks) {
            // Extra tasks would still be executed, so they must be removed.
            this.clusterUpdaterTasks.subList(numTasks, this.clusterUpdaterTasks.size()).clear();
        }
    }

    /**
     * Clones this model. The cached worker tasks are not carried over. They are
     * inner-class instances bound to this model, so a clone that reused them
     * would read and update this model's state instead of its own.
     *
     * @return A clone whose worker tasks will be rebuilt on first use.
     */
    @Override
    @SuppressWarnings("unchecked")
    public ParallelDirichletProcessMixtureModel<ObservationType> clone() {
        final ParallelDirichletProcessMixtureModel<ObservationType> clone =
            (ParallelDirichletProcessMixtureModel<ObservationType>) super.clone();
        clone.assignmentTasks = null;
        clone.assignmentTasksData = null;
        clone.clusterUpdaterTasks = null;
        clone.clusterUpdaterTasksUpdater = null;
        return clone;
    }

    /**
     * Converts a failure from parallel execution into an unchecked exception.
     * Runtime exceptions are returned unchanged; other exceptions are wrapped.
     * If the failure is an interruption, the interrupt status is restored first.
     *
     * @param e The exception thrown while running tasks.
     * @return An unchecked exception for the caller to throw.
     */
    private static RuntimeException asUnchecked(final Exception e) {
        if (e instanceof InterruptedException) {
            Thread.currentThread().interrupt();
        }
        return (e instanceof RuntimeException) ? (RuntimeException) e : new RuntimeException(e);
    }
}
