/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.rtree;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.common.tree.LeafNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.SplitNode;
import org.tribuo.common.tree.TreeModel;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.Regressor.DimensionTuple;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;

/**
 * A {@link Model} wrapped around a list of decision tree root {@link Node}s used
 * to generate independent predictions for each dimension in a regression.
052 */ 053public final class IndependentRegressionTreeModel extends TreeModel<Regressor> { 054 private static final long serialVersionUID = 1L; 055 056 private final Map<String,Node<Regressor>> roots; 057 058 IndependentRegressionTreeModel(String name, ModelProvenance description, 059 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, boolean generatesProbabilities, 060 Map<String,Node<Regressor>> roots) { 061 super(name, description, featureIDMap, outputIDInfo, generatesProbabilities, gatherActiveFeatures(featureIDMap,roots)); 062 this.roots = roots; 063 } 064 065 private static Map<String,List<String>> gatherActiveFeatures(ImmutableFeatureMap fMap, Map<String,Node<Regressor>> roots) { 066 HashMap<String,List<String>> outputMap = new HashMap<>(); 067 for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) { 068 Set<String> activeFeatures = new LinkedHashSet<>(); 069 070 Queue<Node<Regressor>> nodeQueue = new LinkedList<>(); 071 072 nodeQueue.offer(e.getValue()); 073 074 while (!nodeQueue.isEmpty()) { 075 Node<Regressor> node = nodeQueue.poll(); 076 if ((node != null) && (!node.isLeaf())) { 077 SplitNode<Regressor> splitNode = (SplitNode<Regressor>) node; 078 String featureName = fMap.get(splitNode.getFeatureID()).getName(); 079 activeFeatures.add(featureName); 080 nodeQueue.offer(splitNode.getGreaterThan()); 081 nodeQueue.offer(splitNode.getLessThanOrEqual()); 082 } 083 } 084 outputMap.put(e.getKey(), new ArrayList<>(activeFeatures)); 085 } 086 return outputMap; 087 } 088 089 /** 090 * Probes the trees to find the depth. 091 * @return The maximum depth across the trees. 
092 */ 093 @Override 094 public int getDepth() { 095 int maxDepth = 0; 096 for (Node<Regressor> curRoot : roots.values()) { 097 int thisDepth = computeDepth(0,curRoot); 098 if (maxDepth < thisDepth) { 099 maxDepth = thisDepth; 100 } 101 } 102 return maxDepth; 103 } 104 105 @Override 106 public Prediction<Regressor> predict(Example<Regressor> example) { 107 // 108 // Ensures we handle collisions correctly 109 SparseVector vec = SparseVector.createSparseVector(example,featureIDMap,false); 110 if (vec.numActiveElements() == 0) { 111 throw new IllegalArgumentException("No features found in Example " + example.toString()); 112 } 113 114 List<Prediction<Regressor>> predictionList = new ArrayList<>(); 115 for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) { 116 Node<Regressor> oldNode = e.getValue(); 117 Node<Regressor> curNode = e.getValue(); 118 119 while (curNode != null) { 120 oldNode = curNode; 121 curNode = oldNode.getNextNode(vec); 122 } 123 124 // 125 // oldNode must be a LeafNode. 126 predictionList.add(((LeafNode<Regressor>) oldNode).getPrediction(vec.numActiveElements(), example)); 127 } 128 return combine(predictionList); 129 } 130 131 @Override 132 public Map<String, List<Pair<String,Double>>> getTopFeatures(int n) { 133 int maxFeatures = n < 0 ? 
featureIDMap.size() : n; 134 135 Map<String, List<Pair<String, Double>>> map = new HashMap<>(); 136 Map<String, Integer> featureCounts = new HashMap<>(); 137 Queue<Node<Regressor>> nodeQueue = new LinkedList<>(); 138 139 for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) { 140 featureCounts.clear(); 141 nodeQueue.clear(); 142 143 nodeQueue.offer(e.getValue()); 144 145 while (!nodeQueue.isEmpty()) { 146 Node<Regressor> node = nodeQueue.poll(); 147 if ((node != null) && !node.isLeaf()) { 148 SplitNode<Regressor> splitNode = (SplitNode<Regressor>) node; 149 String featureName = featureIDMap.get(splitNode.getFeatureID()).getName(); 150 featureCounts.put(featureName, featureCounts.getOrDefault(featureName, 0) + 1); 151 nodeQueue.offer(splitNode.getGreaterThan()); 152 nodeQueue.offer(splitNode.getLessThanOrEqual()); 153 } 154 } 155 156 Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB())); 157 PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(maxFeatures, comparator); 158 159 for (Map.Entry<String, Integer> featureCount : featureCounts.entrySet()) { 160 Pair<String, Double> cur = new Pair<>(featureCount.getKey(), (double) featureCount.getValue()); 161 if (q.size() < maxFeatures) { 162 q.offer(cur); 163 } else if (comparator.compare(cur, q.peek()) > 0) { 164 q.poll(); 165 q.offer(cur); 166 } 167 } 168 List<Pair<String, Double>> list = new ArrayList<>(); 169 while (q.size() > 0) { 170 list.add(q.poll()); 171 } 172 Collections.reverse(list); 173 174 map.put(e.getKey(), list); 175 } 176 177 return map; 178 } 179 180 @Override 181 public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) { 182 SparseVector vec = SparseVector.createSparseVector(example, featureIDMap, false); 183 if (vec.numActiveElements() == 0) { 184 return Optional.empty(); 185 } 186 187 List<String> list = new ArrayList<>(); 188 List<Prediction<Regressor>> predList = new ArrayList<>(); 189 Map<String, List<Pair<String, Double>>> 
map = new HashMap<>(); 190 191 for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) { 192 list.clear(); 193 194 // 195 // Ensures we handle collisions correctly 196 Node<Regressor> oldNode = e.getValue(); 197 Node<Regressor> curNode = e.getValue(); 198 199 while (curNode != null) { 200 oldNode = curNode; 201 if (oldNode instanceof SplitNode) { 202 SplitNode<?> node = (SplitNode<?>) curNode; 203 list.add(featureIDMap.get(node.getFeatureID()).getName()); 204 } 205 curNode = oldNode.getNextNode(vec); 206 } 207 208 // 209 // oldNode must be a LeafNode. 210 predList.add(((LeafNode<Regressor>) oldNode).getPrediction(vec.numActiveElements(), example)); 211 212 List<Pair<String, Double>> pairs = new ArrayList<>(); 213 int i = list.size() + 1; 214 for (String s : list) { 215 pairs.add(new Pair<>(s, i + 0.0)); 216 i--; 217 } 218 219 map.put(e.getKey(), pairs); 220 } 221 Prediction<Regressor> combinedPrediction = combine(predList); 222 223 return Optional.of(new Excuse<>(example,combinedPrediction,map)); 224 } 225 226 @Override 227 protected IndependentRegressionTreeModel copy(String newName, ModelProvenance newProvenance) { 228 Map<String,Node<Regressor>> newRoots = new HashMap<>(); 229 for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) { 230 newRoots.put(e.getKey(),e.getValue().copy()); 231 } 232 return new IndependentRegressionTreeModel(newName,newProvenance,featureIDMap,outputIDInfo,generatesProbabilities,newRoots); 233 } 234 235 private Prediction<Regressor> combine(List<Prediction<Regressor>> predictions) { 236 DimensionTuple[] tuples = new DimensionTuple[predictions.size()]; 237 int numUsed = 0; 238 int i = 0; 239 for (Prediction<Regressor> p : predictions) { 240 if (numUsed < p.getNumActiveFeatures()) { 241 numUsed = p.getNumActiveFeatures(); 242 } 243 Regressor output = p.getOutput(); 244 if (output instanceof DimensionTuple) { 245 tuples[i] = (DimensionTuple)output; 246 } else { 247 throw new IllegalStateException("All the leaves should contain 
DimensionTuple not Regressor"); 248 } 249 i++; 250 } 251 252 Example<Regressor> example = predictions.get(0).getExample(); 253 return new Prediction<>(new Regressor(tuples),numUsed,example); 254 } 255 256 @Override 257 public Set<String> getFeatures() { 258 Set<String> features = new HashSet<>(); 259 260 Queue<Node<Regressor>> nodeQueue = new LinkedList<>(); 261 262 for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) { 263 nodeQueue.offer(e.getValue()); 264 265 while (!nodeQueue.isEmpty()) { 266 Node<Regressor> node = nodeQueue.poll(); 267 if ((node != null) && !node.isLeaf()) { 268 SplitNode<Regressor> splitNode = (SplitNode<Regressor>) node; 269 features.add(featureIDMap.get(splitNode.getFeatureID()).getName()); 270 nodeQueue.offer(splitNode.getGreaterThan()); 271 nodeQueue.offer(splitNode.getLessThanOrEqual()); 272 } 273 } 274 } 275 276 return features; 277 } 278 279 @Override 280 public String toString() { 281 StringBuilder sb = new StringBuilder(); 282 for (Map.Entry<String,Node<Regressor>> curRoot : roots.entrySet()) { 283 sb.append("Output '"); 284 sb.append(curRoot.getKey()); 285 sb.append("' - tree = "); 286 sb.append(curRoot.getValue().toString()); 287 sb.append('\n'); 288 } 289 return "IndependentTreeModel(description="+provenance.toString()+",\n"+sb.toString()+")"; 290 } 291 292 /** 293 * Returns an unmodifiable view on the root node collection. 294 * <p> 295 * The nodes themselves are immutable. 296 * @return The root node collection. 297 */ 298 public Map<String,Node<Regressor>> getRoots() { 299 return Collections.unmodifiableMap(roots); 300 } 301 302 /** 303 * Returns null, as this model contains multiple roots, one per regression output dimension. 304 * <p> 305 * Use {@link #getRoots()} instead. 306 * @return null. 307 */ 308 @Override 309 public Node<Regressor> getRoot() { 310 return null; 311 } 312}