Class AdaBoostTrainer

java.lang.Object
org.tribuo.classification.ensemble.AdaBoostTrainer
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<Label>

public class AdaBoostTrainer extends Object implements Trainer<Label>
Implements AdaBoost.SAMME, one of the more popular algorithms for multiclass boosting. Based on the paper by Zhu et al. cited below.

If the inner trainer implements WeightedExamples, boosting is performed by reweighting the training examples; otherwise each round trains on a weighted bootstrap sample.
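The weighted bootstrap branch can be sketched as follows. This is a minimal illustration that assumes one non-negative weight per training example. The `WeightedBootstrap` class and its `sample` method are hypothetical and not part of Tribuo. Each index is drawn with replacement with probability proportional to its weight, using a binary search over the cumulative weights.

```java
import java.util.SplittableRandom;

/** Hypothetical helper (not part of Tribuo): weighted bootstrap sampling of example indices. */
final class WeightedBootstrap {
    private WeightedBootstrap() {}

    /**
     * Draws n indices with replacement; index i is chosen with probability
     * weights[i] / sum(weights). Weights need not be normalised.
     */
    static int[] sample(double[] weights, int n, SplittableRandom rng) {
        // Cumulative weights: cdf[i] = weights[0] + ... + weights[i].
        double[] cdf = new double[weights.length];
        double total = 0.0;
        for (int i = 0; i < weights.length; i++) {
            if (weights[i] < 0.0) {
                throw new IllegalArgumentException("Weights must be non-negative");
            }
            total += weights[i];
            cdf[i] = total;
        }
        if (total <= 0.0) {
            throw new IllegalArgumentException("Weights must have a positive sum");
        }
        int[] indices = new int[n];
        for (int j = 0; j < n; j++) {
            // Uniform draw in [0, total).
            double u = rng.nextDouble() * total;
            // Upper-bound binary search: first index whose cumulative weight exceeds u.
            // Zero-weight examples can never be selected this way.
            int lo = 0;
            int hi = cdf.length - 1;
            while (lo < hi) {
                int mid = (lo + hi) >>> 1;
                if (cdf[mid] > u) {
                    hi = mid;
                } else {
                    lo = mid + 1;
                }
            }
            indices[j] = lo;
        }
        return indices;
    }
}
```

Heavily weighted (previously misclassified) examples tend to appear several times in the sample, which approximates reweighting for a weak learner that cannot consume weights directly.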

See:

 J. Zhu, S. Rosset, H. Zou, T. Hastie.
 "Multi-class Adaboost"
 2006.
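For reference, SAMME as described by Zhu et al. starts from uniform weights w_i = 1/n. In each round m it fits a weak learner T_m, computes its weighted error err_m, and sets alpha_m = ln((1 - err_m) / err_m) + ln(K - 1) for K classes. It then multiplies the weights of misclassified examples by exp(alpha_m) and renormalizes. The final classifier predicts the class k that maximizes the sum of alpha_m over the members that voted for k. The sketch below implements these update rules on integer class indices. The `Samme` class and its method names are illustrative and are not Tribuo's API.

```java
/** Illustrative SAMME update rules (Zhu et al., 2006); names are not Tribuo's API. */
final class Samme {
    private Samme() {}

    /** Weighted error: sum of w_i over misclassified examples divided by sum of all w_i. */
    static double weightedError(int[] truth, int[] predicted, double[] weights) {
        double wrong = 0.0;
        double total = 0.0;
        for (int i = 0; i < truth.length; i++) {
            total += weights[i];
            if (truth[i] != predicted[i]) {
                wrong += weights[i];
            }
        }
        return wrong / total;
    }

    /** Member weight alpha = ln((1 - err) / err) + ln(K - 1) for K classes. */
    static double alpha(double err, int numClasses) {
        return Math.log((1.0 - err) / err) + Math.log(numClasses - 1);
    }

    /** Scales misclassified examples' weights by exp(alpha), then renormalises to sum to 1. */
    static void updateWeights(int[] truth, int[] predicted, double[] weights, double alpha) {
        double factor = Math.exp(alpha);
        double sum = 0.0;
        for (int i = 0; i < weights.length; i++) {
            if (truth[i] != predicted[i]) {
                weights[i] *= factor;
            }
            sum += weights[i];
        }
        for (int i = 0; i < weights.length; i++) {
            weights[i] /= sum;
        }
    }

    /** Ensemble prediction: argmax over k of the summed alpha of members that predicted k. */
    static int vote(int[] memberPredictions, double[] alphas, int numClasses) {
        double[] scores = new double[numClasses];
        for (int m = 0; m < memberPredictions.length; m++) {
            scores[memberPredictions[m]] += alphas[m];
        }
        int best = 0;
        for (int k = 1; k < numClasses; k++) {
            if (scores[k] > scores[best]) {
                best = k;
            }
        }
        return best;
    }
}
```

With K = 2 the ln(K - 1) term vanishes and the member weight reduces to the classic two-class AdaBoost form. For larger K the extra term keeps alpha_m positive whenever err_m < 1 - 1/K, so each weak learner only has to beat random guessing among K classes rather than reach 50% accuracy.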
 
  • Field Details

    • innerTrainer

      @Config(mandatory=true, description="The trainer to use to build each weak learner.") protected Trainer<Label> innerTrainer
    • numMembers

      @Config(mandatory=true, description="The number of ensemble members to train.") protected int numMembers
    • seed

      @Config(mandatory=true, description="The seed for the RNG.") protected long seed
    • rng

      protected SplittableRandom rng
    • trainInvocationCounter

      protected int trainInvocationCounter
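The field list does not say how `seed`, `rng` and `trainInvocationCounter` interact. One plausible pattern, assumed here rather than taken from the Tribuo source, is to seed a single `SplittableRandom` from `seed`, split off a child stream on each call to train, and count those calls. `SeededTrainerSketch` is a hypothetical class that illustrates this pattern.

```java
import java.util.SplittableRandom;

/**
 * Hypothetical sketch, not Tribuo's source: seed one SplittableRandom from the
 * configured seed and split a child stream off it on every train call.
 */
final class SeededTrainerSketch {
    private final long seed;
    private final SplittableRandom rng;
    private int trainInvocationCounter = 0;

    SeededTrainerSketch(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    /** Stand-in for train(): takes a child stream under a lock, counts the call, draws from it. */
    long train() {
        SplittableRandom localRng;
        synchronized (this) {
            localRng = rng.split();
            trainInvocationCounter++;
        }
        // All randomness for this call (e.g. bootstrap draws) would come from localRng.
        return localRng.nextLong();
    }

    synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }

    long getSeed() {
        return seed;
    }
}
```

Under this assumption, two trainers built with the same seed produce identical random streams call for call, while successive calls on one trainer each get a fresh stream.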
  • Constructor Details

    • AdaBoostTrainer

      public AdaBoostTrainer(Trainer<Label> trainer, int numMembers)
      Constructs an AdaBoost trainer using the supplied weak learner trainer and the specified number of boosting rounds. Uses the default seed.
      Parameters:
      trainer - The weak learner trainer.
      numMembers - The maximum number of boosting rounds.
    • AdaBoostTrainer

      public AdaBoostTrainer(Trainer<Label> trainer, int numMembers, long seed)
      Constructs an AdaBoost trainer using the supplied weak learner trainer, the specified number of boosting rounds, and the supplied seed.
      Parameters:
      trainer - The weak learner trainer.
      numMembers - The maximum number of boosting rounds.
      seed - The RNG seed.
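Both constructors describe `numMembers` as the maximum number of boosting rounds, which suggests that training can stop early. The sketch below assumes a convention common in SAMME implementations but not confirmed for this class. It stops and discards a member whose weighted error is at least 1 - 1/K (no better than chance). It keeps a perfect member and stops, because no misclassified examples remain to reweight. `BoostingLoopSketch`, its `WeakLearner` interface, and the `MIN_ERROR` floor are hypothetical stand-ins.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical SAMME training loop with early stopping; not Tribuo's implementation. */
final class BoostingLoopSketch {
    private BoostingLoopSketch() {}

    /** Stand-in for the inner trainer: fits on the weights, returns predictions on the training set. */
    interface WeakLearner {
        int[] fitAndPredict(double[] weights);
    }

    // Assumption of this sketch: floor for the error of a perfect member so alpha stays finite.
    private static final double MIN_ERROR = 1e-10;

    /** Runs at most numMembers rounds and returns the alpha of every member kept. */
    static List<Double> boost(WeakLearner learner, int[] truth, int numClasses, int numMembers) {
        int n = truth.length;
        double[] w = new double[n];
        Arrays.fill(w, 1.0 / n);
        List<Double> alphas = new ArrayList<>();
        for (int m = 0; m < numMembers; m++) {
            int[] pred = learner.fitAndPredict(w.clone());
            double wrong = 0.0;
            double total = 0.0;
            for (int i = 0; i < n; i++) {
                total += w[i];
                if (pred[i] != truth[i]) {
                    wrong += w[i];
                }
            }
            double err = wrong / total;
            if (err >= 1.0 - 1.0 / numClasses) {
                break; // No better than chance: discard this member and stop.
            }
            double clamped = Math.max(err, MIN_ERROR);
            double alpha = Math.log((1.0 - clamped) / clamped) + Math.log(numClasses - 1);
            alphas.add(alpha);
            if (err <= 0.0) {
                break; // Perfect fit: keep the member; nothing is left to reweight.
            }
            double sum = 0.0;
            for (int i = 0; i < n; i++) {
                if (pred[i] != truth[i]) {
                    w[i] *= Math.exp(alpha);
                }
                sum += w[i];
            }
            for (int i = 0; i < n; i++) {
                w[i] /= sum;
            }
        }
        return alphas;
    }
}
```

At an error of 1 - 1/K the member weight alpha would be zero, and above it alpha would be negative. Stopping there avoids adding members that cannot help the weighted vote.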
  • Method Details