package gov.sandia.cognition.learning.algorithm.hmm;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.Distribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.statistics.distribution.DirichletDistribution;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.Random;

/**
 * A hidden Markov model (HMM): a discrete {@link MarkovChain} over hidden
 * states in which every state emits observations through its own
 * {@link ComputableDistribution}.
 * <p>
 * Transition-matrix convention: element {@code (i, j)} of the transition
 * matrix is the probability of moving <em>to</em> state {@code i}
 * <em>from</em> state {@code j}. Each column therefore holds the next-state
 * distribution of one current state. This is the transpose of the matrix used
 * by Rabiner (see the publication note).
 *
 * @param <ObservationType> type of the observations emitted by the states
 */
@PublicationReference(author = {"Lawrence R. Rabiner"}, title = "A tutorial on hidden Markov models and selected applications in speech recognition", type = PublicationType.Journal, year = 1989, publication = "Proceedings of the IEEE", pages = {257, 286}, url = "http://www.cs.ubc.ca/~murphyk/Bayes/rabiner.pdf", notes = {"Rabiner's transition matrix is transposed from mine."})
/* loaded from: input_file:gov/sandia/cognition/learning/algorithm/hmm/HiddenMarkovModel.class */
public class HiddenMarkovModel<ObservationType> extends MarkovChain implements Distribution<ObservationType> {

    /** Emission distributions, one per hidden state, in state-index order. */
    protected Collection<? extends ComputableDistribution<ObservationType>> emissionFunctions;

    /** Creates a hidden Markov model using the superclass defaults; emission functions are left unset. */
    public HiddenMarkovModel() {
        super();
    }

    /**
     * Creates a hidden Markov model with the given number of states.
     * The emission functions are left unset.
     *
     * @param numStates number of hidden states
     */
    public HiddenMarkovModel(int numStates) {
        super(numStates);
    }

    /**
     * Creates a fully specified hidden Markov model.
     *
     * @param initialProbability initial hidden-state distribution
     * @param transitionProbability transition matrix; column {@code j} is the
     *        next-state distribution from state {@code j}
     * @param emissionFunctions one emission distribution per state
     * @throws IllegalArgumentException if the number of emission functions
     *         differs from the number of states
     */
    public HiddenMarkovModel(Vector initialProbability, Matrix transitionProbability,
        Collection<? extends ComputableDistribution<ObservationType>> emissionFunctions) {
        super(initialProbability, transitionProbability);
        if (emissionFunctions.size() != this.getNumStates()) {
            throw new IllegalArgumentException("Number of PDFs must be equal to number of states!");
        }
        this.setEmissionFunctions(emissionFunctions);
    }

    /**
     * Creates a random HMM whose states all start from the same emission
     * distribution. That distribution is learned from the data, with every
     * observation given weight 1.
     *
     * @param numStates number of hidden states
     * @param learner learner that fits an emission distribution to weighted data
     * @param data observations used to fit the shared emission distribution
     * @param random source of randomness for the initial/transition probabilities
     * @param <ObservationType> observation type
     * @return a new randomly initialized model
     */
    public static <ObservationType> HiddenMarkovModel<ObservationType> createRandom(int numStates,
        BatchLearner<Collection<? extends WeightedValue<? extends ObservationType>>, ? extends ComputableDistribution<ObservationType>> learner,
        Collection<? extends ObservationType> data, Random random) {
        final ArrayList<DefaultWeightedValue<ObservationType>> weightedData = new ArrayList<>(data.size());
        for (ObservationType observation : data) {
            weightedData.add(new DefaultWeightedValue<ObservationType>(observation, 1.0));
        }
        final ComputableDistribution<ObservationType> distribution = learner.learn(weightedData);
        return createRandom(numStates, distribution, random);
    }

    /**
     * Creates a random HMM in which every state uses the probability function
     * of the given distribution. The same probability-function instance is
     * shared by all states ({@link Collections#nCopies}).
     *
     * @param numStates number of hidden states
     * @param distribution distribution whose probability function every state uses
     * @param random source of randomness for the initial/transition probabilities
     * @param <ObservationType> observation type
     * @return a new randomly initialized model
     */
    public static <ObservationType> HiddenMarkovModel<ObservationType> createRandom(int numStates,
        ComputableDistribution<ObservationType> distribution, Random random) {
        return createRandom(Collections.nCopies(numStates, distribution.getProbabilityFunction()), random);
    }

    /**
     * Creates an HMM with the given emission functions and random probabilities.
     * The initial distribution and each transition-matrix column are
     * independent samples from {@code new DirichletDistribution(numStates)}.
     *
     * @param emissionFunctions one emission function per state
     * @param random source of randomness
     * @param <ObservationType> observation type
     * @return a new randomly initialized model
     */
    public static <ObservationType> HiddenMarkovModel<ObservationType> createRandom(
        Collection<? extends ProbabilityFunction<ObservationType>> emissionFunctions, Random random) {
        final int numStates = emissionFunctions.size();
        final DirichletDistribution dirichlet = new DirichletDistribution(numStates);
        final Vector initial = dirichlet.sample(random);
        final Matrix transitions = MatrixFactory.getDefault().createMatrix(numStates, numStates);
        final ArrayList<Vector> columns = dirichlet.sample(random, numStates);
        for (int from = 0; from < numStates; from++) {
            // Column "from" is the next-state distribution when leaving state "from".
            transitions.setColumn(from, columns.get(from));
        }
        return new HiddenMarkovModel<>(initial, transitions, emissionFunctions);
    }

    /**
     * Deep-copies this model: the superclass state plus a cloned list of the
     * emission functions. The decompiler renamed this method from
     * {@code clone}; the name is kept so it still overrides the superclass.
     *
     * @return a copy of this model
     */
    @Override
    @SuppressWarnings("unchecked")
    public HiddenMarkovModel<ObservationType> mo0clone() {
        final HiddenMarkovModel<ObservationType> copy = (HiddenMarkovModel<ObservationType>) super.mo0clone();
        copy.setEmissionFunctions(ObjectUtil.cloneSmartElementsAsArrayList(this.getEmissionFunctions()));
        return copy;
    }

    /**
     * Computes the log-likelihood of an observation sequence with the scaled
     * forward algorithm. The forward variable is normalized at every step, and
     * the logs of the normalizers are summed.
     *
     * @param observations the observation sequence; an empty sequence yields 0
     * @return the log-likelihood, or {@link Double#NEGATIVE_INFINITY} if the
     *         sequence has zero probability under the model
     */
    public double computeObservationLogLikelihood(Collection<? extends ObservationType> observations) {
        final Vector likelihoods = VectorFactory.getDefault().createVector(this.getNumStates());
        final Matrix transitions = this.getTransitionProbability();
        // Cloned because the forward variable is modified in place below.
        Vector alpha = this.getInitialProbability().mo0clone();
        double logLikelihood = 0.0;
        boolean first = true;
        for (ObservationType observation : observations) {
            if (!first) {
                alpha = transitions.times(alpha);
            }
            first = false;
            this.computeObservationLikelihoods(observation, likelihoods);
            alpha.dotTimesEquals(likelihoods);
            final double scale = alpha.norm1();
            if (scale == 0.0) {
                // The sequence is impossible under the model. Scaling by 1/0
                // would turn the forward variable (and the result) into NaN.
                return Double.NEGATIVE_INFINITY;
            }
            alpha.scaleEquals(1.0 / scale);
            logLikelihood += Math.log(scale);
        }
        return logLikelihood;
    }

    /**
     * Sums {@link #computeObservationLogLikelihood(Collection)} over several
     * sequences, treating them as independent.
     *
     * @param sequences the observation sequences
     * @return the total log-likelihood
     */
    protected double computeMultipleObservationLogLikelihood(
        Collection<? extends Collection<? extends ObservationType>> sequences) {
        double total = 0.0;
        for (Collection<? extends ObservationType> sequence : sequences) {
            total += this.computeObservationLogLikelihood(sequence);
        }
        return total;
    }

    /**
     * Computes the joint log-likelihood of an observation sequence together
     * with a given hidden-state sequence. Each step contributes its emission
     * log-density plus the log of the initial probability (first step) or of
     * the transition probability from the previous state.
     *
     * @param observations the observation sequence
     * @param states the hidden-state index for each observation
     * @return the joint log-likelihood
     * @throws IllegalArgumentException if the two sequences differ in length
     */
    public double computeObservationLogLikelihood(Collection<? extends ObservationType> observations,
        Collection<Integer> states) {
        if (observations.size() != states.size()) {
            throw new IllegalArgumentException("Observations and states must be the same size");
        }
        final ArrayList<ProbabilityFunction<ObservationType>> pdfs = new ArrayList<>(this.getNumStates());
        for (ComputableDistribution<ObservationType> emission : this.getEmissionFunctions()) {
            pdfs.add(emission.getProbabilityFunction());
        }
        final Iterator<Integer> stateIterator = states.iterator();
        double logLikelihood = 0.0;
        int previousState = -1;
        for (ObservationType observation : observations) {
            final int state = stateIterator.next();
            final double emissionLog = pdfs.get(state).logEvaluate(observation);
            // previousState must still hold the PREVIOUS state here; it is
            // updated only after the transition term has been computed.
            final double stateLog = (previousState < 0)
                ? Math.log(this.initialProbability.getElement(state))
                : Math.log(this.transitionProbability.getElement(state, previousState));
            logLikelihood += emissionLog + stateLog;
            previousState = state;
        }
        return logLikelihood;
    }

    /**
     * Draws one observation: the first element of a freshly sampled sequence,
     * with the hidden state drawn from the initial distribution.
     *
     * @param random source of randomness
     * @return the sampled observation
     */
    @Override
    public ObservationType sample(Random random) {
        return this.sample(random, 1).get(0);
    }

    /**
     * Samples an observation sequence of the given length. The hidden state
     * starts from the initial distribution and then follows the transition
     * matrix.
     *
     * @param random source of randomness
     * @param numSamples length of the sequence to draw
     * @return the sampled observations, in order
     */
    @Override
    public ArrayList<ObservationType> sample(Random random, int numSamples) {
        // Copy once so each emission lookup is O(1), whatever the collection type.
        final ArrayList<ComputableDistribution<ObservationType>> emissions = new ArrayList<>(this.getEmissionFunctions());
        final Matrix transitions = this.getTransitionProbability();
        final ArrayList<ObservationType> samples = new ArrayList<>(numSamples);
        Vector stateDistribution = this.getInitialProbability();
        for (int n = 0; n < numSamples; n++) {
            final int state = sampleStateIndex(stateDistribution, random);
            samples.add(emissions.get(state).sample(random));
            stateDistribution = transitions.getColumn(state);
        }
        return samples;
    }

    /**
     * Draws an index from a discrete distribution by inverse-CDF sampling.
     * It never returns a negative index, even when {@code nextDouble()} yields 0.
     * If rounding leaves the cumulative sum just below the uniform draw, it
     * falls back to the last index instead of running past the end.
     */
    private static int sampleStateIndex(Vector probabilities, Random random) {
        final int last = probabilities.getDimensionality() - 1;
        double remaining = random.nextDouble();
        for (int index = 0; index < last; index++) {
            final double p = probabilities.getElement(index);
            remaining -= p;
            // Requiring p > 0 only matters when the draw was exactly 0; otherwise
            // the first index that exhausts the draw necessarily has p > 0.
            if (remaining <= 0.0 && p > 0.0) {
                return index;
            }
        }
        return last;
    }

    /**
     * Gets the emission distributions.
     *
     * @return the emission distributions, one per state in state-index order
     */
    public Collection<? extends ComputableDistribution<ObservationType>> getEmissionFunctions() {
        return this.emissionFunctions;
    }

    /**
     * Sets the emission distributions. The collection is stored as-is; it is not copied.
     *
     * @param emissionFunctions one emission distribution per state
     */
    public void setEmissionFunctions(Collection<? extends ComputableDistribution<ObservationType>> emissionFunctions) {
        this.emissionFunctions = emissionFunctions;
    }

    /**
     * Performs one forward-recursion step: {@code alpha' = (A * alpha) .* b}.
     *
     * @param previousAlpha forward variable at the previous time step
     * @param observationLikelihoods per-state likelihoods of the current observation
     * @param normalize whether to rescale the result to unit L1 norm
     * @return the new forward variable, weighted by the scale factor applied
     *         (1 when not normalizing)
     */
    protected WeightedValue<Vector> computeForwardProbabilities(Vector previousAlpha,
        Vector observationLikelihoods, boolean normalize) {
        final Vector alpha = this.getTransitionProbability().times(previousAlpha);
        alpha.dotTimesEquals(observationLikelihoods);
        double weight = 1.0;
        if (normalize) {
            weight = 1.0 / alpha.norm1();
            alpha.scaleEquals(weight);
        }
        return new DefaultWeightedValue<Vector>(alpha, weight);
    }

    /**
     * Runs the forward recursion over a whole sequence.
     *
     * @param observationLikelihoods per-state likelihoods for each observation
     * @param normalize whether to rescale each step to unit L1 norm
     * @return the forward variable for each time step, each weighted by its
     *         scale factor; empty for an empty input
     */
    public ArrayList<WeightedValue<Vector>> computeForwardProbabilities(ArrayList<Vector> observationLikelihoods,
        boolean normalize) {
        final int length = observationLikelihoods.size();
        final ArrayList<WeightedValue<Vector>> forward = new ArrayList<>(length);
        if (length == 0) {
            return forward;
        }
        final Vector alpha = observationLikelihoods.get(0).dotTimes(this.getInitialProbability());
        double weight = 1.0;
        if (normalize) {
            weight = 1.0 / alpha.norm1();
            alpha.scaleEquals(weight);
        }
        WeightedValue<Vector> current = new DefaultWeightedValue<Vector>(alpha, weight);
        forward.add(current);
        for (int t = 1; t < length; t++) {
            current = this.computeForwardProbabilities(current.getValue(), observationLikelihoods.get(t), normalize);
            forward.add(current);
        }
        return forward;
    }

    /**
     * Computes the per-state likelihoods of a single observation.
     *
     * @param observation the observation
     * @return a vector whose element {@code i} is state {@code i}'s emission density at the observation
     */
    public Vector computeObservationLikelihoods(ObservationType observation) {
        final Vector likelihoods = VectorFactory.getDefault().createVector(this.getEmissionFunctions().size());
        this.computeObservationLikelihoods(observation, likelihoods);
        return likelihoods;
    }

    /**
     * Writes the per-state likelihoods of a single observation into an existing vector.
     *
     * @param observation the observation
     * @param likelihoods output vector; element {@code i} receives state {@code i}'s density
     */
    protected void computeObservationLikelihoods(ObservationType observation, Vector likelihoods) {
        int state = 0;
        for (ComputableDistribution<ObservationType> emission : this.getEmissionFunctions()) {
            likelihoods.setElement(state, emission.getProbabilityFunction().evaluate(observation).doubleValue());
            state++;
        }
    }

    /**
     * Computes the per-state likelihoods of every observation in a sequence.
     *
     * @param observations the observation sequence
     * @return one likelihood vector per observation, in order
     */
    public ArrayList<Vector> computeObservationLikelihoods(Collection<? extends ObservationType> observations) {
        final ArrayList<Vector> result = new ArrayList<>(observations.size());
        for (ObservationType observation : observations) {
            result.add(this.computeObservationLikelihoods(observation));
        }
        return result;
    }

    /**
     * Performs one backward-recursion step:
     * {@code beta_t = ((b_{t+1} .* beta_{t+1})^T * A) * weight}.
     *
     * @param nextBeta backward variable at time {@code t + 1}
     * @param nextLikelihoods per-state likelihoods of the observation at {@code t + 1}
     * @param weight scale factor to apply (skipped when exactly 1)
     * @return the backward variable at time {@code t}, weighted by {@code weight}
     */
    protected WeightedValue<Vector> computeBackwardProbabilities(Vector nextBeta, Vector nextLikelihoods,
        double weight) {
        final Vector beta = nextLikelihoods.dotTimes(nextBeta).times(this.getTransitionProbability());
        if (weight != 1.0) {
            beta.scaleEquals(weight);
        }
        return new DefaultWeightedValue<Vector>(beta, weight);
    }

    /**
     * Runs the backward recursion over a whole sequence. It reuses the forward
     * scale factors so that the forward and backward variables share one scaling.
     *
     * @param observationLikelihoods per-state likelihoods for each observation
     * @param forwardProbabilities weighted forward variables from
     *        {@link #computeForwardProbabilities(ArrayList, boolean)}
     * @return the weighted backward variable for each time step; empty for an empty input
     */
    public ArrayList<WeightedValue<Vector>> computeBackwardProbabilities(ArrayList<Vector> observationLikelihoods,
        ArrayList<WeightedValue<Vector>> forwardProbabilities) {
        final int length = observationLikelihoods.size();
        final ArrayList<WeightedValue<Vector>> backward = new ArrayList<>(length);
        if (length == 0) {
            return backward;
        }
        for (int t = 0; t < length; t++) {
            backward.add(null);
        }
        final int numStates = this.getInitialProbability().getDimensionality();
        final double lastWeight = forwardProbabilities.get(length - 1).getWeight();
        final Vector lastBeta = VectorFactory.getDefault().createVector(numStates, 1.0);
        if (lastWeight != 1.0) {
            lastBeta.scaleEquals(lastWeight);
        }
        WeightedValue<Vector> current = new DefaultWeightedValue<Vector>(lastBeta, lastWeight);
        backward.set(length - 1, current);
        for (int t = length - 2; t >= 0; t--) {
            current = this.computeBackwardProbabilities(current.getValue(), observationLikelihoods.get(t + 1),
                forwardProbabilities.get(t).getWeight());
            backward.set(t, current);
        }
        return backward;
    }

    /**
     * Computes {@code alpha .* beta}, rescaled so that its L1 norm equals {@code scale}.
     *
     * @param alpha forward variable
     * @param beta backward variable
     * @param scale target L1 norm of the result
     * @return the rescaled element-wise product
     */
    public static Vector computeStateObservationLikelihood(Vector alpha, Vector beta, double scale) {
        final Vector gamma = alpha.dotTimes(beta);
        gamma.scaleEquals(scale / gamma.norm1());
        return gamma;
    }

    /**
     * Applies {@link #computeStateObservationLikelihood(Vector, Vector, double)}
     * to each time step.
     *
     * @param forwardProbabilities weighted forward variables
     * @param backwardProbabilities weighted backward variables
     * @param scale target L1 norm of each result
     * @return one vector per time step
     */
    public ArrayList<Vector> computeStateObservationLikelihood(ArrayList<WeightedValue<Vector>> forwardProbabilities,
        ArrayList<WeightedValue<Vector>> backwardProbabilities, double scale) {
        final int length = forwardProbabilities.size();
        final ArrayList<Vector> result = new ArrayList<>(length);
        for (int t = 0; t < length; t++) {
            result.add(computeStateObservationLikelihood(forwardProbabilities.get(t).getValue(),
                backwardProbabilities.get(t).getValue(), scale));
        }
        return result;
    }

    /**
     * Computes the outer product {@code (b_{t+1} .* beta_{t+1}) * alpha_t^T}.
     * Element {@code (i, j)} pairs destination state {@code i} with source state {@code j}.
     *
     * @param alpha forward variable at time {@code t}
     * @param nextBeta backward variable at time {@code t + 1}
     * @param nextLikelihoods per-state likelihoods at time {@code t + 1}
     * @return the outer-product matrix
     */
    public static Matrix computeTransitions(Vector alpha, Vector nextBeta, Vector nextLikelihoods) {
        return nextLikelihoods.dotTimes(nextBeta).outerProduct(alpha);
    }

    /**
     * Re-estimates the transition matrix by summing per-step transition terms,
     * multiplying element-wise by the current transition matrix, and
     * normalizing with {@code normalizeTransitionMatrix}.
     *
     * @param forwardProbabilities weighted forward variables
     * @param backwardProbabilities weighted backward variables
     * @param observationLikelihoods per-state likelihoods for each observation
     * @return the re-estimated transition matrix
     */
    public Matrix computeTransitions(ArrayList<WeightedValue<Vector>> forwardProbabilities,
        ArrayList<WeightedValue<Vector>> backwardProbabilities, ArrayList<Vector> observationLikelihoods) {
        final int length = observationLikelihoods.size();
        final RingAccumulator<Matrix> accumulator = new RingAccumulator<>();
        for (int t = 0; t < length - 1; t++) {
            accumulator.accumulate(computeTransitions(forwardProbabilities.get(t).getValue(),
                backwardProbabilities.get(t + 1).getValue(), observationLikelihoods.get(t + 1)));
        }
        // NOTE(review): with fewer than two observations nothing is accumulated;
        // confirm what RingAccumulator.getSum() returns for an empty accumulator.
        final Matrix expected = accumulator.getSum();
        expected.dotTimesEquals(this.getTransitionProbability());
        normalizeTransitionMatrix(expected);
        return expected;
    }

    /**
     * Appends each emission function to the superclass description.
     *
     * @return a human-readable description of the model
     */
    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder(super.toString());
        for (ComputableDistribution<ObservationType> emission : this.getEmissionFunctions()) {
            sb.append("F: ");
            sb.append(emission.toString());
        }
        return sb.toString();
    }

    /**
     * Finds the predecessor state that maximizes
     * {@code A(destinationState, j) * delta(j)} (first index wins ties).
     *
     * @param destinationState the state being transitioned into
     * @param delta Viterbi scores at the previous time step
     * @return the best predecessor index, weighted by its score; the index is
     *         -1 if no score exceeds negative infinity
     */
    public WeightedValue<Integer> findMostLikelyState(int destinationState, Vector delta) {
        double bestScore = Double.NEGATIVE_INFINITY;
        int bestState = -1;
        final int numStates = delta.getDimensionality();
        for (int source = 0; source < numStates; source++) {
            final double score = this.transitionProbability.getElement(destinationState, source) * delta.getElement(source);
            if (score > bestScore) {
                bestScore = score;
                bestState = source;
            }
        }
        return new DefaultWeightedValue<Integer>(bestState, bestScore);
    }

    /**
     * Performs one Viterbi step. For every state it picks the best predecessor,
     * multiplies the scores by the observation likelihoods, and normalizes the
     * result to unit L1 norm.
     *
     * @param delta Viterbi scores at the previous time step
     * @param observationLikelihoods per-state likelihoods of the current observation
     * @return the new scores paired with the back-pointer for each state
     */
    protected Pair<Vector, int[]> computeViterbiRecursion(Vector delta, Vector observationLikelihoods) {
        final int numStates = delta.getDimensionality();
        final Vector nextDelta = VectorFactory.getDefault().createVector(numStates);
        final int[] backPointers = new int[numStates];
        for (int state = 0; state < numStates; state++) {
            final WeightedValue<Integer> best = this.findMostLikelyState(state, delta);
            backPointers[state] = best.getValue();
            nextDelta.setElement(state, best.getWeight());
        }
        nextDelta.dotTimesEquals(observationLikelihoods);
        nextDelta.scaleEquals(1.0 / nextDelta.norm1());
        return DefaultPair.create(nextDelta, backPointers);
    }

    /**
     * Finds the most likely hidden-state sequence for the observations (Viterbi algorithm).
     *
     * @param observations the observation sequence
     * @return the most likely state index for each observation; empty for an empty input
     */
    @PublicationReference(author = {"Wikipedia"}, title = "Viterbi algorithm", year = 2010, type = PublicationType.WebPage, url = "http://en.wikipedia.org/wiki/Viterbi_algorithm")
    public ArrayList<Integer> viterbi(Collection<? extends ObservationType> observations) {
        final int length = observations.size();
        final ArrayList<Integer> path = new ArrayList<>(length);
        if (length == 0) {
            return path;
        }
        final int numStates = this.getNumStates();
        final ArrayList<Vector> likelihoods = this.computeObservationLikelihoods(observations);
        Vector delta = this.getInitialProbability().dotTimes(likelihoods.get(0));

        // backPointers.get(t)[i] is the best predecessor of state i at time t + 1.
        final ArrayList<int[]> backPointers = new ArrayList<>(length);
        for (int t = 1; t < length; t++) {
            final Pair<Vector, int[]> step = this.computeViterbiRecursion(delta, likelihoods.get(t));
            delta = step.getFirst();
            backPointers.add(step.getSecond());
        }

        int state = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int candidate = 0; candidate < numStates; candidate++) {
            final double score = delta.getElement(candidate);
            if (score > bestScore) {
                bestScore = score;
                state = candidate;
            }
        }

        final int[] states = new int[length];
        states[length - 1] = state;
        for (int t = length - 2; t >= 0; t--) {
            state = backPointers.get(t)[state];
            states[t] = state;
        }
        for (int s : states) {
            path.add(s);
        }
        return path;
    }

    /**
     * Computes the filtered state beliefs: the normalized forward variable at each time step.
     *
     * @param observations the observation sequence
     * @return one belief vector per observation; empty for an empty input
     */
    public ArrayList<Vector> stateBeliefs(Collection<? extends ObservationType> observations) {
        final ArrayList<WeightedValue<Vector>> forward =
            this.computeForwardProbabilities(this.computeObservationLikelihoods(observations), true);
        final ArrayList<Vector> beliefs = new ArrayList<>(forward.size());
        for (WeightedValue<Vector> alpha : forward) {
            beliefs.add(alpha.getValue());
        }
        return beliefs;
    }
}
