001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.sequence.viterbi; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.provenance.Provenance; 021import org.tribuo.Dataset; 022import org.tribuo.Example; 023import org.tribuo.Feature; 024import org.tribuo.Model; 025import org.tribuo.Trainer; 026import org.tribuo.classification.Label; 027import org.tribuo.classification.sequence.viterbi.ViterbiModel.ScoreAggregation; 028import org.tribuo.provenance.ModelProvenance; 029import org.tribuo.provenance.TrainerProvenance; 030import org.tribuo.provenance.impl.TrainerProvenanceImpl; 031import org.tribuo.sequence.ImmutableSequenceDataset; 032import org.tribuo.sequence.MutableSequenceDataset; 033import org.tribuo.sequence.SequenceDataset; 034import org.tribuo.sequence.SequenceExample; 035import org.tribuo.sequence.SequenceModel; 036import org.tribuo.sequence.SequenceTrainer; 037 038import java.time.OffsetDateTime; 039import java.util.ArrayList; 040import java.util.List; 041import java.util.Map; 042 043/** 044 * Builds a Viterbi model using the supplied {@link Trainer}. 045 * Has a parameter to control the label features which are added to the features supplied by the data. 046 */ 047public final class ViterbiTrainer implements SequenceTrainer<Label> { 048 049 @Config(mandatory = true, description = "Inner trainer for each sequence element.") 050 private Trainer<Label> trainer; 051 052 @Config(mandatory = true, description = "Feature extractor to pull in surrounding label features.") 053 private LabelFeatureExtractor labelFeatureExtractor; 054 055 @Config(mandatory = true, description = "Number of candidate paths.") 056 private int stackSize; 057 058 @Config(mandatory = true, description = "Score aggregation function.") 059 private ScoreAggregation scoreAggregation; 060 061 private int trainInvocationCounter = 0; 062 063 /** 064 * Constructs a ViterbiTrainer wrapping the specified trainer, with an unbounded stack size. 065 * @param trainer The trainer to wrap. 066 * @param labelFeatureExtractor The feature extraction function for labels. 067 * @param scoreAggregation The score aggregation function. 068 */ 069 public ViterbiTrainer(Trainer<Label> trainer, LabelFeatureExtractor labelFeatureExtractor, 070 ScoreAggregation scoreAggregation) { 071 this(trainer, labelFeatureExtractor, -1, scoreAggregation); 072 } 073 074 /** 075 * Constructs a ViterbiTrainer wrapping the specified trainer. 076 * @param trainer The trainer to wrap. 077 * @param labelFeatureExtractor The feature extraction function for labels. 078 * @param stackSize The stack size. 079 * @param scoreAggregation The score aggregation function. 080 */ 081 public ViterbiTrainer(Trainer<Label> trainer, LabelFeatureExtractor labelFeatureExtractor, int stackSize, 082 ScoreAggregation scoreAggregation) { 083 this.trainer = trainer; 084 this.labelFeatureExtractor = labelFeatureExtractor; 085 this.stackSize = stackSize; 086 this.scoreAggregation = scoreAggregation; 087 } 088 089 /** 090 * For OLCUT. 091 */ 092 private ViterbiTrainer() { } 093 094 /** 095 * The viterbi train method is unique because it delegates to a regular 096 * {@link Model} train method, but before it does, it adds features derived 097 * from preceding labels. The pipeline upstream of this call should not care 098 * that these features are being added - that is, we would not want to make 099 * the upstream logic worry about what kind of trainer will be used and have 100 * conditional logic that says to add special label-derived features if 101 * using the ViterbiTrainer. So, these one-of-a-kind unique-in-the-world 102 * label-derived features are generated here and added to the sequence 103 * examples of the passed in dataset. If you pass in a 104 * MutableSequenceDataset, then please be aware that your dataset will be 105 * modified after calling this method and therefore subsequent calls to 106 * other SequenceModel.train methods with your dataset should be avoided. If 107 * you pass in an ImmutableSequenceDataset, then please be aware that your 108 * entire dataset is going to be copied as a MutableSequenceDataset - so 109 * there is a memory penalty. 110 * @param dataset The input dataset. 111 * @param runProvenance Any additional information to record in the provenance. 112 * @return A {@link SequenceModel} using Viterbi wrapped around an inner {@link Model}. 113 */ 114 @Override 115 public SequenceModel<Label> train(SequenceDataset<Label> dataset, Map<String, Provenance> runProvenance) { 116 if (dataset.getOutputInfo().getUnknownCount() > 0) { 117 throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised."); 118 } 119 // if stack size isn't specified, then we will calculate it based on the 120 // number of unique output values 121 if (stackSize == -1) { 122 stackSize = dataset.getOutputIDInfo().size(); 123 } 124 125 // create a copy of the dataset to a mutable one. See note above. 126 if (dataset instanceof ImmutableSequenceDataset) { 127 dataset = new MutableSequenceDataset<>((ImmutableSequenceDataset<Label>) dataset); 128 } 129 130 if (!(dataset instanceof MutableSequenceDataset)) { 131 throw new IllegalArgumentException("unable to handle sub-type of dataset: " + dataset.getClass().getName()); 132 } 133 134 for (SequenceExample<Label> sequenceExample : dataset) { 135 List<Label> labels = new ArrayList<>(); 136 137 for (Example<Label> example : sequenceExample) { 138 List<Feature> labelFeatures = extractFeatures(labels, (MutableSequenceDataset<Label>) dataset, 139 1.0); 140 example.addAll(labelFeatures); 141 labels.add(example.getOutput()); 142 } 143 } 144 145 TrainerProvenance trainerProvenance = getProvenance(); 146 ModelProvenance provenance = new ModelProvenance(ViterbiModel.class.getName(), OffsetDateTime.now(), dataset.getProvenance(), trainerProvenance, runProvenance); 147 trainInvocationCounter++; 148 Dataset<Label> flatData = dataset.getFlatDataset(); 149 Model<Label> model = trainer.train(flatData); 150 return new ViterbiModel("viterbi+" + model.getName(), provenance, model, 151 labelFeatureExtractor, stackSize, scoreAggregation); 152 } 153 154 @Override 155 public int getInvocationCount() { 156 return trainInvocationCounter; 157 } 158 159 private List<Feature> extractFeatures(List<Label> labels, 160 MutableSequenceDataset<Label> dataset, double value) { 161 List<Feature> labelFeatures = new ArrayList<>(); 162 for (Feature labelFeature : labelFeatureExtractor.extractFeatures(labels, value)) { 163 dataset.getFeatureMap().add(labelFeature.getName(), labelFeature.getValue()); 164 labelFeatures.add(labelFeature); 165 } 166 return labelFeatures; 167 } 168 169 @Override 170 public String toString() { 171 return "ViterbiTrainer(innerTrainer=" + trainer.toString() + ",labelFeatureExtractor=" + labelFeatureExtractor.toString() + ")"; 172 } 173 174 @Override 175 public TrainerProvenance getProvenance() { 176 return new TrainerProvenanceImpl(this); 177 } 178}