001/*
002 * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.ensemble;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.Provenance;
022import org.tribuo.Dataset;
023import org.tribuo.Example;
024import org.tribuo.ImmutableDataset;
025import org.tribuo.ImmutableFeatureMap;
026import org.tribuo.ImmutableOutputInfo;
027import org.tribuo.Model;
028import org.tribuo.Prediction;
029import org.tribuo.Trainer;
030import org.tribuo.WeightedExamples;
031import org.tribuo.classification.Label;
032import org.tribuo.dataset.DatasetView;
033import org.tribuo.ensemble.WeightedEnsembleModel;
034import org.tribuo.provenance.EnsembleModelProvenance;
035import org.tribuo.provenance.TrainerProvenance;
036import org.tribuo.provenance.impl.TrainerProvenanceImpl;
037import org.tribuo.util.Util;
038
039import java.time.OffsetDateTime;
040import java.util.ArrayList;
041import java.util.Arrays;
042import java.util.List;
043import java.util.Map;
044import java.util.SplittableRandom;
045import java.util.logging.Level;
046import java.util.logging.Logger;
047
048/**
049 * Implements Adaboost.SAMME one of the more popular algorithms for multiclass boosting.
050 * Based on  <a href="https://web.stanford.edu/~hastie/Papers/samme.pdf">this paper</a>.
051 * <p>
052 * If the trainer implements {@link WeightedExamples} then it performs boosting by weighting,
053 * otherwise it uses a weighted bootstrap sample.
054 * <p>
055 * See:
056 * <pre>
057 * J. Zhu, S. Rosset, H. Zou, T. Hastie.
058 * "Multi-class Adaboost"
059 * 2006.
060 * </pre>
061 */
062public class AdaBoostTrainer implements Trainer<Label> {
063
064    private static final Logger logger = Logger.getLogger(AdaBoostTrainer.class.getName());
065    
066    @Config(mandatory=true, description="The trainer to use to build each weak learner.")
067    protected Trainer<Label> innerTrainer;
068
069    @Config(mandatory=true, description="The number of ensemble members to train.")
070    protected int numMembers;
071
072    @Config(mandatory=true, description="The seed for the RNG.")
073    protected long seed;
074
075    protected SplittableRandom rng;
076
077    protected int trainInvocationCounter;
078
079    /**
080     * For the OLCUT configuration system.
081     */
082    private AdaBoostTrainer() { }
083
084    /**
085     * Constructs an adaboost trainer using the supplied weak learner trainer and the specified number of
086     * boosting rounds. Uses the default seed.
087     * @param trainer The weak learner trainer.
088     * @param numMembers The maximum number of boosting rounds.
089     */
090    public AdaBoostTrainer(Trainer<Label> trainer, int numMembers) {
091        this(trainer, numMembers, Trainer.DEFAULT_SEED);
092    }
093
094    /**
095     * Constructs an adaboost trainer using the supplied weak learner trainer, the specified number of
096     * boosting rounds and the supplied seed.
097     * @param trainer The weak learner trainer.
098     * @param numMembers The maximum number of boosting rounds.
099     * @param seed The RNG seed.
100     */
101    public AdaBoostTrainer(Trainer<Label> trainer, int numMembers, long seed) {
102        this.innerTrainer = trainer;
103        this.numMembers = numMembers;
104        this.seed = seed;
105        postConfig();
106    }
107
108    /**
109     * Used by the OLCUT configuration system, and should not be called by external code.
110     */
111    @Override
112    public synchronized void postConfig() {
113        this.rng = new SplittableRandom(seed);
114    }
115
116    @Override
117    public String toString() {
118        StringBuilder buffer = new StringBuilder();
119
120        buffer.append("AdaBoostTrainer(");
121        buffer.append("innerTrainer=");
122        buffer.append(innerTrainer.toString());
123        buffer.append(",numMembers=");
124        buffer.append(numMembers);
125        buffer.append(",seed=");
126        buffer.append(seed);
127        buffer.append(")");
128
129        return buffer.toString();
130    }
131
132    /**
133     * If the trainer implements {@link WeightedExamples} then do boosting by weighting,
134     * otherwise do boosting by sampling.
135     * @param examples the data set containing the examples.
136     * @return A {@link WeightedEnsembleModel}.
137     */
138    @Override
139    public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance) {
140        return(train(examples, runProvenance, INCREMENT_INVOCATION_COUNT));
141    }
142
143    @Override
144    public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance, int invocationCount) {
145        if (examples.getOutputInfo().getUnknownCount() > 0) {
146            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
147        }
148        // Creates a new RNG, adds one to the invocation count.
149        SplittableRandom localRNG;
150        TrainerProvenance trainerProvenance;
151        synchronized(this) {
152            if(invocationCount != INCREMENT_INVOCATION_COUNT) {
153                setInvocationCount(invocationCount);
154            }
155            localRNG = rng.split();
156            trainerProvenance = getProvenance();
157            trainInvocationCounter++;
158        }
159        boolean weighted = innerTrainer instanceof WeightedExamples;
160        ImmutableFeatureMap featureIDs = examples.getFeatureIDMap();
161        ImmutableOutputInfo<Label> labelIDs = examples.getOutputIDInfo();
162        int numClasses = labelIDs.size();
163        logger.log(Level.INFO,"NumClasses = " + numClasses);
164        ArrayList<Model<Label>> models = new ArrayList<>();
165        float[] modelWeights = new float[numMembers];
166        float[] exampleWeights = Util.generateUniformFloatVector(examples.size(), 1.0f/examples.size());
167        if (weighted) {
168            logger.info("Using weighted Adaboost.");
169            examples = ImmutableDataset.copyDataset(examples);
170            for (int i = 0; i < examples.size(); i++) {
171                Example<Label> e = examples.getExample(i);
172                e.setWeight(exampleWeights[i]);
173            }
174        } else {
175            logger.info("Using sampling Adaboost.");
176        }
177        for (int i = 0; i < numMembers; i++) {
178            logger.info("Building model " + i);
179            Model<Label> newModel;
180            if (weighted) {
181                newModel = innerTrainer.train(examples);
182            } else {
183                DatasetView<Label> bag = DatasetView.createWeightedBootstrapView(examples,examples.size(),localRNG.nextLong(),exampleWeights,featureIDs,labelIDs);
184                newModel = innerTrainer.train(bag);
185            }
186
187            //
188            // Score this model
189            List<Prediction<Label>> predictions = newModel.predict(examples);
190            float accuracy = accuracy(predictions,examples,exampleWeights);
191            float error = 1.0f - accuracy;
192            float alpha = (float) (Math.log(accuracy/error) + Math.log(numClasses - 1));
193            models.add(newModel);
194            modelWeights[i] = alpha;
195            if ((accuracy + 1e-10) > 1.0) {
196                //
197                // Perfect accuracy, can no longer boost.
198                float[] newModelWeights = Arrays.copyOf(modelWeights, models.size());
199                newModelWeights[models.size()-1] = 1.0f; //Set the last weight to 1, as it's infinity.
200                logger.log(Level.FINE, "Perfect accuracy reached on iteration " + i + ", returning current model.");
201                logger.log(Level.FINE, "Model weights:");
202                Util.logVector(logger, Level.FINE, newModelWeights);
203                EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(models));
204                return new WeightedEnsembleModel<>("boosted-ensemble",provenance,featureIDs,labelIDs,models,new VotingCombiner(),newModelWeights);
205            }
206
207            //
208            // Update the weights
209            for (int j = 0; j < predictions.size(); j++) {
210                if (!predictions.get(j).getOutput().equals(examples.getExample(j).getOutput())) {
211                    exampleWeights[j] *= Math.exp(alpha);
212                }
213            }
214            Util.inplaceNormalizeToDistribution(exampleWeights);
215            if (weighted) {
216                for (int j = 0; j < examples.size(); j++) {
217                    examples.getExample(j).setWeight(exampleWeights[j]);
218                }
219            }
220        }
221        logger.log(Level.FINE, "Model weights:");
222        Util.logVector(logger, Level.FINE, modelWeights);
223        EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(models));
224        return new WeightedEnsembleModel<>("boosted-ensemble",provenance,featureIDs,labelIDs,models,new VotingCombiner(),modelWeights);
225    }
226
227    @Override
228    public int getInvocationCount() {
229        return trainInvocationCounter;
230    }
231
232    @Override
233    public synchronized void setInvocationCount(int invocationCount){
234        if(invocationCount < 0){
235            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
236        }
237
238        rng = new SplittableRandom(seed);
239
240        for (trainInvocationCounter = 0; trainInvocationCounter < invocationCount; trainInvocationCounter++){
241            SplittableRandom localRNG = rng.split();
242        }
243
244    }
245
246    /**
247     * Compute the accuracy of a set of predictions.
248     * @param predictions The base learner predictions.
249     * @param examples The training examples.
250     * @param weights The current example weights.
251     * @return The accuracy.
252     */
253    private static float accuracy(List<Prediction<Label>> predictions, Dataset<Label> examples, float[] weights) {
254        float correctSum = 0;
255        float total = 0;
256        for (int i = 0; i < predictions.size(); i++) {
257            if (predictions.get(i).getOutput().equals(examples.getExample(i).getOutput())) {
258                correctSum += weights[i];
259            }
260            total += weights[i];
261        }
262
263        logger.log(Level.FINEST, "Correct count = " + correctSum + " size = " + examples.size());
264
265        return correctSum / total;
266    }
267
268    @Override
269    public TrainerProvenance getProvenance() {
270        return new TrainerProvenanceImpl(this);
271    }
272}