package org.cleartk.examples.documentclassification.advanced;

import com.google.common.base.Function;
import java.io.File;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.filefilter.FileFilterUtils;
import org.apache.commons.io.filefilter.HiddenFileFilter;
import org.apache.commons.io.filefilter.IOFileFilter;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.collection.CollectionReader;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import org.cleartk.classifier.CleartkAnnotator;
import org.cleartk.classifier.Instance;
import org.cleartk.classifier.feature.transform.InstanceDataWriter;
import org.cleartk.classifier.feature.transform.InstanceStream;
import org.cleartk.classifier.feature.transform.extractor.CentroidTfidfSimilarityExtractor;
import org.cleartk.classifier.feature.transform.extractor.MinMaxNormalizationExtractor;
import org.cleartk.classifier.feature.transform.extractor.TfidfExtractor;
import org.cleartk.classifier.feature.transform.extractor.ZeroMeanUnitStddevExtractor;
import org.cleartk.classifier.jar.DefaultDataWriterFactory;
import org.cleartk.classifier.jar.DirectoryDataWriterFactory;
import org.cleartk.classifier.jar.GenericJarClassifierFactory;
import org.cleartk.classifier.jar.JarClassifierBuilder;
import org.cleartk.classifier.libsvm.LIBSVMStringOutcomeDataWriter;
import org.cleartk.eval.AnnotationStatistics;
import org.cleartk.eval.Evaluation_ImplBase;
import org.cleartk.examples.type.UsenetDocument;
import org.cleartk.syntax.opennlp.SentenceAnnotator;
import org.cleartk.token.stem.snowball.DefaultSnowballStemmer;
import org.cleartk.token.tokenizer.TokenAnnotator;
import org.cleartk.util.Options_ImplBase;
import org.cleartk.util.ae.UriToDocumentTextAnnotator;
import org.cleartk.util.cr.UriCollectionReader;
import org.kohsuke.args4j.Option;
import org.uimafit.component.ViewTextCopierAnnotator;
import org.uimafit.factory.AggregateBuilder;
import org.uimafit.factory.AnalysisEngineFactory;
import org.uimafit.factory.ConfigurationParameterFactory;
import org.uimafit.pipeline.JCasIterable;
import org.uimafit.pipeline.SimplePipeline;
import org.uimafit.testing.util.HideOutput;
import org.uimafit.util.JCasUtil;

/* loaded from: input_file:org/cleartk/examples/documentclassification/advanced/DocumentClassificationEvaluation.class */
public class DocumentClassificationEvaluation extends Evaluation_ImplBase<File, AnnotationStatistics<String>> {
    /**
     * Name of the CAS view that receives a copy of the document text in TEST mode and on which
     * the gold category annotator is run; the test step reads reference annotations from it.
     */
    public static final String GOLD_VIEW_NAME = "DocumentClassificationGoldView";
    /**
     * Name of the CAS view the classifier's predictions are read from during the test step. It is
     * also the source view whose text is copied into {@link #GOLD_VIEW_NAME}.
     */
    public static final String SYSTEM_VIEW_NAME = "_InitialView";
    /** Arguments forwarded to {@code JarClassifierBuilder.trainAndPackage} when the model is trained. */
    private List<String> trainingArguments;

    /* loaded from: input_file:org/cleartk/examples/documentclassification/advanced/DocumentClassificationEvaluation$AnnotatorMode.class */
    /**
     * Selects how the preprocessing and classification pipelines are assembled by
     * {@code createPreprocessingAggregate} and {@code createDocumentClassificationAggregate}.
     */
    public enum AnnotatorMode {
        /**
         * Gold category annotator runs on the default view and the classification annotator is
         * configured with a data writer that writes into the output directory.
         */
        TRAIN,
        /**
         * Document text is copied into the gold view, the gold category annotator is mapped onto
         * that view, and the classification annotator loads {@code model.jar}.
         */
        TEST,
        /**
         * No gold category annotator is added; the classification annotator loads
         * {@code model.jar}.
         */
        CLASSIFY
    }

    /* loaded from: input_file:org/cleartk/examples/documentclassification/advanced/DocumentClassificationEvaluation$Options.class */
    /**
     * Command line options for {@link #main(String[])}, populated by args4j through
     * {@code parseOptions}. Every option has a default, so the example can be run without
     * arguments.
     */
    public static class Options extends Options_ImplBase {

        /** Directory of training documents; used for cross-validation and for holdout training. */
        @Option(name = "--train-dir", usage = "Specify the directory containing the training documents.  This is used for cross-validation, and for training in a holdout set evaluation. When we run this example we point to a directory containing training data from a subset of the 20 newsgroup corpus - i.e. a directory called '3news-bydate/train'")
        public File trainDirectory = new File("src/main/resources/data/3news-bydate/train");

        /** Directory of held-out documents used for the holdout evaluation. */
        @Option(name = "--test-dir", usage = "Specify the directory containing the test (aka holdout/validation) documents.  This is for holdout set evaluation. When we run this example we point to a directory containing test data from a subset of the 20 newsgroup corpus - i.e. a directory called '3news-bydate/test'")
        public File testDirectory = new File("src/main/resources/data/3news-bydate/test");

        /** Directory the trained model files are written to. */
        @Option(name = "--models-dir", usage = "specify the directory in which to write out the trained model files")
        public File modelsDirectory = new File("target/document_classification/models");

        /**
         * Arguments passed through to the learner.
         *
         * <p>The default must be a growable list: args4j's handling of multi-valued {@code List}
         * fields is expected to add each parsed value to the existing list instance, which a
         * fixed-size {@code Arrays.asList} view rejects with
         * {@code UnsupportedOperationException}. NOTE(review): values given on the command line
         * are therefore presumably appended after the defaults rather than replacing them; verify
         * against the args4j version in use.
         */
        @Option(name = "--training-args", usage = "specify training arguments to be passed to the learner.  For multiple values specify --training-args for each - e.g. '--training-args -t --training-args 0'")
        public List<String> trainingArguments = new ArrayList<String>(Arrays.asList("-t", "0"));
    }

    /**
     * Recursively collects the files below a directory, skipping hidden files, hidden
     * directories and Subversion metadata directories.
     *
     * @param file the root directory to search
     * @return a new mutable list of the files found
     */
    public static List<File> getFilesFromDirectory(File file) {
        // Accept only visible files, ignoring SVN bookkeeping.
        IOFileFilter visibleFiles = FileFilterUtils.makeSVNAware(HiddenFileFilter.VISIBLE);
        // Descend only into visible directories, again ignoring SVN bookkeeping.
        IOFileFilter visibleDirectories = FileFilterUtils.makeSVNAware(
                FileFilterUtils.and(FileFilterUtils.directoryFileFilter(), HiddenFileFilter.VISIBLE));
        return new ArrayList<File>(FileUtils.listFiles(file, visibleFiles, visibleDirectories));
    }

    /**
     * Runs the example: a 2-fold cross-validation over the training documents followed by a
     * holdout evaluation (train on the training directory, test on the test directory). Results
     * and confusion matrices are printed to standard error.
     *
     * @param strArr command line arguments, parsed into {@link Options}
     * @throws Exception if option parsing, training or testing fails
     */
    public static void main(String[] strArr) throws Exception {
        Options options = new Options();
        options.parseOptions(strArr);

        List<File> trainFiles = getFilesFromDirectory(options.trainDirectory);
        List<File> testFiles = getFilesFromDirectory(options.testDirectory);
        DocumentClassificationEvaluation evaluation =
                new DocumentClassificationEvaluation(options.modelsDirectory, options.trainingArguments);

        // Cross-validation: sum the per-fold statistics into one report.
        List<AnnotationStatistics<String>> foldStats = evaluation.crossValidation(trainFiles, 2);
        AnnotationStatistics<String> crossValidationStats = AnnotationStatistics.addAll(foldStats);
        System.err.println("Cross Validation Results:");
        System.err.print(crossValidationStats);
        System.err.println();
        System.err.println(crossValidationStats.confusions());
        System.err.println();

        // Holdout: train on all training files, evaluate on the test files.
        AnnotationStatistics<String> holdoutStats = evaluation.trainAndTest(trainFiles, testFiles);
        System.err.println("Holdout Set Results:");
        System.err.print(holdoutStats);
        System.err.println();
        System.err.println(holdoutStats.confusions());
    }

    /**
     * Creates an evaluation that passes no extra arguments to the learner.
     *
     * @param file directory handed to the {@code Evaluation_ImplBase} constructor
     */
    public DocumentClassificationEvaluation(File file) {
        this(file, Arrays.<String>asList());
    }

    /**
     * Creates an evaluation with explicit learner arguments.
     *
     * @param file directory handed to the {@code Evaluation_ImplBase} constructor
     * @param list arguments later forwarded to the classifier trainer; the list is stored by
     *     reference, not copied
     */
    public DocumentClassificationEvaluation(File file, List<String> list) {
        super(file);
        this.trainingArguments = list;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Builds the collection reader that feeds the given files into a pipeline.
     *
     * @param list the files to read
     * @return a reader created by {@code UriCollectionReader.getCollectionReaderFromFiles}
     * @throws Exception if the reader cannot be created
     */
    public CollectionReader getCollectionReader(List<File> list) throws Exception {
        CollectionReader uriReader = UriCollectionReader.getCollectionReaderFromFiles(list);
        return uriReader;
    }

    /**
     * Trains the document classifier in four stages:
     * <ol>
     *   <li>runs the TRAIN pipeline over the documents, writing raw instances into {@code file};</li>
     *   <li>trains the TF-IDF, centroid similarity, zero-mean/unit-stddev and min-max extractors
     *       on those instances and saves each one's statistics to its URI;</li>
     *   <li>transforms every instance with the four extractors, in that order, and writes the
     *       result as LIBSVM training data;</li>
     *   <li>trains the model and packages it with the configured training arguments.</li>
     * </ol>
     *
     * @param collectionReader reader supplying the training documents
     * @param file directory that receives the instance data, extractor statistics and model
     * @throws Exception if any pipeline, extractor, writer or training step fails
     */
    public void train(CollectionReader collectionReader, File file) throws Exception {
        System.err.println("Step 1: Extracting features and writing raw instances data");
        SimplePipeline.runPipeline(
                collectionReader,
                createDocumentClassificationAggregate(file, AnnotatorMode.TRAIN).createAggregateDescription());

        // The serialized instances are iterated once per extractor and once more for writing.
        Iterable<Instance<String>> instances = InstanceStream.loadFromDirectory(file);

        System.err.println("Collection feature normalization statistics");
        URI tfidfUri = DocumentClassificationAnnotator.createTokenTfIdfDataURI(file);
        TfidfExtractor<String> tfidfExtractor =
                new TfidfExtractor<String>(DocumentClassificationAnnotator.TFIDF_EXTRACTOR_KEY);
        tfidfExtractor.train(instances);
        tfidfExtractor.save(tfidfUri);

        URI centroidUri = DocumentClassificationAnnotator.createIdfCentroidSimilarityDataURI(file);
        CentroidTfidfSimilarityExtractor<String> centroidExtractor =
                new CentroidTfidfSimilarityExtractor<String>(DocumentClassificationAnnotator.CENTROID_TFIDF_SIM_EXTRACTOR_KEY);
        centroidExtractor.train(instances);
        centroidExtractor.save(centroidUri);

        URI zmusUri = DocumentClassificationAnnotator.createZmusDataURI(file);
        ZeroMeanUnitStddevExtractor<String> zmusExtractor =
                new ZeroMeanUnitStddevExtractor<String>("LengthFeatures");
        zmusExtractor.train(instances);
        zmusExtractor.save(zmusUri);

        URI minMaxUri = DocumentClassificationAnnotator.createMinMaxDataURI(file);
        MinMaxNormalizationExtractor<String> minMaxExtractor =
                new MinMaxNormalizationExtractor<String>("LengthFeatures");
        minMaxExtractor.train(instances);
        minMaxExtractor.save(minMaxUri);

        System.err.println("Write out model training data");
        LIBSVMStringOutcomeDataWriter dataWriter = new LIBSVMStringOutcomeDataWriter(file);
        for (Instance<String> instance : instances) {
            // Transformation order matters: tf-idf, centroid similarity, z-score, min-max.
            instance = tfidfExtractor.transform(instance);
            instance = centroidExtractor.transform(instance);
            instance = zmusExtractor.transform(instance);
            instance = minMaxExtractor.transform(instance);
            dataWriter.write(instance);
        }
        dataWriter.finish();

        System.err.println("Train model and write model.jar file.");
        HideOutput hideOutput = new HideOutput();
        try {
            JarClassifierBuilder.trainAndPackage(
                    file, this.trainingArguments.toArray(new String[this.trainingArguments.size()]));
        } finally {
            // Always restore the console streams; otherwise a training failure would leave
            // output suppressed and hide the resulting stack trace.
            hideOutput.restoreOutput();
        }
    }

    /**
     * Assembles the preprocessing pipeline: document text loading, sentence splitting,
     * tokenization and English Snowball stemming, followed by mode-specific gold annotation.
     *
     * <ul>
     *   <li>TRAIN: the gold category annotator runs on the default view.</li>
     *   <li>TEST: the text of {@link #SYSTEM_VIEW_NAME} is copied to {@link #GOLD_VIEW_NAME} and
     *       the gold category annotator is mapped onto the gold view.</li>
     *   <li>Any other mode: no gold annotator is added.</li>
     * </ul>
     *
     * @param file not used by this method
     * @param annotatorMode the mode to build the pipeline for; must not be null
     * @return the builder holding the preprocessing components
     * @throws ResourceInitializationException if a component description cannot be created
     */
    public static AggregateBuilder createPreprocessingAggregate(File file, AnnotatorMode annotatorMode) throws ResourceInitializationException {
        AggregateBuilder builder = new AggregateBuilder();
        builder.add(UriToDocumentTextAnnotator.getDescription());
        builder.add(SentenceAnnotator.getDescription());
        builder.add(TokenAnnotator.getDescription());
        builder.add(DefaultSnowballStemmer.getDescription("English"));

        switch (annotatorMode) {
            case TRAIN: {
                AnalysisEngineDescription goldAnnotator =
                        AnalysisEngineFactory.createPrimitiveDescription(GoldDocumentCategoryAnnotator.class);
                builder.add(goldAnnotator);
                break;
            }
            case TEST: {
                // Keep gold labels apart from system output by giving them their own view.
                AnalysisEngineDescription textCopier = AnalysisEngineFactory.createPrimitiveDescription(
                        ViewTextCopierAnnotator.class,
                        ViewTextCopierAnnotator.PARAM_SOURCE_VIEW_NAME, SYSTEM_VIEW_NAME,
                        ViewTextCopierAnnotator.PARAM_DESTINATION_VIEW_NAME, GOLD_VIEW_NAME);
                builder.add(textCopier);

                AnalysisEngineDescription goldAnnotator =
                        AnalysisEngineFactory.createPrimitiveDescription(GoldDocumentCategoryAnnotator.class);
                builder.add(goldAnnotator, SYSTEM_VIEW_NAME, GOLD_VIEW_NAME);
                break;
            }
            default:
                break;
        }
        return builder;
    }

    /**
     * Assembles the full pipeline: the preprocessing components for the given mode plus the
     * document classification annotator.
     *
     * <p>In TRAIN mode the annotator is configured with an instance data writer that writes into
     * {@code file}. In every other mode it is configured as a classifier that loads
     * {@code model.jar} and the four feature-transform data URIs from {@code file}.
     *
     * @param file directory the training data is written to, or the model is read from
     * @param annotatorMode the mode to build the pipeline for
     * @return the builder holding the complete pipeline
     * @throws ResourceInitializationException if a component description cannot be created
     */
    public static AggregateBuilder createDocumentClassificationAggregate(File file, AnnotatorMode annotatorMode) throws ResourceInitializationException {
        AggregateBuilder builder = createPreprocessingAggregate(file, annotatorMode);

        if (annotatorMode == AnnotatorMode.TRAIN) {
            builder.add(AnalysisEngineFactory.createPrimitiveDescription(
                    DocumentClassificationAnnotator.class,
                    DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME, InstanceDataWriter.class.getName(),
                    DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY, file.getPath()));
            return builder;
        }

        // TEST, CLASSIFY and anything else: classify with the packaged model.
        String modelJarPath = new File(file, "model.jar").getPath();
        AnalysisEngineDescription classifier = AnalysisEngineFactory.createPrimitiveDescription(
                DocumentClassificationAnnotator.class,
                CleartkAnnotator.PARAM_IS_TRAINING, false,
                GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH, modelJarPath);
        ConfigurationParameterFactory.addConfigurationParameters(
                classifier,
                DocumentClassificationAnnotator.PARAM_TF_IDF_URI,
                DocumentClassificationAnnotator.createTokenTfIdfDataURI(file),
                DocumentClassificationAnnotator.PARAM_TF_IDF_CENTROID_SIMILARITY_URI,
                DocumentClassificationAnnotator.createIdfCentroidSimilarityDataURI(file),
                DocumentClassificationAnnotator.PARAM_MINMAX_URI,
                DocumentClassificationAnnotator.createMinMaxDataURI(file),
                DocumentClassificationAnnotator.PARAM_ZMUS_URI,
                DocumentClassificationAnnotator.createZmusDataURI(file));
        builder.add(classifier);
        return builder;
    }

    /**
     * Runs the TEST pipeline over the documents and compares, per document, the gold category
     * annotations in {@link #GOLD_VIEW_NAME} with the predicted ones in
     * {@link #SYSTEM_VIEW_NAME}.
     *
     * <p>Decompiler notes carried over from the original listing: the access modifier was changed
     * from {@code protected}, and the method had been renamed from {@code test} to
     * {@code m3test} when it was merged with its bridge method. The name {@code test} is restored
     * here so the method again matches the superclass method that bridge method implies.
     *
     * @param collectionReader reader supplying the test documents
     * @param file directory containing the trained model and feature-transform data
     * @return statistics accumulated over all documents
     * @throws Exception if the pipeline cannot be created or a document fails to process
     */
    public AnnotationStatistics<String> test(CollectionReader collectionReader, File file) throws Exception {
        AnnotationStatistics<String> statistics = new AnnotationStatistics<>();
        AnalysisEngine engine = createDocumentClassificationAggregate(file, AnnotatorMode.TEST).createAggregate();

        // Annotations are matched by span and scored on their "category" feature.
        Function<UsenetDocument, ?> toSpan = AnnotationStatistics.annotationToSpan();
        Function<UsenetDocument, String> toCategory = AnnotationStatistics.annotationToFeatureValue("category");

        for (JCas jCas : new JCasIterable(collectionReader, engine)) {
            statistics.add(
                    JCasUtil.select(jCas.getView(GOLD_VIEW_NAME), UsenetDocument.class),
                    JCasUtil.select(jCas.getView(SYSTEM_VIEW_NAME), UsenetDocument.class),
                    toSpan,
                    toCategory);
        }
        return statistics;
    }

    /**
     * Decompiler-generated alias kept so existing references to this name keep working.
     *
     * @param collectionReader reader supplying the test documents
     * @param file directory containing the trained model and feature-transform data
     * @return the result of {@link #test(CollectionReader, File)}
     * @throws Exception if testing fails
     * @deprecated use {@link #test(CollectionReader, File)}
     */
    @Deprecated
    public AnnotationStatistics<String> m3test(CollectionReader collectionReader, File file) throws Exception {
        return test(collectionReader, file);
    }
}
