001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sequence;
018
019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
020import com.oracle.labs.mlrg.olcut.config.Option;
021import com.oracle.labs.mlrg.olcut.config.Options;
022import com.oracle.labs.mlrg.olcut.config.UsageException;
023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
024import org.tribuo.classification.Label;
025import org.tribuo.classification.sequence.example.SequenceDataGenerator;
026import org.tribuo.sequence.SequenceDataset;
027import org.tribuo.sequence.SequenceModel;
028import org.tribuo.sequence.SequenceTrainer;
029import org.tribuo.util.Util;
030
031import java.io.BufferedInputStream;
032import java.io.FileInputStream;
033import java.io.FileOutputStream;
034import java.io.IOException;
035import java.io.ObjectInputStream;
036import java.io.ObjectOutputStream;
037import java.nio.file.Path;
038import java.util.logging.Logger;
039
040/**
041 * Build and run a sequence classifier on a generated or serialized dataset using the trainer specified in the configuration file.
042 */
043public class SeqTrainTest {
044
045    private static final Logger logger = Logger.getLogger(SeqTrainTest.class.getName());
046
047    /**
048     * Command line options.
049     */
050    public static class SeqTrainTestOptions implements Options {
051        @Override
052        public String getOptionsDescription() {
053            return "Trains and tests a sequence classification model on the specified dataset.";
054        }
055
056        /**
057         * Name of the example dataset, options are {gorilla}.
058         */
059        @Option(charName = 'd', longName = "dataset-name", usage = "Name of the example dataset, options are {gorilla}.")
060        public String datasetName = "";
061        /**
062         * Path to serialize model to.
063         */
064        @Option(charName = 'f', longName = "output-path", usage = "Path to serialize model to.")
065        public Path outputPath;
066        /**
067         * Path to a serialised SequenceDataset used for training.
068         */
069        @Option(charName = 'u', longName = "train-dataset", usage = "Path to a serialised SequenceDataset used for training.")
070        public Path trainDataset = null;
071        /**
072         * Path to a serialised SequenceDataset used for testing.
073         */
074        @Option(charName = 'v', longName = "test-dataset", usage = "Path to a serialised SequenceDataset used for testing.")
075        public Path testDataset = null;
076        /**
077         * Name of the trainer in the configuration file.
078         */
079        @Option(charName = 't', longName = "trainer-name", usage = "Name of the trainer in the configuration file.")
080        public SequenceTrainer<Label> trainer;
081    }
082
083    /**
084     * @param args the command line arguments
085     * @throws ClassNotFoundException if it failed to load the model.
086     * @throws IOException            if there is any error reading the examples.
087     */
088    public static void main(String[] args) throws ClassNotFoundException, IOException {
089
090        //
091        // Use the labs format logging.
092        LabsLogFormatter.setAllLogFormatters();
093
094        SeqTrainTestOptions o = new SeqTrainTestOptions();
095        ConfigurationManager cm;
096        try {
097            cm = new ConfigurationManager(args, o);
098        } catch (UsageException e) {
099            logger.info(e.getMessage());
100            return;
101        }
102
103        SequenceDataset<Label> train;
104        SequenceDataset<Label> test;
105        switch (o.datasetName) {
106            case "Gorilla":
107            case "gorilla":
108                logger.info("Generating gorilla dataset");
109                train = SequenceDataGenerator.generateGorillaDataset(1);
110                test = SequenceDataGenerator.generateGorillaDataset(1);
111                break;
112            default:
113                if ((o.trainDataset != null) && (o.testDataset != null)) {
114                    logger.info("Loading training data from " + o.trainDataset);
115                    try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.trainDataset.toFile())));
116                         ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.testDataset.toFile())))) {
117                        @SuppressWarnings("unchecked") // deserialising a generic dataset.
118                        SequenceDataset<Label> tmpTrain = (SequenceDataset<Label>) ois.readObject();
119                        train = tmpTrain;
120                        logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
121                        logger.info("Found " + train.getFeatureIDMap().size() + " features");
122                        logger.info("Loading testing data from " + o.testDataset);
123                        @SuppressWarnings("unchecked") // deserialising a generic dataset.
124                        SequenceDataset<Label> tmpTest = (SequenceDataset<Label>) oits.readObject();
125                        test = tmpTest;
126                        logger.info(String.format("Loaded %d testing examples", test.size()));
127                    }
128                } else {
129                    logger.warning("Unknown dataset " + o.datasetName);
130                    logger.info(cm.usage());
131                    return;
132                }
133        }
134
135        logger.info("Training using " + o.trainer.toString());
136        final long trainStart = System.currentTimeMillis();
137        SequenceModel<Label> model = o.trainer.train(train);
138        final long trainStop = System.currentTimeMillis();
139        logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));
140
141        LabelSequenceEvaluator labelEvaluator = new LabelSequenceEvaluator();
142        final long testStart = System.currentTimeMillis();
143        LabelSequenceEvaluation evaluation = labelEvaluator.evaluate(model,test);
144        final long testStop = System.currentTimeMillis();
145        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
146        System.out.println(evaluation.toString());
147        System.out.println();
148        System.out.println(evaluation.getConfusionMatrix().toString());
149
150        if (o.outputPath != null) {
151            try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(o.outputPath.toFile()))) {
152                oos.writeObject(model);
153                logger.info("Serialized model to file: " + o.outputPath);
154            }
155        }
156    }
157}