package ai.idylnlp.models.opennlp.training;

import ai.idylnlp.model.nlp.subjects.SubjectOfTrainingOrEvaluation;
import ai.idylnlp.model.training.FMeasure;
import ai.idylnlp.model.training.FMeasureModelValidationResult;
import ai.idylnlp.models.ModelOperationsUtils;
import ai.idylnlp.models.opennlp.training.model.ModelCrossValidationOperations;
import ai.idylnlp.models.opennlp.training.model.ModelSeparateDataValidationOperations;
import ai.idylnlp.models.opennlp.training.model.ModelTrainingOperations;
import ai.idylnlp.models.opennlp.training.model.TrainingAlgorithm;
import ai.idylnlp.opennlp.custom.encryption.OpenNLPEncryptionFactory;
import ai.idylnlp.training.definition.model.TrainingDefinitionReader;
import com.neovisionaries.i18n.LanguageCode;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.regex.Pattern;
import opennlp.tools.cmdline.tokenizer.TokenEvaluationErrorListener;
import opennlp.tools.dictionary.Dictionary;
import opennlp.tools.tokenize.TokenSampleStream;
import opennlp.tools.tokenize.TokenizerCrossValidator;
import opennlp.tools.tokenize.TokenizerEvaluationMonitor;
import opennlp.tools.tokenize.TokenizerEvaluator;
import opennlp.tools.tokenize.TokenizerFactory;
import opennlp.tools.tokenize.TokenizerME;
import opennlp.tools.tokenize.TokenizerModel;
import opennlp.tools.util.MarkableFileInputStreamFactory;
import opennlp.tools.util.PlainTextByLineStream;
import opennlp.tools.util.TrainingParameters;
import org.apache.commons.lang.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/* loaded from: input_file:ai/idylnlp/models/opennlp/training/TokenModelOperations.class */
/**
 * Trains and evaluates OpenNLP tokenizer models from a training definition.
 *
 * <p>The static {@link #train} and {@link #crossValidate} entry points read their settings from a
 * {@link TrainingDefinitionReader} and dispatch on the configured algorithm name (perceptron or
 * MaxEnt quasi-Newton). The instance methods implement the individual operations and can be called
 * directly with explicit settings.
 *
 * <p>Training and separate-data evaluation set a key on the default
 * {@link OpenNLPEncryptionFactory} for the duration of the operation and always clear it again,
 * even when the operation fails. NOTE(review): that factory appears to be shared process-wide, so
 * concurrent operations using different keys would presumably interfere with each other - confirm
 * before running these in parallel.
 */
public class TokenModelOperations implements ModelTrainingOperations, ModelSeparateDataValidationOperations<FMeasureModelValidationResult>, ModelCrossValidationOperations<FMeasureModelValidationResult> {
    private static final Logger LOGGER = LogManager.getLogger(TokenModelOperations.class);

    /** Character encoding used to read every training/evaluation input file. */
    private static final String INPUT_ENCODING = "UTF-8";

    /**
     * Trains a tokenizer model as described by the given training definition.
     *
     * @param trainingDefinitionReader source of the model, language and algorithm settings
     * @return the value returned by serializing the trained model (presumably the model's
     *     identifier - confirm against {@code TokenizerModel.serialize})
     * @throws IOException if the algorithm name or language code is not recognised, or if reading
     *     the input or writing the model fails
     */
    public static String train(TrainingDefinitionReader trainingDefinitionReader) throws IOException {
        TokenModelOperations operations = new TokenModelOperations();
        SubjectOfTrainingOrEvaluation subject = ModelOperationsUtils.getSubjectOfTrainingOrEvaluation(trainingDefinitionReader);
        String modelFile = trainingDefinitionReader.getTrainingDefinition().getModel().getFile();
        String encryptionKey = trainingDefinitionReader.getTrainingDefinition().getModel().getEncryptionkey();
        String algorithmName = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getName();

        // Null-safe comparison so a missing algorithm name is reported as an invalid definition
        // instead of surfacing as a NullPointerException.
        boolean perceptron = StringUtils.equalsIgnoreCase(algorithmName, TrainingAlgorithm.PERCEPTRON.getName());
        boolean maxEntQN = StringUtils.equalsIgnoreCase(algorithmName, TrainingAlgorithm.MAXENT_QN.getName());
        if (!perceptron && !maxEntQN) {
            throw new IOException("Invalid algorithm specified in the training definition file: " + algorithmName);
        }

        LanguageCode languageCode = resolveLanguage(trainingDefinitionReader.getTrainingDefinition().getModel().getLanguage());
        int cutoff = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getCutoff().intValue();
        int iterations = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getIterations().intValue();
        if (perceptron) {
            return operations.trainPerceptron(subject, modelFile, languageCode, encryptionKey, cutoff, iterations);
        }

        // The remaining settings only apply to MaxEnt quasi-Newton training, so they are read
        // (and unboxed) only on this path.
        int threads = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getThreads().intValue();
        double l1 = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getL1().doubleValue();
        double l2 = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getL2().doubleValue();
        int m = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getM().intValue();
        int max = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getMax().intValue();
        return operations.trainMaxEntQN(subject, modelFile, languageCode, encryptionKey, cutoff, iterations, threads, l1, l2, m, max);
    }

    /**
     * Evaluates the configured tokenizer training setup using k-fold cross-validation.
     *
     * @param trainingDefinitionReader source of the language and algorithm settings
     * @param i number of cross-validation folds
     * @return precision, recall and F-measure reported by the cross-validator
     * @throws IOException if the algorithm name or language code is not recognised, or if reading
     *     the input fails
     */
    public static FMeasureModelValidationResult crossValidate(TrainingDefinitionReader trainingDefinitionReader, int i) throws IOException {
        String algorithmName = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getName();
        boolean perceptron = StringUtils.equalsIgnoreCase(algorithmName, TrainingAlgorithm.PERCEPTRON.getName());
        boolean maxEntQN = StringUtils.equalsIgnoreCase(algorithmName, TrainingAlgorithm.MAXENT_QN.getName());
        if (!perceptron && !maxEntQN) {
            throw new IOException("Invalid algorithm specified in the training definition file: " + algorithmName);
        }

        LanguageCode languageCode = resolveLanguage(trainingDefinitionReader.getTrainingDefinition().getModel().getLanguage());
        int iterations = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getIterations().intValue();
        int cutoff = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getCutoff().intValue();
        SubjectOfTrainingOrEvaluation subject = ModelOperationsUtils.getSubjectOfTrainingOrEvaluation(trainingDefinitionReader);
        TokenModelOperations operations = new TokenModelOperations();
        if (perceptron) {
            return operations.crossValidationEvaluatePerceptron(subject, languageCode, iterations, cutoff, i);
        }

        // Quasi-Newton settings are read only when they are actually used, so a perceptron
        // definition that omits them no longer fails while unboxing.
        double l1 = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getL1().doubleValue();
        double l2 = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getL2().doubleValue();
        int m = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getM().intValue();
        int max = trainingDefinitionReader.getTrainingDefinition().getAlgorithm().getMax().intValue();
        return operations.crossValidationEvaluateMaxEntQN(subject, languageCode, iterations, cutoff, i, l1, l2, m, max);
    }

    /**
     * Cross-validates MaxEnt quasi-Newton tokenizer training on the subject's input file.
     *
     * @param subjectOfTrainingOrEvaluation provides the path of the input file
     * @param languageCode language of the input; its alpha-3 code is passed to the tokenizer factory
     * @param iterations value for the {@code Iterations} training parameter
     * @param cutoff value for the {@code Cutoff} training parameter
     * @param folds number of cross-validation folds
     * @param l1 value for the {@code L1Cost} training parameter
     * @param l2 value for the {@code L2Cost} training parameter
     * @param m value for the {@code NumOfUpdates} training parameter
     * @param max value for the {@code MaxFctEval} training parameter
     * @return precision, recall and F-measure reported by the cross-validator
     * @throws IOException if the input cannot be read
     */
    @Override // ai.idylnlp.models.opennlp.training.model.ModelCrossValidationOperations
    public FMeasureModelValidationResult crossValidationEvaluateMaxEntQN(SubjectOfTrainingOrEvaluation subjectOfTrainingOrEvaluation, LanguageCode languageCode, int iterations, int cutoff, int folds, double l1, double l2, int m, int max) throws IOException {
        TrainingParameters trainingParameters = baseParameters(TrainingAlgorithm.MAXENT_QN, cutoff, iterations);
        addQuasiNewtonParameters(trainingParameters, l1, l2, m, max);
        return runCrossValidation(subjectOfTrainingOrEvaluation, languageCode, trainingParameters, folds);
    }

    /**
     * Cross-validates perceptron tokenizer training on the subject's input file.
     *
     * @param subjectOfTrainingOrEvaluation provides the path of the input file
     * @param languageCode language of the input; its alpha-3 code is passed to the tokenizer factory
     * @param iterations value for the {@code Iterations} training parameter
     * @param cutoff value for the {@code Cutoff} training parameter
     * @param folds number of cross-validation folds
     * @return precision, recall and F-measure reported by the cross-validator
     * @throws IOException if the input cannot be read
     */
    @Override // ai.idylnlp.models.opennlp.training.model.ModelCrossValidationOperations
    public FMeasureModelValidationResult crossValidationEvaluatePerceptron(SubjectOfTrainingOrEvaluation subjectOfTrainingOrEvaluation, LanguageCode languageCode, int iterations, int cutoff, int folds) throws IOException {
        TrainingParameters trainingParameters = baseParameters(TrainingAlgorithm.PERCEPTRON, cutoff, iterations);
        return runCrossValidation(subjectOfTrainingOrEvaluation, languageCode, trainingParameters, folds);
    }

    /**
     * Evaluates an existing tokenizer model against the subject's input file.
     *
     * @param subjectOfTrainingOrEvaluation provides the path of the evaluation input file
     * @param modelFile path of the tokenizer model to evaluate
     * @param encryptionKey key set on the default encryption factory while the model is loaded and
     *     evaluated
     * @return precision, recall and F-measure reported by the evaluator
     * @throws IOException if the model or the input cannot be read
     */
    @Override // ai.idylnlp.models.opennlp.training.model.ModelSeparateDataValidationOperations
    public FMeasureModelValidationResult separateDataEvaluate(SubjectOfTrainingOrEvaluation subjectOfTrainingOrEvaluation, String modelFile, String encryptionKey) throws IOException {
        LOGGER.info("Doing model evaluation using separate training data.");
        OpenNLPEncryptionFactory.getDefault().setKey(encryptionKey);
        try (TokenSampleStream samples = openSamples(subjectOfTrainingOrEvaluation)) {
            TokenizerEvaluator evaluator = new TokenizerEvaluator(new TokenizerME(new TokenizerModel(new File(modelFile))), new TokenizerEvaluationMonitor[0]);
            evaluator.evaluate(samples);
            return new FMeasureModelValidationResult(new FMeasure(evaluator.getFMeasure().getPrecisionScore(), evaluator.getFMeasure().getRecallScore(), evaluator.getFMeasure().getFMeasure()));
        } finally {
            // Never leave the key set on the default factory, even if loading or evaluation failed.
            OpenNLPEncryptionFactory.getDefault().clearKey();
        }
    }

    /**
     * Trains a tokenizer model with the MaxEnt quasi-Newton algorithm and writes it to a file.
     *
     * @param subjectOfTrainingOrEvaluation provides the path of the training input file
     * @param modelFile path the trained model is written to
     * @param languageCode language of the input; its alpha-3 code is passed to the tokenizer factory
     * @param encryptionKey key set on the default encryption factory while training and serializing
     * @param cutoff value for the {@code Cutoff} training parameter
     * @param iterations value for the {@code Iterations} training parameter
     * @param threads value for the {@code Threads} training parameter
     * @param l1 value for the {@code L1Cost} training parameter
     * @param l2 value for the {@code L2Cost} training parameter
     * @param m value for the {@code NumOfUpdates} training parameter
     * @param max value for the {@code MaxFctEval} training parameter
     * @return the value returned by serializing the trained model
     * @throws IOException if the input cannot be read or the model cannot be written
     */
    @Override // ai.idylnlp.models.opennlp.training.model.ModelTrainingOperations
    public String trainMaxEntQN(SubjectOfTrainingOrEvaluation subjectOfTrainingOrEvaluation, String modelFile, LanguageCode languageCode, String encryptionKey, int cutoff, int iterations, int threads, double l1, double l2, int m, int max) throws IOException {
        TrainingParameters trainingParameters = baseParameters(TrainingAlgorithm.MAXENT_QN, cutoff, iterations);
        trainingParameters.put("Threads", Integer.toString(threads));
        addQuasiNewtonParameters(trainingParameters, l1, l2, m, max);
        return trainAndSerialize(subjectOfTrainingOrEvaluation, modelFile, languageCode, encryptionKey, trainingParameters);
    }

    /**
     * Trains a tokenizer model with the perceptron algorithm and writes it to a file.
     *
     * @param subjectOfTrainingOrEvaluation provides the path of the training input file
     * @param modelFile path the trained model is written to
     * @param languageCode language of the input; its alpha-3 code is passed to the tokenizer factory
     * @param encryptionKey key set on the default encryption factory while training and serializing
     * @param cutoff value for the {@code Cutoff} training parameter
     * @param iterations value for the {@code Iterations} training parameter
     * @return the value returned by serializing the trained model
     * @throws IOException if the input cannot be read or the model cannot be written
     */
    @Override // ai.idylnlp.models.opennlp.training.model.ModelTrainingOperations
    public String trainPerceptron(SubjectOfTrainingOrEvaluation subjectOfTrainingOrEvaluation, String modelFile, LanguageCode languageCode, String encryptionKey, int cutoff, int iterations) throws IOException {
        TrainingParameters trainingParameters = baseParameters(TrainingAlgorithm.PERCEPTRON, cutoff, iterations);
        return trainAndSerialize(subjectOfTrainingOrEvaluation, modelFile, languageCode, encryptionKey, trainingParameters);
    }

    /** Resolves a language code string, rejecting values the lookup does not recognise. */
    private static LanguageCode resolveLanguage(String language) throws IOException {
        LanguageCode languageCode = LanguageCode.getByCodeIgnoreCase(language);
        if (languageCode == null) {
            throw new IOException("Invalid language specified in the training definition file: " + language);
        }
        return languageCode;
    }

    /** Opens the subject's input file as a stream of token samples; the caller must close it. */
    private static TokenSampleStream openSamples(SubjectOfTrainingOrEvaluation subject) throws IOException {
        File inputFile = new File(subject.getInputFile());
        return new TokenSampleStream(new PlainTextByLineStream(new MarkableFileInputStreamFactory(inputFile), INPUT_ENCODING));
    }

    /** Builds the training parameters common to every algorithm: cutoff, iterations, algorithm. */
    private static TrainingParameters baseParameters(TrainingAlgorithm algorithm, int cutoff, int iterations) {
        TrainingParameters trainingParameters = new TrainingParameters();
        trainingParameters.put("Cutoff", Integer.toString(cutoff));
        trainingParameters.put("Iterations", Integer.toString(iterations));
        trainingParameters.put("Algorithm", algorithm.getAlgorithm());
        return trainingParameters;
    }

    /** Adds the settings that are only passed for MaxEnt quasi-Newton training. */
    private static void addQuasiNewtonParameters(TrainingParameters trainingParameters, double l1, double l2, int m, int max) {
        trainingParameters.put("L1Cost", String.valueOf(l1));
        trainingParameters.put("L2Cost", String.valueOf(l2));
        trainingParameters.put("NumOfUpdates", String.valueOf(m));
        trainingParameters.put("MaxFctEval", String.valueOf(max));
    }

    /** Runs k-fold cross-validation over the subject's input and converts the resulting scores. */
    private static FMeasureModelValidationResult runCrossValidation(SubjectOfTrainingOrEvaluation subject, LanguageCode languageCode, TrainingParameters trainingParameters, int folds) throws IOException {
        LOGGER.info("Doing model evaluation using cross-validation with {} folds.", Integer.valueOf(folds));
        // Cross-validation passes no abbreviation dictionary to the factory (training passes an empty one).
        TokenizerFactory tokenizerFactory = new TokenizerFactory(languageCode.getAlpha3().toString(), (Dictionary) null, false, (Pattern) null);
        TokenizerCrossValidator validator = new TokenizerCrossValidator(trainingParameters, tokenizerFactory, new TokenizerEvaluationMonitor[]{new TokenEvaluationErrorListener()});
        try (TokenSampleStream samples = openSamples(subject)) {
            validator.evaluate(samples, folds);
        }
        return new FMeasureModelValidationResult(new FMeasure(validator.getFMeasure().getPrecisionScore(), validator.getFMeasure().getRecallScore(), validator.getFMeasure().getFMeasure()));
    }

    /**
     * Trains a tokenizer model with the given parameters and serializes it to {@code modelFile}.
     * The input stream and output file are closed and the encryption key is cleared whether or
     * not training and serialization succeed.
     */
    private static String trainAndSerialize(SubjectOfTrainingOrEvaluation subject, String modelFile, LanguageCode languageCode, String encryptionKey, TrainingParameters trainingParameters) throws IOException {
        LOGGER.info("Beginning tokenizer model training. Output model will be: " + modelFile);
        // Build the factory before opening any file so a failure here cannot leak a stream.
        TokenizerFactory tokenizerFactory = new TokenizerFactory(languageCode.getAlpha3().toString(), new Dictionary(), false, (Pattern) null);
        try (TokenSampleStream samples = openSamples(subject)) {
            OpenNLPEncryptionFactory.getDefault().setKey(encryptionKey);
            try {
                TokenizerModel model = TokenizerME.train(samples, tokenizerFactory, trainingParameters);
                try (BufferedOutputStream modelOut = new BufferedOutputStream(new FileOutputStream(modelFile))) {
                    return model.serialize(modelOut);
                }
            } finally {
                // The key is set before training, so it must be cleared on training failures too.
                OpenNLPEncryptionFactory.getDefault().clearKey();
            }
        }
    }
}
