package org.broadinstitute.hellbender.utils.hmm;

import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.lang.math.IntRange;
import org.apache.commons.math3.util.FastMath;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

/**
 * Forward-backward algorithm over a hidden Markov model ({@link HMM}), computed in log space.
 *
 * <p>{@link #apply} runs the forward and backward recursions over a data sequence and its matching
 * position sequence. It returns a {@link Result} that answers posterior queries about the hidden-state
 * chain.</p>
 */
public final class ForwardBackwardAlgorithm {

    /**
     * Array-backed {@link Result}.
     *
     * <p>Stores the full forward and backward log-probability matrices, indexed as
     * {@code [positionIndex][stateIndex]}. The state index is the state's index in
     * {@code model.hiddenStates()}; the position index is its index in the position sequence.</p>
     */
    private static final class ArrayResult<D, T, S> implements Result<D, T, S>, Serializable {
        private static final long serialVersionUID = -8556604447304292642L;

        private final List<D> data;
        private final List<T> positions;
        /** Valid position indices. Only consulted when {@link #positions} is non-empty (see {@link #checkedPositionIndex}). */
        private final IntRange positionIndexRange;
        private final HMM<D, T, S> model;
        private final Object2IntMap<T> positionIndex;
        private final Object2IntMap<S> stateIndex;
        private final double[][] logForwardProbabilities;
        private final double[][] logBackwardProbabilities;
        /** Per position: logSumExp over states of (forward + backward). */
        private final double[] logDataLikelihood;

        private ArrayResult(final List<D> data, final List<T> positions, final HMM<D, T, S> model,
                            final double[][] logForwardProbabilities, final double[][] logBackwardProbabilities) {
            this.data = Collections.unmodifiableList(new ArrayList<>(data));
            this.positions = Collections.unmodifiableList(new ArrayList<>(positions));
            // NOTE: commons-lang IntRange reorders reversed bounds, so for an empty sequence this
            // becomes [-1, 0]; checkedPositionIndex rejects that case explicitly.
            this.positionIndexRange = new IntRange(0, positions.size() - 1);
            this.model = model;
            this.positionIndex = composeIndexMap(this.positions);
            this.stateIndex = composeIndexMap(model.hiddenStates());
            this.logForwardProbabilities = logForwardProbabilities;
            this.logBackwardProbabilities = logBackwardProbabilities;
            this.logDataLikelihood = calculateLogDataLikelihood(logForwardProbabilities, logBackwardProbabilities);
        }

        /**
         * Combines the forward and backward matrices into a per-position log data likelihood.
         *
         * @return an array whose entry {@code p} is logSumExp over states of {@code forward[p][s] + backward[p][s]}
         */
        private static double[] calculateLogDataLikelihood(final double[][] logForwardProbabilities,
                                                           final double[][] logBackwardProbabilities) {
            final double[] result = new double[logForwardProbabilities.length];
            for (int p = 0; p < result.length; p++) {
                final double[] forward = logForwardProbabilities[p];
                final double[] backward = logBackwardProbabilities[p];
                final double[] joint = new double[forward.length];
                for (int s = 0; s < joint.length; s++) {
                    joint[s] = backward[s] + forward[s];
                }
                result[p] = MathUtils.logSumExp(joint);
            }
            return result;
        }

        /**
         * Maps each element to its index in {@code elements}.
         * If an element appears more than once, the last index wins.
         */
        private static <E> Object2IntMap<E> composeIndexMap(final List<E> elements) {
            final Object2IntMap<E> result = new Object2IntOpenHashMap<>(elements.size());
            for (int i = 0; i < elements.size(); i++) {
                result.put(elements.get(i), i);
            }
            return result;
        }

        @Override
        public List<D> data() {
            return data;
        }

        @Override
        public List<T> positions() {
            return positions;
        }

        @Override
        public HMM<D, T, S> model() {
            return model;
        }

        @Override
        public double logForwardProbability(final int position, final S state) {
            return logForwardProbabilities[checkedPositionIndex(position)][validStateIndex(state)];
        }

        @Override
        public double logForwardProbability(final T position, final S state) {
            return logForwardProbabilities[validPositionIndex(position)][validStateIndex(state)];
        }

        @Override
        public double logBackwardProbability(final int position, final S state) {
            return logBackwardProbabilities[checkedPositionIndex(position)][validStateIndex(state)];
        }

        @Override
        public double logBackwardProbability(final T position, final S state) {
            return logBackwardProbabilities[validPositionIndex(position)][validStateIndex(state)];
        }

        @Override
        public double logProbability(final int position, final S state) {
            // State is validated first, matching the original validation order.
            final int s = validStateIndex(state);
            return logPosterior(checkedPositionIndex(position), s);
        }

        @Override
        public double logProbability(final T position, final S state) {
            final int s = validStateIndex(state);
            return logPosterior(validPositionIndex(position), s);
        }

        @Override
        public double logProbability(final List<S> states) {
            Utils.nonNull(states);
            Utils.validateArg(states.size() == data.size(), "the input states sequence does not have the same length as the data sequence");
            return states.isEmpty() ? 0.0d : logProbability(0, states);
        }

        @Override
        public double logProbability(final int from, final List<S> states) {
            checkedPositionIndex(from);
            Utils.nonNull(states, "the input states sequence cannot be null");
            if (states.isEmpty()) {
                return 0.0d;
            }
            final int to = from + states.size() - 1;
            Utils.validateArg(to < data.size(), "the input state sequence is too long");
            // Use the index directly: re-resolving positions.get(from) through the hash index
            // would pick the wrong row if the position sequence contains duplicates.
            double result = logForwardProbability(from, states.get(0));
            for (int p = from + 1, k = 1; p <= to; p++, k++) {
                final T previousPosition = positions.get(p - 1);
                final T currentPosition = positions.get(p);
                final S previousState = states.get(k - 1);
                final S currentState = states.get(k);
                result = result
                        + model.logTransitionProbability(previousState, previousPosition, currentState, currentPosition)
                        + model.logEmissionProbability(data.get(p), currentState, currentPosition);
            }
            return result + logBackwardProbability(to, states.get(states.size() - 1)) - logDataLikelihood[to];
        }

        @Override
        public double logProbability(final T from, final List<S> states) {
            return logProbability(validPositionIndex(from), states);
        }

        @Override
        public double logConstrainedProbability(final int from, final List<Set<S>> constraints) {
            checkedPositionIndex(from);
            Utils.nonNull(constraints, "the input state constraints sequence cannot be null");
            if (constraints.isEmpty()) {
                return 0.0d;
            }
            final int to = from + constraints.size() - 1;
            Utils.validateArg(to < data.size(), "the input state sequence is too long");

            // Copy each set into a list so that states and their partial sums share a fixed order.
            List<S> previousStates = new ArrayList<>(Utils.nonNull(constraints.get(0)));
            double[] partial = new double[previousStates.size()];
            for (int k = 0; k < partial.length; k++) {
                // Validate: an unknown state must not silently fall back to the map's default index.
                partial[k] = logForwardProbabilities[from][validStateIndex(previousStates.get(k))];
            }

            // Constrained forward recursion: only states in each position's set are summed over.
            for (int p = from + 1; p <= to; p++) {
                final T previousPosition = positions.get(p - 1);
                final T currentPosition = positions.get(p);
                final D currentDatum = data.get(p);
                final List<S> currentStates = new ArrayList<>(Utils.nonNull(constraints.get(p - from)));
                final double[] next = new double[currentStates.size()];
                final double[] terms = new double[previousStates.size()];
                for (int k = 0; k < next.length; k++) {
                    final S currentState = currentStates.get(k);
                    for (int j = 0; j < terms.length; j++) {
                        terms[j] = partial[j]
                                + model.logTransitionProbability(previousStates.get(j), previousPosition, currentState, currentPosition);
                    }
                    next[k] = MathUtils.logSumExp(terms) + model.logEmissionProbability(currentDatum, currentState, currentPosition);
                }
                partial = next;
                previousStates = currentStates;
            }

            for (int k = 0; k < partial.length; k++) {
                partial[k] = partial[k] + logBackwardProbabilities[to][validStateIndex(previousStates.get(k))];
            }
            return MathUtils.logSumExp(partial) - logDataLikelihood[to];
        }

        @Override
        public double logConstrainedProbability(final T from, final List<Set<S>> constraints) {
            return logConstrainedProbability(validPositionIndex(from), constraints);
        }

        @Override
        public double logDataLikelihood() {
            return logDataLikelihood.length == 0 ? 0.0d : logDataLikelihood[0];
        }

        @Override
        public double logDataLikelihood(final int position) {
            return logDataLikelihood[checkedPositionIndex(position)];
        }

        @Override
        public double logDataLikelihood(final T position) {
            return logDataLikelihood[validPositionIndex(position)];
        }

        @Override
        public double logChainPosteriorProbability() {
            final List<S> hiddenStates = model.hiddenStates();
            if (positions.isEmpty() || hiddenStates.isEmpty()) {
                return 0.0d;
            }
            // logDataLikelihood() minus the posterior-weighted sum of log emission probabilities.
            // DoubleStream.sum() is kept (rather than a plain loop) to preserve its summation behavior.
            return logDataLikelihood() - IntStream.range(0, positions.size())
                    .mapToDouble(p -> IntStream.range(0, hiddenStates.size())
                            .mapToDouble(s -> FastMath.exp(logProbability(p, hiddenStates.get(s)))
                                    * model.logEmissionProbability(data.get(p), hiddenStates.get(s), positions.get(p)))
                            .sum())
                    .sum();
        }

        /** Returns the log posterior of state index {@code s} at position index {@code p}; both must already be valid. */
        private double logPosterior(final int p, final int s) {
            return (logBackwardProbabilities[p][s] + logForwardProbabilities[p][s]) - logDataLikelihood[p];
        }

        /**
         * Checks that {@code index} is a valid position index and returns it.
         *
         * @throws IllegalArgumentException if the sequence is empty or the index is out of range
         */
        private int checkedPositionIndex(final int index) {
            Utils.validateArg(!positions.isEmpty(), "position index is out of range: the position sequence is empty");
            ParamUtils.inRange(positionIndexRange, index, "position index");
            return index;
        }

        private int validStateIndex(final S state) {
            Utils.validateArg(stateIndex.containsKey(state), "the input state is not recognized by the model");
            return stateIndex.getInt(state);
        }

        private int validPositionIndex(final T position) {
            Utils.validateArg(positionIndex.containsKey(position), "the input position is not recognized by the model");
            return positionIndex.getInt(position);
        }
    }

    /**
     * Outcome of running the forward-backward algorithm on a data/position sequence.
     *
     * @param <D> data element type
     * @param <T> position type
     * @param <S> hidden-state type
     */
    public interface Result<D, T, S> {

        /** Returns the data sequence the algorithm was run on. */
        List<D> data();

        /** Returns the position sequence the algorithm was run on; it has the same length as {@link #data()}. */
        List<T> positions();

        /** Returns the model the algorithm was run with. */
        HMM<D, T, S> model();

        /** Returns the log forward probability for {@code state} at position index {@code position}. */
        double logForwardProbability(int position, S state);

        /** Returns the log forward probability for {@code state} at {@code position}. */
        double logForwardProbability(T position, S state);

        /** Returns the log backward probability for {@code state} at position index {@code position}. */
        double logBackwardProbability(int position, S state);

        /** Returns the log backward probability for {@code state} at {@code position}. */
        double logBackwardProbability(T position, S state);

        /** Returns the log posterior probability that the chain is in {@code state} at position index {@code position}. */
        double logProbability(int position, S state);

        /** Returns the log posterior probability that the chain is in {@code state} at {@code position}. */
        double logProbability(T position, S state);

        /**
         * Returns the log posterior probability of a complete hidden-state sequence.
         * The sequence must be as long as the data; an empty sequence yields 0.
         */
        double logProbability(List<S> states);

        /**
         * Returns the log posterior probability that every position in the half-open index range
         * {@code [from, to)} is in {@code state}.
         */
        default double logProbability(final int from, final int to, final S state) {
            ParamUtils.inRange(to, from, positions().size(), "the 'to' index must be between 'from' and the length of the data/position sequence");
            return logProbability(from, Collections.nCopies(to - from, state));
        }

        /** Returns the log posterior probability of the state subsequence {@code states} starting at position index {@code from}. */
        double logProbability(int from, List<S> states);

        /** Returns the log posterior probability of the state subsequence {@code states} starting at position {@code from}. */
        double logProbability(T from, List<S> states);

        /**
         * Returns the log posterior probability that, for each offset {@code k}, the state at position index
         * {@code from + k} is a member of {@code constraints.get(k)}.
         */
        double logConstrainedProbability(int from, List<Set<S>> constraints);

        /** Same as {@link #logConstrainedProbability(int, List)} but with the start given as a position. */
        double logConstrainedProbability(T from, List<Set<S>> constraints);

        /**
         * Returns the log posterior probability that the chain is in {@code fromState} at index {@code from}
         * and in {@code toState} at index {@code to}, with any states in between.
         */
        default double logJointProbability(final int from, final int to, final S fromState, final S toState) {
            ParamUtils.inRange(from, 0, positions().size() - 1, "The 'from' index must be between 0 and the length of the data/position sequence - 1");
            ParamUtils.inRange(to, from + 1, positions().size() - 1, "the 'to' index must be between 'from' + 1 and the length of the data/position sequence - 1");
            final List<Set<S>> constraints = new ArrayList<>((to - from) + 1);
            constraints.add(Collections.singleton(fromState));
            if (to > from + 1) {
                // Intermediate positions are unconstrained: every hidden state is allowed.
                final Set<S> anyState = new HashSet<>(model().hiddenStates());
                constraints.addAll(Collections.nCopies((to - from) - 1, anyState));
            }
            constraints.add(Collections.singleton(toState));
            return logConstrainedProbability(from, constraints);
        }

        /** Returns the log likelihood of the data, or 0 for an empty sequence. */
        double logDataLikelihood();

        /**
         * Returns the log data likelihood computed at position index {@code position}.
         * In exact arithmetic every position should agree.
         */
        double logDataLikelihood(int position);

        /** Returns the log data likelihood computed at {@code position}. */
        double logDataLikelihood(T position);

        /**
         * Returns {@link #logDataLikelihood()} minus the posterior-weighted sum, over all positions and
         * states, of the log emission probabilities. Returns 0 if there are no positions or no states.
         */
        double logChainPosteriorProbability();
    }

    /**
     * Runs the forward-backward algorithm.
     *
     * @param data      the observed data sequence; must not be {@code null}
     * @param positions the position of each datum; must not be {@code null} and must have the same size as {@code data}
     * @param model     the hidden Markov model; must not be {@code null}
     * @return the resulting forward/backward matrices wrapped in a {@link Result}
     * @throws IllegalArgumentException if {@code data} and {@code positions} differ in length
     */
    public static <D, T, S> Result<D, T, S> apply(final List<D> data, final List<T> positions, final HMM<D, T, S> model) {
        Utils.nonNull(data, "the input data sequence cannot be null.");
        Utils.nonNull(positions, "the input position sequence cannot be null.");
        Utils.nonNull(model, "the input model cannot be null");
        // Snapshot the inputs so later caller mutations cannot desynchronize them from the matrices.
        final List<D> dataCopy = Collections.unmodifiableList(new ArrayList<>(data));
        final List<T> positionsCopy = Collections.unmodifiableList(new ArrayList<>(positions));
        Utils.validateArg(dataCopy.size() == positionsCopy.size(), "the data sequence and position sequence must have the same number of elements");
        return new ArrayResult<>(dataCopy, positionsCopy, model,
                calculateLogForwardProbabilities(model, dataCopy, positionsCopy),
                calculateLogBackwardProbabilities(model, dataCopy, positionsCopy));
    }

    /**
     * Computes the log forward matrix {@code alpha[p][s]}:
     * <ul>
     *   <li>{@code alpha[0][s] = logPrior(s) + logEmission(d_0 | s)}</li>
     *   <li>{@code alpha[p][s] = logSumExp_r(alpha[p-1][r] + logTransition(r -> s)) + logEmission(d_p | s)}</li>
     * </ul>
     */
    private static <D, T, S> double[][] calculateLogForwardProbabilities(final HMM<D, T, S> model,
                                                                        final List<D> data,
                                                                        final List<T> positions) {
        final List<S> states = model.hiddenStates();
        final int stateCount = states.size();
        final int length = data.size();
        final double[][] result = new double[length][stateCount];
        if (length == 0) {
            return result;
        }
        final T firstPosition = positions.get(0);
        final D firstDatum = data.get(0);
        for (int s = 0; s < stateCount; s++) {
            final S state = states.get(s);
            result[0][s] = model.logPriorProbability(state, firstPosition) + model.logEmissionProbability(firstDatum, state, firstPosition);
        }
        // Scratch buffer, fully overwritten before each logSumExp.
        final double[] terms = new double[stateCount];
        for (int p = 1; p < length; p++) {
            final int previous = p - 1;
            final T previousPosition = positions.get(previous);
            final T currentPosition = positions.get(p);
            final D currentDatum = data.get(p);
            for (int s = 0; s < stateCount; s++) {
                final S state = states.get(s);
                for (int r = 0; r < stateCount; r++) {
                    terms[r] = result[previous][r] + model.logTransitionProbability(states.get(r), previousPosition, state, currentPosition);
                }
                result[p][s] = MathUtils.logSumExp(terms) + model.logEmissionProbability(currentDatum, state, currentPosition);
            }
        }
        return result;
    }

    /**
     * Computes the log backward matrix {@code beta[p][s]}:
     * <ul>
     *   <li>{@code beta[last][s] = 0} (log 1)</li>
     *   <li>{@code beta[p][s] = logSumExp_r(beta[p+1][r] + logTransition(s -> r) + logEmission(d_{p+1} | r))}</li>
     * </ul>
     */
    private static <D, T, S> double[][] calculateLogBackwardProbabilities(final HMM<D, T, S> model,
                                                                         final List<D> data,
                                                                         final List<T> positions) {
        final List<S> states = model.hiddenStates();
        final int stateCount = states.size();
        final int length = data.size();
        // The last row keeps Java's default 0.0, i.e. log(1).
        final double[][] result = new double[length][stateCount];
        if (length == 0) {
            return result;
        }
        final double[] terms = new double[stateCount];
        final double[] nextEmissions = new double[stateCount];
        for (int p = length - 2; p >= 0; p--) {
            final int next = p + 1;
            final T position = positions.get(p);
            final T nextPosition = positions.get(next);
            final D nextDatum = data.get(next);
            // Emissions at p+1 do not depend on the current state, so compute them once per position.
            for (int r = 0; r < stateCount; r++) {
                nextEmissions[r] = model.logEmissionProbability(nextDatum, states.get(r), nextPosition);
            }
            for (int s = 0; s < stateCount; s++) {
                final S state = states.get(s);
                for (int r = 0; r < stateCount; r++) {
                    terms[r] = result[next][r] + model.logTransitionProbability(state, position, states.get(r), nextPosition) + nextEmissions[r];
                }
                result[p][s] = MathUtils.logSumExp(terms);
            }
        }
        return result;
    }
}
