package org.maochen.nlp.ml.classifier.maxent;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import opennlp.maxent.GISModel;
import opennlp.maxent.io.GISModelReader;
import opennlp.maxent.io.PlainTextGISModelWriter;
import opennlp.model.PlainTextFileDataReader;
import opennlp.model.Prior;
import opennlp.model.RealValueFileEventStream;
import opennlp.model.UniformPrior;
import org.maochen.nlp.ml.Tuple;
import org.maochen.nlp.ml.classifier.IClassifier;
import org.maochen.nlp.ml.classifier.maxent.eventstream.EventStream;
import org.maochen.nlp.ml.classifier.maxent.eventstream.StringEventStream;
import org.maochen.nlp.ml.classifier.maxent.eventstream.TupleEventStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Maximum-entropy classifier backed by an OpenNLP {@link GISModel}.
 *
 * <p>Training runs a {@link GISTrainer} with a {@link UniformPrior} over a
 * {@link OnePassRealValueDataIndexer}. Models are persisted with {@link PlainTextGISModelWriter}
 * and loaded with {@link GISModelReader} over a {@link PlainTextFileDataReader}.
 *
 * <p>Training can be tuned through {@link #setParameter(Map)} using the keys
 * {@code use_smoothing}, {@code iterations}, {@code cutoff}, {@code nthreads} and
 * {@code smoothing_observation}.
 *
 * <p>NOTE(review): instances hold mutable state (the parameters and the current model) and this
 * class performs no synchronization of its own.
 */
public class MaxEntClassifier implements IClassifier {
    private static final Logger LOG = LoggerFactory.getLogger(MaxEntClassifier.class);

    /** Whether smoothing is enabled on the GIS trainer (default {@code true}). */
    private boolean useSmoothing = true;
    /** Number of GIS training iterations (default 100). */
    private int iterations = 100;
    /** Cutoff passed to both the data indexer and the trainer (default 0). */
    private int cutoff = 0;
    /** Number of training threads; defaults to the number of available processors. */
    private int nthreads = Runtime.getRuntime().availableProcessors();
    /** Smoothing observation passed to the GIS trainer (default 0.1). */
    private double smoothingObservation = 0.1d;
    /** Current model; {@code null} until a train method or {@link #loadModel} succeeds. */
    private GISModel model = null;

    /**
     * Trains a new model from string-encoded training events.
     *
     * @param list training events, converted through {@link StringEventStream}
     * @return this classifier, for call chaining
     */
    public MaxEntClassifier trainString(List<String[]> list) {
        return train(new StringEventStream(list));
    }

    /**
     * Runs GIS training over {@code eventStream} with the configured parameters and replaces the
     * current model with the result.
     *
     * @param eventStream source of training events
     * @return this classifier, for call chaining
     */
    private MaxEntClassifier train(EventStream eventStream) {
        Prior prior = new UniformPrior();
        // NOTE(review): the meaning of the trailing boolean is defined by OnePassRealValueDataIndexer;
        // presumably it toggles event sorting/merging — confirm against that class.
        OnePassRealValueDataIndexer indexer = new OnePassRealValueDataIndexer(eventStream, this.cutoff, true);
        GISTrainer trainer = new GISTrainer();
        trainer.setSmoothing(this.useSmoothing);
        trainer.setSmoothingObservation(this.smoothingObservation);
        this.model = trainer.trainModel(this.iterations, indexer, prior, this.cutoff, this.nthreads);
        return this;
    }

    /**
     * Predicts outcome probabilities for features given as strings.
     *
     * <p>An entry may carry a real value as {@code "name=value"}. Entries without a parseable value
     * are weighted 1.0, including the case where no entry carries a value at all. The caller's
     * array is not modified.
     *
     * @param strArr feature strings, optionally suffixed with {@code =value}
     * @return map from each outcome label to the probability the model assigns it
     * @throws IllegalStateException if no model has been trained or loaded
     */
    public Map<String, Double> predict(String[] strArr) {
        // parseContexts strips "=value" suffixes from the array in place, so operate on a copy
        // rather than silently rewriting the caller's input.
        String[] featureNames = strArr.clone();
        float[] parsedValues = RealValueFileEventStream.parseContexts(featureNames);

        double[] featureVector = new double[featureNames.length];
        for (int i = 0; i < featureVector.length; i++) {
            // parseContexts returns null when no entry carries a real value; previously this caused
            // an NPE. Treat every feature as present with weight 1.0 in that case.
            featureVector[i] = parsedValues == null ? 1.0d : parsedValues[i];
        }

        Tuple tuple = new Tuple((double[]) null);
        tuple.featureVector = featureVector;
        tuple.featureName = featureNames;
        return predict(tuple);
    }

    /**
     * Trains a new model from the given tuples.
     *
     * @param list training tuples, converted through {@link TupleEventStream}
     * @return this classifier
     */
    @Override // org.maochen.nlp.ml.classifier.IClassifier
    public IClassifier train(List<Tuple> list) {
        return train(new TupleEventStream(list));
    }

    /**
     * Evaluates the current model on a tuple.
     *
     * <p>{@code tuple.featureName[i]} is paired with {@code tuple.featureVector[i]}, which is
     * narrowed to {@code float} for the OpenNLP API.
     *
     * @param tuple features to evaluate
     * @return map from each outcome label to the probability the model assigns it
     * @throws IllegalStateException if no model has been trained or loaded
     */
    @Override // org.maochen.nlp.ml.classifier.IClassifier
    public Map<String, Double> predict(Tuple tuple) {
        GISModel current = requireModel();

        float[] values = new float[tuple.featureVector.length];
        for (int i = 0; i < values.length; i++) {
            values[i] = (float) tuple.featureVector[i];
        }

        double[] probabilities = current.eval(tuple.featureName, values, new double[current.getNumOutcomes()]);
        Map<String, Double> result = new HashMap<>();
        for (int i = 0; i < probabilities.length; i++) {
            result.put(current.getOutcome(i), probabilities[i]);
        }
        return result;
    }

    /**
     * Applies training parameters from {@code map}. Keys that are absent leave the current value
     * unchanged, and a {@code null} map is ignored.
     *
     * <p>Recognized keys are {@code use_smoothing} (boolean), {@code iterations} (int),
     * {@code cutoff} (int), {@code nthreads} (int) and {@code smoothing_observation} (double).
     *
     * @param map parameter name to string value
     * @throws NumberFormatException if a numeric parameter cannot be parsed
     */
    @Override // org.maochen.nlp.ml.classifier.IClassifier
    public void setParameter(Map<String, String> map) {
        if (map == null) {
            return;
        }
        if (map.containsKey("use_smoothing")) {
            this.useSmoothing = Boolean.parseBoolean(map.get("use_smoothing"));
        }
        if (map.containsKey("iterations")) {
            this.iterations = Integer.parseInt(map.get("iterations"));
        }
        if (map.containsKey("cutoff")) {
            this.cutoff = Integer.parseInt(map.get("cutoff"));
        }
        if (map.containsKey("nthreads")) {
            this.nthreads = Integer.parseInt(map.get("nthreads"));
        }
        if (map.containsKey("smoothing_observation")) {
            this.smoothingObservation = Double.parseDouble(map.get("smoothing_observation"));
        }
    }

    /**
     * Writes the current model to {@code str} in OpenNLP's plain-text GIS format.
     *
     * @param str destination file path
     * @throws IOException if writing fails
     * @throws IllegalStateException if no model has been trained or loaded
     */
    @Override // org.maochen.nlp.ml.classifier.IClassifier
    public void persistModel(String str) throws IOException {
        new PlainTextGISModelWriter(requireModel(), new File(str)).persist();
    }

    /**
     * Loads a plain-text GIS model from {@code inputStream} and makes it the current model.
     *
     * <p>If reading fails with an {@link IOException}, the error is logged and the previous model
     * (possibly {@code null}) is kept. This method does not itself close {@code inputStream}.
     *
     * @param inputStream source of the serialized model
     */
    @Override // org.maochen.nlp.ml.classifier.IClassifier
    public void loadModel(InputStream inputStream) {
        LOG.info("Loading MaxEnt model.");
        try {
            this.model = new GISModelReader(new PlainTextFileDataReader(inputStream)).getModel();
        } catch (IOException e) {
            LOG.error("model load err.", e);
        }
    }

    /**
     * Returns the current model, or fails fast with a descriptive error when none is available.
     */
    private GISModel requireModel() {
        if (this.model == null) {
            throw new IllegalStateException("MaxEnt model is not available; train or load a model first.");
        }
        return this.model;
    }
}
