/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.rtree.impl;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.common.tree.LeafNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.SplitNode;
import org.tribuo.common.tree.impl.IntArrayContainer;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.Regressor.DimensionTuple;
import org.tribuo.regression.rtree.impurity.RegressorImpurity;
import org.tribuo.regression.rtree.impurity.RegressorImpurity.ImpurityTuple;
import org.tribuo.util.Util;

import java.io.IOException;
import java.io.NotSerializableException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.SplittableRandom;
import java.util.logging.Logger;

/**
 * A decision tree node used at training time.
048 * Contains a list of the example indices currently found in this node, 049 * the current impurity and a bunch of other statistics. 050 */ 051public class RegressorTrainingNode extends AbstractTrainingNode<Regressor> { 052 private static final long serialVersionUID = 1L; 053 054 private static final Logger logger = Logger.getLogger(RegressorTrainingNode.class.getName()); 055 056 private static final ThreadLocal<IntArrayContainer> mergeBufferOne = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE)); 057 private static final ThreadLocal<IntArrayContainer> mergeBufferTwo = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE)); 058 059 private transient ArrayList<TreeFeature> data; 060 061 private final ImmutableOutputInfo<Regressor> labelIDMap; 062 063 private final ImmutableFeatureMap featureIDMap; 064 065 private final RegressorImpurity impurity; 066 067 private final int[] indices; 068 069 private final float[] targets; 070 071 private final float[] weights; 072 073 private final String dimName; 074 075 private final float weightSum; 076 077 /** 078 * Constructs a tree training node for regression problems. 079 * @param impurity The impurity function. 080 * @param tuple The data tuple. 081 * @param dimIndex The output dimension index of this node. 082 * @param dimName The output dimension name. 083 * @param numExamples The number of examples. 084 * @param featureIDMap The feature domain. 085 * @param outputInfo The output domain. 086 * @param leafDeterminer The leaf determination parameters. 
087 */ 088 public RegressorTrainingNode(RegressorImpurity impurity, InvertedData tuple, int dimIndex, String dimName, 089 int numExamples, ImmutableFeatureMap featureIDMap, 090 ImmutableOutputInfo<Regressor> outputInfo, LeafDeterminer leafDeterminer) { 091 this(impurity,tuple.copyData(),tuple.indices,tuple.targets[dimIndex],tuple.weights,dimName,numExamples,0, 092 featureIDMap,outputInfo, leafDeterminer); 093 } 094 095 private RegressorTrainingNode(RegressorImpurity impurity, ArrayList<TreeFeature> data, int[] indices, 096 float[] targets, float[] weights, String dimName, int numExamples, int depth, 097 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap, 098 LeafDeterminer leafDeterminer) { 099 super(depth, numExamples, leafDeterminer); 100 this.data = data; 101 this.featureIDMap = featureIDMap; 102 this.labelIDMap = labelIDMap; 103 this.impurity = impurity; 104 this.indices = indices; 105 this.targets = targets; 106 this.weights = weights; 107 this.dimName = dimName; 108 this.weightSum = Util.sum(indices,indices.length,weights); 109 this.impurityScore = impurity.impurity(indices, targets, weights); 110 } 111 112 private RegressorTrainingNode(RegressorImpurity impurity, ArrayList<TreeFeature> data, int[] indices, 113 float[] targets, float[] weights, String dimName, int numExamples, int depth, 114 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap, 115 LeafDeterminer leafDeterminer, float weightSum, double impurityScore) { 116 super(depth, numExamples, leafDeterminer); 117 this.data = data; 118 this.featureIDMap = featureIDMap; 119 this.labelIDMap = labelIDMap; 120 this.impurity = impurity; 121 this.indices = indices; 122 this.targets = targets; 123 this.weights = weights; 124 this.dimName = dimName; 125 this.weightSum = weightSum; 126 this.impurityScore = impurityScore; 127 } 128 129 @Override 130 public double getImpurity() { 131 return impurityScore; 132 } 133 134 @Override 135 public float getWeightSum() { 136 
return weightSum; 137 } 138 139 /** 140 * Builds a tree according to CART (as it does not do multi-way splits on categorical values like C4.5). 141 * @param featureIDs Indices of the features available in this split. 142 * @param rng Splittable random number generator. 143 * @param useRandomSplitPoints Whether to choose split points for features at random. 144 * @return A possibly empty list of TrainingNodes. 145 */ 146 @Override 147 public List<AbstractTrainingNode<Regressor>> buildTree(int[] featureIDs, SplittableRandom rng, 148 boolean useRandomSplitPoints) { 149 if (useRandomSplitPoints) { 150 return buildRandomTree(featureIDs, rng); 151 } else { 152 return buildGreedyTree(featureIDs); 153 } 154 } 155 156 /** 157 * Builds a tree according to CART 158 * @param featureIDs Indices of the features available in this split. 159 * @return A possibly empty list of TrainingNodes. 160 */ 161 private List<AbstractTrainingNode<Regressor>> buildGreedyTree(int[] featureIDs) { 162 int bestID = -1; 163 double bestSplitValue = 0.0; 164 double bestScore = getImpurity(); 165 //logger.info("Cur node score = " + bestScore); 166 List<int[]> curIndices = new ArrayList<>(); 167 List<int[]> bestLeftIndices = new ArrayList<>(); 168 List<int[]> bestRightIndices = new ArrayList<>(); 169 for (int i = 0; i < featureIDs.length; i++) { 170 List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature(); 171 172 curIndices.clear(); 173 for (int j = 0; j < feature.size(); j++) { 174 InvertedFeature f = feature.get(j); 175 int[] curFeatureIndices = f.indices(); 176 curIndices.add(curFeatureIndices); 177 } 178 179 // searching for the intervals between features. 
180 for (int j = 0; j < feature.size()-1; j++) { 181 List<int[]> curLeftIndices = curIndices.subList(0,j+1); 182 List<int[]> curRightIndices = curIndices.subList(j+1,feature.size()); 183 ImpurityTuple lessThanScore = impurity.impurityTuple(curLeftIndices,targets,weights); 184 ImpurityTuple greaterThanScore = impurity.impurityTuple(curRightIndices,targets,weights); 185 double score = (lessThanScore.impurity*lessThanScore.weight + greaterThanScore.impurity*greaterThanScore.weight) / weightSum; 186 if (score < bestScore) { 187 bestID = i; 188 bestScore = score; 189 bestSplitValue = (feature.get(j).value + feature.get(j + 1).value) / 2.0; 190 // Clear out the old best indices before storing the new ones. 191 bestLeftIndices.clear(); 192 bestLeftIndices.addAll(curLeftIndices); 193 bestRightIndices.clear(); 194 bestRightIndices.addAll(curRightIndices); 195 //logger.info("id = " + featureIDs[i] + ", split = " + bestSplitValue + ", score = " + score); 196 //logger.info("less score = " +lessThanScore+", less size = "+lessThanIndices.size+", greater score = " + greaterThanScore+", greater size = "+greaterThanIndices.size); 197 } 198 } 199 } 200 List<AbstractTrainingNode<Regressor>> output; 201 double impurityDecrease = weightSum * (getImpurity() - bestScore); 202 // If we found a split better than the current impurity. 203 if ((bestID != -1) && (impurityDecrease >= leafDeterminer.getScaledMinImpurityDecrease())) { 204 output = splitAtBest(featureIDs, bestID, bestSplitValue, bestLeftIndices, bestRightIndices); 205 } else { 206 output = Collections.emptyList(); 207 } 208 data = null; 209 return output; 210 } 211 212 /** 213 * Builds a CART tree with randomly chosen split points. 214 * @param featureIDs Indices of the features available in this split. 215 * @param rng Splittable random number generator. 216 * @return A possibly empty list of TrainingNodes. 
217 */ 218 private List<AbstractTrainingNode<Regressor>> buildRandomTree(int[] featureIDs, SplittableRandom rng) { 219 int bestID = -1; 220 double bestSplitValue = 0.0; 221 double bestScore = getImpurity(); 222 //logger.info("Cur node score = " + bestScore); 223 List<int[]> curLeftIndices = new ArrayList<>(); 224 List<int[]> curRightIndices = new ArrayList<>(); 225 List<int[]> bestLeftIndices = new ArrayList<>(); 226 List<int[]> bestRightIndices = new ArrayList<>(); 227 228 // split each feature once randomly and record the least impure amongst these 229 for (int i = 0; i < featureIDs.length; i++) { 230 List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature(); 231 // if there is only 1 inverted feature for this feature, it has only 1 value, so cannot be split 232 if (feature.size() == 1) { 233 continue; 234 } 235 236 int splitIdx = rng.nextInt(feature.size()-1); 237 238 for (int j = 0; j < splitIdx + 1; j++) { 239 InvertedFeature vf; 240 vf = feature.get(j); 241 curLeftIndices.add(vf.indices()); 242 } 243 for (int j = splitIdx + 1; j < feature.size(); j++) { 244 InvertedFeature vf; 245 vf = feature.get(j); 246 curRightIndices.add(vf.indices()); 247 } 248 249 ImpurityTuple lessThanScore = impurity.impurityTuple(curLeftIndices,targets,weights); 250 ImpurityTuple greaterThanScore = impurity.impurityTuple(curRightIndices,targets,weights); 251 double score = (lessThanScore.impurity*lessThanScore.weight + greaterThanScore.impurity*greaterThanScore.weight) / weightSum; 252 if (score < bestScore) { 253 bestID = i; 254 bestScore = score; 255 bestSplitValue = (feature.get(splitIdx).value + feature.get(splitIdx + 1).value) / 2.0; 256 // Clear out the old best indices before storing the new ones. 
257 bestLeftIndices.clear(); 258 bestLeftIndices.addAll(curLeftIndices); 259 bestRightIndices.clear(); 260 bestRightIndices.addAll(curRightIndices); 261 //logger.info("id = " + featureIDs[i] + ", split = " + bestSplitValue + ", score = " + score); 262 //logger.info("less score = " +lessThanScore+", less size = "+lessThanIndices.size+", greater score = " + greaterThanScore+", greater size = "+greaterThanIndices.size); 263 } 264 } 265 266 List<AbstractTrainingNode<Regressor>> output; 267 double impurityDecrease = weightSum * (getImpurity() - bestScore); 268 // If we found a split better than the current impurity. 269 if ((bestID != -1) && (impurityDecrease >= leafDeterminer.getScaledMinImpurityDecrease())) { 270 output = splitAtBest(featureIDs, bestID, bestSplitValue, bestLeftIndices, bestRightIndices); 271 } else { 272 output = Collections.emptyList(); 273 } 274 data = null; 275 return output; 276 } 277 278 /** 279 * Splits the data to form two nodes. 280 * @param featureIDs Indices of the features available in this split. 281 * @param bestID ID of the feature on which the split should be based. 282 * @param bestSplitValue Feature value to use for splitting the data. 283 * @param bestLeftIndices The indices of the examples less than or equal to the split value for the given feature. 284 * @param bestRightIndices The indices of the examples greater than the split value for the given feature. 285 * @return A list of training nodes resulting from the split. 
286 */ 287 private List<AbstractTrainingNode<Regressor>> splitAtBest(int[] featureIDs, int bestID, double bestSplitValue, 288 List<int[]> bestLeftIndices, List<int[]> bestRightIndices){ 289 290 splitID = featureIDs[bestID]; 291 split = true; 292 splitValue = bestSplitValue; 293 IntArrayContainer firstBuffer = mergeBufferOne.get(); 294 firstBuffer.size = 0; 295 firstBuffer.grow(indices.length); 296 IntArrayContainer secondBuffer = mergeBufferTwo.get(); 297 secondBuffer.size = 0; 298 secondBuffer.grow(indices.length); 299 int[] leftIndices = IntArrayContainer.merge(bestLeftIndices, firstBuffer, secondBuffer); 300 int[] rightIndices = IntArrayContainer.merge(bestRightIndices, firstBuffer, secondBuffer); 301 302 float leftWeightSum = Util.sum(leftIndices,leftIndices.length,weights); 303 double leftImpurityScore = impurity.impurity(leftIndices, targets, weights); 304 305 float rightWeightSum = Util.sum(rightIndices,rightIndices.length,weights); 306 double rightImpurityScore = impurity.impurity(rightIndices, targets, weights); 307 308 boolean shouldMakeLeftLeaf = shouldMakeLeaf(leftImpurityScore, leftWeightSum); 309 boolean shouldMakeRightLeaf = shouldMakeLeaf(rightImpurityScore, rightWeightSum); 310 311 if (shouldMakeLeftLeaf && shouldMakeRightLeaf) { 312 lessThanOrEqual = createLeaf(leftImpurityScore, leftIndices); 313 greaterThan = createLeaf(rightImpurityScore, rightIndices); 314 return Collections.emptyList(); 315 } 316 317 //logger.info("Splitting on feature " + bestID + " with value " + bestSplitValue + " at depth " + depth + ", " + numExamples + " examples in node."); 318 //logger.info("left indices length = " + leftIndices.length); 319 ArrayList<TreeFeature> lessThanData = new ArrayList<>(data.size()); 320 ArrayList<TreeFeature> greaterThanData = new ArrayList<>(data.size()); 321 for (TreeFeature feature : data) { 322 Pair<TreeFeature,TreeFeature> split = feature.split(leftIndices, rightIndices, firstBuffer, secondBuffer); 323 lessThanData.add(split.getA()); 324 
greaterThanData.add(split.getB()); 325 } 326 327 List<AbstractTrainingNode<Regressor>> output = new ArrayList<>(2); 328 AbstractTrainingNode<Regressor> tmpNode; 329 if (shouldMakeLeftLeaf) { 330 lessThanOrEqual = createLeaf(leftImpurityScore, leftIndices); 331 } else { 332 tmpNode = new RegressorTrainingNode(impurity, lessThanData, leftIndices, targets, weights, dimName, 333 leftIndices.length, depth + 1, featureIDMap, labelIDMap, leafDeterminer, leftWeightSum, leftImpurityScore); 334 lessThanOrEqual = tmpNode; 335 output.add(tmpNode); 336 } 337 338 if (shouldMakeRightLeaf) { 339 greaterThan = createLeaf(rightImpurityScore, rightIndices); 340 } else { 341 tmpNode = new RegressorTrainingNode(impurity, greaterThanData, rightIndices, targets, weights, dimName, 342 rightIndices.length, depth + 1, featureIDMap, labelIDMap, leafDeterminer, rightWeightSum, rightImpurityScore); 343 greaterThan = tmpNode; 344 output.add(tmpNode); 345 } 346 return output; 347 } 348 349 /** 350 * Generates a test time tree (made of {@link SplitNode} and {@link LeafNode}) from the tree rooted at this node. 351 * @return A subtree using the SplitNode and LeafNode classes. 352 */ 353 @Override 354 public Node<Regressor> convertTree() { 355 if (split) { 356 return createSplitNode(); 357 } else { 358 return createLeaf(getImpurity(), indices); 359 } 360 } 361 362 /** 363 * Makes a {@link LeafNode} 364 * @param impurityScore the impurity score for the node. 365 * @param leafIndices the indices of the examples to be placed in the node. 
366 * @return A {@link LeafNode} 367 */ 368 private LeafNode<Regressor> createLeaf(double impurityScore, int[] leafIndices) { 369 double mean = 0.0; 370 double leafWeightSum = 0.0; 371 double variance = 0.0; 372 for (int i = 0; i < leafIndices.length; i++) { 373 int idx = leafIndices[i]; 374 float value = targets[idx]; 375 float weight = weights[idx]; 376 377 leafWeightSum += weight; 378 double oldMean = mean; 379 mean += (weight / leafWeightSum) * (value - oldMean); 380 variance += weight * (value - oldMean) * (value - mean); 381 } 382 variance = leafIndices.length > 1 ? variance / (leafWeightSum-1) : 0; 383 DimensionTuple leafPred = new DimensionTuple(dimName,mean,variance); 384 return new LeafNode<>(impurityScore,leafPred,Collections.emptyMap(),false); 385 } 386 387 /** 388 * Inverts a training dataset from row major to column major. This partially de-sparsifies the dataset 389 * so it's very expensive in terms of memory. 390 * @param examples An input dataset. 391 * @return A list of TreeFeatures which contain {@link InvertedFeature}s. 
392 */ 393 public static InvertedData invertData(Dataset<Regressor> examples) { 394 ImmutableFeatureMap featureInfos = examples.getFeatureIDMap(); 395 ImmutableOutputInfo<Regressor> labelInfo = examples.getOutputIDInfo(); 396 int numLabels = labelInfo.size(); 397 int numFeatures = featureInfos.size(); 398 int[] indices = new int[examples.size()]; 399 float[][] targets = new float[numLabels][examples.size()]; 400 float[] weights = new float[examples.size()]; 401 402 logger.fine("Building initial List<TreeFeature> for " + numFeatures + " features and " + numLabels + " outputs"); 403 ArrayList<TreeFeature> data = new ArrayList<>(featureInfos.size()); 404 405 for (int i = 0; i < featureInfos.size(); i++) { 406 data.add(new TreeFeature(i)); 407 } 408 409 int[] ids = ((ImmutableRegressionInfo) labelInfo).getNaturalOrderToIDMapping(); 410 for (int i = 0; i < examples.size(); i++) { 411 Example<Regressor> e = examples.getExample(i); 412 indices[i] = i; 413 weights[i] = e.getWeight(); 414 double[] output = e.getOutput().getValues(); 415 for (int j = 0; j < output.length; j++) { 416 targets[ids[j]][i] = (float) output[j]; 417 } 418 SparseVector vec = SparseVector.createSparseVector(e,featureInfos,false); 419 int lastID = 0; 420 for (VectorTuple f : vec) { 421 int curID = f.index; 422 for (int j = lastID; j < curID; j++) { 423 data.get(j).observeValue(0.0,i); 424 } 425 data.get(curID).observeValue(f.value,i); 426 // 427 // These two checks should never occur as SparseVector deals with 428 // collisions, and Dataset prevents repeated features. 429 // They are left in just to make sure. 430 if (lastID > curID) { 431 logger.severe("Example = " + e.toString()); 432 throw new IllegalStateException("Features aren't ordered. At id " + i + ", lastID = " + lastID + ", curID = " + curID); 433 } else if (lastID-1 == curID) { 434 logger.severe("Example = " + e.toString()); 435 throw new IllegalStateException("Features are repeated. 
At id " + i + ", lastID = " + lastID + ", curID = " + curID); 436 } 437 lastID = curID + 1; 438 } 439 for (int j = lastID; j < numFeatures; j++) { 440 data.get(j).observeValue(0.0,i); 441 } 442 if (i % 1000 == 0) { 443 logger.fine("Processed example " + i); 444 } 445 } 446 447 logger.fine("Sorting features"); 448 449 data.forEach(TreeFeature::sort); 450 451 logger.fine("Fixing InvertedFeature sizes"); 452 453 data.forEach(TreeFeature::fixSize); 454 455 logger.fine("Built initial List<TreeFeature>"); 456 457 return new InvertedData(data,indices,targets,weights); 458 } 459 460 /** 461 * Tuple containing an inverted dataset (i.e., feature-wise not exmaple-wise). 462 */ 463 public static class InvertedData { 464 final ArrayList<TreeFeature> data; 465 final int[] indices; 466 final float[][] targets; 467 final float[] weights; 468 469 InvertedData(ArrayList<TreeFeature> data, int[] indices, float[][] targets, float[] weights) { 470 this.data = data; 471 this.indices = indices; 472 this.targets = targets; 473 this.weights = weights; 474 } 475 476 ArrayList<TreeFeature> copyData() { 477 ArrayList<TreeFeature> newData = new ArrayList<>(); 478 479 for (TreeFeature f : data) { 480 newData.add(f.deepCopy()); 481 } 482 483 return newData; 484 } 485 } 486 487 private void writeObject(java.io.ObjectOutputStream stream) 488 throws IOException { 489 throw new NotSerializableException("RegressorTrainingNode is a runtime class only, and should not be serialized."); 490 } 491}