Class AdaBoostTrainer
java.lang.Object
org.tribuo.classification.ensemble.AdaBoostTrainer
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<Label>
Implements AdaBoost.SAMME, one of the more popular algorithms for multiclass boosting. It is based on the paper by Zhu et al. cited below.
If the inner (weak learner) trainer implements WeightedExamples, boosting is performed by reweighting the examples; otherwise each boosting round trains on a weighted bootstrap sample.
See:
J. Zhu, S. Rosset, H. Zou, T. Hastie. "Multi-class Adaboost" 2006.
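The SAMME update from the paper above can be sketched in plain Java, independent of Tribuo. This is a minimal sketch, not the AdaBoostTrainer source: it assumes labels are encoded as integers 0..K-1, takes a weak learner's predictions as an array, and uses hypothetical names (SammeSketch, weightedError, memberWeight, updateWeights, weightedVote). Each round computes the weighted error err, gives the member the weight alpha = log((1 - err) / err) + log(K - 1), multiplies the weights of misclassified examples by exp(alpha) and renormalises; prediction returns the class with the largest alpha-weighted vote.

```java
import java.util.Arrays;

/** Hypothetical, library-independent sketch of the SAMME arithmetic (Zhu et al., 2006). */
final class SammeSketch {

    /** Weighted error: total weight on misclassified examples divided by total weight. */
    static double weightedError(int[] labels, int[] predictions, double[] weights) {
        double wrong = 0.0;
        double total = 0.0;
        for (int i = 0; i < labels.length; i++) {
            total += weights[i];
            if (labels[i] != predictions[i]) {
                wrong += weights[i];
            }
        }
        return wrong / total;
    }

    /**
     * SAMME member weight log((1 - err) / err) + log(K - 1).
     * Assumes 0 < err < 1; a perfect weak learner (err == 0) would need separate handling.
     */
    static double memberWeight(double err, int numClasses) {
        return Math.log((1.0 - err) / err) + Math.log(numClasses - 1.0);
    }

    /** Multiplies each misclassified example's weight by exp(alpha), then renormalises to sum to 1. */
    static void updateWeights(int[] labels, int[] predictions, double[] weights, double alpha) {
        double sum = 0.0;
        for (int i = 0; i < weights.length; i++) {
            if (labels[i] != predictions[i]) {
                weights[i] *= Math.exp(alpha);
            }
            sum += weights[i];
        }
        for (int i = 0; i < weights.length; i++) {
            weights[i] /= sum;
        }
    }

    /** Ensemble prediction: the class whose alpha-weighted vote total is largest. */
    static int weightedVote(int[] memberPredictions, double[] alphas, int numClasses) {
        double[] scores = new double[numClasses];
        for (int m = 0; m < memberPredictions.length; m++) {
            scores[memberPredictions[m]] += alphas[m];
        }
        int best = 0;
        for (int k = 1; k < numClasses; k++) {
            if (scores[k] > scores[best]) {
                best = k;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        int numClasses = 3;
        int[] labels = {0, 1, 2, 0};
        int[] predictions = {0, 1, 1, 0};            // the weak learner gets example 2 wrong
        double[] weights = {0.25, 0.25, 0.25, 0.25}; // uniform initial weights

        double err = weightedError(labels, predictions, weights);
        double alpha = memberWeight(err, numClasses);
        updateWeights(labels, predictions, weights, alpha);
        System.out.println(err + " " + alpha + " " + Arrays.toString(weights));
    }
}
```

With K = 3 and one of four equally weighted examples misclassified, err = 0.25 and alpha = log 3 + log 2 = log 6, so the misclassified example's weight rises from 1/4 to 2/3 while the other three fall to 1/9 each. The extra log(K - 1) term is what distinguishes SAMME from two-class AdaBoost: it keeps alpha positive whenever the weak learner beats random guessing among K classes (err < 1 - 1/K), rather than requiring err < 1/2.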
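Field Summary
Fields
Modifier and Type | Field | Description
protected Trainer<Label> | innerTrainer | The trainer to use to build each weak learner.
protected int | numMembers | The number of ensemble members to train.
protected SplittableRandom | rng
protected long | seed
protected int | trainInvocationCounter
Fields inherited from interface org.tribuo.Trainer:
DEFAULT_SEED, INCREMENT_INVOCATION_COUNT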
Constructor Summary
Constructors
Constructor | Description
AdaBoostTrainer(Trainer<Label> trainer, int numMembers) | Constructs an AdaBoost trainer using the supplied weak learner trainer and the specified number of boosting rounds.
AdaBoostTrainer(Trainer<Label> trainer, int numMembers, long seed) | Constructs an AdaBoost trainer using the supplied weak learner trainer, the specified number of boosting rounds and the supplied seed.
Method Summary
Modifier and Type | Method | Description
int | getInvocationCount()
TrainerProvenance | getProvenance()
void | postConfig() | Used by the OLCUT configuration system, and should not be called by external code.
void | setInvocationCount(int invocationCount)
String | toString()
Model<Label> | train(Dataset<Label> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) | If the inner trainer implements WeightedExamples, boosting is done by weighting; otherwise it is done by sampling.
Model<Label> | train(Dataset<Label> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount)
Field Details
innerTrainer
@Config(mandatory=true, description="The trainer to use to build each weak learner.")
protected Trainer<Label> innerTrainer
numMembers
@Config(mandatory=true, description="The number of ensemble members to train.")
protected int numMembers
seed
protected long seed
rng
protected SplittableRandom rng
trainInvocationCounter
protected int trainInvocationCounter
Constructor Details
AdaBoostTrainer
public AdaBoostTrainer(Trainer<Label> trainer, int numMembers)
Constructs an AdaBoost trainer using the supplied weak learner trainer and the specified number of boosting rounds. Uses the default seed.
- Parameters:
trainer - The weak learner trainer.
numMembers - The maximum number of boosting rounds.
AdaBoostTrainer
public AdaBoostTrainer(Trainer<Label> trainer, int numMembers, long seed)
Constructs an AdaBoost trainer using the supplied weak learner trainer, the specified number of boosting rounds and the supplied seed.
- Parameters:
trainer - The weak learner trainer.
numMembers - The maximum number of boosting rounds.
seed - The RNG seed.
Method Details
postConfig
public void postConfig()
Used by the OLCUT configuration system, and should not be called by external code.
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
toString
public String toString()
- Overrides:
toString in class Object
train
public Model<Label> train(Dataset<Label> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
If the inner trainer implements WeightedExamples, boosting is done by weighting the examples; otherwise it is done by sampling.
- Specified by:
train in interface Trainer<Label>
- Parameters:
examples - the data set containing the examples.
- Returns:
A WeightedEnsembleModel.
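When the inner trainer does not implement WeightedExamples, each round instead trains on a weighted bootstrap sample. The following is a minimal, library-independent sketch of one way to draw such a sample: inverse-CDF sampling over the cumulative example weights, driven by a seeded java.util.SplittableRandom. The class and method names are hypothetical and Tribuo's own sampling code may differ in detail. The sample is drawn with replacement and has the same size as the dataset, so examples that earlier members misclassified (and which therefore carry more weight) tend to appear several times.

```java
import java.util.Arrays;
import java.util.SplittableRandom;

/** Hypothetical sketch of weighted bootstrap sampling for boosting-by-sampling. */
final class WeightedBootstrapSketch {

    /**
     * Draws sampleSize indices with replacement; index i is chosen with probability
     * weights[i] / sum(weights). Weights must be non-negative with a positive sum.
     */
    static int[] weightedBootstrap(double[] weights, int sampleSize, SplittableRandom rng) {
        // Cumulative weights: cdf[i] = weights[0] + ... + weights[i].
        double[] cdf = new double[weights.length];
        double running = 0.0;
        for (int i = 0; i < weights.length; i++) {
            running += weights[i];
            cdf[i] = running;
        }
        int[] indices = new int[sampleSize];
        for (int s = 0; s < sampleSize; s++) {
            double u = rng.nextDouble() * running; // uniform draw in [0, total weight)
            // Binary search for the smallest i with cdf[i] > u (skips zero-weight examples).
            int lo = 0;
            int hi = cdf.length - 1;
            while (lo < hi) {
                int mid = (lo + hi) >>> 1;
                if (cdf[mid] > u) {
                    hi = mid;
                } else {
                    lo = mid + 1;
                }
            }
            indices[s] = lo;
        }
        return indices;
    }

    public static void main(String[] args) {
        double[] weights = {0.1, 0.2, 0.7}; // e.g. example weights after a few boosting rounds
        int[] sample = weightedBootstrap(weights, weights.length, new SplittableRandom(1L));
        System.out.println(Arrays.toString(sample));
    }
}
```

Because the RNG is seeded, the same seed and weights always yield the same sample, which is the property the seed constructor argument is there to provide.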
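train
public Model<Label> train(Dataset<Label> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount)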
getInvocationCount
public int getInvocationCount()
- Specified by:
getInvocationCount in interface Trainer<Label>
setInvocationCount
public void setInvocationCount(int invocationCount)
- Specified by:
setInvocationCount in interface Trainer<Label>
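The seed, rng and trainInvocationCounter fields suggest that the randomness used by a training call depends on the seed and on how many times train has already been invoked, which is what makes setInvocationCount useful for reproducing a particular run. The sketch below shows one plausible way to implement that contract with java.util.SplittableRandom: each simulated training call consumes one split of a master RNG, and setInvocationCount resets the master RNG from the seed and replays that many splits. It is a hypothetical illustration (InvocationCountSketch and simulateTrain are made-up names), not the confirmed AdaBoostTrainer implementation.

```java
import java.util.SplittableRandom;

/** Hypothetical sketch: per-invocation RNG state derived from a seed and an invocation counter. */
final class InvocationCountSketch {
    private final long seed;
    private SplittableRandom rng;
    private int trainInvocationCounter;

    InvocationCountSketch(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    /** Stand-in for a train call: each invocation takes one split of the master RNG. */
    synchronized long simulateTrain() {
        SplittableRandom localRng = rng.split();
        trainInvocationCounter++;
        return localRng.nextLong(); // stands in for the randomness one training run would consume
    }

    /** Resets to the seed and replays invocationCount splits, as if train had been called that often. */
    synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("invocationCount must be non-negative");
        }
        rng = new SplittableRandom(seed);
        for (trainInvocationCounter = 0; trainInvocationCounter < invocationCount; trainInvocationCounter++) {
            rng.split();
        }
    }

    synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }

    public static void main(String[] args) {
        InvocationCountSketch a = new InvocationCountSketch(1L);
        a.simulateTrain();
        a.simulateTrain();
        long third = a.simulateTrain();

        // A fresh instance fast-forwarded by two invocations reproduces the third run.
        InvocationCountSketch b = new InvocationCountSketch(1L);
        b.setInvocationCount(2);
        System.out.println(third == b.simulateTrain()); // prints true
    }
}
```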
getProvenance
public TrainerProvenance getProvenance()
- Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>