/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.rtree.impl;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.common.tree.LeafNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.SplitNode;
import org.tribuo.common.tree.impl.IntArrayContainer;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.rtree.impurity.RegressorImpurity;
import org.tribuo.regression.rtree.impurity.RegressorImpurity.ImpurityTuple;
import org.tribuo.util.Util;

import java.io.IOException;
import java.io.NotSerializableException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.SplittableRandom;
import java.util.logging.Logger;

/**
 * A decision tree node used at training time.
 * Contains a list of the example indices currently found in this node,
 * the current impurity and a bunch of other statistics.
049 */ 050public class JointRegressorTrainingNode extends AbstractTrainingNode<Regressor> { 051 private static final long serialVersionUID = 1L; 052 053 private static final Logger logger = Logger.getLogger(JointRegressorTrainingNode.class.getName()); 054 055 private static final ThreadLocal<IntArrayContainer> mergeBufferOne = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE)); 056 private static final ThreadLocal<IntArrayContainer> mergeBufferTwo = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE)); 057 058 private transient ArrayList<TreeFeature> data; 059 060 private final boolean normalize; 061 062 private final ImmutableOutputInfo<Regressor> labelIDMap; 063 064 private final ImmutableFeatureMap featureIDMap; 065 066 private final RegressorImpurity impurity; 067 068 private final int[] indices; 069 070 private final float[][] targets; 071 072 private final float[] weights; 073 074 private final float weightSum; 075 076 /** 077 * Constructor which creates the inverted file. 078 * @param impurity The impurity function to use. 079 * @param examples The training data. 080 * @param normalize Normalizes the leaves so each leaf has a distribution which sums to 1.0. 081 * @param leafDeterminer Contains parameters needed to determine whether a node is a leaf. 
082 */ 083 public JointRegressorTrainingNode(RegressorImpurity impurity, Dataset<Regressor> examples, boolean normalize, 084 LeafDeterminer leafDeterminer) { 085 this(impurity, invertData(examples), examples.size(), examples.getFeatureIDMap(), examples.getOutputIDInfo(), 086 normalize, leafDeterminer); 087 } 088 089 private JointRegressorTrainingNode(RegressorImpurity impurity, InvertedData tuple, int numExamples, 090 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputInfo, 091 boolean normalize, LeafDeterminer leafDeterminer) { 092 this(impurity,tuple.data,tuple.indices,tuple.targets,tuple.weights,numExamples,0,featureIDMap,outputInfo, 093 normalize, leafDeterminer); 094 } 095 096 private JointRegressorTrainingNode(RegressorImpurity impurity, ArrayList<TreeFeature> data, int[] indices, 097 float[][] targets, float[] weights, int numExamples, int depth, 098 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap, 099 boolean normalize, LeafDeterminer leafDeterminer) { 100 super(depth, numExamples, leafDeterminer); 101 this.data = data; 102 this.normalize = normalize; 103 this.featureIDMap = featureIDMap; 104 this.labelIDMap = labelIDMap; 105 this.impurity = impurity; 106 this.indices = indices; 107 this.targets = targets; 108 this.weights = weights; 109 this.weightSum = Util.sum(indices,indices.length,weights); 110 this.impurityScore = calcImpurity(indices); 111 } 112 113 private JointRegressorTrainingNode(RegressorImpurity impurity, ArrayList<TreeFeature> data, int[] indices, 114 float[][] targets, float[] weights, int numExamples, int depth, 115 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap, 116 boolean normalize, LeafDeterminer leafDeterminer, float weightSum, 117 double impurityScore) { 118 super(depth, numExamples, leafDeterminer); 119 this.data = data; 120 this.normalize = normalize; 121 this.featureIDMap = featureIDMap; 122 this.labelIDMap = labelIDMap; 123 this.impurity = impurity; 124 
this.indices = indices; 125 this.targets = targets; 126 this.weights = weights; 127 this.weightSum = weightSum; 128 this.impurityScore = impurityScore; 129 } 130 131 @Override 132 public double getImpurity() { 133 return impurityScore; 134 } 135 136 @Override 137 public float getWeightSum() { 138 return weightSum; 139 } 140 141 /** 142 * Calculates the impurity score of the node. 143 * @return The impurity score of the node. 144 */ 145 private double calcImpurity(int[] curIndices) { 146 double tmp = 0.0; 147 for (int i = 0; i < targets.length; i++) { 148 tmp += impurity.impurity(curIndices, targets[i], weights); 149 } 150 return tmp / targets.length; 151 } 152 153 /** 154 * Builds a tree according to CART (as it does not do multi-way splits on categorical values like C4.5). 155 * @param featureIDs Indices of the features available in this split. 156 * @param rng Splittable random number generator. 157 * @param useRandomSplitPoints Whether to choose split points for features at random. 158 * @return A possibly empty list of TrainingNodes. 159 */ 160 @Override 161 public List<AbstractTrainingNode<Regressor>> buildTree(int[] featureIDs, SplittableRandom rng, 162 boolean useRandomSplitPoints) { 163 if (useRandomSplitPoints) { 164 return buildRandomTree(featureIDs, rng); 165 } else { 166 return buildGreedyTree(featureIDs); 167 } 168 } 169 170 /** 171 * Builds a tree according to CART 172 * @param featureIDs Indices of the features available in this split. 173 * @return A possibly empty list of TrainingNodes. 
174 */ 175 private List<AbstractTrainingNode<Regressor>> buildGreedyTree(int[] featureIDs) { 176 int bestID = -1; 177 double bestSplitValue = 0.0; 178 double bestScore = getImpurity(); 179 //logger.info("Cur node score = " + bestScore); 180 List<int[]> curIndices = new ArrayList<>(); 181 List<int[]> bestLeftIndices = new ArrayList<>(); 182 List<int[]> bestRightIndices = new ArrayList<>(); 183 for (int i = 0; i < featureIDs.length; i++) { 184 List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature(); 185 186 curIndices.clear(); 187 for (int j = 0; j < feature.size(); j++) { 188 InvertedFeature f = feature.get(j); 189 int[] curFeatureIndices = f.indices(); 190 curIndices.add(curFeatureIndices); 191 } 192 193 // searching for the intervals between features. 194 for (int j = 0; j < feature.size()-1; j++) { 195 List<int[]> curLeftIndices = curIndices.subList(0,j+1); 196 List<int[]> curRightIndices = curIndices.subList(j+1,feature.size()); 197 double lessThanScore = 0.0; 198 double greaterThanScore = 0.0; 199 for (int k = 0; k < targets.length; k++) { 200 ImpurityTuple left = impurity.impurityTuple(curLeftIndices,targets[k],weights); 201 lessThanScore += left.impurity * left.weight; 202 ImpurityTuple right = impurity.impurityTuple(curRightIndices,targets[k],weights); 203 greaterThanScore += right.impurity * right.weight; 204 } 205 double score = (lessThanScore + greaterThanScore) / (targets.length * weightSum); 206 if (score < bestScore) { 207 bestID = i; 208 bestScore = score; 209 bestSplitValue = (feature.get(j).value + feature.get(j + 1).value) / 2.0; 210 // Clear out the old best indices before storing the new ones. 
211 bestLeftIndices.clear(); 212 bestLeftIndices.addAll(curLeftIndices); 213 bestRightIndices.clear(); 214 bestRightIndices.addAll(curRightIndices); 215 //logger.info("id = " + featureIDs[i] + ", split = " + bestSplitValue + ", score = " + score); 216 //logger.info("less score = " +lessThanScore+", less size = "+lessThanIndices.size+", greater score = " + greaterThanScore+", greater size = "+greaterThanIndices.size); 217 } 218 } 219 } 220 List<AbstractTrainingNode<Regressor>> output; 221 double impurityDecrease = weightSum * (getImpurity() - bestScore); 222 // If we found a split better than the current impurity. 223 if ((bestID != -1) && (impurityDecrease >= leafDeterminer.getScaledMinImpurityDecrease())) { 224 output = splitAtBest(featureIDs, bestID, bestSplitValue, bestLeftIndices, bestRightIndices); 225 } else { 226 output = Collections.emptyList(); 227 } 228 data = null; 229 return output; 230 } 231 232 /** 233 * Builds a CART tree with randomly chosen split points. 234 * @param featureIDs Indices of the features available in this split. 235 * @param rng Splittable random number generator. 236 * @return A possibly empty list of TrainingNodes. 
237 */ 238 private List<AbstractTrainingNode<Regressor>> buildRandomTree(int[] featureIDs, SplittableRandom rng) { 239 int bestID = -1; 240 double bestSplitValue = 0.0; 241 double bestScore = getImpurity(); 242 //logger.info("Cur node score = " + bestScore); 243 List<int[]> curLeftIndices = new ArrayList<>(); 244 List<int[]> curRightIndices = new ArrayList<>(); 245 List<int[]> bestLeftIndices = new ArrayList<>(); 246 List<int[]> bestRightIndices = new ArrayList<>(); 247 248 // split each feature once randomly and record the least impure amongst these 249 for (int i = 0; i < featureIDs.length; i++) { 250 List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature(); 251 // if there is only 1 inverted feature for this feature, it has only 1 value, so cannot be split 252 if (feature.size() == 1) { 253 continue; 254 } 255 256 double lessThanScore = 0.0; 257 double greaterThanScore = 0.0; 258 259 int splitIdx = rng.nextInt(feature.size()-1); 260 261 for (int j = 0; j < splitIdx + 1; j++) { 262 InvertedFeature vf; 263 vf = feature.get(j); 264 curLeftIndices.add(vf.indices()); 265 } 266 for (int j = splitIdx + 1; j < feature.size(); j++) { 267 InvertedFeature vf; 268 vf = feature.get(j); 269 curRightIndices.add(vf.indices()); 270 } 271 272 for (int k = 0; k < targets.length; k++) { 273 ImpurityTuple left = impurity.impurityTuple(curLeftIndices,targets[k],weights); 274 lessThanScore += left.impurity * left.weight; 275 ImpurityTuple right = impurity.impurityTuple(curRightIndices,targets[k],weights); 276 greaterThanScore += right.impurity * right.weight; 277 } 278 279 double score = (lessThanScore + greaterThanScore) / (targets.length * weightSum); 280 if (score < bestScore) { 281 bestID = i; 282 bestScore = score; 283 bestSplitValue = (feature.get(splitIdx).value + feature.get(splitIdx + 1).value) / 2.0; 284 // Clear out the old best indices before storing the new ones. 
285 bestLeftIndices.clear(); 286 bestLeftIndices.addAll(curLeftIndices); 287 bestRightIndices.clear(); 288 bestRightIndices.addAll(curRightIndices); 289 //logger.info("id = " + featureIDs[i] + ", split = " + bestSplitValue + ", score = " + score); 290 //logger.info("less score = " +lessThanScore+", less size = "+lessThanIndices.size+", greater score = " + greaterThanScore+", greater size = "+greaterThanIndices.size); 291 } 292 } 293 294 List<AbstractTrainingNode<Regressor>> output; 295 double impurityDecrease = weightSum * (getImpurity() - bestScore); 296 // If we found a split better than the current impurity. 297 if ((bestID != -1) && (impurityDecrease >= leafDeterminer.getScaledMinImpurityDecrease())) { 298 output = splitAtBest(featureIDs, bestID, bestSplitValue, bestLeftIndices, bestRightIndices); 299 } else { 300 output = Collections.emptyList(); 301 } 302 data = null; 303 return output; 304 } 305 306 /** 307 * Splits the data to form two nodes. 308 * @param featureIDs Indices of the features available in this split. 309 * @param bestID ID of the feature on which the split should be based. 310 * @param bestSplitValue Feature value to use for splitting the data. 311 * @param bestLeftIndices The indices of the examples less than or equal to the split value for the given feature. 312 * @param bestRightIndices The indices of the examples greater than the split value for the given feature. 313 * @return A list of training nodes resulting from the split. 
314 */ 315 private List<AbstractTrainingNode<Regressor>> splitAtBest(int[] featureIDs, int bestID, double bestSplitValue, 316 List<int[]> bestLeftIndices, List<int[]> bestRightIndices){ 317 splitID = featureIDs[bestID]; 318 split = true; 319 splitValue = bestSplitValue; 320 IntArrayContainer firstBuffer = mergeBufferOne.get(); 321 firstBuffer.size = 0; 322 firstBuffer.grow(indices.length); 323 IntArrayContainer secondBuffer = mergeBufferTwo.get(); 324 secondBuffer.size = 0; 325 secondBuffer.grow(indices.length); 326 int[] leftIndices = IntArrayContainer.merge(bestLeftIndices, firstBuffer, secondBuffer); 327 int[] rightIndices = IntArrayContainer.merge(bestRightIndices, firstBuffer, secondBuffer); 328 329 float leftWeightSum = Util.sum(leftIndices,leftIndices.length,weights); 330 double leftImpurityScore = calcImpurity(leftIndices); 331 332 float rightWeightSum = Util.sum(rightIndices,rightIndices.length,weights); 333 double rightImpurityScore = calcImpurity(rightIndices); 334 335 boolean shouldMakeLeftLeaf = shouldMakeLeaf(leftImpurityScore, leftWeightSum); 336 boolean shouldMakeRightLeaf = shouldMakeLeaf(rightImpurityScore, rightWeightSum); 337 338 if (shouldMakeLeftLeaf && shouldMakeRightLeaf) { 339 lessThanOrEqual = createLeaf(leftImpurityScore, leftIndices); 340 greaterThan = createLeaf(rightImpurityScore, rightIndices); 341 return Collections.emptyList(); 342 } 343 //logger.info("Splitting on feature " + bestID + " with value " + bestSplitValue + " at depth " + depth + ", " + numExamples + " examples in node."); 344 //logger.info("left indices length = " + leftIndices.length); 345 ArrayList<TreeFeature> lessThanData = new ArrayList<>(data.size()); 346 ArrayList<TreeFeature> greaterThanData = new ArrayList<>(data.size()); 347 for (TreeFeature feature : data) { 348 Pair<TreeFeature,TreeFeature> split = feature.split(leftIndices, rightIndices, firstBuffer, secondBuffer); 349 lessThanData.add(split.getA()); 350 greaterThanData.add(split.getB()); 351 } 352 353 
List<AbstractTrainingNode<Regressor>> output = new ArrayList<>(2); 354 AbstractTrainingNode<Regressor> tmpNode; 355 if (shouldMakeLeftLeaf) { 356 lessThanOrEqual = createLeaf(leftImpurityScore, leftIndices); 357 } else { 358 tmpNode = new JointRegressorTrainingNode(impurity, lessThanData, leftIndices, targets, 359 weights, leftIndices.length, depth + 1, featureIDMap, labelIDMap, normalize, leafDeterminer, 360 leftWeightSum, leftImpurityScore); 361 lessThanOrEqual = tmpNode; 362 output.add(tmpNode); 363 } 364 365 if (shouldMakeRightLeaf) { 366 greaterThan = createLeaf(rightImpurityScore, rightIndices); 367 } else { 368 tmpNode = new JointRegressorTrainingNode(impurity, greaterThanData, rightIndices, targets, 369 weights, rightIndices.length, depth + 1, featureIDMap, labelIDMap, normalize, leafDeterminer, 370 rightWeightSum, rightImpurityScore); 371 greaterThan = tmpNode; 372 output.add(tmpNode); 373 } 374 return output; 375 } 376 377 /** 378 * Generates a test time tree (made of {@link SplitNode} and {@link LeafNode}) from the tree rooted at this node. 379 * @return A subtree using the SplitNode and LeafNode classes. 380 */ 381 @Override 382 public Node<Regressor> convertTree() { 383 if (split) { 384 return createSplitNode(); 385 } else { 386 return createLeaf(getImpurity(), indices); 387 } 388 } 389 390 /** 391 * Makes a {@link LeafNode} 392 * @param impurityScore the impurity score for the node. 393 * @param leafIndices the indices of the examples to be placed in the node. 
394 * @return A {@link LeafNode} 395 */ 396 private LeafNode<Regressor> createLeaf(double impurityScore, int[] leafIndices) { 397 double leafWeightSum = 0.0; 398 double[] mean = new double[targets.length]; 399 Regressor leafPred; 400 if (normalize) { 401 for (int i = 0; i < leafIndices.length; i++) { 402 int idx = leafIndices[i]; 403 float weight = weights[idx]; 404 leafWeightSum += weight; 405 for (int j = 0; j < targets.length; j++) { 406 float value = targets[j][idx]; 407 408 double oldMean = mean[j]; 409 mean[j] += (weight / leafWeightSum) * (value - oldMean); 410 } 411 } 412 String[] names = new String[targets.length]; 413 double sum = 0.0; 414 for (int i = 0; i < targets.length; i++) { 415 names[i] = labelIDMap.getOutput(i).getNames()[0]; 416 sum += mean[i]; 417 } 418 // Normalize all the outputs so that they sum to 1.0. 419 for (int i = 0; i < targets.length; i++) { 420 mean[i] /= sum; 421 } 422 // Both names and mean are in id order, so the regressor constructor 423 // will convert them to natural order if they are different. 424 leafPred = new Regressor(names, mean); 425 } else { 426 double[] variance = new double[targets.length]; 427 for (int i = 0; i < leafIndices.length; i++) { 428 int idx = leafIndices[i]; 429 float weight = weights[idx]; 430 leafWeightSum += weight; 431 for (int j = 0; j < targets.length; j++) { 432 float value = targets[j][idx]; 433 434 double oldMean = mean[j]; 435 mean[j] += (weight / leafWeightSum) * (value - oldMean); 436 variance[j] += weight * (value - oldMean) * (value - mean[j]); 437 } 438 } 439 String[] names = new String[targets.length]; 440 for (int i = 0; i < targets.length; i++) { 441 names[i] = labelIDMap.getOutput(i).getNames()[0]; 442 variance[i] = leafIndices.length > 1 ? variance[i] / (leafWeightSum - 1) : 0; 443 } 444 // Both names, mean and variance are in id order, so the regressor constructor 445 // will convert them to natural order if they are different. 
446 leafPred = new Regressor(names, mean, variance); 447 } 448 return new LeafNode<>(impurityScore,leafPred,Collections.emptyMap(),false); 449 } 450 451 /** 452 * Inverts a training dataset from row major to column major. This partially de-sparsifies the dataset 453 * so it's very expensive in terms of memory. 454 * @param examples An input dataset. 455 * @return A list of TreeFeatures which contain {@link InvertedFeature}s. 456 */ 457 private static InvertedData invertData(Dataset<Regressor> examples) { 458 ImmutableFeatureMap featureInfos = examples.getFeatureIDMap(); 459 ImmutableOutputInfo<Regressor> labelInfo = examples.getOutputIDInfo(); 460 int numLabels = labelInfo.size(); 461 int numFeatures = featureInfos.size(); 462 int[] indices = new int[examples.size()]; 463 float[][] targets = new float[labelInfo.size()][examples.size()]; 464 float[] weights = new float[examples.size()]; 465 466 logger.fine("Building initial List<TreeFeature> for " + numFeatures + " features and " + numLabels + " outputs"); 467 ArrayList<TreeFeature> data = new ArrayList<>(featureInfos.size()); 468 469 for (int i = 0; i < featureInfos.size(); i++) { 470 data.add(new TreeFeature(i)); 471 } 472 473 int[] ids = ((ImmutableRegressionInfo) labelInfo).getNaturalOrderToIDMapping(); 474 for (int i = 0; i < examples.size(); i++) { 475 Example<Regressor> e = examples.getExample(i); 476 indices[i] = i; 477 weights[i] = e.getWeight(); 478 double[] output = e.getOutput().getValues(); 479 for (int j = 0; j < targets.length; j++) { 480 targets[ids[j]][i] = (float) output[j]; 481 } 482 SparseVector vec = SparseVector.createSparseVector(e,featureInfos,false); 483 int lastID = 0; 484 for (VectorTuple f : vec) { 485 int curID = f.index; 486 for (int j = lastID; j < curID; j++) { 487 data.get(j).observeValue(0.0,i); 488 } 489 data.get(curID).observeValue(f.value,i); 490 // 491 // These two checks should never occur as SparseVector deals with 492 // collisions, and Dataset prevents repeated features. 
493 // They are left in just to make sure. 494 if (lastID > curID) { 495 logger.severe("Example = " + e.toString()); 496 throw new IllegalStateException("Features aren't ordered. At id " + i + ", lastID = " + lastID + ", curID = " + curID); 497 } else if (lastID-1 == curID) { 498 logger.severe("Example = " + e.toString()); 499 throw new IllegalStateException("Features are repeated. At id " + i + ", lastID = " + lastID + ", curID = " + curID); 500 } 501 lastID = curID + 1; 502 } 503 for (int j = lastID; j < numFeatures; j++) { 504 data.get(j).observeValue(0.0,i); 505 } 506 if (i % 1000 == 0) { 507 logger.fine("Processed example " + i); 508 } 509 } 510 511 logger.fine("Sorting features"); 512 513 data.forEach(TreeFeature::sort); 514 515 logger.fine("Fixing InvertedFeature sizes"); 516 517 data.forEach(TreeFeature::fixSize); 518 519 logger.fine("Built initial List<TreeFeature>"); 520 521 return new InvertedData(data,indices,targets,weights); 522 } 523 524 private static class InvertedData { 525 final ArrayList<TreeFeature> data; 526 final int[] indices; 527 final float[][] targets; 528 final float[] weights; 529 530 InvertedData(ArrayList<TreeFeature> data, int[] indices, float[][] targets, float[] weights) { 531 this.data = data; 532 this.indices = indices; 533 this.targets = targets; 534 this.weights = weights; 535 } 536 } 537 538 private void writeObject(java.io.ObjectOutputStream stream) 539 throws IOException { 540 throw new NotSerializableException("JointRegressorTrainingNode is a runtime class only, and should not be serialized."); 541 } 542}