package ai.idylnlp.models.opennlp.training;

import ai.idylnlp.model.nlp.subjects.SubjectOfTrainingOrEvaluation;
import ai.idylnlp.model.training.AccuracyEvaluationResult;
import ai.idylnlp.models.ModelOperationsUtils;
import ai.idylnlp.models.opennlp.training.model.ModelSeparateDataValidationOperations;
import ai.idylnlp.models.opennlp.training.model.ModelTrainingOperations;
import ai.idylnlp.models.opennlp.training.model.TrainingAlgorithm;
import ai.idylnlp.opennlp.custom.encryption.OpenNLPEncryptionFactory;
import ai.idylnlp.training.definition.model.TrainingDefinitionReader;
import com.neovisionaries.i18n.LanguageCode;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import opennlp.tools.cmdline.postag.POSEvaluationErrorListener;
import opennlp.tools.cmdline.postag.POSTaggerFineGrainedReportListener;
import opennlp.tools.postag.POSEvaluator;
import opennlp.tools.postag.POSModel;
import opennlp.tools.postag.POSTaggerEvaluationMonitor;
import opennlp.tools.postag.POSTaggerFactory;
import opennlp.tools.postag.POSTaggerME;
import opennlp.tools.postag.WordTagSampleStream;
import opennlp.tools.util.MarkableFileInputStreamFactory;
import opennlp.tools.util.PlainTextByLineStream;
import opennlp.tools.util.TrainingParameters;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/* loaded from: input_file:ai/idylnlp/models/opennlp/training/PartOfSpeechModelOperations.class */
/**
 * Trains and evaluates OpenNLP part-of-speech tagger models.
 *
 * <p>Input files are read as UTF-8 text and parsed with {@link WordTagSampleStream}. The supplied
 * encryption key is registered with {@link OpenNLPEncryptionFactory} for the duration of each
 * operation and is always cleared again afterwards, including when the operation fails, so a key
 * never outlives the call that set it.
 */
public class PartOfSpeechModelOperations
        implements ModelTrainingOperations, ModelSeparateDataValidationOperations<AccuracyEvaluationResult> {
    private static final Logger LOGGER = LogManager.getLogger(PartOfSpeechModelOperations.class);

    /** Character encoding used for training and evaluation input files. */
    private static final String INPUT_ENCODING = "UTF-8";

    /**
     * Trains a part-of-speech model as described by a training definition.
     *
     * @param trainingDefinitionReader reader providing the model output file, language, encryption
     *     key and algorithm settings
     * @return the value returned by the model's {@code serialize} call
     * @throws IOException if the language or algorithm in the definition is not recognized, or if
     *     reading the training data or writing the model fails
     */
    public static String train(TrainingDefinitionReader trainingDefinitionReader) throws IOException {
        PartOfSpeechModelOperations operations = new PartOfSpeechModelOperations();
        SubjectOfTrainingOrEvaluation subject = ModelOperationsUtils.getSubjectOfTrainingOrEvaluation(trainingDefinitionReader);
        String modelFile = trainingDefinitionReader.getTrainingDefinition().getModel().getFile();
        String language = trainingDefinitionReader.getTrainingDefinition().getModel().getLanguage();
        String encryptionKey = trainingDefinitionReader.getTrainingDefinition().getModel().getEncryptionkey();
        String algorithm = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getName();

        // getByCodeIgnoreCase returns null for unknown codes; fail with a clear message instead of
        // an NPE deep inside training.
        LanguageCode languageCode = LanguageCode.getByCodeIgnoreCase(language);
        if (languageCode == null) {
            throw new IOException("Invalid language specified in the training definition file: " + language);
        }

        int cutOff = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getCutoff().intValue();
        int iterations = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getIterations().intValue();

        // Compare from the constant side so a missing algorithm name reaches the error below.
        if (TrainingAlgorithm.PERCEPTRON.getName().equalsIgnoreCase(algorithm)) {
            return operations.trainPerceptron(subject, modelFile, languageCode, encryptionKey, cutOff, iterations);
        }
        if (TrainingAlgorithm.MAXENT_QN.getName().equalsIgnoreCase(algorithm)) {
            // QN-only settings are read only when they are actually used.
            int threads = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getThreads().intValue();
            double l1 = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getL1().doubleValue();
            double l2 = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getL2().doubleValue();
            int m = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getM().intValue();
            int max = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getMax().intValue();
            return operations.trainMaxEntQN(subject, modelFile, languageCode, encryptionKey, cutOff, iterations, threads, l1, l2, m, max);
        }
        throw new IOException("Invalid algorithm specified in the training definition file: " + algorithm);
    }

    /**
     * Evaluates an existing model against a separate evaluation file and reports word accuracy.
     *
     * <p>Per-sample errors and a fine-grained report are written by OpenNLP listeners; the
     * fine-grained report goes to {@code System.out}.
     *
     * @param subjectOfTrainingOrEvaluation supplies the evaluation input file
     * @param modelFile path of the serialized {@link POSModel} to evaluate
     * @param encryptionKey key registered with {@link OpenNLPEncryptionFactory} while the model is
     *     loaded and evaluated
     * @return word accuracy and word count reported by the evaluator
     * @throws IOException if the model or the evaluation data cannot be read
     */
    @Override // ai.idylnlp.models.opennlp.training.model.ModelSeparateDataValidationOperations
    public AccuracyEvaluationResult separateDataEvaluate(SubjectOfTrainingOrEvaluation subjectOfTrainingOrEvaluation, String modelFile, String encryptionKey) throws IOException {
        LOGGER.info("Doing model evaluation using separate training data.");
        OpenNLPEncryptionFactory.getDefault().setKey(encryptionKey);
        // The finally clause also covers failures while opening the input stream.
        try (WordTagSampleStream samples = new WordTagSampleStream(new PlainTextByLineStream(
                new MarkableFileInputStreamFactory(new File(subjectOfTrainingOrEvaluation.getInputFile())), INPUT_ENCODING))) {
            POSEvaluator evaluator = new POSEvaluator(
                    new POSTaggerME(new POSModel(new File(modelFile))),
                    new POSTaggerEvaluationMonitor[]{new POSEvaluationErrorListener(), new POSTaggerFineGrainedReportListener(System.out)});
            evaluator.evaluate(samples);
            return new AccuracyEvaluationResult(evaluator.getWordAccuracy(), evaluator.getWordCount());
        } finally {
            OpenNLPEncryptionFactory.getDefault().clearKey();
        }
    }

    /**
     * Trains a part-of-speech model with the quasi-Newton maximum-entropy algorithm and writes it
     * to {@code modelFile}.
     *
     * @param subjectOfTrainingOrEvaluation supplies the training input file
     * @param modelFile output path for the serialized model
     * @param languageCode language of the model; must not be {@code null}
     * @param encryptionKey key registered with {@link OpenNLPEncryptionFactory} during training
     * @param cutOff value for the {@code Cutoff} training parameter
     * @param iterations value for the {@code Iterations} training parameter
     * @param threads value for the {@code Threads} training parameter
     * @param l1 value for the {@code L1Cost} training parameter
     * @param l2 value for the {@code L2Cost} training parameter
     * @param numOfUpdates value for the {@code NumOfUpdates} training parameter
     * @param maxFctEval value for the {@code MaxFctEval} training parameter
     * @return the value returned by the model's {@code serialize} call
     * @throws IOException if reading the training data or writing the model fails
     */
    @Override // ai.idylnlp.models.opennlp.training.model.ModelTrainingOperations
    public String trainMaxEntQN(SubjectOfTrainingOrEvaluation subjectOfTrainingOrEvaluation, String modelFile, LanguageCode languageCode, String encryptionKey, int cutOff, int iterations, int threads, double l1, double l2, int numOfUpdates, int maxFctEval) throws IOException {
        TrainingParameters trainingParameters = new TrainingParameters();
        trainingParameters.put("Cutoff", Integer.toString(cutOff));
        trainingParameters.put("Iterations", Integer.toString(iterations));
        trainingParameters.put("Algorithm", TrainingAlgorithm.MAXENT_QN.getAlgorithm());
        trainingParameters.put("Threads", Integer.toString(threads));
        trainingParameters.put("L1Cost", String.valueOf(l1));
        trainingParameters.put("L2Cost", String.valueOf(l2));
        trainingParameters.put("NumOfUpdates", String.valueOf(numOfUpdates));
        trainingParameters.put("MaxFctEval", String.valueOf(maxFctEval));
        return trainAndSerialize(subjectOfTrainingOrEvaluation, modelFile, languageCode, encryptionKey, trainingParameters);
    }

    /**
     * Trains a part-of-speech model with the perceptron algorithm and writes it to
     * {@code modelFile}.
     *
     * @param subjectOfTrainingOrEvaluation supplies the training input file
     * @param modelFile output path for the serialized model
     * @param languageCode language of the model; must not be {@code null}
     * @param encryptionKey key registered with {@link OpenNLPEncryptionFactory} during training
     * @param cutOff value for the {@code Cutoff} training parameter
     * @param iterations value for the {@code Iterations} training parameter
     * @return the value returned by the model's {@code serialize} call
     * @throws IOException if reading the training data or writing the model fails
     */
    @Override // ai.idylnlp.models.opennlp.training.model.ModelTrainingOperations
    public String trainPerceptron(SubjectOfTrainingOrEvaluation subjectOfTrainingOrEvaluation, String modelFile, LanguageCode languageCode, String encryptionKey, int cutOff, int iterations) throws IOException {
        TrainingParameters trainingParameters = new TrainingParameters();
        trainingParameters.put("Cutoff", Integer.toString(cutOff));
        trainingParameters.put("Iterations", Integer.toString(iterations));
        trainingParameters.put("Algorithm", TrainingAlgorithm.PERCEPTRON.getAlgorithm());
        return trainAndSerialize(subjectOfTrainingOrEvaluation, modelFile, languageCode, encryptionKey, trainingParameters);
    }

    /**
     * Shared training pipeline: reads the training file, trains a {@link POSModel} with the given
     * parameters and serializes it to {@code modelFile}.
     *
     * <p>The input stream and output file are always closed, and the encryption key is always
     * cleared, even when training or serialization throws.
     */
    private static String trainAndSerialize(SubjectOfTrainingOrEvaluation subject, String modelFile, LanguageCode languageCode, String encryptionKey, TrainingParameters trainingParameters) throws IOException {
        if (languageCode == null) {
            throw new IllegalArgumentException("languageCode must not be null");
        }
        LOGGER.info("Beginning parts-of-speech model training. Output model will be: " + modelFile);
        try (PlainTextByLineStream lineStream = new PlainTextByLineStream(
                new MarkableFileInputStreamFactory(new File(subject.getInputFile())), INPUT_ENCODING)) {
            OpenNLPEncryptionFactory.getDefault().setKey(encryptionKey);
            try {
                POSModel model = POSTaggerME.train(languageCode.getAlpha3().toString(),
                        new WordTagSampleStream(lineStream), trainingParameters, new POSTaggerFactory());
                try (BufferedOutputStream modelOut = new BufferedOutputStream(new FileOutputStream(modelFile))) {
                    return model.serialize(modelOut);
                }
            } finally {
                OpenNLPEncryptionFactory.getDefault().clearKey();
            }
        }
    }
}
