package org.maochen.nlp.ml.classifier.crfsuite;

import com.github.jcrfsuite.util.CrfSuiteLoader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.CopyOption;
import java.nio.file.Files;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.maochen.nlp.ml.ISeqClassifier;
import org.maochen.nlp.ml.SequenceTuple;
import org.maochen.nlp.ml.Tuple;
import org.maochen.nlp.ml.vector.FeatNamedVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import third_party.org.chokkan.crfsuite.Attribute;
import third_party.org.chokkan.crfsuite.Item;
import third_party.org.chokkan.crfsuite.ItemSequence;
import third_party.org.chokkan.crfsuite.StringList;
import third_party.org.chokkan.crfsuite.Tagger;
import third_party.org.chokkan.crfsuite.Trainer;

/* loaded from: input_file:org/maochen/nlp/ml/classifier/crfsuite/CRFClassifier.class */
/**
 * Sequence classifier backed by CRFsuite's {@link Trainer} and {@link Tagger}.
 *
 * <p>Typical usage: configure with {@link #setParameter(Properties)} (the {@code "model"} key
 * names the model file; {@code "algorithm"} and {@code "graphicalModelType"} select the trainer;
 * every other entry is forwarded to {@link Trainer#set}), call {@link #train(List)} and then
 * {@link #predict(SequenceTuple)}.
 *
 * <p>Only {@link #predict(SequenceTuple)} is synchronized; the remaining methods perform no
 * locking of their own.
 */
public class CRFClassifier implements ISeqClassifier {
    private static final Logger LOG = LoggerFactory.getLogger(CRFClassifier.class);
    private static final String DEFAULT_ALGORITHM = "lbfgs";
    private static final String DEFAULT_GRAPHICAL_MODEL_TYPE = "crf1d";
    private static final String ALGORITHM_KEY = "algorithm";
    private static final String GRAPHICAL_MODEL_TYPE_KEY = "graphicalModelType";
    private static final String MODEL_KEY = "model";

    static {
        // Runs once at class-load time; presumably loads the native CRFsuite bindings.
        try {
            CrfSuiteLoader.load();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Trainer parameters (never contains the {@code "model"} entry). */
    private Properties props = new Properties();
    /** Location of the model file; {@code null} until configured or until the first training. */
    private String modelPath = null;
    /** Lazily opened tagger; {@code null} while no model has been loaded successfully. */
    private Tagger tagger = null;

    /** Creates an unconfigured classifier that uses the default algorithm and model type. */
    public CRFClassifier() {
    }

    /**
     * Creates a classifier configured from the given properties.
     *
     * @param properties see {@link #setParameter(Properties)}
     */
    public CRFClassifier(Properties properties) {
        setParameter(properties);
    }

    /**
     * Converts every training sequence into the pair of structures CRFsuite consumes.
     *
     * @param list training sequences
     * @return left: one {@link ItemSequence} per input sequence; right: the matching label lists,
     *         in the same order
     */
    private static Pair<List<ItemSequence>, List<StringList>> loadTrainingData(List<SequenceTuple> list) {
        List<ItemSequence> itemSequences = new ArrayList<>(list.size());
        List<StringList> labelSequences = new ArrayList<>(list.size());
        for (SequenceTuple sequenceTuple : list) {
            itemSequences.add(getXseqForOneSeqTuple(sequenceTuple));
            StringList labels = new StringList();
            for (String label : sequenceTuple.getLabel()) {
                labels.add(label);
            }
            labelSequences.add(labels);
        }
        return new ImmutablePair<>(itemSequences, labelSequences);
    }

    /**
     * Builds the CRFsuite item sequence for one input sequence.
     *
     * <p>For a {@link FeatNamedVector} each attribute carries only the feature name; for any other
     * vector the attribute is keyed by the feature index and carries the feature value.
     *
     * @param sequenceTuple the sequence to convert
     * @return one {@link Item} per entry of the sequence
     */
    private static ItemSequence getXseqForOneSeqTuple(SequenceTuple sequenceTuple) {
        ItemSequence itemSequence = new ItemSequence();
        for (Tuple tuple : sequenceTuple.entries) {
            Item item = new Item();
            int featureCount = tuple.vector.getVector().length;
            for (int i = 0; i < featureCount; i++) {
                item.add(tuple.vector instanceof FeatNamedVector ? new Attribute(tuple.vector.featsName[i]) : new Attribute(String.valueOf(i), tuple.vector.getVector()[i]));
            }
            itemSequence.add(item);
        }
        return itemSequence;
    }

    /**
     * Trains a model on the given sequences and writes it to the configured model path.
     *
     * <p>When no model path has been configured, a temporary file is created and used. The stored
     * parameters are left untouched, so the same configuration applies to repeated calls.
     *
     * @param list training sequences; {@code null} or empty input only logs a warning
     * @return this classifier
     * @throws IllegalStateException if no model path is set and a temporary file cannot be created
     */
    public ISeqClassifier train(List<SequenceTuple> list) {
        if (list == null || list.isEmpty()) {
            LOG.warn("Training data is empty.");
            return this;
        }
        if (this.modelPath == null) {
            try {
                // Trainer.train expects a model file name, so a temp file (not a directory) is used.
                this.modelPath = Files.createTempFile("crfsuite", ".model").toAbsolutePath().toString();
            } catch (IOException e) {
                LOG.error("Create temp directory failed.", e);
                // Training without a destination cannot succeed; fail here instead of passing null on.
                throw new IllegalStateException("Unable to create a temporary model file.", e);
            }
        }

        Pair<List<ItemSequence>, List<StringList>> trainingData = loadTrainingData(list);
        Trainer trainer = new Trainer();
        String algorithm = String.valueOf(this.props.getOrDefault(ALGORITHM_KEY, DEFAULT_ALGORITHM));
        String graphicalModelType =
                String.valueOf(this.props.getOrDefault(GRAPHICAL_MODEL_TYPE_KEY, DEFAULT_GRAPHICAL_MODEL_TYPE));
        trainer.select(algorithm, graphicalModelType);

        // The two selector keys were consumed by select(); everything else is a trainer parameter.
        this.props.forEach((key, value) -> {
            String name = String.valueOf(key);
            if (!ALGORITHM_KEY.equals(name) && !GRAPHICAL_MODEL_TYPE_KEY.equals(name)) {
                trainer.set(name, String.valueOf(value));
            }
        });

        List<ItemSequence> itemSequences = trainingData.getLeft();
        List<StringList> labelSequences = trainingData.getRight();
        for (int i = 0; i < itemSequences.size(); i++) {
            trainer.append(itemSequences.get(i), labelSequences.get(i), 0);
        }
        trainer.train(this.modelPath, -1);
        return this;
    }

    /**
     * Labels one sequence with the Viterbi path of the loaded model.
     *
     * <p>The model is loaded from the configured path on first use.
     *
     * @param sequenceTuple the sequence to label
     * @return for each position, the predicted label paired with its marginal probability
     * @throws IllegalArgumentException if no model path has been configured
     * @throws IllegalStateException if the model cannot be opened
     */
    public synchronized List<Pair<String, Double>> predict(SequenceTuple sequenceTuple) {
        if (this.tagger == null) {
            loadModel(null);
            if (this.tagger == null) {
                throw new IllegalStateException("Unable to load model: " + this.modelPath);
            }
        }
        List<Pair<String, Double>> predictions = new ArrayList<>();
        this.tagger.set(getXseqForOneSeqTuple(sequenceTuple));
        StringList path = this.tagger.viterbi();
        for (int i = 0; i < path.size(); i++) {
            String label = path.get(i);
            predictions.add(new ImmutablePair<>(label, Double.valueOf(this.tagger.marginal(label, i))));
        }
        return predictions;
    }

    /**
     * Replaces the configuration of this classifier.
     *
     * <p>The {@code "model"} entry becomes the model path ({@code null} when absent); all other
     * entries are kept as trainer parameters. The given object is copied, not retained or modified.
     *
     * @param properties the new configuration
     */
    public void setParameter(Properties properties) {
        Object model = properties.get(MODEL_KEY);
        this.modelPath = model == null ? null : model.toString();

        Properties copy = new Properties();
        copy.putAll(properties);
        copy.remove(MODEL_KEY);
        this.props = copy;
    }

    /**
     * Copies the current model file to another location.
     *
     * @param str destination path; must differ from the current model path
     * @throws IOException if the destination equals the current model path or the copy fails
     * @throws IllegalStateException if there is no model path yet
     */
    public void persistModel(String str) throws IOException {
        if (this.modelPath == null) {
            throw new IllegalStateException("No model available: train a model or set the model path first.");
        }
        if (this.modelPath.equals(str)) {
            throw new IOException("same as original model path.");
        }
        Files.copy(new File(this.modelPath).toPath(), new File(str).toPath());
    }

    /**
     * Predicts every sequence and counts positions whose label differs from the expected one.
     * A summary is printed to standard output.
     *
     * @param list labelled sequences to evaluate
     * @return left: number of mislabelled positions; right: total number of positions
     * @throws RuntimeException if a prediction has a different length than the expected labels
     */
    public Pair<Integer, Integer> validate(List<SequenceTuple> list) {
        int total = 0;
        for (SequenceTuple sequenceTuple : list) {
            total += sequenceTuple.entries.size();
        }

        int errors = 0;
        for (SequenceTuple sequenceTuple : list) {
            List<String> predicted = predict(sequenceTuple).stream().map(Pair::getLeft).collect(Collectors.toList());
            List<String> expected = sequenceTuple.getLabel();
            if (predicted.size() != expected.size()) {
                throw new RuntimeException("Actual size: " + predicted.size() + "\tExpected size: " + expected.size());
            }
            for (int i = 0; i < predicted.size(); i++) {
                if (!predicted.get(i).equals(expected.get(i))) {
                    errors++;
                }
            }
        }

        // Floating-point division: integer division would truncate the error rate to 0 or 1.
        // With no positions at all the ratio is 0/0, which is reported as NaN instead of throwing.
        double accuracy = (1.0d - ((double) errors / total)) * 100.0d;
        System.out.println("Err/Total: " + errors + "/" + total);
        System.out.println("Accuracy: " + accuracy + "%");
        return new ImmutablePair<>(Integer.valueOf(errors), Integer.valueOf(total));
    }

    /**
     * Opens the model stored at the configured model path.
     *
     * <p>On failure an error is logged and no tagger is kept, so a later
     * {@link #predict(SequenceTuple)} attempts the load again.
     *
     * @param inputStream ignored; the model is always read from the configured path
     * @throws IllegalArgumentException if no model path has been configured
     */
    public void loadModel(InputStream inputStream) {
        if (this.modelPath == null) {
            throw new IllegalArgumentException("Please set model path parameter to load model");
        }
        Tagger candidate = new Tagger();
        if (candidate.open(this.modelPath)) {
            this.tagger = candidate;
            return;
        }
        // Never keep a tagger that failed to open: it could not be used for prediction.
        this.tagger = null;
        LOG.error("Unable load model: " + this.modelPath);
    }
}
