package org.broadinstitute.hellbender.utils.hmm;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.stream.IntStream;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.RandomGeneratorFactory;
import org.apache.commons.math3.util.FastMath;
import org.broadinstitute.hellbender.tools.walkers.genotyper.StandardCallerArgumentCollection;
import org.broadinstitute.hellbender.utils.MathUtils;

/* loaded from: input_file:org/broadinstitute/hellbender/utils/hmm/HMM.class */
/**
 * A hidden Markov model over observed data of type {@code D}, positions of type {@code T} and
 * hidden states of type {@code S}.
 *
 * <p>Implementations supply the hidden-state set and the log-space prior, transition and emission
 * probabilities. The default methods convert those log values back to probabilities with
 * {@code exp}, so they are expected to be natural logarithms.
 *
 * @param <D> the observed data type
 * @param <T> the position type
 * @param <S> the hidden state type
 */
public interface HMM<D, T, S> {

    /** Fixed seed so that {@link #generateHiddenStateChain(List)} produces reproducible chains. */
    int RANDOM_SEED_FOR_CHAIN_GENERATION = 13;

    /**
     * Returns the hidden states of this model.
     *
     * @return the hidden states; the default methods of this interface access them by index
     */
    List<S> hiddenStates();

    /**
     * Returns the natural-log prior probability of being in {@code state} at {@code position}.
     *
     * @param state the hidden state
     * @param position the position
     * @return the log prior probability
     */
    double logPriorProbability(S state, T position);

    /**
     * Returns the natural-log probability of moving from {@code currentState} at
     * {@code currentPosition} to {@code nextState} at {@code nextPosition}.
     *
     * @param currentState the state at the earlier position
     * @param currentPosition the earlier position
     * @param nextState the state at the later position
     * @param nextPosition the later position
     * @return the log transition probability
     */
    double logTransitionProbability(S currentState, T currentPosition, S nextState, T nextPosition);

    /**
     * Returns the natural-log probability of observing {@code data} given {@code state} at
     * {@code position}.
     *
     * @param data the observed data
     * @param state the hidden state
     * @param position the position
     * @return the log emission probability
     */
    double logEmissionProbability(D data, S state, T position);

    /**
     * Samples a hidden-state chain over {@code positions}.
     *
     * <p>The first state is drawn using the prior probabilities at the first position. Each
     * subsequent state is drawn using the transition probabilities from the previously drawn state.
     * A fixed seed ({@link #RANDOM_SEED_FOR_CHAIN_GENERATION}) is used, so repeated calls with the
     * same model and positions return the same chain.
     *
     * @param positions the positions to generate states for
     * @return one sampled state per position, in order; empty if {@code positions} is empty
     */
    default List<S> generateHiddenStateChain(final List<T> positions) {
        final List<S> chain = new ArrayList<>(positions.size());
        if (positions.isEmpty()) {
            // Nothing to sample; avoids an IndexOutOfBoundsException on positions.get(0).
            return chain;
        }
        final RandomGenerator rng =
                RandomGeneratorFactory.createRandomGenerator(new Random(RANDOM_SEED_FOR_CHAIN_GENERATION));
        final List<S> hiddenStates = hiddenStates();

        final T firstPosition = positions.get(0);
        chain.add(MathUtils.randomSelect(hiddenStates,
                state -> Math.exp(logPriorProbability(state, firstPosition)), rng));

        for (int i = 1; i < positions.size(); i++) {
            // Effectively-final copies so the lambda below can capture them.
            final S previousState = chain.get(i - 1);
            final T previousPosition = positions.get(i - 1);
            final T currentPosition = positions.get(i);
            chain.add(MathUtils.randomSelect(hiddenStates,
                    state -> Math.exp(logTransitionProbability(previousState, previousPosition, state, currentPosition)),
                    rng));
        }
        return chain;
    }

    /**
     * Computes a prior-based log score for a chain over {@code positions}.
     *
     * <p>Let {@code p_s} be the prior probability of state {@code s} at the first position. The
     * result is the sum of two terms:
     * <ol>
     *   <li>{@code p_s * log p_s} over all states {@code s}. States with zero prior contribute 0
     *   (the limit of {@code x log x}).</li>
     *   <li>For each pair of consecutive positions, {@code log T(i -> j) * p_i * p_j} over all
     *   state pairs. Transitions whose log probability is {@code -Infinity} are skipped.</li>
     * </ol>
     *
     * <p>NOTE(review): every transition term is weighted by the priors at the <em>first</em>
     * position, not at the positions involved. This matches the original implementation; confirm
     * that it is intended.
     *
     * @param positions the positions of the chain
     * @return the score, or {@code 0.0} if {@code positions} or the hidden-state list is empty
     */
    default double calculateLogChainPriorProbability(final List<T> positions) {
        final List<S> hiddenStates = hiddenStates();
        if (positions.isEmpty() || hiddenStates.isEmpty()) {
            return 0.0;
        }
        final int numStates = hiddenStates.size();
        final int numPositions = positions.size();
        final T firstPosition = positions.get(0);

        final double[] logPriors = hiddenStates.stream()
                .mapToDouble(state -> logPriorProbability(state, firstPosition))
                .toArray();
        final double[] priors = Arrays.stream(logPriors).map(FastMath::exp).toArray();

        // Skip zero-prior states: 0 * -Infinity is NaN in IEEE arithmetic and would poison the sum.
        // This mirrors the -Infinity guard on the transition terms below.
        double logChainPrior = IntStream.range(0, numStates)
                .filter(i -> logPriors[i] != Double.NEGATIVE_INFINITY)
                .mapToDouble(i -> logPriors[i] * priors[i])
                .sum();

        for (int position = 0; position < numPositions - 1; position++) {
            final T currentPosition = positions.get(position);
            final T nextPosition = positions.get(position + 1);
            for (int from = 0; from < numStates; from++) {
                final S fromState = hiddenStates.get(from);
                for (int to = 0; to < numStates; to++) {
                    final double logTransition =
                            logTransitionProbability(fromState, currentPosition, hiddenStates.get(to), nextPosition);
                    if (logTransition != Double.NEGATIVE_INFINITY) {
                        logChainPrior += logTransition * priors[from] * priors[to];
                    }
                }
            }
        }
        return logChainPrior;
    }
}
