package org.maochen.nlp.sentencetype;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.maochen.nlp.ml.Tuple;
import org.maochen.nlp.ml.classifier.maxent.MaxEntClassifier;
import org.maochen.nlp.ml.vector.DenseVector;
import org.maochen.nlp.parser.DTree;
import org.maochen.nlp.parser.IParser;
import org.maochen.nlp.parser.stanford.nn.StanfordNNDepParser;
import org.maochen.nlp.parser.stanford.pcfg.StanfordPCFGParser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/maochen/nlp/sentencetype/SentenceTypeClassifier.class */
/**
 * Sentence-type classifier that combines a dependency/constituency parser, a
 * {@link FeatureExtractor} and a {@link MaxEntClassifier}.
 *
 * <p>Typical usage is either {@link #train(String)} followed by
 * {@link #persist(String)}, or {@link #loadModel(InputStream)} followed by
 * {@link #predict(String)}.
 */
public class SentenceTypeClassifier {
    private static final Logger LOG = LoggerFactory.getLogger(SentenceTypeClassifier.class);

    /** Corpus path used by {@link #main(String[])} when no first argument is given. */
    private static final String DEFAULT_CORPUS_PATH =
            "/Users/Maochen/workspace/nlp-service_training-data/sentence_type_corpus.txt";

    /** Model path used by {@link #main(String[])} when no second argument is given. */
    private static final String DEFAULT_MODEL_PATH = "/Users/Maochen/Desktop//sent_type_model.dat";

    private final MaxEntClassifier maxEntClassifier;
    private final FeatureExtractor featureExtractor;
    private final IParser parser;

    /** Creates a classifier backed by a {@link StanfordNNDepParser}. */
    public SentenceTypeClassifier() {
        this(new StanfordNNDepParser());
    }

    /**
     * Creates a classifier that parses sentences with the given parser.
     *
     * @param iParser parser used to build the tree handed to the feature extractor
     */
    public SentenceTypeClassifier(IParser iParser) {
        this.maxEntClassifier = new MaxEntClassifier();
        this.featureExtractor = new FeatureExtractor();
        this.parser = iParser;
    }

    /**
     * Trains the underlying classifier from a tab-separated corpus file.
     *
     * <p>Each non-blank line is split on tab characters: column 0 is passed to the
     * training tuple (presumably the sentence-type label) and column 1 is the
     * sentence that gets parsed. Duplicate lines are collapsed and blank lines are
     * skipped. A non-blank line with fewer than two columns still fails with
     * {@link ArrayIndexOutOfBoundsException}.
     *
     * @param str path of the training corpus file
     * @throws IOException if the corpus cannot be read
     */
    public void train(String str) throws IOException {
        // Declared as HashMap so the call compiles whichever of Map/HashMap setParameter accepts.
        HashMap<String, String> parameters = new HashMap<>();
        parameters.put("iterations", "120");
        this.maxEntClassifier.setParameter(parameters);

        // NOTE(review): result is discarded; presumably this warms up the parser
        // before it is hit from several threads below - confirm.
        this.parser.parse(".");

        // NOTE(review): FileReader decodes with the platform default charset (pre-Java 18);
        // confirm the corpus encoding before pinning one explicitly.
        HashSet<String> lines = new HashSet<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(str))) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                if (!StringUtils.isBlank(line)) {
                    lines.add(line);
                }
            }
        }
        LOG.info("Loaded Training data.");

        LOG.info("Generating parse tree.");
        // Parsing is done in parallel, so results go into a concurrent map keyed by sentence.
        Map<String, DTree> parseTrees = new ConcurrentHashMap<>();
        lines.parallelStream().forEach(line -> {
            String sentence = line.split("\\t")[1];
            parseTrees.put(sentence, this.parser.parse(sentence));
        });

        LOG.info("Generating feats");
        List<Tuple> samples = lines.stream().map(line -> {
            String[] columns = line.split("\\t");
            String sentence = columns[1];
            return new Tuple(1, toVector(sentence, parseTrees.get(sentence)), columns[0]);
        }).collect(Collectors.toList());
        LOG.info("Extracted Feats.");

        this.maxEntClassifier.train(samples);
    }

    /**
     * Writes the trained model through {@link MaxEntClassifier#persistModel}.
     *
     * @param str destination handed to the underlying classifier
     * @throws IOException if the model cannot be written
     */
    public void persist(String str) throws IOException {
        this.maxEntClassifier.persistModel(str);
    }

    /**
     * Loads a previously persisted model into the underlying classifier.
     *
     * @param inputStream stream to read the model from; not closed by this method
     * @throws IOException if the model cannot be read
     */
    public void loadModel(InputStream inputStream) throws IOException {
        this.maxEntClassifier.loadModel(inputStream);
    }

    /**
     * Classifies a sentence whose parse tree is already available.
     *
     * @param str   the sentence text
     * @param dTree the parse tree of {@code str}
     * @return the map produced by the underlying classifier; {@link #main(String[])}
     *         treats the key with the highest value as the predicted type
     */
    public Map<String, Double> predict(String str, DTree dTree) {
        return this.maxEntClassifier.predict(new Tuple(toVector(str, dTree)));
    }

    /**
     * Parses the sentence with the configured parser and classifies it.
     *
     * @param str the sentence text
     * @return see {@link #predict(String, DTree)}
     */
    public Map<String, Double> predict(String str) {
        return predict(str, this.parser.parse(str));
    }

    /**
     * Builds the feature vector shared by training and prediction.
     *
     * <p>NOTE(review): every extracted feature contributes the constant 1.0, so only the
     * number of features reaches the vector - confirm this is intended.
     */
    private DenseVector toVector(String sentence, DTree tree) {
        double[] values = this.featureExtractor.generateFeats(sentence, tree).stream()
                .mapToDouble(feat -> 1.0d)
                .toArray();
        return new DenseVector(values);
    }

    /**
     * Trains a model, persists it, reloads it and then classifies sentences read from
     * standard input until {@code exit} is entered or the input ends.
     *
     * @param strArr optional arguments: {@code [0]} corpus path, {@code [1]} model path;
     *               the built-in default paths are used for any that are missing
     * @throws IOException if the corpus or the model file cannot be read or written
     */
    public static void main(String[] strArr) throws IOException {
        String corpusPath = strArr.length > 0 ? strArr[0] : DEFAULT_CORPUS_PATH;
        String modelPath = strArr.length > 1 ? strArr[1] : DEFAULT_MODEL_PATH;

        SentenceTypeClassifier classifier = new SentenceTypeClassifier(new StanfordPCFGParser());
        classifier.train(corpusPath);
        classifier.persist(modelPath);
        try (InputStream modelStream = new FileInputStream(modelPath)) {
            classifier.loadModel(modelStream);
        }

        try (Scanner scanner = new Scanner(System.in)) {
            System.out.println("Input Sentence:");
            // hasNextLine() lets the loop end cleanly when stdin is closed.
            while (scanner.hasNextLine()) {
                String sentence = scanner.nextLine();
                if (sentence.equalsIgnoreCase("exit")) {
                    break;
                }
                Map<String, Double> scores = classifier.predict(sentence);
                System.out.println(scores);
                String best = scores.entrySet().stream()
                        .max(Map.Entry.comparingByValue())
                        .map(Map.Entry::getKey)
                        .orElse(null);
                System.out.println(StringUtils.capitalize(best));
            }
        }
        System.exit(0);
    }
}
