001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sequence.viterbi;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Example;
021import org.tribuo.Feature;
022import org.tribuo.Model;
023import org.tribuo.Prediction;
024import org.tribuo.classification.Label;
025import org.tribuo.provenance.ModelProvenance;
026import org.tribuo.sequence.SequenceDataset;
027import org.tribuo.sequence.SequenceExample;
028import org.tribuo.sequence.SequenceModel;
029
030import java.util.ArrayList;
031import java.util.Collection;
032import java.util.Collections;
033import java.util.Comparator;
034import java.util.HashMap;
035import java.util.List;
036import java.util.Map;
037import java.util.stream.Collectors;
038
039/**
040 * An implementation of a viterbi model.
041 */
042public class ViterbiModel extends SequenceModel<Label> {
043
044    private static final long serialVersionUID = 1L;
045
046    /**
047     * Types of label score aggregation.
048     */
049    public enum ScoreAggregation {
050        /**
051         * Adds the scores.
052         */
053        ADD,
054        /**
055         * Multiplies the scores.
056         */
057        MULTIPLY
058    }
059
060    private final Model<Label> model;
061
062    private final LabelFeatureExtractor labelFeatureExtractor;
063
064    /**
065     * Specifies the maximum number of candidate paths to keep track of. In general, this number
066     * should be higher than the number of possible classifications at any given point in the
067     * sequence. This guarantees that highest-possible scoring sequence will be returned. If,
068     * however, the number of possible classifications is quite high and/or you are concerned about
069     * throughput performance, then you may want to reduce the number of candidate paths to
070     * maintain.
071     */
072    private final int stackSize;
073
074    /**
075     * Specifies the score aggregation algorithm.
076     */
077    private final ScoreAggregation scoreAggregation;
078
079    ViterbiModel(String name, ModelProvenance description,
080                        Model<Label> model, LabelFeatureExtractor labelFeatureExtractor, int stackSize, ScoreAggregation scoreAggregation) {
081        super(name, description, model.getFeatureIDMap(), model.getOutputIDInfo());
082        this.model = model;
083        this.labelFeatureExtractor = labelFeatureExtractor;
084        this.stackSize = stackSize;
085        this.scoreAggregation = scoreAggregation;
086    }
087
088    @Override
089    public List<List<Prediction<Label>>> predict(SequenceDataset<Label> examples) {
090        List<List<Prediction<Label>>> predictions = new ArrayList<>();
091        for (SequenceExample<Label> e : examples) {
092            predictions.add(predict(e));
093        }
094        return predictions;
095    }
096
097    @Override
098    public List<Prediction<Label>> predict(SequenceExample<Label> examples) {
099        if (stackSize == 1) {
100            List<Label> labels = new ArrayList<>();
101            List<Prediction<Label>> returnValues = new ArrayList<>();
102            for (Example<Label> example : examples) {
103                List<Feature> labelFeatures = extractFeatures(labels);
104                example.addAll(labelFeatures);
105                Prediction<Label> prediction = model.predict(example);
106                labels.add(prediction.getOutput());
107                returnValues.add(prediction);
108            }
109            return returnValues;
110        } else {
111            return viterbi(examples);
112        }
113
114    }
115
116    private List<Feature> extractFeatures(List<Label> labels) {
117        List<Feature> labelFeatures = new ArrayList<>();
118        for (Feature labelFeature : labelFeatureExtractor.extractFeatures(labels, 1.0)) {
119            int id = featureIDMap.getID(labelFeature.getName());
120            if (id > -1) {
121                labelFeatures.add(labelFeature);
122            }
123        }
124        return labelFeatures;
125    }
126
127    /**
128     * This implementation of Viterbi requires at most stackSize * sequenceLength calls to the
129     * classifier. If this proves to be too expensive, then consider using a smaller stack size.
130     *
131     * @param examples a sequence-worth of features. Each {@code List<Feature>} in features should correspond to
132     *                 all of the features for a given element in a sequence to be classified.
133     * @return a list of Predictions - one for each member of the sequence.
134     * @see LabelFeatureExtractor
135     */
136    private List<Prediction<Label>> viterbi(SequenceExample<Label> examples) {
137        // find the best paths through the label lattice
138        Collection<Path> paths = null;
139        int[] numUsed = new int[examples.size()];
140        int i = 0;
141        for (Example<Label> example : examples) {
142            // if this is the first instance, start new paths for each label
143            if (paths == null) {
144                paths = new ArrayList<>();
145                Prediction<Label> prediction = this.model.predict(example);
146                numUsed[i] = prediction.getNumActiveFeatures();
147                Map<String, Label> distribution = prediction.getOutputScores();
148                for (Label label : this.getTopLabels(distribution)) {
149                    paths.add(new Path(label, label.getScore(), null));
150                }
151            } else {
152                // for later instances, find the best previous path for each label
153                Map<Label, Path> maxPaths = new HashMap<>();
154                for (Path path : paths) {
155                    Example<Label> clonedExample = example.copy();
156                    List<Label> previousLabels = new ArrayList<>(path.labels);
157                    List<Feature> labelFeatures = extractFeatures(previousLabels);
158                    clonedExample.addAll(labelFeatures);
159                    Prediction<Label> prediction = this.model.predict(clonedExample);
160                    // TODO this isn't quite correct as it includes label features.
161                    numUsed[i] = prediction.getNumActiveFeatures();
162                    Map<String, Label> distribution = prediction.getOutputScores();
163
164                    for (Label label : this.getTopLabels(distribution)) {
165                        double labelScore = label.getScore();
166                        double score = this.scoreAggregation == ScoreAggregation.ADD ? path.score + labelScore : path.score * labelScore;
167                        Path maxPath = maxPaths.get(label);
168                        if (maxPath == null || score > maxPath.score) {
169                            maxPaths.put(label, new Path(label, score, path));
170                        }
171                    }
172                }
173                paths = maxPaths.values();
174            }
175            i++;
176        }
177
178        Path maxPath = Collections.max(paths);
179
180        ArrayList<Prediction<Label>> output = new ArrayList<>();
181
182        for (int j = 0; j < examples.size(); j++) {
183            Example<Label> e = examples.get(j);
184            output.add(new Prediction<>(maxPath.labels.get(j), numUsed[j], e));
185        }
186
187        return output;
188    }
189
190    protected List<Label> getTopLabels(Map<String, Label> distribution) {
191        return getTopLabels(distribution, this.stackSize);
192    }
193
194    protected static List<Label> getTopLabels(Map<String, Label> distribution, int stackSize) {
195        return distribution.values().stream().sorted(Comparator.comparingDouble(Label::getScore).reversed()).limit(stackSize)
196                .collect(Collectors.toList());
197        // get just the labels that fit within the stack
198    }
199
200    private static class Path implements Comparable<Path> {
201
202        public final double score;
203
204        public final Path parent;
205
206        public final List<Label> labels;
207
208        public Path(Label label, double score, Path parent) {
209            this.score = score;
210            this.parent = parent;
211            this.labels = new ArrayList<>();
212            if (this.parent != null) {
213                this.labels.addAll(this.parent.labels);
214            }
215            this.labels.add(label);
216        }
217
218        @Override
219        public int compareTo(Path that) {
220            return Double.compare(this.score, that.score);
221        }
222
223    }
224
225    /**
226     * Gets the stack size of this model.
227     * @return The stack size.
228     */
229    public int getStackSize() {
230        return stackSize;
231    }
232
233    /**
234     * Gets the score aggregation function.
235     * @return The score aggregation function.
236     */
237    public ScoreAggregation getScoreAggregation() {
238        return scoreAggregation;
239    }
240
241    @Override
242    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
243        return model.getTopFeatures(n);
244    }
245
246}