package org.maochen.nlp.ml.classifier.naivebayes;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.lang3.NotImplementedException;
import org.maochen.nlp.ml.IClassifier;
import org.maochen.nlp.ml.Tuple;
import org.maochen.nlp.ml.vector.DenseVector;
import org.maochen.nlp.utils.VectorUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/maochen/nlp/ml/classifier/naivebayes/NaiveBayesClassifier.class */
/**
 * Naive Bayes classifier whose per-feature likelihood is a Gaussian density
 * parameterised by the per-label mean and variance vectors held in a
 * {@link NaiveBayesModel}.
 *
 * <p>An instance is usable for prediction once a model has been supplied, either
 * through {@link #NaiveBayesClassifier(InputStream)}, {@link #loadModel(InputStream)}
 * or {@link #train(List)}.
 */
public class NaiveBayesClassifier implements IClassifier {
    private static final Logger LOG = LoggerFactory.getLogger(NaiveBayesClassifier.class);

    /** Whitespace regex used by {@link #splitData(String)} to split each data line. */
    private static final String SPLIT_DATA_DELIMITER = "\\s";

    /** Label that {@link #splitData(String)} never moves to the "wrong" set. */
    private static final String PROTECTED_LABEL = "1";

    private NaiveBayesModel model;

    /**
     * Creates a classifier and loads its model from the given stream.
     *
     * @param inputStream stream handed to {@link NaiveBayesModel#load}
     */
    public NaiveBayesClassifier(InputStream inputStream) {
        this.model = new NaiveBayesModel();
        this.model.load(inputStream);
    }

    /** Creates a classifier without a model; train or load one before predicting. */
    public NaiveBayesClassifier() {
    }

    /**
     * Not supported: this classifier has no tunable parameters.
     *
     * @param map ignored
     * @throws NotImplementedException always
     */
    public void setParameter(Map<String, String> map) {
        throw new NotImplementedException("No parameter needed for Naive Bayes.");
    }

    /**
     * Computes the normalised posterior probability of every known label for the
     * given tuple.
     *
     * <p>The posterior is accumulated in log space and normalised relative to the
     * largest log-posterior, so that a long feature vector does not underflow (or
     * overflow) the plain product of densities and turn every probability into NaN.
     *
     * <p>Side effect: when {@code tuple.label} is {@code null} or empty it is set to
     * the most probable label (or to {@code ""} if there is none).
     *
     * @param tuple the instance to classify; its feature vector is read via
     *              {@code tuple.vector.getVector()}
     * @return label name to probability; an empty map if the model knows no labels
     *         or if the posterior cannot be normalised (all-zero, infinite or NaN)
     * @throws IllegalStateException if no model has been trained or loaded
     */
    public Map<String, Double> predict(Tuple tuple) {
        if (this.model == null) {
            throw new IllegalStateException("No model available: call train() or loadModel() before predict().");
        }

        int featureCount = tuple.vector.getVector().length;
        Map<Integer, Double> posteriors = new HashMap<>();
        double maxLogPosterior = Double.NEGATIVE_INFINITY;

        for (Integer labelIndex : this.model.labelIndexer.getIndexSet()) {
            int idx = labelIndex.intValue();
            // log(prior) + sum(log(density)) == log(prior * product(density))
            double logPosterior = Math.log(this.model.labelPrior.get(labelIndex).doubleValue());
            for (int i = 0; i < featureCount; i++) {
                logPosterior += Math.log(VectorUtils.gaussianPDF(
                        this.model.meanVectors[idx][i],
                        this.model.varianceVectors[idx][i],
                        tuple.vector.getVector()[i]));
            }
            posteriors.put(labelIndex, logPosterior);
            // Math.max propagates NaN, so a single NaN posterior is detected below.
            maxLogPosterior = Math.max(maxLogPosterior, logPosterior);
        }

        if (posteriors.isEmpty()) {
            LOG.error("Evidence is Empty!");
            return new HashMap<>();
        }
        if (Double.isNaN(maxLogPosterior) || Double.isInfinite(maxLogPosterior)) {
            // Every label has zero posterior, or some density was infinite/NaN:
            // there is no meaningful way to normalise.
            LOG.error("Evidence is degenerate (zero, infinite or NaN); cannot normalise posteriors.");
            return new HashMap<>();
        }

        // Shift by the maximum before exponentiating; the best label contributes
        // exp(0) == 1, so the evidence below is always >= 1 and finite.
        double evidence = 0.0d;
        for (Map.Entry<Integer, Double> entry : posteriors.entrySet()) {
            double scaled = Math.exp(entry.getValue() - maxLogPosterior);
            entry.setValue(scaled);
            evidence += scaled;
        }
        final double normaliser = evidence;
        posteriors.replaceAll((labelIndex, scaled) -> scaled / normaliser);

        Map<String, Double> labelProbabilities = this.model.labelIndexer.convertMapKey(posteriors);
        if (tuple.label == null || tuple.label.isEmpty()) {
            tuple.label = mostProbableLabel(labelProbabilities, "");
        }
        return labelProbabilities;
    }

    /**
     * Returns the label with the highest probability according to
     * {@link #predict(Tuple)}.
     *
     * @param tuple the instance to classify
     * @return the most probable label, or {@code null} if {@code predict} returned
     *         an empty map
     */
    public String predictLabel(Tuple tuple) {
        return mostProbableLabel(predict(tuple), null);
    }

    /** Returns the key with the largest value, or {@code fallback} for an empty map. */
    private static String mostProbableLabel(Map<String, Double> probabilities, String fallback) {
        return probabilities.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElse(fallback);
    }

    /**
     * Trains a new model from the given tuples, replacing any existing model.
     *
     * @param list training tuples passed to {@link NBTrainingEngine}
     * @return this classifier
     */
    public IClassifier train(List<Tuple> list) {
        this.model = new NBTrainingEngine(list).train();
        return this;
    }

    /**
     * Persists the current model via {@link NaiveBayesModel#persist}.
     * Does nothing (other than logging a warning) when there is no model.
     *
     * @param str destination handed to the model's {@code persist} method
     */
    public void persistModel(String str) {
        if (this.model == null) {
            LOG.warn("No model to persist; skipping write to " + str);
            return;
        }
        this.model.persist(str);
    }

    /**
     * Replaces the current model with one loaded from the given stream.
     *
     * @param inputStream stream handed to {@link NaiveBayesModel#load}
     */
    public void loadModel(InputStream inputStream) {
        this.model = new NaiveBayesModel();
        this.model.load(inputStream);
    }

    /**
     * Reads training tuples from a text file, one tuple per line.
     *
     * <p>Each line is trimmed and split with the regex {@code str2}. The first token
     * is the label; every following token is parsed as a double, and if a token
     * contains {@code ':'} the part after the first colon is used as the value.
     * Duplicate lines are dropped (first occurrence wins, file order is kept) and
     * blank lines are skipped.
     *
     * <p>An {@link IOException} is logged, not thrown; whatever was read before the
     * failure is still parsed and returned.
     *
     * @param str  path of the file to read
     * @param str2 regular expression used to split a line into tokens
     * @return the parsed tuples, possibly empty
     * @throws NumberFormatException if a feature token is not a valid double
     */
    public static List<Tuple> readTrainingData(String str, String str2) {
        List<String> lines = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(str))) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                lines.add(line.trim());
            }
        } catch (IOException e) {
            LOG.error("Failed to read training data from " + str, e);
        }

        List<String> uniqueLines = lines.stream().distinct().collect(Collectors.toList());
        List<Tuple> tuples = new ArrayList<>(uniqueLines.size());
        for (String line : uniqueLines) {
            if (line.isEmpty()) {
                // A blank line would otherwise yield an empty label with no features.
                continue;
            }
            String[] tokens = line.split(str2);
            String label = tokens[0];
            double[] features = new double[tokens.length - 1];
            for (int i = 1; i < tokens.length; i++) {
                String token = tokens[i];
                String value = token.contains(":") ? token.split(":")[1] : token;
                features[i - 1] = Double.parseDouble(value);
            }
            tuples.add(new Tuple(0, new DenseVector(features), label));
        }
        return tuples;
    }

    /**
     * Writes each tuple's {@code toString()} on its own line to the given file,
     * overwriting any existing content. An {@link IOException} is logged, not thrown.
     *
     * @param list tuples to write
     * @param str  path of the file to write
     */
    public static void writeToFile(List<Tuple> list, String str) {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(new File(str)))) {
            for (Tuple tuple : list) {
                writer.write(tuple.toString() + System.lineSeparator());
            }
        } catch (IOException e) {
            LOG.error("Failed to write tuples to " + str, e);
        }
    }

    /**
     * Iteratively separates a data file into tuples the classifier agrees with and
     * tuples it does not.
     *
     * <p>Each iteration trains a fresh classifier on the current training set, moves
     * every misclassified tuple whose label is not {@code "1"} to a "wrong" set, and
     * moves back any "wrong" tuple that this classifier predicts correctly. The loop
     * stops when an iteration leaves the training-set size unchanged. The surviving
     * training set is written to {@code str + ".aligned"} and the rest to
     * {@code str + ".wrong"}.
     *
     * @param str path of the whitespace-delimited data file
     */
    public static void splitData(String str) {
        List<Tuple> trainingSet = readTrainingData(str, SPLIT_DATA_DELIMITER);
        List<Tuple> wrongSet = new ArrayList<>();
        int iteration = 0;
        int sizeBefore;
        do {
            iteration++;
            System.out.println("Iteration:\t" + iteration);
            sizeBefore = trainingSet.size();

            NaiveBayesClassifier classifier = new NaiveBayesClassifier();
            classifier.train(trainingSet);

            // Evict misclassified tuples, except those carrying the protected label.
            Iterator<Tuple> trainingIt = trainingSet.iterator();
            while (trainingIt.hasNext()) {
                Tuple candidate = trainingIt.next();
                if (!candidate.label.equals(classifier.predictLabel(candidate))
                        && !candidate.label.equals(PROTECTED_LABEL)) {
                    wrongSet.add(candidate);
                    trainingIt.remove();
                }
            }

            // Re-admit previously evicted tuples that are now predicted correctly.
            Iterator<Tuple> wrongIt = wrongSet.iterator();
            while (wrongIt.hasNext()) {
                Tuple candidate = wrongIt.next();
                if (candidate.label.equals(classifier.predictLabel(candidate))) {
                    trainingSet.add(candidate);
                    wrongIt.remove();
                }
            }
        } while (trainingSet.size() != sizeBefore);

        writeToFile(trainingSet, str + ".aligned");
        writeToFile(wrongSet, str + ".wrong");
    }
}
