001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.sequence.viterbi; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import org.tribuo.Example; 021import org.tribuo.Feature; 022import org.tribuo.Model; 023import org.tribuo.Prediction; 024import org.tribuo.classification.Label; 025import org.tribuo.provenance.ModelProvenance; 026import org.tribuo.sequence.SequenceDataset; 027import org.tribuo.sequence.SequenceExample; 028import org.tribuo.sequence.SequenceModel; 029 030import java.util.ArrayList; 031import java.util.Collection; 032import java.util.Collections; 033import java.util.Comparator; 034import java.util.HashMap; 035import java.util.List; 036import java.util.Map; 037import java.util.stream.Collectors; 038 039/** 040 * An implementation of a viterbi model. 041 */ 042public class ViterbiModel extends SequenceModel<Label> { 043 044 private static final long serialVersionUID = 1L; 045 046 /** 047 * Types of label score aggregation. 048 */ 049 public enum ScoreAggregation { 050 /** 051 * Adds the scores. 052 */ 053 ADD, 054 /** 055 * Multiplies the scores. 056 */ 057 MULTIPLY 058 } 059 060 private final Model<Label> model; 061 062 private final LabelFeatureExtractor labelFeatureExtractor; 063 064 /** 065 * Specifies the maximum number of candidate paths to keep track of. In general, this number 066 * should be higher than the number of possible classifications at any given point in the 067 * sequence. This guarantees that highest-possible scoring sequence will be returned. If, 068 * however, the number of possible classifications is quite high and/or you are concerned about 069 * throughput performance, then you may want to reduce the number of candidate paths to 070 * maintain. 071 */ 072 private final int stackSize; 073 074 /** 075 * Specifies the score aggregation algorithm. 076 */ 077 private final ScoreAggregation scoreAggregation; 078 079 ViterbiModel(String name, ModelProvenance description, 080 Model<Label> model, LabelFeatureExtractor labelFeatureExtractor, int stackSize, ScoreAggregation scoreAggregation) { 081 super(name, description, model.getFeatureIDMap(), model.getOutputIDInfo()); 082 this.model = model; 083 this.labelFeatureExtractor = labelFeatureExtractor; 084 this.stackSize = stackSize; 085 this.scoreAggregation = scoreAggregation; 086 } 087 088 @Override 089 public List<List<Prediction<Label>>> predict(SequenceDataset<Label> examples) { 090 List<List<Prediction<Label>>> predictions = new ArrayList<>(); 091 for (SequenceExample<Label> e : examples) { 092 predictions.add(predict(e)); 093 } 094 return predictions; 095 } 096 097 @Override 098 public List<Prediction<Label>> predict(SequenceExample<Label> examples) { 099 if (stackSize == 1) { 100 List<Label> labels = new ArrayList<>(); 101 List<Prediction<Label>> returnValues = new ArrayList<>(); 102 for (Example<Label> example : examples) { 103 List<Feature> labelFeatures = extractFeatures(labels); 104 example.addAll(labelFeatures); 105 Prediction<Label> prediction = model.predict(example); 106 labels.add(prediction.getOutput()); 107 returnValues.add(prediction); 108 } 109 return returnValues; 110 } else { 111 return viterbi(examples); 112 } 113 114 } 115 116 private List<Feature> extractFeatures(List<Label> labels) { 117 List<Feature> labelFeatures = new ArrayList<>(); 118 for (Feature labelFeature : labelFeatureExtractor.extractFeatures(labels, 1.0)) { 119 int id = featureIDMap.getID(labelFeature.getName()); 120 if (id > -1) { 121 labelFeatures.add(labelFeature); 122 } 123 } 124 return labelFeatures; 125 } 126 127 /** 128 * This implementation of Viterbi requires at most stackSize * sequenceLength calls to the 129 * classifier. If this proves to be too expensive, then consider using a smaller stack size. 130 * 131 * @param examples a sequence-worth of features. Each {@code List<Feature>} in features should correspond to 132 * all of the features for a given element in a sequence to be classified. 133 * @return a list of Predictions - one for each member of the sequence. 134 * @see LabelFeatureExtractor 135 */ 136 private List<Prediction<Label>> viterbi(SequenceExample<Label> examples) { 137 // find the best paths through the label lattice 138 Collection<Path> paths = null; 139 int[] numUsed = new int[examples.size()]; 140 int i = 0; 141 for (Example<Label> example : examples) { 142 // if this is the first instance, start new paths for each label 143 if (paths == null) { 144 paths = new ArrayList<>(); 145 Prediction<Label> prediction = this.model.predict(example); 146 numUsed[i] = prediction.getNumActiveFeatures(); 147 Map<String, Label> distribution = prediction.getOutputScores(); 148 for (Label label : this.getTopLabels(distribution)) { 149 paths.add(new Path(label, label.getScore(), null)); 150 } 151 } else { 152 // for later instances, find the best previous path for each label 153 Map<Label, Path> maxPaths = new HashMap<>(); 154 for (Path path : paths) { 155 Example<Label> clonedExample = example.copy(); 156 List<Label> previousLabels = new ArrayList<>(path.labels); 157 List<Feature> labelFeatures = extractFeatures(previousLabels); 158 clonedExample.addAll(labelFeatures); 159 Prediction<Label> prediction = this.model.predict(clonedExample); 160 // TODO this isn't quite correct as it includes label features. 161 numUsed[i] = prediction.getNumActiveFeatures(); 162 Map<String, Label> distribution = prediction.getOutputScores(); 163 164 for (Label label : this.getTopLabels(distribution)) { 165 double labelScore = label.getScore(); 166 double score = this.scoreAggregation == ScoreAggregation.ADD ? path.score + labelScore : path.score * labelScore; 167 Path maxPath = maxPaths.get(label); 168 if (maxPath == null || score > maxPath.score) { 169 maxPaths.put(label, new Path(label, score, path)); 170 } 171 } 172 } 173 paths = maxPaths.values(); 174 } 175 i++; 176 } 177 178 Path maxPath = Collections.max(paths); 179 180 ArrayList<Prediction<Label>> output = new ArrayList<>(); 181 182 for (int j = 0; j < examples.size(); j++) { 183 Example<Label> e = examples.get(j); 184 output.add(new Prediction<>(maxPath.labels.get(j), numUsed[j], e)); 185 } 186 187 return output; 188 } 189 190 protected List<Label> getTopLabels(Map<String, Label> distribution) { 191 return getTopLabels(distribution, this.stackSize); 192 } 193 194 protected static List<Label> getTopLabels(Map<String, Label> distribution, int stackSize) { 195 return distribution.values().stream().sorted(Comparator.comparingDouble(Label::getScore).reversed()).limit(stackSize) 196 .collect(Collectors.toList()); 197 // get just the labels that fit within the stack 198 } 199 200 private static class Path implements Comparable<Path> { 201 202 public final double score; 203 204 public final Path parent; 205 206 public final List<Label> labels; 207 208 public Path(Label label, double score, Path parent) { 209 this.score = score; 210 this.parent = parent; 211 this.labels = new ArrayList<>(); 212 if (this.parent != null) { 213 this.labels.addAll(this.parent.labels); 214 } 215 this.labels.add(label); 216 } 217 218 @Override 219 public int compareTo(Path that) { 220 return Double.compare(this.score, that.score); 221 } 222 223 } 224 225 /** 226 * Gets the stack size of this model. 227 * @return The stack size. 228 */ 229 public int getStackSize() { 230 return stackSize; 231 } 232 233 /** 234 * Gets the score aggregation function. 235 * @return The score aggregation function. 236 */ 237 public ScoreAggregation getScoreAggregation() { 238 return scoreAggregation; 239 } 240 241 @Override 242 public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) { 243 return model.getTopFeatures(n); 244 } 245 246}