/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.sequence;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import org.tribuo.classification.Label;
import org.tribuo.classification.sequence.example.SequenceDataGenerator;
import org.tribuo.sequence.SequenceDataset;
import org.tribuo.sequence.SequenceModel;
import org.tribuo.sequence.SequenceTrainer;
import org.tribuo.util.Util;

import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.file.Path;
import java.util.logging.Logger;

/**
 * Build and run a sequence classifier on a generated or serialized dataset using the trainer specified in the configuration file.
 */
public class SeqTrainTest {

    private static final Logger logger = Logger.getLogger(SeqTrainTest.class.getName());

    /**
     * Command line options.
     */
    public static class SeqTrainTestOptions implements Options {
        @Override
        public String getOptionsDescription() {
            return "Trains and tests a sequence classification model on the specified dataset.";
        }

        /**
         * Name of the example dataset, options are {gorilla}.
         */
        @Option(charName = 'd', longName = "dataset-name", usage = "Name of the example dataset, options are {gorilla}.")
        public String datasetName = "";
        /**
         * Path to serialize model to.
         */
        @Option(charName = 'f', longName = "output-path", usage = "Path to serialize model to.")
        public Path outputPath;
        /**
         * Path to a serialised SequenceDataset used for training.
         */
        @Option(charName = 'u', longName = "train-dataset", usage = "Path to a serialised SequenceDataset used for training.")
        public Path trainDataset = null;
        /**
         * Path to a serialised SequenceDataset used for testing.
         */
        @Option(charName = 'v', longName = "test-dataset", usage = "Path to a serialised SequenceDataset used for testing.")
        public Path testDataset = null;
        /**
         * Name of the trainer in the configuration file.
         */
        @Option(charName = 't', longName = "trainer-name", usage = "Name of the trainer in the configuration file.")
        public SequenceTrainer<Label> trainer;
    }

    /**
     * Trains a sequence model on the selected training data, evaluates it on the
     * test data, prints the evaluation and confusion matrix to standard output,
     * and optionally serialises the trained model.
     * <p>
     * The data is either the generated gorilla dataset (when the dataset name is
     * "gorilla" or "Gorilla") or a pair of serialised datasets read from the
     * train and test paths. If neither is available, or no trainer was supplied,
     * a warning and the usage message are logged and the method returns.
     *
     * @param args the command line arguments
     * @throws ClassNotFoundException if a class needed to deserialise one of the datasets cannot be found.
     * @throws IOException if there is any error reading the datasets or writing the model.
     */
    public static void main(String[] args) throws ClassNotFoundException, IOException {

        //
        // Use the labs format logging.
        LabsLogFormatter.setAllLogFormatters();

        SeqTrainTestOptions o = new SeqTrainTestOptions();
        ConfigurationManager cm;
        try {
            cm = new ConfigurationManager(args, o);
        } catch (UsageException e) {
            // The exception message is logged in place of running the program.
            logger.info(e.getMessage());
            return;
        }

        // Fail fast: without a trainer the run cannot succeed, so don't spend time
        // generating or deserialising data only to hit a NullPointerException later.
        if (o.trainer == null) {
            logger.warning("No trainer specified, use -t/--trainer-name to supply one.");
            logger.info(cm.usage());
            return;
        }

        SequenceDataset<Label> train;
        SequenceDataset<Label> test;
        switch (o.datasetName) {
            case "Gorilla":
            case "gorilla":
                logger.info("Generating gorilla dataset");
                train = SequenceDataGenerator.generateGorillaDataset(1);
                test = SequenceDataGenerator.generateGorillaDataset(1);
                break;
            default:
                // No recognised example dataset name: fall back to the serialised datasets,
                // which must both be supplied.
                if ((o.trainDataset != null) && (o.testDataset != null)) {
                    logger.info("Loading training data from " + o.trainDataset);
                    try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.trainDataset.toFile())));
                         ObjectInputStream oits = new ObjectInputStream(new BufferedInputStream(new FileInputStream(o.testDataset.toFile())))) {
                        @SuppressWarnings("unchecked") // deserialising a generic dataset.
                        SequenceDataset<Label> tmpTrain = (SequenceDataset<Label>) ois.readObject();
                        train = tmpTrain;
                        logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
                        logger.info("Found " + train.getFeatureIDMap().size() + " features");
                        logger.info("Loading testing data from " + o.testDataset);
                        @SuppressWarnings("unchecked") // deserialising a generic dataset.
                        SequenceDataset<Label> tmpTest = (SequenceDataset<Label>) oits.readObject();
                        test = tmpTest;
                        logger.info(String.format("Loaded %d testing examples", test.size()));
                    }
                } else {
                    logger.warning("Unknown dataset " + o.datasetName);
                    logger.info(cm.usage());
                    return;
                }
        }

        // Train the model, timing the call.
        logger.info("Training using " + o.trainer.toString());
        final long trainStart = System.currentTimeMillis();
        SequenceModel<Label> model = o.trainer.train(train);
        final long trainStop = System.currentTimeMillis();
        logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));

        // Evaluate on the test data, timing the call, and print the results.
        LabelSequenceEvaluator labelEvaluator = new LabelSequenceEvaluator();
        final long testStart = System.currentTimeMillis();
        LabelSequenceEvaluation evaluation = labelEvaluator.evaluate(model,test);
        final long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
        System.out.println(evaluation.toString());
        System.out.println();
        System.out.println(evaluation.getConfusionMatrix().toString());

        // Serialise the trained model only when an output path was supplied.
        if (o.outputPath != null) {
            try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(o.outputPath.toFile()))) {
                oos.writeObject(model);
                logger.info("Serialized model to file: " + o.outputPath);
            }
        }
    }
}