package gov.sandia.cognition.learning.algorithm.hmm;

import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * A {@link HiddenMarkovModel} that evaluates several of its inner loops in
 * parallel on a {@link ThreadPoolExecutor}: the per-sequence log-likelihoods,
 * the per-time-step transition statistics, the per-column normalization of
 * the transition matrix, the per-time-step state-observation likelihoods, and
 * the per-state step of the Viterbi recursion.
 * <p>
 * Each parallel method packages its work into {@link Callable} tasks, runs
 * them on {@link #getThreadPool()}, and combines the results. Task objects are
 * cached in transient fields and reused between calls to avoid reallocation.
 * Because those fields are shared and unsynchronized, a single instance should
 * not be driven from several threads at the same time.
 *
 * @param <ObservationType> type of the observations emitted by the model.
 */
@PublicationReference(author = {"William Turin"}, title = "Unidirectional and Parallel Baum–Welch Algorithms", type = PublicationType.Journal, publication = "IEEE Transactions on Speech and Audio Processing", year = 1998, url = "http://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=00725318")
public class ParallelHiddenMarkovModel<ObservationType> extends HiddenMarkovModel<ObservationType> implements ParallelAlgorithm {

    /** Pool used for all parallel work; created lazily by {@link #getThreadPool()}. */
    private transient ThreadPoolExecutor threadPool;

    /** Cached observation-likelihood tasks. NOTE(review): not referenced by any method in this class. */
    protected transient ArrayList<ObservationLikelihoodTask<ObservationType>> observationLikelihoodTasks;

    /** Reusable tasks for {@link #computeTransitions(ArrayList, ArrayList, ArrayList)}, one per time-step transition. */
    protected transient ArrayList<ComputeTransitionsTask> computeTransitionTasks;

    /** Reusable tasks for {@link #normalizeTransitionMatrix(Matrix)}, one per matrix column. */
    protected transient ArrayList<NormalizeTransitionTask> normalizeTransitionTasks;

    /** Reusable tasks for {@link #computeStateObservationLikelihood(ArrayList, ArrayList, double)}, one per time step. */
    protected transient ArrayList<StateObservationLikelihoodTask> stateObservationLikelihoodTasks;

    /** Reusable tasks for {@link #computeViterbiRecursion(Vector, Vector)}, one per destination state. */
    protected transient ArrayList<ParallelHiddenMarkovModel<ObservationType>.ViterbiTask> viterbiTasks;

    /**
     * Computes the transition statistics contributed by a single time step
     * from the alpha vector at step n, the beta vector at step n+1 and the
     * observation likelihoods at step n+1. The per-step results are summed
     * and normalized by the enclosing model.
     */
    protected static class ComputeTransitionsTask extends AbstractCloneableSerializable implements Callable<Matrix> {
        /** Alpha vector at time step n. */
        Vector alphan;
        /** Beta vector at time step n+1. */
        Vector betanp1;
        /** Observation-likelihood vector at time step n+1. */
        Vector bnp1;

        @Override
        public Matrix call() {
            return ParallelHiddenMarkovModel.computeTransitions(this.alphan, this.betanp1, this.bnp1);
        }
    }

    /**
     * Computes the observation log-likelihood of one observation sequence
     * under the enclosing model.
     */
    protected class LogLikelihoodTask extends AbstractCloneableSerializable implements Callable<Double> {
        /** The observation sequence to evaluate. */
        protected Collection<? extends ObservationType> data;

        /**
         * Creates a task for one observation sequence.
         *
         * @param data the observation sequence to evaluate.
         */
        public LogLikelihoodTask(Collection<? extends ObservationType> data) {
            this.data = data;
        }

        @Override
        public Double call() throws Exception {
            return ParallelHiddenMarkovModel.this.computeObservationLogLikelihood(this.data);
        }
    }

    /**
     * Normalizes a single column {@code j} of a transition matrix {@code A}.
     * The work is done by the static {@code normalizeTransitionMatrix(Matrix, int)}.
     */
    public static class NormalizeTransitionTask extends AbstractCloneableSerializable implements Callable<Void> {
        /** Matrix whose column is normalized. */
        private Matrix A;
        /** Index of the column to normalize. */
        private int j;

        @Override
        public Void call() {
            ParallelHiddenMarkovModel.normalizeTransitionMatrix(this.A, this.j);
            return null;
        }
    }

    /**
     * Evaluates a single probability function on every observation of a
     * sequence, returning the values in iteration order.
     *
     * @param <ObservationType> type of the observations.
     */
    protected static class ObservationLikelihoodTask<ObservationType> extends AbstractCloneableSerializable implements Callable<double[]> {
        /** Observations to evaluate. */
        protected Collection<? extends ObservationType> observations;
        /** Function evaluated on each observation. */
        protected ProbabilityFunction<ObservationType> distributionFunction;

        @Override
        public double[] call() {
            final double[] likelihoods = new double[this.observations.size()];
            int index = 0;
            for (ObservationType observation : this.observations) {
                likelihoods[index++] = this.distributionFunction.evaluate(observation).doubleValue();
            }
            return likelihoods;
        }
    }

    /**
     * Computes the state-observation likelihood vector for one time step from
     * that step's alpha and beta vectors.
     */
    protected static class StateObservationLikelihoodTask extends AbstractCloneableSerializable implements Callable<Vector> {
        /** Alpha vector for the time step. */
        protected Vector alpha;
        /** Beta vector for the time step. */
        protected Vector beta;
        /**
         * Scale factor forwarded to the per-step computation. Defaults to 1.0,
         * which matches the value this task always used before it was configurable.
         */
        protected double scaleFactor = 1.0d;

        @Override
        public Vector call() throws Exception {
            return ParallelHiddenMarkovModel.computeStateObservationLikelihood(this.alpha, this.beta, this.scaleFactor);
        }
    }

    /**
     * Finds, for one destination state, the most likely predecessor state
     * given the current Viterbi vector {@code delta}.
     */
    protected class ViterbiTask extends AbstractCloneableSerializable implements Callable<WeightedValue<Integer>> {
        /** Destination state whose best predecessor is sought. */
        int destinationState;
        /** Current Viterbi vector. */
        Vector delta;

        ViterbiTask() {
        }

        @Override
        public WeightedValue<Integer> call() throws Exception {
            return ParallelHiddenMarkovModel.this.findMostLikelyState(this.destinationState, this.delta);
        }
    }

    /** Creates a model using the superclass's default constructor. */
    public ParallelHiddenMarkovModel() {
    }

    /**
     * Creates a model via the superclass's {@code int} constructor.
     *
     * @param numStates presumably the number of hidden states; forwarded
     *        unchanged to the superclass constructor.
     */
    public ParallelHiddenMarkovModel(int numStates) {
        super(numStates);
    }

    /**
     * Creates a model from explicit parameters.
     *
     * @param initialProbability initial state probabilities.
     * @param transitionProbability state transition probabilities.
     * @param emissionFunctions emission distributions.
     */
    public ParallelHiddenMarkovModel(Vector initialProbability, Matrix transitionProbability, Collection<? extends ComputableDistribution<ObservationType>> emissionFunctions) {
        super(initialProbability, transitionProbability, emissionFunctions);
    }

    /**
     * Creates a parallel copy of an existing model. The initial probabilities,
     * transition matrix and emission functions are cloned, so the new model
     * does not share them with {@code other}.
     *
     * @param other the model to copy.
     */
    public ParallelHiddenMarkovModel(HiddenMarkovModel<ObservationType> other) {
        this((Vector) ObjectUtil.cloneSafe(other.getInitialProbability()), (Matrix) ObjectUtil.cloneSafe(other.getTransitionProbability()), ObjectUtil.cloneSmartElementsAsArrayList(other.getEmissionFunctions()));
    }

    /**
     * Returns the thread pool used for parallel work. If none has been set,
     * one is created with {@code ParallelUtil.createThreadPool()}.
     *
     * @return the thread pool, never {@code null}.
     */
    @Override
    public ThreadPoolExecutor getThreadPool() {
        if (this.threadPool == null) {
            setThreadPool(ParallelUtil.createThreadPool());
        }
        return this.threadPool;
    }

    /**
     * Sets the thread pool used for parallel work.
     *
     * @param threadPool the pool to use; {@code null} causes a new one to be
     *        created lazily on the next {@link #getThreadPool()} call.
     */
    @Override
    public void setThreadPool(ThreadPoolExecutor threadPool) {
        this.threadPool = threadPool;
    }

    /**
     * Returns the number of threads available for parallel work.
     *
     * @return value from {@code ParallelUtil.getNumThreads(this)}.
     */
    @Override
    public int getNumThreads() {
        return ParallelUtil.getNumThreads(this);
    }

    /**
     * Computes the total log-likelihood of several observation sequences by
     * evaluating each sequence's log-likelihood in parallel and summing the
     * results.
     *
     * @param sequences the observation sequences.
     * @return sum of the per-sequence log-likelihoods (0.0 when empty).
     * @throws RuntimeException wrapping any failure raised while running the
     *         tasks; if the failure was an interruption, the thread's
     *         interrupt status is restored first.
     */
    @Override
    public double computeMultipleObservationLogLikelihood(Collection<? extends Collection<? extends ObservationType>> sequences) {
        final ArrayList<LogLikelihoodTask> tasks = new ArrayList<>(sequences.size());
        for (Collection<? extends ObservationType> sequence : sequences) {
            tasks.add(new LogLikelihoodTask(sequence));
        }
        try {
            double logLikelihood = 0.0d;
            for (Double sequenceLogLikelihood : ParallelUtil.executeInParallel(tasks, getThreadPool())) {
                logLikelihood += sequenceLogLikelihood;
            }
            return logLikelihood;
        } catch (Exception e) {
            throw wrapParallelFailure(e);
        }
    }

    /**
     * Estimates a new transition matrix. The per-time-step statistics are
     * computed in parallel, summed, multiplied elementwise by the current
     * transition probabilities, and finally column-normalized.
     *
     * @param alphas alpha vectors for each time step.
     * @param betas beta vectors for each time step.
     * @param b observation-likelihood vectors for each time step.
     * @return the re-estimated transition matrix.
     * @throws IllegalArgumentException if there are fewer than two time steps,
     *         so no transition can be estimated.
     * @throws RuntimeException wrapping any failure raised while running the
     *         tasks; if the failure was an interruption, the thread's
     *         interrupt status is restored first.
     */
    @Override
    public Matrix computeTransitions(ArrayList<WeightedValue<Vector>> alphas, ArrayList<WeightedValue<Vector>> betas, ArrayList<Vector> b) {
        final int numTransitions = alphas.size() - 1;
        if (numTransitions < 1) {
            // Without any n -> n+1 pair the accumulated sum would be empty.
            throw new IllegalArgumentException(
                "At least two time steps are required to compute transitions, got " + alphas.size());
        }
        if (this.computeTransitionTasks == null) {
            this.computeTransitionTasks = new ArrayList<>(numTransitions);
        }
        this.computeTransitionTasks.ensureCapacity(numTransitions);
        truncate(this.computeTransitionTasks, numTransitions);
        while (this.computeTransitionTasks.size() < numTransitions) {
            this.computeTransitionTasks.add(new ComputeTransitionsTask());
        }
        for (int n = 0; n < numTransitions; n++) {
            final ComputeTransitionsTask task = this.computeTransitionTasks.get(n);
            task.alphan = alphas.get(n).getValue();
            task.betanp1 = betas.get(n + 1).getValue();
            task.bnp1 = b.get(n + 1);
        }
        try {
            final RingAccumulator<Matrix> sum = new RingAccumulator<>();
            for (Future<Matrix> future : getThreadPool().invokeAll(this.computeTransitionTasks)) {
                sum.accumulate(future.get());
            }
            final Matrix transitions = sum.getSum();
            transitions.dotTimesEquals(getTransitionProbability());
            normalizeTransitionMatrix(transitions);
            return transitions;
        } catch (Exception e) {
            throw wrapParallelFailure(e);
        }
    }

    /**
     * Normalizes every column of {@code matrix} in place. Each column is
     * handled by its own task, which calls the static
     * {@code normalizeTransitionMatrix(Matrix, int)}.
     * NOTE(review): tasks write disjoint columns of the same matrix
     * concurrently; confirm the Matrix implementation tolerates this.
     *
     * @param matrix the matrix to normalize; modified in place.
     * @throws RuntimeException wrapping any failure raised while running the
     *         tasks; if the failure was an interruption, the thread's
     *         interrupt status is restored first.
     */
    @Override
    public void normalizeTransitionMatrix(Matrix matrix) {
        final int numColumns = matrix.getNumColumns();
        if (this.normalizeTransitionTasks == null) {
            this.normalizeTransitionTasks = new ArrayList<>(numColumns);
        }
        this.normalizeTransitionTasks.ensureCapacity(numColumns);
        truncate(this.normalizeTransitionTasks, numColumns);
        while (this.normalizeTransitionTasks.size() < numColumns) {
            this.normalizeTransitionTasks.add(new NormalizeTransitionTask());
        }
        for (int column = 0; column < numColumns; column++) {
            final NormalizeTransitionTask task = this.normalizeTransitionTasks.get(column);
            task.j = column;
            task.A = matrix;
        }
        try {
            ParallelUtil.executeInParallel(this.normalizeTransitionTasks, getThreadPool());
        } catch (Exception e) {
            throw wrapParallelFailure(e);
        }
    }

    /**
     * Computes the state-observation likelihood vector of every time step in
     * parallel, one task per step.
     *
     * @param alphas alpha vectors for each time step.
     * @param betas beta vectors for each time step.
     * @param scaleFactor scale factor forwarded to each per-step computation.
     * @return one likelihood vector per time step, in time order.
     * @throws RuntimeException wrapping any failure raised while running the
     *         tasks; if the failure was an interruption, the thread's
     *         interrupt status is restored first.
     */
    @Override
    public ArrayList<Vector> computeStateObservationLikelihood(ArrayList<WeightedValue<Vector>> alphas, ArrayList<WeightedValue<Vector>> betas, double scaleFactor) {
        final int numSteps = alphas.size();
        if (this.stateObservationLikelihoodTasks == null) {
            this.stateObservationLikelihoodTasks = new ArrayList<>(numSteps);
        }
        this.stateObservationLikelihoodTasks.ensureCapacity(numSteps);
        truncate(this.stateObservationLikelihoodTasks, numSteps);
        while (this.stateObservationLikelihoodTasks.size() < numSteps) {
            this.stateObservationLikelihoodTasks.add(new StateObservationLikelihoodTask());
        }
        for (int n = 0; n < numSteps; n++) {
            final StateObservationLikelihoodTask task = this.stateObservationLikelihoodTasks.get(n);
            task.alpha = alphas.get(n).getValue();
            task.beta = betas.get(n).getValue();
            // Honor the caller's scale factor instead of silently using 1.0.
            task.scaleFactor = scaleFactor;
        }
        try {
            return ParallelUtil.executeInParallel(this.stateObservationLikelihoodTasks, getThreadPool());
        } catch (Exception e) {
            throw wrapParallelFailure(e);
        }
    }

    /**
     * Performs one step of the Viterbi recursion. For every destination
     * state, the most likely predecessor is found in parallel. The resulting
     * weights are multiplied elementwise by {@code bn} and rescaled to unit
     * L1 norm.
     *
     * @param delta the current Viterbi vector.
     * @param bn observation likelihoods for the next time step.
     * @return the next Viterbi vector paired with the back-pointer
     *         (most likely predecessor) for each state.
     * @throws RuntimeException wrapping any failure raised while running the
     *         tasks; if the failure was an interruption, the thread's
     *         interrupt status is restored first.
     */
    @Override
    protected Pair<Vector, int[]> computeViterbiRecursion(Vector delta, Vector bn) {
        final int numStates = getNumStates();
        if (this.viterbiTasks == null) {
            this.viterbiTasks = new ArrayList<>(numStates);
        }
        this.viterbiTasks.ensureCapacity(numStates);
        truncate(this.viterbiTasks, numStates);
        while (this.viterbiTasks.size() < numStates) {
            this.viterbiTasks.add(new ViterbiTask());
        }
        for (int state = 0; state < numStates; state++) {
            final ParallelHiddenMarkovModel<ObservationType>.ViterbiTask task = this.viterbiTasks.get(state);
            task.destinationState = state;
            task.delta = delta;
        }
        try {
            final ArrayList<WeightedValue<Integer>> bestTransitions = ParallelUtil.executeInParallel(this.viterbiTasks, getThreadPool());
            final int[] psi = new int[numStates];
            final Vector nextDelta = VectorFactory.getDefault().createVector(numStates);
            for (int state = 0; state < numStates; state++) {
                final WeightedValue<Integer> transition = bestTransitions.get(state);
                psi[state] = transition.getValue();
                nextDelta.setElement(state, transition.getWeight());
            }
            nextDelta.dotTimesEquals(bn);
            // Rescale to unit L1 norm (presumably to avoid underflow on long
            // sequences). A zero norm yields non-finite entries, as in the
            // sequential implementation.
            nextDelta.scaleEquals(1.0d / nextDelta.norm1());
            return DefaultPair.create(nextDelta, psi);
        } catch (Exception e) {
            throw wrapParallelFailure(e);
        }
    }

    /**
     * Removes trailing elements so that {@code list} holds at most
     * {@code size} elements; a shorter list is left unchanged.
     *
     * @param list the list to trim.
     * @param size the maximum number of elements to keep.
     */
    private static void truncate(final ArrayList<?> list, final int size) {
        if (list.size() > size) {
            list.subList(size, list.size()).clear();
        }
    }

    /**
     * Converts a failure from a parallel section into an unchecked exception.
     * If the failure is an {@link InterruptedException}, the current thread's
     * interrupt status is restored first, so that code further up the stack
     * can still observe the interruption.
     *
     * @param e the exception raised while running or collecting tasks.
     * @return a {@link RuntimeException} wrapping {@code e}, for the caller to throw.
     */
    private static RuntimeException wrapParallelFailure(final Exception e) {
        if (e instanceof InterruptedException) {
            Thread.currentThread().interrupt();
        }
        return new RuntimeException(e);
    }
}
