001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.rtree; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import org.tribuo.Dataset; 021import org.tribuo.Trainer; 022import org.tribuo.common.tree.AbstractCARTTrainer; 023import org.tribuo.common.tree.AbstractTrainingNode; 024import org.tribuo.provenance.TrainerProvenance; 025import org.tribuo.provenance.impl.TrainerProvenanceImpl; 026import org.tribuo.regression.Regressor; 027import org.tribuo.regression.rtree.impl.JointRegressorTrainingNode; 028import org.tribuo.regression.rtree.impurity.MeanSquaredError; 029import org.tribuo.regression.rtree.impurity.RegressorImpurity; 030 031/** 032 * A {@link org.tribuo.Trainer} that uses an approximation of the CART algorithm to build a decision tree. 033 * <p> 034 * Builds a single tree for all the regression dimensions. 035 * <p> 036 * See: 037 * <pre> 038 * J. Friedman, T. Hastie, & R. Tibshirani. 039 * "The Elements of Statistical Learning" 040 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a> 041 * </pre> 042 */ 043public class CARTJointRegressionTrainer extends AbstractCARTTrainer<Regressor> { 044 045 /** 046 * Impurity measure used to determine split quality. 047 */ 048 @Config(description="The regression impurity to use.") 049 private RegressorImpurity impurity = new MeanSquaredError(); 050 051 /** 052 * Normalizes the output of each leaf so it sums to one (i.e., is a probability distribution). 053 */ 054 @Config(description="Normalize the output of each leaf so it sums to one.") 055 private boolean normalize = false; 056 057 /** 058 * Creates a CART Trainer. 059 * 060 * @param maxDepth maxDepth The maximum depth of the tree. 061 * @param minChildWeight minChildWeight The minimum node weight to consider it for a split. 062 * @param minImpurityDecrease The minimum decrease in impurity necessary to split a node. 063 * @param fractionFeaturesInSplit fractionFeaturesInSplit The fraction of features available in each split. 064 * @param useRandomSplitPoints Whether to choose split points for features at random. 065 * @param impurity impurity The impurity function to use to determine split quality. 066 * @param normalize Normalize the leaves so each output sums to one. 067 * @param seed The seed to use for the RNG. 068 */ 069 public CARTJointRegressionTrainer( 070 int maxDepth, 071 float minChildWeight, 072 float minImpurityDecrease, 073 float fractionFeaturesInSplit, 074 boolean useRandomSplitPoints, 075 RegressorImpurity impurity, 076 boolean normalize, 077 long seed 078 ) { 079 super(maxDepth, minChildWeight, minImpurityDecrease, fractionFeaturesInSplit, useRandomSplitPoints, seed); 080 this.impurity = impurity; 081 this.normalize = normalize; 082 postConfig(); 083 } 084 085 /** 086 * Creates a CART Trainer. 087 * <p> 088 * Computes the exact split point. 089 * @param maxDepth maxDepth The maximum depth of the tree. 090 * @param minChildWeight minChildWeight The minimum node weight to consider it for a split. 091 * @param minImpurityDecrease The minimum decrease in impurity necessary to split a node. 092 * @param fractionFeaturesInSplit fractionFeaturesInSplit The fraction of features available in each split. 093 * @param impurity impurity The impurity function to use to determine split quality. 094 * @param normalize Normalize the leaves so each output sums to one. 095 * @param seed The seed to use for the RNG. 096 */ 097 public CARTJointRegressionTrainer( 098 int maxDepth, 099 float minChildWeight, 100 float minImpurityDecrease, 101 float fractionFeaturesInSplit, 102 RegressorImpurity impurity, 103 boolean normalize, 104 long seed 105 ) { 106 this(maxDepth, minChildWeight, minImpurityDecrease, fractionFeaturesInSplit, false, impurity, normalize, seed); 107 } 108 109 /** 110 * Creates a CART Trainer. 111 * <p> 112 * Sets the impurity to the {@link MeanSquaredError}, computes an arbitrary depth 113 * tree with exact split points using all the features, and does not normalize the outputs. 114 */ 115 public CARTJointRegressionTrainer() { 116 this(Integer.MAX_VALUE, MIN_EXAMPLES, 0.0f, 1.0f, false, new MeanSquaredError(), false, Trainer.DEFAULT_SEED); 117 } 118 119 /** 120 * Creates a CART Trainer. 121 * <p> 122 * Sets the impurity to the {@link MeanSquaredError}, computes the exact split 123 * points using all the features, and does not normalize the outputs. 124 * @param maxDepth The maximum depth of the tree. 125 */ 126 public CARTJointRegressionTrainer(int maxDepth) { 127 this(maxDepth, MIN_EXAMPLES, 0.0f, 1.0f, false, new MeanSquaredError(), false, Trainer.DEFAULT_SEED); 128 } 129 130 /** 131 * Creates a CART Trainer. Sets the impurity to the {@link MeanSquaredError}. 132 * @param maxDepth The maximum depth of the tree. 133 * @param normalize Normalises the leaves so each leaf has a distribution which sums to 1.0. 134 */ 135 public CARTJointRegressionTrainer(int maxDepth, boolean normalize) { 136 this(maxDepth, MIN_EXAMPLES, 0.0f, 1.0f, false, new MeanSquaredError(), normalize, Trainer.DEFAULT_SEED); 137 } 138 139 @Override 140 protected AbstractTrainingNode<Regressor> mkTrainingNode(Dataset<Regressor> examples, 141 AbstractTrainingNode.LeafDeterminer leafDeterminer) { 142 return new JointRegressorTrainingNode(impurity, examples, normalize, leafDeterminer); 143 } 144 145 @Override 146 public String toString() { 147 StringBuilder buffer = new StringBuilder(); 148 149 buffer.append("CARTJointRegressionTrainer(maxDepth="); 150 buffer.append(maxDepth); 151 buffer.append(",minChildWeight="); 152 buffer.append(minChildWeight); 153 buffer.append(",minImpurityDecrease="); 154 buffer.append(minImpurityDecrease); 155 buffer.append(",fractionFeaturesInSplit="); 156 buffer.append(fractionFeaturesInSplit); 157 buffer.append(",useRandomSplitPoints="); 158 buffer.append(useRandomSplitPoints); 159 buffer.append(",impurity="); 160 buffer.append(impurity.toString()); 161 buffer.append(",normalize="); 162 buffer.append(normalize); 163 buffer.append(",seed="); 164 buffer.append(seed); 165 buffer.append(")"); 166 167 return buffer.toString(); 168 } 169 170 @Override 171 public TrainerProvenance getProvenance() { 172 return new TrainerProvenanceImpl(this); 173 } 174}