/*
 * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.ensemble;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableDataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.dataset.DatasetView;
import org.tribuo.ensemble.WeightedEnsembleModel;
import org.tribuo.provenance.EnsembleModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.util.Util;

import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Implements Adaboost.SAMME, one of the more popular algorithms for multiclass boosting.
050 * Based on <a href="https://web.stanford.edu/~hastie/Papers/samme.pdf">this paper</a>. 051 * <p> 052 * If the trainer implements {@link WeightedExamples} then it performs boosting by weighting, 053 * otherwise it uses a weighted bootstrap sample. 054 * <p> 055 * See: 056 * <pre> 057 * J. Zhu, S. Rosset, H. Zou, T. Hastie. 058 * "Multi-class Adaboost" 059 * 2006. 060 * </pre> 061 */ 062public class AdaBoostTrainer implements Trainer<Label> { 063 064 private static final Logger logger = Logger.getLogger(AdaBoostTrainer.class.getName()); 065 066 @Config(mandatory=true, description="The trainer to use to build each weak learner.") 067 protected Trainer<Label> innerTrainer; 068 069 @Config(mandatory=true, description="The number of ensemble members to train.") 070 protected int numMembers; 071 072 @Config(mandatory=true, description="The seed for the RNG.") 073 protected long seed; 074 075 protected SplittableRandom rng; 076 077 protected int trainInvocationCounter; 078 079 /** 080 * For the OLCUT configuration system. 081 */ 082 private AdaBoostTrainer() { } 083 084 /** 085 * Constructs an adaboost trainer using the supplied weak learner trainer and the specified number of 086 * boosting rounds. Uses the default seed. 087 * @param trainer The weak learner trainer. 088 * @param numMembers The maximum number of boosting rounds. 089 */ 090 public AdaBoostTrainer(Trainer<Label> trainer, int numMembers) { 091 this(trainer, numMembers, Trainer.DEFAULT_SEED); 092 } 093 094 /** 095 * Constructs an adaboost trainer using the supplied weak learner trainer, the specified number of 096 * boosting rounds and the supplied seed. 097 * @param trainer The weak learner trainer. 098 * @param numMembers The maximum number of boosting rounds. 099 * @param seed The RNG seed. 
100 */ 101 public AdaBoostTrainer(Trainer<Label> trainer, int numMembers, long seed) { 102 this.innerTrainer = trainer; 103 this.numMembers = numMembers; 104 this.seed = seed; 105 postConfig(); 106 } 107 108 /** 109 * Used by the OLCUT configuration system, and should not be called by external code. 110 */ 111 @Override 112 public synchronized void postConfig() { 113 this.rng = new SplittableRandom(seed); 114 } 115 116 @Override 117 public String toString() { 118 StringBuilder buffer = new StringBuilder(); 119 120 buffer.append("AdaBoostTrainer("); 121 buffer.append("innerTrainer="); 122 buffer.append(innerTrainer.toString()); 123 buffer.append(",numMembers="); 124 buffer.append(numMembers); 125 buffer.append(",seed="); 126 buffer.append(seed); 127 buffer.append(")"); 128 129 return buffer.toString(); 130 } 131 132 /** 133 * If the trainer implements {@link WeightedExamples} then do boosting by weighting, 134 * otherwise do boosting by sampling. 135 * @param examples the data set containing the examples. 136 * @return A {@link WeightedEnsembleModel}. 137 */ 138 @Override 139 public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance) { 140 return(train(examples, runProvenance, INCREMENT_INVOCATION_COUNT)); 141 } 142 143 @Override 144 public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance, int invocationCount) { 145 if (examples.getOutputInfo().getUnknownCount() > 0) { 146 throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised."); 147 } 148 // Creates a new RNG, adds one to the invocation count. 
149 SplittableRandom localRNG; 150 TrainerProvenance trainerProvenance; 151 synchronized(this) { 152 if(invocationCount != INCREMENT_INVOCATION_COUNT) { 153 setInvocationCount(invocationCount); 154 } 155 localRNG = rng.split(); 156 trainerProvenance = getProvenance(); 157 trainInvocationCounter++; 158 } 159 boolean weighted = innerTrainer instanceof WeightedExamples; 160 ImmutableFeatureMap featureIDs = examples.getFeatureIDMap(); 161 ImmutableOutputInfo<Label> labelIDs = examples.getOutputIDInfo(); 162 int numClasses = labelIDs.size(); 163 logger.log(Level.INFO,"NumClasses = " + numClasses); 164 ArrayList<Model<Label>> models = new ArrayList<>(); 165 float[] modelWeights = new float[numMembers]; 166 float[] exampleWeights = Util.generateUniformFloatVector(examples.size(), 1.0f/examples.size()); 167 if (weighted) { 168 logger.info("Using weighted Adaboost."); 169 examples = ImmutableDataset.copyDataset(examples); 170 for (int i = 0; i < examples.size(); i++) { 171 Example<Label> e = examples.getExample(i); 172 e.setWeight(exampleWeights[i]); 173 } 174 } else { 175 logger.info("Using sampling Adaboost."); 176 } 177 for (int i = 0; i < numMembers; i++) { 178 logger.info("Building model " + i); 179 Model<Label> newModel; 180 if (weighted) { 181 newModel = innerTrainer.train(examples); 182 } else { 183 DatasetView<Label> bag = DatasetView.createWeightedBootstrapView(examples,examples.size(),localRNG.nextLong(),exampleWeights,featureIDs,labelIDs); 184 newModel = innerTrainer.train(bag); 185 } 186 187 // 188 // Score this model 189 List<Prediction<Label>> predictions = newModel.predict(examples); 190 float accuracy = accuracy(predictions,examples,exampleWeights); 191 float error = 1.0f - accuracy; 192 float alpha = (float) (Math.log(accuracy/error) + Math.log(numClasses - 1)); 193 models.add(newModel); 194 modelWeights[i] = alpha; 195 if ((accuracy + 1e-10) > 1.0) { 196 // 197 // Perfect accuracy, can no longer boost. 
198 float[] newModelWeights = Arrays.copyOf(modelWeights, models.size()); 199 newModelWeights[models.size()-1] = 1.0f; //Set the last weight to 1, as it's infinity. 200 logger.log(Level.FINE, "Perfect accuracy reached on iteration " + i + ", returning current model."); 201 logger.log(Level.FINE, "Model weights:"); 202 Util.logVector(logger, Level.FINE, newModelWeights); 203 EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(models)); 204 return new WeightedEnsembleModel<>("boosted-ensemble",provenance,featureIDs,labelIDs,models,new VotingCombiner(),newModelWeights); 205 } 206 207 // 208 // Update the weights 209 for (int j = 0; j < predictions.size(); j++) { 210 if (!predictions.get(j).getOutput().equals(examples.getExample(j).getOutput())) { 211 exampleWeights[j] *= Math.exp(alpha); 212 } 213 } 214 Util.inplaceNormalizeToDistribution(exampleWeights); 215 if (weighted) { 216 for (int j = 0; j < examples.size(); j++) { 217 examples.getExample(j).setWeight(exampleWeights[j]); 218 } 219 } 220 } 221 logger.log(Level.FINE, "Model weights:"); 222 Util.logVector(logger, Level.FINE, modelWeights); 223 EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(models)); 224 return new WeightedEnsembleModel<>("boosted-ensemble",provenance,featureIDs,labelIDs,models,new VotingCombiner(),modelWeights); 225 } 226 227 @Override 228 public int getInvocationCount() { 229 return trainInvocationCounter; 230 } 231 232 @Override 233 public synchronized void setInvocationCount(int invocationCount){ 234 if(invocationCount < 0){ 235 throw new IllegalArgumentException("The supplied invocationCount is less than zero."); 236 } 237 238 rng = new 
SplittableRandom(seed); 239 240 for (trainInvocationCounter = 0; trainInvocationCounter < invocationCount; trainInvocationCounter++){ 241 SplittableRandom localRNG = rng.split(); 242 } 243 244 } 245 246 /** 247 * Compute the accuracy of a set of predictions. 248 * @param predictions The base learner predictions. 249 * @param examples The training examples. 250 * @param weights The current example weights. 251 * @return The accuracy. 252 */ 253 private static float accuracy(List<Prediction<Label>> predictions, Dataset<Label> examples, float[] weights) { 254 float correctSum = 0; 255 float total = 0; 256 for (int i = 0; i < predictions.size(); i++) { 257 if (predictions.get(i).getOutput().equals(examples.getExample(i).getOutput())) { 258 correctSum += weights[i]; 259 } 260 total += weights[i]; 261 } 262 263 logger.log(Level.FINEST, "Correct count = " + correctSum + " size = " + examples.size()); 264 265 return correctSum / total; 266 } 267 268 @Override 269 public TrainerProvenance getProvenance() { 270 return new TrainerProvenanceImpl(this); 271 } 272}