Package org.tribuo.regression.rtree
Class CARTRegressionTrainer
java.lang.Object
    org.tribuo.common.tree.AbstractCARTTrainer<Regressor>
        org.tribuo.regression.rtree.CARTRegressionTrainer
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<org.tribuo.provenance.TrainerProvenance>, DecisionTreeTrainer<Regressor>, org.tribuo.SparseTrainer<Regressor>, org.tribuo.Trainer<Regressor>, org.tribuo.WeightedExamples
A Trainer that uses an approximation of the CART algorithm to build a decision tree.
Trains an independent tree for each output dimension.
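To make "an independent tree for each output dimension" concrete, the sketch below fits a depth-1 regression tree (a stump) separately to each column of a multi-dimensional target, so each dimension is free to pick its own feature and threshold. It is plain Java with no Tribuo dependency. The class PerDimensionStumps, its methods, and the brute-force split search are hypothetical illustrations, not Tribuo's implementation.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch, not Tribuo's code: fits one independent depth-1
 * regression tree ("stump") per output dimension.
 */
final class PerDimensionStumps {

    /** Predicts leftMean when x[feature] <= threshold, otherwise rightMean. */
    static final class Stump {
        final int feature;
        final double threshold;
        final double leftMean;
        final double rightMean;

        Stump(int feature, double threshold, double leftMean, double rightMean) {
            this.feature = feature;
            this.threshold = threshold;
            this.leftMean = leftMean;
            this.rightMean = rightMean;
        }

        double predict(double[] x) {
            return x[feature] <= threshold ? leftMean : rightMean;
        }
    }

    /** Tries every (feature, observed value) threshold and keeps the lowest squared error. */
    static Stump fitStump(double[][] xs, double[] y) {
        double mean = Arrays.stream(y).average().orElse(0.0);
        // Fallback when no feature separates the data: a single leaf predicting the mean.
        Stump best = new Stump(0, Double.POSITIVE_INFINITY, mean, mean);
        double bestSse = Double.POSITIVE_INFINITY;
        for (int f = 0; f < xs[0].length; f++) {
            for (double[] row : xs) {
                double t = row[f];
                double sumL = 0, sumR = 0;
                int nL = 0, nR = 0;
                for (int i = 0; i < xs.length; i++) {
                    if (xs[i][f] <= t) { sumL += y[i]; nL++; } else { sumR += y[i]; nR++; }
                }
                if (nL == 0 || nR == 0) continue; // threshold does not split the node
                double mL = sumL / nL, mR = sumR / nR, sse = 0;
                for (int i = 0; i < xs.length; i++) {
                    double r = y[i] - (xs[i][f] <= t ? mL : mR);
                    sse += r * r;
                }
                if (sse < bestSse) { bestSse = sse; best = new Stump(f, t, mL, mR); }
            }
        }
        return best;
    }

    /** Fits each output column on its own; nothing is shared between dimensions. */
    static Stump[] fitPerDimension(double[][] xs, double[][] ys) {
        Stump[] stumps = new Stump[ys[0].length];
        for (int d = 0; d < stumps.length; d++) {
            double[] column = new double[ys.length];
            for (int i = 0; i < ys.length; i++) column[i] = ys[i][d];
            stumps[d] = fitStump(xs, column);
        }
        return stumps;
    }

    public static void main(String[] args) {
        double[][] xs = {{0, 0}, {1, 1}, {2, 0}, {3, 1}};
        double[][] ys = {{0, -5}, {0, 5}, {1, -5}, {1, 5}};
        Stump[] s = fitPerDimension(xs, ys);
        for (int d = 0; d < s.length; d++) {
            System.out.println("dim " + d + ": feature " + s[d].feature + " <= " + s[d].threshold);
        }
    }
}
```

With the data in main, dimension 0 ends up split on feature 0 and dimension 1 on feature 1, because each tree only ever sees its own target column.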
See:
J. Friedman, T. Hastie, & R. Tibshirani. "The Elements of Statistical Learning." Springer, 2001.
Field Summary
Fields inherited from class org.tribuo.common.tree.AbstractCARTTrainer:
    MIN_EXAMPLES
Fields inherited from interface org.tribuo.Trainer:
    DEFAULT_SEED, INCREMENT_INVOCATION_COUNT
Constructor Summary
CARTRegressionTrainer()
    Creates a CART trainer.
CARTRegressionTrainer(int maxDepth)
    Creates a CART trainer.
CARTRegressionTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease, float fractionFeaturesInSplit, boolean useRandomSplitPoints, RegressorImpurity impurity, long seed)
    Creates a CART trainer.
CARTRegressionTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease, float fractionFeaturesInSplit, RegressorImpurity impurity, long seed)
    Creates a CART trainer.
Method Summary
org.tribuo.provenance.TrainerProvenance getProvenance()
String toString()
TreeModel<Regressor> train(org.tribuo.Dataset<Regressor> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
TreeModel<Regressor> train(org.tribuo.Dataset<Regressor> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount)
Methods inherited from class org.tribuo.common.tree.AbstractCARTTrainer:
    getFractionFeaturesInSplit, getInvocationCount, getMinImpurityDecrease, getUseRandomSplitPoints, postConfig, setInvocationCount, train
Constructor Details
CARTRegressionTrainer
public CARTRegressionTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease, float fractionFeaturesInSplit, boolean useRandomSplitPoints, RegressorImpurity impurity, long seed)
Creates a CART trainer.
Parameters:
    maxDepth - The maximum depth of the tree.
    minChildWeight - The minimum weight a node must have to be considered for a split.
    minImpurityDecrease - The minimum decrease in impurity necessary to split a node.
    fractionFeaturesInSplit - The fraction of features considered at each split.
    useRandomSplitPoints - Whether to choose split points for features at random instead of computing the exact split point.
    impurity - The impurity function used to determine split quality.
    seed - The RNG seed.
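Three of these parameters (fractionFeaturesInSplit, useRandomSplitPoints, and seed) involve randomness. The following is a minimal sketch, assuming the trainer samples max(1, floor(fraction * numFeatures)) distinct features per split and, when random split points are enabled, draws a threshold uniformly between a feature's minimum and maximum value in the node. Both rules are assumptions, and SplitRandomnessSketch with its methods are hypothetical names rather than Tribuo API.

```java
import java.util.Arrays;
import java.util.SplittableRandom;

/**
 * Hypothetical sketch of how a seeded RNG could drive fractionFeaturesInSplit
 * and useRandomSplitPoints. The rounding rule and the uniform threshold draw
 * are assumptions, not Tribuo's documented behaviour.
 */
final class SplitRandomnessSketch {

    /** Chooses max(1, floor(fraction * numFeatures)) distinct feature indices, sorted. */
    static int[] sampleFeatures(int numFeatures, float fraction, SplittableRandom rng) {
        int k = Math.min(numFeatures, Math.max(1, (int) (numFeatures * fraction)));
        int[] idx = new int[numFeatures];
        for (int i = 0; i < numFeatures; i++) idx[i] = i;
        // Partial Fisher-Yates shuffle: only the first k slots need to be random.
        for (int i = 0; i < k; i++) {
            int j = i + rng.nextInt(numFeatures - i); // uniform in [i, numFeatures)
            int tmp = idx[i];
            idx[i] = idx[j];
            idx[j] = tmp;
        }
        int[] chosen = Arrays.copyOf(idx, k);
        Arrays.sort(chosen);
        return chosen;
    }

    /** Draws a threshold uniformly from [min, max) of one feature's values in the node. */
    static double randomSplitPoint(double[] values, SplittableRandom rng) {
        double min = Arrays.stream(values).min().orElseThrow();
        double max = Arrays.stream(values).max().orElseThrow();
        return min + rng.nextDouble() * (max - min);
    }

    public static void main(String[] args) {
        SplittableRandom rng = new SplittableRandom(42L); // stands in for the seed parameter
        System.out.println(Arrays.toString(sampleFeatures(10, 0.5f, rng)));
        System.out.println(randomSplitPoint(new double[]{2.0, 7.0, 4.0}, rng));
    }
}
```

Because every draw comes from a generator built from the seed, rerunning with the same seed reproduces the same feature subsets and thresholds.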
CARTRegressionTrainer
public CARTRegressionTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease, float fractionFeaturesInSplit, RegressorImpurity impurity, long seed)
Creates a CART trainer that computes the exact split point.
Parameters:
    maxDepth - The maximum depth of the tree.
    minChildWeight - The minimum weight a node must have to be considered for a split.
    minImpurityDecrease - The minimum decrease in impurity necessary to split a node.
    fractionFeaturesInSplit - The fraction of features considered at each split.
    impurity - The impurity function used to determine split quality.
    seed - The RNG seed.
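Computing an exact split point typically means evaluating every boundary between distinct sorted values of a feature. The sketch below does this for a single feature in one sweep using running weighted sums, scoring each candidate by the children's total weighted squared error. Placing the threshold at the midpoint between neighbouring values, and the names ExactSplitSketch, Split, and bestSplit, are assumptions for illustration rather than Tribuo's implementation.

```java
import java.util.Arrays;
import java.util.Comparator;

/**
 * Hypothetical sketch of an exact split search on one feature under weighted
 * squared error. Midpoint thresholds are an assumption.
 */
final class ExactSplitSketch {

    static final class Split {
        final double threshold;   // examples with x <= threshold go left
        final double weightedSse; // sum of w * (y - childMean)^2 over both children

        Split(double threshold, double weightedSse) {
            this.threshold = threshold;
            this.weightedSse = weightedSse;
        }
    }

    /** Returns the best split, or null if the feature is constant. Weights must be positive. */
    static Split bestSplit(double[] x, double[] y, double[] w) {
        int n = x.length;
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) order[i] = i;
        Arrays.sort(order, Comparator.comparingDouble(i -> x[i]));

        // For any node: weighted SSE = sum(w*y^2) - (sum(w*y))^2 / sum(w).
        double totW = 0, totWy = 0, totWyy = 0;
        for (int i = 0; i < n; i++) {
            totW += w[i];
            totWy += w[i] * y[i];
            totWyy += w[i] * y[i] * y[i];
        }

        double lW = 0, lWy = 0, lWyy = 0;
        Split best = null;
        for (int k = 0; k < n - 1; k++) {
            int i = order[k];
            lW += w[i];
            lWy += w[i] * y[i];
            lWyy += w[i] * y[i] * y[i];
            double here = x[i], next = x[order[k + 1]];
            if (here == next) continue; // equal values cannot be separated
            double rW = totW - lW, rWy = totWy - lWy, rWyy = totWyy - lWyy;
            double sse = (lWyy - lWy * lWy / lW) + (rWyy - rWy * rWy / rW);
            if (best == null || sse < best.weightedSse) {
                best = new Split((here + next) / 2, sse);
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] x = {10, 1, 12, 3, 2, 11};
        double[] y = {5, 1, 5, 1, 1, 5};
        double[] w = {1, 1, 1, 1, 1, 1};
        Split s = bestSplit(x, y, w);
        System.out.println("threshold " + s.threshold + ", weighted SSE " + s.weightedSse);
    }
}
```

Sorting once and updating prefix sums makes the search O(n log n) per feature instead of recomputing child statistics for every candidate.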
CARTRegressionTrainer
public CARTRegressionTrainer()
Creates a CART trainer.
Sets the impurity to MeanSquaredError, uses all the features, computes the exact split point, and sets the minimum number of examples in a leaf to AbstractCARTTrainer.MIN_EXAMPLES.
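The default impurity is MeanSquaredError. One common weighted definition, assumed here rather than taken from Tribuo's source, is the weighted mean of squared deviations from the node's weighted mean. Multiplying it by the node weight gives the weighted sum of squared errors, and a split's impurity decrease is the parent's weighted SSE minus the children's. How Tribuo scales this decrease before comparing it with minImpurityDecrease is not stated on this page; WeightedMseSketch and its methods are hypothetical.

```java
import java.util.Arrays;
import java.util.stream.DoubleStream;

/** Hypothetical weighted MSE impurity and impurity decrease; not Tribuo's code. */
final class WeightedMseSketch {

    static double totalWeight(double[] w) {
        return Arrays.stream(w).sum();
    }

    /** Weighted mean squared deviation of y from its weighted mean. */
    static double impurity(double[] y, double[] w) {
        double sumW = 0, sumWy = 0;
        for (int i = 0; i < y.length; i++) {
            sumW += w[i];
            sumWy += w[i] * y[i];
        }
        double mean = sumWy / sumW;
        double sse = 0;
        for (int i = 0; i < y.length; i++) {
            sse += w[i] * (y[i] - mean) * (y[i] - mean);
        }
        return sse / sumW;
    }

    /** W_parent * I(parent) - W_left * I(left) - W_right * I(right): the drop in weighted SSE. */
    static double impurityDecrease(double[] yL, double[] wL, double[] yR, double[] wR) {
        double[] y = DoubleStream.concat(Arrays.stream(yL), Arrays.stream(yR)).toArray();
        double[] w = DoubleStream.concat(Arrays.stream(wL), Arrays.stream(wR)).toArray();
        return totalWeight(w) * impurity(y, w)
                - totalWeight(wL) * impurity(yL, wL)
                - totalWeight(wR) * impurity(yR, wR);
    }

    public static void main(String[] args) {
        double[] ones = {1, 1};
        // A perfect split of {0, 0, 4, 4} removes all of the parent's weighted SSE.
        System.out.println(impurityDecrease(new double[]{0, 0}, ones, new double[]{4, 4}, ones));
    }
}
```

Weighting by node weight keeps a split that only tidies up a handful of examples from counting as much as one that separates most of the node.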
CARTRegressionTrainer
public CARTRegressionTrainer(int maxDepth)
Creates a CART trainer.
Sets the impurity to MeanSquaredError, uses all the features, computes the exact split point, and sets the minimum number of examples in a leaf to AbstractCARTTrainer.MIN_EXAMPLES.
Parameters:
    maxDepth - The maximum depth of the tree.
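The depth limit and the weight and impurity thresholds act as stopping rules while the tree grows. The sketch assumes the root sits at depth 0, that a node is a split candidate only while its depth is below maxDepth and its weight is at least minChildWeight, and that a candidate split is kept only if its impurity decrease reaches minImpurityDecrease. With unit example weights, node weight equals example count, which is presumably how the "minimum number of examples in a leaf" above relates to minChildWeight; that link, the exact conditions, and the class StoppingRulesSketch are all assumptions. The 2^maxDepth leaf bound is a general property of binary trees with the root at depth 0.

```java
/** Hypothetical stopping rules; the exact conditions Tribuo applies are assumptions. */
final class StoppingRulesSketch {

    /** A node is a split candidate only below maxDepth (root = depth 0) and with enough weight. */
    static boolean mayConsiderSplit(int depth, int maxDepth, double nodeWeight, float minChildWeight) {
        return depth < maxDepth && nodeWeight >= minChildWeight;
    }

    /** A candidate split is kept only if it lowers impurity by at least minImpurityDecrease. */
    static boolean acceptSplit(double impurityDecrease, float minImpurityDecrease) {
        return impurityDecrease >= minImpurityDecrease;
    }

    /** Upper bound on the number of leaves when the deepest leaves sit at depth maxDepth. */
    static long maxLeaves(int maxDepth) {
        return 1L << maxDepth;
    }

    public static void main(String[] args) {
        for (int d = 1; d <= 4; d++) {
            System.out.println("maxDepth " + d + " -> at most " + maxLeaves(d) + " leaves");
        }
    }
}
```

The leaf bound is a quick way to reason about model size: each extra level of maxDepth can at most double the number of leaves.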
Method Details
train
public TreeModel<Regressor> train(org.tribuo.Dataset<Regressor> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Specified by:
    train in interface org.tribuo.SparseTrainer<Regressor>
Specified by:
    train in interface org.tribuo.Trainer<Regressor>
Overrides:
    train in class AbstractCARTTrainer<Regressor>
train
public TreeModel<Regressor> train(org.tribuo.Dataset<Regressor> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount)
Specified by:
    train in interface org.tribuo.SparseTrainer<Regressor>
Specified by:
    train in interface org.tribuo.Trainer<Regressor>
Overrides:
    train in class AbstractCARTTrainer<Regressor>
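The invocationCount argument presumably sets the trainer's invocation counter (compare setInvocationCount and getInvocationCount) before training, which matters for reproducibility when training is randomised. The sketch below shows one way such a counter can make runs repeatable with java.util.SplittableRandom: rewind to the seed and replay the same number of split() calls. This is an assumption about the mechanism, not a description of Tribuo's internals, and InvocationRngSketch is a hypothetical name.

```java
import java.util.SplittableRandom;

/** Hypothetical per-invocation RNG bookkeeping; not Tribuo's implementation. */
final class InvocationRngSketch {
    private final long seed;
    private SplittableRandom rng;
    private int invocationCount;

    InvocationRngSketch(long seed) {
        this.seed = seed;
        setInvocationCount(0);
    }

    /** Rewinds to the seed and replays `count` splits, as if `count` training calls had happened. */
    synchronized void setInvocationCount(int count) {
        if (count < 0) {
            throw new IllegalArgumentException("count must be non-negative");
        }
        rng = new SplittableRandom(seed);
        for (int i = 0; i < count; i++) {
            rng.split(); // discard the child RNG a past training call would have used
        }
        invocationCount = count;
    }

    /** One training call: hands out a fresh child RNG and bumps the counter. */
    synchronized SplittableRandom nextTrainingRng() {
        invocationCount++;
        return rng.split();
    }

    synchronized int getInvocationCount() {
        return invocationCount;
    }

    public static void main(String[] args) {
        InvocationRngSketch trainer = new InvocationRngSketch(12345L);
        trainer.nextTrainingRng();
        long second = trainer.nextTrainingRng().nextLong();
        InvocationRngSketch replay = new InvocationRngSketch(12345L);
        replay.setInvocationCount(1);
        System.out.println(second == replay.nextTrainingRng().nextLong()); // prints true
    }
}
```

Handing each call its own child generator keeps calls independent while still letting any single call be replayed from the seed and its position in the sequence.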
toString
getProvenance