package org.maochen.nlp.ml.classifier.hmm;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/maochen/nlp/ml/classifier/hmm/Viterbi.class */
/**
 * Viterbi decoder over an {@code HMMModel}.
 *
 * <p>{@link #resolve} frames the input tokens with the pseudo-tokens {@code <START>} and
 * {@code <END>} (also used as pseudo-states) and returns the tag sequence on the highest-scoring
 * path through the transition/emission lattice.
 */
public class Viterbi {
    private static final Logger LOG = LoggerFactory.getLogger(Viterbi.class);

    /** Pseudo-token and pseudo-state that opens every sequence. */
    private static final String START = "<START>";

    /** Pseudo-token and pseudo-state that closes every sequence. */
    private static final String END = "<END>";

    /** Compiled once and shared: {@link Pattern} instances are immutable and thread-safe. */
    private static final Pattern PUNCT = Pattern.compile("\\p{Punct}+");

    /** True when the string contains at least one (ASCII, POSIX-class) punctuation character. */
    private static final Predicate<String> IS_PUNCT = str -> PUNCT.matcher(str).find();

    /**
     * Tags {@code words} with the most likely state sequence under {@code model}.
     *
     * <p>Candidate states are every tag that any input word has an emission entry for. Path scores
     * are accumulated as sums of natural logs (equivalent in ranking to the product of
     * probabilities) so that long inputs do not underflow to zero. Emission scores follow these
     * rules:
     * <ul>
     *   <li>a recorded {@code emission(word, tag)} value is used as-is;</li>
     *   <li>for a word with no emission row at all, {@code emissionMin(tag)} is used as a back-off,
     *       except that a punctuation tag may never emit a non-punctuation word;</li>
     *   <li>a known word paired with a tag it has no entry for scores zero.</li>
     * </ul>
     * Missing or non-positive transition/emission values make that step impossible. On equal
     * scores the earlier predecessor state wins.
     *
     * @param model model supplying the {@code emission} and {@code transition} tables (indexed by
     *     (word, tag) and (fromTag, toTag) respectively) and the {@code emissionMin} back-off
     * @param words tokens to tag, in order
     * @return a new mutable list with one tag per word; empty when {@code words} is empty or when
     *     no tag sequence has non-zero probability
     * @throws NullPointerException if {@code model} or {@code words} is null
     */
    public static List<String> resolve(HMMModel model, String[] words) {
        Objects.requireNonNull(model, "model");
        Objects.requireNonNull(words, "words");
        if (words.length == 0) {
            return new ArrayList<>();
        }

        List<String> states = candidateStates(model, words);
        List<String> observations = new ArrayList<>(words.length + 2);
        observations.add(START);
        observations.addAll(Arrays.asList(words));
        observations.add(END);

        int numStates = states.size();
        int numSteps = observations.size();

        // Transition scores do not depend on the step, so look them up once.
        double[][] logTransition = new double[numStates][numStates];
        for (int from = 0; from < numStates; from++) {
            for (int to = 0; to < numStates; to++) {
                logTransition[from][to] = transitionLogProb(model, states.get(from), states.get(to));
            }
        }

        // score[s][t]: best log-score of a path that is in state s at step t (-inf = unreachable).
        // back[s][t]:  the state at step t - 1 on that best path (one pointer per cell, as the
        //              backtrace requires).
        double[][] score = new double[numStates][numSteps];
        int[][] back = new int[numStates][numSteps];
        for (double[] row : score) {
            Arrays.fill(row, Double.NEGATIVE_INFINITY);
        }
        score[0][0] = 0.0; // log(1): every path starts in START at step 0

        for (int t = 1; t < numSteps; t++) {
            String observation = observations.get(t);
            boolean knownWord = model.emission.rowKeySet().contains(observation);
            if (!knownWord) {
                LOG.debug("Missing word: " + observation);
            }
            for (int to = 1; to < numStates; to++) {
                double logEmission = emissionLogProb(model, observation, states.get(to), knownWord);
                if (logEmission == Double.NEGATIVE_INFINITY) {
                    continue; // this state cannot emit the observation; it stays unreachable
                }
                for (int from = 0; from < numStates; from++) {
                    double previous = score[from][t - 1];
                    if (previous == Double.NEGATIVE_INFINITY) {
                        continue;
                    }
                    double candidate = previous + logTransition[from][to] + logEmission;
                    // Strict '>' keeps the first predecessor on ties.
                    if (candidate > score[to][t]) {
                        score[to][t] = candidate;
                        back[to][t] = from;
                    }
                }
            }
        }

        if (LOG.isDebugEnabled()) {
            logLattice(states, observations, score);
        }

        // Best-scoring state at the final (<END>) step; START (index 0) is never a target.
        int last = numSteps - 1;
        int best = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int s = 1; s < numStates; s++) {
            if (score[s][last] > bestScore) {
                bestScore = score[s][last];
                best = s;
            }
        }
        if (best < 0) {
            LOG.warn("No tag sequence with non-zero probability for {}", Arrays.toString(words));
            return new ArrayList<>();
        }

        // Follow back pointers from the <END> step to step 1; step t holds words[t - 1], and the
        // pointer stored at step t names the state at step t - 1.
        String[] tags = new String[words.length];
        int state = best;
        for (int t = last; t > 1; t--) {
            state = back[state][t];
            tags[t - 2] = states.get(state);
        }

        List<String> result = new ArrayList<>(Arrays.asList(tags));
        LOG.debug(result.toString());
        return result;
    }

    /** Builds the state list: START, every tag any of {@code words} has an emission for, then END. */
    private static List<String> candidateStates(HMMModel model, String[] words) {
        Set<String> tags = new HashSet<>();
        for (String word : words) {
            tags.addAll(model.emission.row(word).keySet());
        }
        // START/END are added explicitly below; never let them appear twice.
        tags.remove(START);
        tags.remove(END);

        List<String> states = new ArrayList<>(tags.size() + 2);
        states.add(START);
        states.addAll(tags);
        states.add(END);
        return states;
    }

    /** Log of the emission score for {@code word} under {@code tag}, applying the unknown-word back-off. */
    private static double emissionLogProb(HMMModel model, String word, String tag, boolean knownWord) {
        Double p = (Double) model.emission.get(word, tag);
        if (p == null && !knownWord) {
            // Back off to the tag's minimum emission, but a punctuation tag may not emit plain text.
            boolean punctTagOnPlainWord = IS_PUNCT.test(tag) && !IS_PUNCT.test(word);
            p = punctTagOnPlainWord ? null : (Double) model.emissionMin.get(tag);
        }
        return logOf(p);
    }

    /** Log of the transition score {@code from -> to}; -inf when the model has no positive entry. */
    private static double transitionLogProb(HMMModel model, String from, String to) {
        return logOf((Double) model.transition.get(from, to));
    }

    /** Natural log of {@code p}, or -inf when {@code p} is null, non-positive or NaN (impossible step). */
    private static double logOf(Double p) {
        return (p != null && p > 0.0) ? Math.log(p) : Double.NEGATIVE_INFINITY;
    }

    /** Dumps the score lattice (log scores) at debug level, one line per state. */
    private static void logLattice(List<String> states, List<String> observations, double[][] score) {
        LOG.debug("\t\t" + String.join("\t", observations));
        for (int s = 0; s < states.size(); s++) {
            StringBuilder row = new StringBuilder(states.get(s)).append('\t');
            for (double value : score[s]) {
                row.append(value).append('\t');
            }
            LOG.debug(row.toString());
        }
    }
}
