/*
 * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.rtree;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.common.tree.AbstractCARTTrainer;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.TreeModel;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.rtree.impl.RegressorTrainingNode;
import org.tribuo.regression.rtree.impl.RegressorTrainingNode.InvertedData;
import org.tribuo.regression.rtree.impurity.MeanSquaredError;
import org.tribuo.regression.rtree.impurity.RegressorImpurity;
import org.tribuo.util.Util;

import java.time.OffsetDateTime;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SplittableRandom;

/**
 * A {@link org.tribuo.Trainer} that uses an approximation of the CART algorithm to build a decision tree.
 * Trains an independent tree for each output dimension.
 * <p>
 * See:
 * <pre>
 * J. Friedman, T. Hastie, &amp; R. Tibshirani.
 * "The Elements of Statistical Learning"
 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a>
 * </pre>
 */
public final class CARTRegressionTrainer extends AbstractCARTTrainer<Regressor> {

    /**
     * Impurity measure used to determine split quality.
     */
    @Config(description="Regression impurity measure used to determine split quality.")
    private RegressorImpurity impurity = new MeanSquaredError();

    /**
     * Creates a CART Trainer.
     *
     * @param maxDepth The maximum depth of the tree.
     * @param minChildWeight The minimum node weight to consider it for a split.
     * @param minImpurityDecrease The minimum decrease in impurity necessary to split a node.
     * @param fractionFeaturesInSplit The fraction of features available in each split.
     * @param useRandomSplitPoints Whether to choose split points for features at random.
     * @param impurity The impurity function to use to determine split quality.
     * @param seed The RNG seed.
     */
    public CARTRegressionTrainer(
            int maxDepth,
            float minChildWeight,
            float minImpurityDecrease,
            float fractionFeaturesInSplit,
            boolean useRandomSplitPoints,
            RegressorImpurity impurity,
            long seed
    ) {
        super(maxDepth, minChildWeight, minImpurityDecrease, fractionFeaturesInSplit, useRandomSplitPoints, seed);
        this.impurity = impurity;
        postConfig();
    }

    /**
     * Creates a CART Trainer.
     * <p>
     * Computes the exact split point.
     * @param maxDepth The maximum depth of the tree.
     * @param minChildWeight The minimum node weight to consider it for a split.
     * @param minImpurityDecrease The minimum decrease in impurity necessary to split a node.
     * @param fractionFeaturesInSplit The fraction of features available in each split.
     * @param impurity The impurity function to use to determine split quality.
     * @param seed The RNG seed.
     */
    public CARTRegressionTrainer(
            int maxDepth,
            float minChildWeight,
            float minImpurityDecrease,
            float fractionFeaturesInSplit,
            RegressorImpurity impurity,
            long seed
    ) {
        this(maxDepth,minChildWeight,minImpurityDecrease,fractionFeaturesInSplit,false,impurity,seed);
    }

    /**
     * Creates a CART trainer.
     * <p>
     * Sets the impurity to the {@link MeanSquaredError}, uses
     * all the features, computes the exact split point, and
     * sets the minimum number of examples in a leaf to {@link #MIN_EXAMPLES}.
     */
    public CARTRegressionTrainer() {
        this(Integer.MAX_VALUE);
    }

    /**
     * Creates a CART trainer.
     * <p>
     * Sets the impurity to the {@link MeanSquaredError}, uses
     * all the features, computes the exact split point and sets
     * the minimum number of examples in a leaf to {@link #MIN_EXAMPLES}.
     * @param maxDepth The maximum depth of the tree.
     */
    public CARTRegressionTrainer(int maxDepth) {
        this(maxDepth, MIN_EXAMPLES, 0.0f, 1.0f, false, new MeanSquaredError(), Trainer.DEFAULT_SEED);
    }

    /**
     * Unused by this trainer: {@link #train(Dataset, Map, int)} is fully overridden below and
     * constructs a {@link RegressorTrainingNode} root per output dimension itself, so this
     * factory method is never invoked.
     * @param examples The training data.
     * @param leafDeterminer The leaf determination criteria.
     * @return Never returns, always throws.
     * @throws IllegalStateException Always, as this method should be unreachable.
     */
    @Override
    protected AbstractTrainingNode<Regressor> mkTrainingNode(Dataset<Regressor> examples,
                                                             AbstractTrainingNode.LeafDeterminer leafDeterminer) {
        throw new IllegalStateException("Shouldn't reach here.");
    }

    @Override
    public TreeModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
        return train(examples, runProvenance, INCREMENT_INVOCATION_COUNT);
    }

    @Override
    public TreeModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance, int invocationCount) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        // Creates a new RNG, adds one to the invocation count.
        SplittableRandom localRNG;
        TrainerProvenance trainerProvenance;
        synchronized(this) {
            if(invocationCount != INCREMENT_INVOCATION_COUNT) {
                setInvocationCount(invocationCount);
            }
            localRNG = rng.split();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }

        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<Regressor> outputIDInfo = examples.getOutputIDInfo();
        Set<Regressor> domain = outputIDInfo.getDomain();

        int numFeaturesInSplit = Math.min(Math.round(fractionFeaturesInSplit * featureIDMap.size()),featureIDMap.size());
        int[] indices;
        int[] originalIndices = new int[featureIDMap.size()];
        for (int i = 0; i < originalIndices.length; i++) {
            originalIndices[i] = i;
        }
        if (numFeaturesInSplit != featureIDMap.size()) {
            indices = new int[numFeaturesInSplit];
        } else {
            indices = originalIndices;
        }

        float weightSum = 0.0f;
        for (Example<Regressor> e : examples) {
            weightSum += e.getWeight();
        }
        // The configured threshold is a fraction; scale it by the total example weight.
        float scaledMinImpurityDecrease = getMinImpurityDecrease() * weightSum;
        AbstractTrainingNode.LeafDeterminer leafDeterminer = new AbstractTrainingNode.LeafDeterminer(maxDepth,
                minChildWeight, scaledMinImpurityDecrease);

        InvertedData data = RegressorTrainingNode.invertData(examples);

        // One independent tree per output dimension, keyed by dimension name.
        Map<String, Node<Regressor>> nodeMap = new HashMap<>();
        for (Regressor r : domain) {
            String dimName = r.getNames()[0];
            int dimIdx = outputIDInfo.getID(r);

            AbstractTrainingNode<Regressor> root = new RegressorTrainingNode(impurity,data,dimIdx,dimName,
                    examples.size(),featureIDMap,outputIDInfo, leafDeterminer);
            Deque<AbstractTrainingNode<Regressor>> queue = new ArrayDeque<>();
            queue.add(root);

            while (!queue.isEmpty()) {
                AbstractTrainingNode<Regressor> node = queue.poll();
                if ((node.getImpurity() > 0.0) && (node.getDepth() < maxDepth) &&
                        (node.getWeightSum() >= minChildWeight)) {
                    if (numFeaturesInSplit != featureIDMap.size()) {
                        // Subsample the features considered at this split.
                        Util.randpermInPlace(originalIndices, localRNG);
                        System.arraycopy(originalIndices, 0, indices, 0, numFeaturesInSplit);
                    }
                    List<AbstractTrainingNode<Regressor>> nodes = node.buildTree(indices, localRNG,
                            getUseRandomSplitPoints());
                    // Use the queue as a stack to improve cache locality.
                    for (AbstractTrainingNode<Regressor> newNode : nodes) {
                        queue.addFirst(newNode);
                    }
                }
            }

            nodeMap.put(dimName,root.convertTree());
        }

        ModelProvenance provenance = new ModelProvenance(TreeModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
        return new IndependentRegressionTreeModel("cart-tree", provenance, featureIDMap, outputIDInfo, false, nodeMap);
    }

    @Override
    public String toString() {
        StringBuilder buffer = new StringBuilder();

        buffer.append("CARTRegressionTrainer(maxDepth=");
        buffer.append(maxDepth);
        buffer.append(",minChildWeight=");
        buffer.append(minChildWeight);
        buffer.append(",minImpurityDecrease=");
        buffer.append(minImpurityDecrease);
        buffer.append(",fractionFeaturesInSplit=");
        buffer.append(fractionFeaturesInSplit);
        buffer.append(",useRandomSplitPoints=");
        buffer.append(useRandomSplitPoints);
        buffer.append(",impurity=");
        buffer.append(impurity.toString());
        buffer.append(",seed=");
        buffer.append(seed);
        buffer.append(")");

        return buffer.toString();
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}