package org.maochen.nlp.ml.classifier.hmm;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.maochen.nlp.ml.SequenceTuple;
import org.maochen.nlp.ml.Tuple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/maochen/nlp/ml/classifier/hmm/HMM.class */
/**
 * Static helpers for training, evaluating and persisting an {@link HMMModel}.
 *
 * <p>Training data is read from column-formatted text files (one token per line, blank line
 * between sequences), turned into emission and transition counts, and then normalized in place.
 * Decoding is delegated to {@link Viterbi}.
 */
public class HMM {
    private static final Logger LOG = LoggerFactory.getLogger(HMM.class);

    /** Key under which the word list is stored in the feature map, and index of the word in {@code featsName}. */
    public static final int WORD_INDEX = 0;

    /** Pseudo word/tag inserted in front of every training sequence. */
    protected static final String START = "<START>";

    /** Pseudo word/tag appended to every training sequence. */
    protected static final String END = "<END>";

    /** Corpus locations used by {@link #main(String[])} when no arguments are given. */
    private static final String DEFAULT_DEV_FILE =
            "/Users/mguan/Dropbox/Course/Natural Lang Processing/HW/HW4_POSTagger_HMM/Homework4_corpus/POSData/development.pos";
    private static final String DEFAULT_TRAIN_FILE =
            "/Users/mguan/Dropbox/Course/Natural Lang Processing/HW/HW4_POSTagger_HMM/Homework4_corpus/POSData/training.pos";

    /**
     * Wraps one sentence into a {@link SequenceTuple} whose only feature column is the word list.
     *
     * @param words normalized words of the sentence
     * @param tags  normalized tags, aligned with {@code words}
     * @return the sequence tuple built from both lists
     */
    private static SequenceTuple getSequenceTuple(List<String> words, List<String> tags) {
        Map<Integer, List<String>> features = new HashMap<>();
        features.put(WORD_INDEX, words);
        return new SequenceTuple(features, tags);
    }

    /**
     * Reads a column-formatted corpus file into a list of sequences.
     *
     * <p>Every non-blank line is split with {@code delimiter} (a regular expression, as in
     * {@link String#split(String)}); a blank line terminates the current sequence. A sequence
     * that is still open at end of file is emitted as well, so a missing trailing blank line does
     * not lose the last sentence.
     *
     * <p>An {@link IOException} is logged and whatever was read up to that point is returned.
     *
     * @param filename     path of the file to read
     * @param delimiter    regular expression separating the columns of a line
     * @param wordColIndex zero-based column holding the word
     * @param tagColIndex  zero-based column holding the tag
     * @return a mutable list of the sequences read; empty if nothing could be read
     * @throws ArrayIndexOutOfBoundsException if a non-blank line has fewer columns than requested
     */
    public static List<SequenceTuple> readTrainFile(String filename, String delimiter, int wordColIndex, int tagColIndex) {
        List<SequenceTuple> sequences = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(filename))) {
            List<String> words = new ArrayList<>();
            List<String> tags = new ArrayList<>();
            String line;
            while ((line = reader.readLine()) != null) {
                if (!line.trim().isEmpty()) {
                    String[] columns = line.split(delimiter);
                    words.add(WordUtils.normalizeWord(columns[wordColIndex]));
                    tags.add(WordUtils.normalizeTag(columns[tagColIndex]));
                } else if (!words.isEmpty()) {
                    sequences.add(getSequenceTuple(words, tags));
                    // Fresh lists: the previous ones are now owned by the tuple just created.
                    words = new ArrayList<>();
                    tags = new ArrayList<>();
                }
            }
            // Flush the last sequence when the file does not end with a blank line.
            if (!words.isEmpty()) {
                sequences.add(getSequenceTuple(words, tags));
            }
        } catch (IOException e) {
            LOG.error("Read train file err.", e);
        }
        return sequences;
    }

    /**
     * Divides every value of {@code counts} by the sum of all its values, writing the result back
     * into the same map.
     *
     * @param counts mutable map of counts; callers pass row/column views of the model tables
     */
    private static void normalizeInPlace(Map<String, Double> counts) {
        double total = counts.values().stream().mapToDouble(Double::doubleValue).sum();
        // Only values of existing keys are replaced, so the key set is not structurally modified.
        for (String key : counts.keySet()) {
            counts.put(key, counts.get(key) / total);
        }
    }

    /**
     * Normalizes the emission table column by column and records, for every column key, the
     * smallest normalized value of that column in {@code emissionMin}.
     *
     * <p>Runs sequentially: the column views are written through, and nothing visible here
     * guarantees that the underlying table tolerates concurrent writers.
     *
     * @param hMMModel model whose {@code emission} and {@code emissionMin} are updated in place
     */
    public static void normalizeEmission(HMMModel hMMModel) {
        for (Map<String, Double> column : hMMModel.emission.columnMap().values()) {
            normalizeInPlace(column);
        }
        for (String columnKey : hMMModel.emission.columnKeySet()) {
            double smallest = hMMModel.emission.column(columnKey).values().stream()
                    .min(Double::compareTo)
                    .orElse(0.0d);
            hMMModel.emissionMin.put(columnKey, smallest);
        }
    }

    /**
     * Normalizes the transition table row by row, so that the values of each row sum to one.
     *
     * @param hMMModel model whose {@code transition} table is updated in place
     */
    public static void normalizeTrans(HMMModel hMMModel) {
        for (Map<String, Double> row : hMMModel.transition.rowMap().values()) {
            normalizeInPlace(row);
        }
    }

    /**
     * Builds a model from labelled sequences.
     *
     * <p>Each sequence is padded with {@link #START} and {@link #END} (as both word and tag).
     * Emission counts are keyed by (word, tag); transition counts by (tag, following tag) and
     * cover every adjacent pair of the padded tag list, including the transition into
     * {@link #END}. Both tables are normalized before the model is returned.
     *
     * <p>The label list obtained from a sequence is copied before padding, so the input tuples
     * are not modified.
     *
     * @param list training sequences
     * @return the trained, normalized model
     */
    public static HMMModel train(List<SequenceTuple> list) {
        HMMModel model = new HMMModel();
        for (SequenceTuple sequence : list) {
            List<String> words = new ArrayList<>(sequence.entries.size() + 2);
            words.add(START);
            for (Tuple entry : sequence.entries) {
                words.add(entry.vector.featsName[WORD_INDEX]);
            }
            words.add(END);

            List<String> tags = new ArrayList<>(sequence.getLabel());
            tags.add(0, START);
            tags.add(END);

            for (int i = 0; i < words.size(); i++) {
                String word = words.get(i);
                String tag = tags.get(i);
                Double count = model.emission.get(word, tag);
                model.emission.put(word, tag, count == null ? 1.0d : count + 1.0d);
            }

            // Bound by the padded tag list; bounding by entries.size() - 1 would skip the last
            // two transitions of the sequence (into the final real tag and into END).
            for (int i = 0; i < tags.size() - 1; i++) {
                String from = tags.get(i);
                String to = tags.get(i + 1);
                Double count = model.transition.get(from, to);
                model.transition.put(from, to, count == null ? 1.0d : count + 1.0d);
            }
        }
        normalizeEmission(model);
        normalizeTrans(model);
        return model;
    }

    /**
     * Decodes a tag sequence for the given words by delegating to {@link Viterbi#resolve}.
     *
     * @param hMMModel trained model
     * @param strArr   words to tag
     * @return the tags produced by the decoder
     */
    public static List<String> viterbi(HMMModel hMMModel, String[] strArr) {
        return Viterbi.resolve(hMMModel, strArr);
    }

    /**
     * Tags every sequence of a labelled file and prints the mismatches plus a summary line to
     * standard output.
     *
     * <p>A prediction counts as correct when, after tag normalization, either tag is a prefix of
     * the other. The summary line shows {@code errors/total} followed by the accuracy in percent.
     * If no token was evaluated (for example because the file could not be read) a warning is
     * logged and no summary is printed.
     *
     * @param hMMModel     model to evaluate
     * @param testFile     path of the labelled file
     * @param delimiter    regular expression separating the columns of a line
     * @param wordColIndex zero-based column holding the word
     * @param tagColIndex  zero-based column holding the expected tag
     */
    public static void eval(HMMModel hMMModel, String testFile, String delimiter, int wordColIndex, int tagColIndex) {
        int total = 0;
        int errors = 0;
        for (SequenceTuple sequence : readTrainFile(testFile, delimiter, wordColIndex, tagColIndex)) {
            String[] words = sequence.entries.stream()
                    .map(entry -> entry.vector.featsName[WORD_INDEX])
                    .toArray(String[]::new);
            List<String> predicted = viterbi(hMMModel, words);
            for (int i = 0; i < predicted.size(); i++) {
                total++;
                String expected = WordUtils.normalizeTag(sequence.entries.get(i).label);
                String actual = WordUtils.normalizeTag(predicted.get(i));
                if (!actual.startsWith(expected) && !expected.startsWith(actual)) {
                    System.out.println(words[i] + " exp: " + expected + " actual: " + predicted.get(i));
                    errors++;
                }
            }
        }

        if (total == 0) {
            LOG.warn("No tokens evaluated from {}", testFile);
            return;
        }
        // Force floating-point division; int / int would truncate the error rate to 0 or 1.
        double accuracyPercent = (1.0d - (double) errors / total) * 100.0d;
        System.out.println("accurancy: " + errors + "/" + total + " -> " + String.format("%.2f", accuracyPercent) + "%");
    }

    /**
     * Loads a model previously written by {@link #saveModel(String, HMMModel)}.
     *
     * <p>NOTE(review): this uses Java native deserialization; only load files from a trusted
     * source.
     *
     * @param str path of the serialized model
     * @return the model, or {@code null} if the file could not be read or deserialized
     */
    public static HMMModel loadModel(String str) {
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(str))) {
            return (HMMModel) in.readObject();
        } catch (IOException | ClassNotFoundException e) {
            LOG.error("Load model err.", e);
            return null;
        }
    }

    /**
     * Serializes a model to a file. Failures are logged, not thrown.
     *
     * @param str      path of the file to write
     * @param hMMModel model to persist
     */
    public static void saveModel(String str, HMMModel hMMModel) {
        try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(str))) {
            out.writeObject(hMMModel);
        } catch (IOException e) {
            LOG.error("Persist model err.", e);
        }
    }

    /** Returns {@code args[index]} when present, otherwise {@code fallback}. */
    private static String argOrDefault(String[] args, int index, String fallback) {
        return args != null && args.length > index ? args[index] : fallback;
    }

    /**
     * Demo entry point: trains on the development and training files, evaluates on the training
     * file and tags a sample sentence.
     *
     * @param strArr optional overrides: {@code [0]} development file, {@code [1]} training file;
     *               the built-in default paths are used for any argument that is missing
     */
    public static void main(String[] strArr) throws InterruptedException {
        String devFile = argOrDefault(strArr, 0, DEFAULT_DEV_FILE);
        String trainFile = argOrDefault(strArr, 1, DEFAULT_TRAIN_FILE);

        List<SequenceTuple> data = readTrainFile(devFile, "\t", 0, 1);
        data.addAll(readTrainFile(trainFile, "\t", 0, 1));
        HMMModel model = train(data);
        eval(model, trainFile, "\t", 0, 1);

        String sentence = "The quick brown fox jumped over the lazy dog .";
        List<String> tags = viterbi(model, sentence.split("\\s"));
        System.out.println(sentence);
        System.out.println(tags);
    }
}
