001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sequence.viterbi;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.Provenance;
021import org.tribuo.Dataset;
022import org.tribuo.Example;
023import org.tribuo.Feature;
024import org.tribuo.Model;
025import org.tribuo.Trainer;
026import org.tribuo.classification.Label;
027import org.tribuo.classification.sequence.viterbi.ViterbiModel.ScoreAggregation;
028import org.tribuo.provenance.ModelProvenance;
029import org.tribuo.provenance.TrainerProvenance;
030import org.tribuo.provenance.impl.TrainerProvenanceImpl;
031import org.tribuo.sequence.ImmutableSequenceDataset;
032import org.tribuo.sequence.MutableSequenceDataset;
033import org.tribuo.sequence.SequenceDataset;
034import org.tribuo.sequence.SequenceExample;
035import org.tribuo.sequence.SequenceModel;
036import org.tribuo.sequence.SequenceTrainer;
037
038import java.time.OffsetDateTime;
039import java.util.ArrayList;
040import java.util.List;
041import java.util.Map;
042
043/**
044 * Builds a Viterbi model using the supplied {@link Trainer}.
045 * Has a parameter to control the label features which are added to the features supplied by the data.
046 */
047public final class ViterbiTrainer implements SequenceTrainer<Label> {
048
049    @Config(mandatory = true, description = "Inner trainer for each sequence element.")
050    private Trainer<Label> trainer;
051
052    @Config(mandatory = true, description = "Feature extractor to pull in surrounding label features.")
053    private LabelFeatureExtractor labelFeatureExtractor;
054
055    @Config(mandatory = true, description = "Number of candidate paths.")
056    private int stackSize;
057
058    @Config(mandatory = true, description = "Score aggregation function.")
059    private ScoreAggregation scoreAggregation;
060
061    private int trainInvocationCounter = 0;
062
063    /**
064     * Constructs a ViterbiTrainer wrapping the specified trainer, with an unbounded stack size.
065     * @param trainer The trainer to wrap.
066     * @param labelFeatureExtractor The feature extraction function for labels.
067     * @param scoreAggregation The score aggregation function.
068     */
069    public ViterbiTrainer(Trainer<Label> trainer, LabelFeatureExtractor labelFeatureExtractor,
070                          ScoreAggregation scoreAggregation) {
071        this(trainer, labelFeatureExtractor, -1, scoreAggregation);
072    }
073
074    /**
075     * Constructs a ViterbiTrainer wrapping the specified trainer.
076     * @param trainer The trainer to wrap.
077     * @param labelFeatureExtractor The feature extraction function for labels.
078     * @param stackSize The stack size.
079     * @param scoreAggregation The score aggregation function.
080     */
081    public ViterbiTrainer(Trainer<Label> trainer, LabelFeatureExtractor labelFeatureExtractor, int stackSize,
082                          ScoreAggregation scoreAggregation) {
083        this.trainer = trainer;
084        this.labelFeatureExtractor = labelFeatureExtractor;
085        this.stackSize = stackSize;
086        this.scoreAggregation = scoreAggregation;
087    }
088
089    /**
090     * For OLCUT.
091     */
092    private ViterbiTrainer() { }
093
094    /**
095     * The viterbi train method is unique because it delegates to a regular
096     * {@link Model} train method, but before it does, it adds features derived
097     * from preceding labels. The pipeline upstream of this call should not care
098     * that these features are being added - that is, we would not want to make
099     * the upstream logic worry about what kind of trainer will be used and have
100     * conditional logic that says to add special label-derived features if
101     * using the ViterbiTrainer. So, these one-of-a-kind unique-in-the-world
102     * label-derived features are generated here and added to the sequence
103     * examples of the passed in dataset. If you pass in a
104     * MutableSequenceDataset, then please be aware that your dataset will be
105     * modified after calling this method and therefore subsequent calls to
106     * other SequenceModel.train methods with your dataset should be avoided. If
107     * you pass in an ImmutableSequenceDataset, then please be aware that your
108     * entire dataset is going to be copied as a MutableSequenceDataset - so
109     * there is a memory penalty.
110     * @param dataset The input dataset.
111     * @param runProvenance Any additional information to record in the provenance.
112     * @return A {@link SequenceModel} using Viterbi wrapped around an inner {@link Model}.
113     */
114    @Override
115    public SequenceModel<Label> train(SequenceDataset<Label> dataset, Map<String, Provenance> runProvenance) {
116        if (dataset.getOutputInfo().getUnknownCount() > 0) {
117            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
118        }
119        // if stack size isn't specified, then we will calculate it based on the
120        // number of unique output values
121        if (stackSize == -1) {
122            stackSize = dataset.getOutputIDInfo().size();
123        }
124
125        // create a copy of the dataset to a mutable one. See note above.
126        if (dataset instanceof ImmutableSequenceDataset) {
127            dataset = new MutableSequenceDataset<>((ImmutableSequenceDataset<Label>) dataset);
128        }
129
130        if (!(dataset instanceof MutableSequenceDataset)) {
131            throw new IllegalArgumentException("unable to handle sub-type of dataset: " + dataset.getClass().getName());
132        }
133
134        for (SequenceExample<Label> sequenceExample : dataset) {
135            List<Label> labels = new ArrayList<>();
136
137            for (Example<Label> example : sequenceExample) {
138                List<Feature> labelFeatures = extractFeatures(labels, (MutableSequenceDataset<Label>) dataset,
139                        1.0);
140                example.addAll(labelFeatures);
141                labels.add(example.getOutput());
142            }
143        }
144
145        TrainerProvenance trainerProvenance = getProvenance();
146        ModelProvenance provenance = new ModelProvenance(ViterbiModel.class.getName(), OffsetDateTime.now(), dataset.getProvenance(), trainerProvenance, runProvenance);
147        trainInvocationCounter++;
148        Dataset<Label> flatData = dataset.getFlatDataset();
149        Model<Label> model = trainer.train(flatData);
150        return new ViterbiModel("viterbi+" + model.getName(), provenance, model,
151                labelFeatureExtractor, stackSize, scoreAggregation);
152    }
153
154    @Override
155    public int getInvocationCount() {
156        return trainInvocationCounter;
157    }
158
159    private List<Feature> extractFeatures(List<Label> labels,
160                                          MutableSequenceDataset<Label> dataset, double value) {
161        List<Feature> labelFeatures = new ArrayList<>();
162        for (Feature labelFeature : labelFeatureExtractor.extractFeatures(labels, value)) {
163            dataset.getFeatureMap().add(labelFeature.getName(), labelFeature.getValue());
164            labelFeatures.add(labelFeature);
165        }
166        return labelFeatures;
167    }
168
169    @Override
170    public String toString() {
171        return "ViterbiTrainer(innerTrainer=" + trainer.toString() + ",labelFeatureExtractor=" + labelFeatureExtractor.toString() + ")";
172    }
173
174    @Override
175    public TrainerProvenance getProvenance() {
176        return new TrainerProvenanceImpl(this);
177    }
178}