001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.rtree;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Example;
021import org.tribuo.Excuse;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.Model;
025import org.tribuo.Prediction;
026import org.tribuo.common.tree.LeafNode;
027import org.tribuo.common.tree.Node;
028import org.tribuo.common.tree.SplitNode;
029import org.tribuo.common.tree.TreeModel;
030import org.tribuo.math.la.SparseVector;
031import org.tribuo.provenance.ModelProvenance;
032import org.tribuo.regression.Regressor;
033import org.tribuo.regression.Regressor.DimensionTuple;
034
035import java.util.ArrayList;
036import java.util.Collections;
037import java.util.Comparator;
038import java.util.HashMap;
039import java.util.HashSet;
040import java.util.LinkedHashSet;
041import java.util.LinkedList;
042import java.util.List;
043import java.util.Map;
044import java.util.Optional;
045import java.util.PriorityQueue;
046import java.util.Queue;
047import java.util.Set;
048
049/**
050 * A {@link Model} wrapped around a list of decision tree root {@link Node}s used
051 * to generate independent predictions for each dimension in a regression.
052 */
053public final class IndependentRegressionTreeModel extends TreeModel<Regressor> {
054    private static final long serialVersionUID = 1L;
055
056    private final Map<String,Node<Regressor>> roots;
057
    /**
     * Constructs an independent regression tree model wrapping one tree root per output dimension.
     * <p>
     * Package-private; instances are built by the regression tree trainers in this package.
     * @param name The model name.
     * @param description The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param generatesProbabilities Does this model generate probabilities?
     * @param roots A map from output dimension name to that dimension's tree root.
     */
    IndependentRegressionTreeModel(String name, ModelProvenance description,
                                          ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, boolean generatesProbabilities,
                                          Map<String,Node<Regressor>> roots) {
        super(name, description, featureIDMap, outputIDInfo, generatesProbabilities, gatherActiveFeatures(featureIDMap,roots));
        this.roots = roots;
    }
064
065    private static Map<String,List<String>> gatherActiveFeatures(ImmutableFeatureMap fMap, Map<String,Node<Regressor>> roots) {
066        HashMap<String,List<String>> outputMap = new HashMap<>();
067        for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) {
068            Set<String> activeFeatures = new LinkedHashSet<>();
069
070            Queue<Node<Regressor>> nodeQueue = new LinkedList<>();
071
072            nodeQueue.offer(e.getValue());
073
074            while (!nodeQueue.isEmpty()) {
075                Node<Regressor> node = nodeQueue.poll();
076                if ((node != null) && (!node.isLeaf())) {
077                    SplitNode<Regressor> splitNode = (SplitNode<Regressor>) node;
078                    String featureName = fMap.get(splitNode.getFeatureID()).getName();
079                    activeFeatures.add(featureName);
080                    nodeQueue.offer(splitNode.getGreaterThan());
081                    nodeQueue.offer(splitNode.getLessThanOrEqual());
082                }
083            }
084            outputMap.put(e.getKey(), new ArrayList<>(activeFeatures));
085        }
086        return outputMap;
087    }
088
089    /**
090     * Probes the trees to find the depth.
091     * @return The maximum depth across the trees.
092     */
093    @Override
094    public int getDepth() {
095        int maxDepth = 0;
096        for (Node<Regressor> curRoot : roots.values()) {
097            int thisDepth = computeDepth(0,curRoot);
098            if (maxDepth < thisDepth) {
099                maxDepth = thisDepth;
100            }
101        }
102        return maxDepth;
103    }
104
105    @Override
106    public Prediction<Regressor> predict(Example<Regressor> example) {
107        //
108        // Ensures we handle collisions correctly
109        SparseVector vec = SparseVector.createSparseVector(example,featureIDMap,false);
110        if (vec.numActiveElements() == 0) {
111            throw new IllegalArgumentException("No features found in Example " + example.toString());
112        }
113
114        List<Prediction<Regressor>> predictionList = new ArrayList<>();
115        for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) {
116            Node<Regressor> oldNode = e.getValue();
117            Node<Regressor> curNode = e.getValue();
118
119            while (curNode != null) {
120                oldNode = curNode;
121                curNode = oldNode.getNextNode(vec);
122            }
123
124            //
125            // oldNode must be a LeafNode.
126            predictionList.add(((LeafNode<Regressor>) oldNode).getPrediction(vec.numActiveElements(), example));
127        }
128        return combine(predictionList);
129    }
130
131    @Override
132    public Map<String, List<Pair<String,Double>>> getTopFeatures(int n) {
133        int maxFeatures = n < 0 ? featureIDMap.size() : n;
134
135        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
136        Map<String, Integer> featureCounts = new HashMap<>();
137        Queue<Node<Regressor>> nodeQueue = new LinkedList<>();
138
139        for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) {
140            featureCounts.clear();
141            nodeQueue.clear();
142
143            nodeQueue.offer(e.getValue());
144
145            while (!nodeQueue.isEmpty()) {
146                Node<Regressor> node = nodeQueue.poll();
147                if ((node != null) && !node.isLeaf()) {
148                    SplitNode<Regressor> splitNode = (SplitNode<Regressor>) node;
149                    String featureName = featureIDMap.get(splitNode.getFeatureID()).getName();
150                    featureCounts.put(featureName, featureCounts.getOrDefault(featureName, 0) + 1);
151                    nodeQueue.offer(splitNode.getGreaterThan());
152                    nodeQueue.offer(splitNode.getLessThanOrEqual());
153                }
154            }
155
156            Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
157            PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(maxFeatures, comparator);
158
159            for (Map.Entry<String, Integer> featureCount : featureCounts.entrySet()) {
160                Pair<String, Double> cur = new Pair<>(featureCount.getKey(), (double) featureCount.getValue());
161                if (q.size() < maxFeatures) {
162                    q.offer(cur);
163                } else if (comparator.compare(cur, q.peek()) > 0) {
164                    q.poll();
165                    q.offer(cur);
166                }
167            }
168            List<Pair<String, Double>> list = new ArrayList<>();
169            while (q.size() > 0) {
170                list.add(q.poll());
171            }
172            Collections.reverse(list);
173
174            map.put(e.getKey(), list);
175        }
176
177        return map;
178    }
179
180    @Override
181    public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
182        SparseVector vec = SparseVector.createSparseVector(example, featureIDMap, false);
183        if (vec.numActiveElements() == 0) {
184            return Optional.empty();
185        }
186
187        List<String> list = new ArrayList<>();
188        List<Prediction<Regressor>> predList = new ArrayList<>();
189        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
190
191        for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) {
192            list.clear();
193
194            //
195            // Ensures we handle collisions correctly
196            Node<Regressor> oldNode = e.getValue();
197            Node<Regressor> curNode = e.getValue();
198
199            while (curNode != null) {
200                oldNode = curNode;
201                if (oldNode instanceof SplitNode) {
202                    SplitNode<?> node = (SplitNode<?>) curNode;
203                    list.add(featureIDMap.get(node.getFeatureID()).getName());
204                }
205                curNode = oldNode.getNextNode(vec);
206            }
207
208            //
209            // oldNode must be a LeafNode.
210            predList.add(((LeafNode<Regressor>) oldNode).getPrediction(vec.numActiveElements(), example));
211
212            List<Pair<String, Double>> pairs = new ArrayList<>();
213            int i = list.size() + 1;
214            for (String s : list) {
215                pairs.add(new Pair<>(s, i + 0.0));
216                i--;
217            }
218
219            map.put(e.getKey(), pairs);
220        }
221        Prediction<Regressor> combinedPrediction = combine(predList);
222
223        return Optional.of(new Excuse<>(example,combinedPrediction,map));
224    }
225
226    @Override
227    protected IndependentRegressionTreeModel copy(String newName, ModelProvenance newProvenance) {
228        Map<String,Node<Regressor>> newRoots = new HashMap<>();
229        for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) {
230            newRoots.put(e.getKey(),e.getValue().copy());
231        }
232        return new IndependentRegressionTreeModel(newName,newProvenance,featureIDMap,outputIDInfo,generatesProbabilities,newRoots);
233    }
234
235    private Prediction<Regressor> combine(List<Prediction<Regressor>> predictions) {
236        DimensionTuple[] tuples = new DimensionTuple[predictions.size()];
237        int numUsed = 0;
238        int i = 0;
239        for (Prediction<Regressor> p : predictions) {
240            if (numUsed < p.getNumActiveFeatures()) {
241                numUsed = p.getNumActiveFeatures();
242            }
243            Regressor output = p.getOutput();
244            if (output instanceof DimensionTuple) {
245                tuples[i] = (DimensionTuple)output;
246            } else {
247                throw new IllegalStateException("All the leaves should contain DimensionTuple not Regressor");
248            }
249            i++;
250        }
251
252        Example<Regressor> example = predictions.get(0).getExample();
253        return new Prediction<>(new Regressor(tuples),numUsed,example);
254    }
255
256    @Override
257    public Set<String> getFeatures() {
258        Set<String> features = new HashSet<>();
259
260        Queue<Node<Regressor>> nodeQueue = new LinkedList<>();
261
262        for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) {
263            nodeQueue.offer(e.getValue());
264
265            while (!nodeQueue.isEmpty()) {
266                Node<Regressor> node = nodeQueue.poll();
267                if ((node != null) && !node.isLeaf()) {
268                    SplitNode<Regressor> splitNode = (SplitNode<Regressor>) node;
269                    features.add(featureIDMap.get(splitNode.getFeatureID()).getName());
270                    nodeQueue.offer(splitNode.getGreaterThan());
271                    nodeQueue.offer(splitNode.getLessThanOrEqual());
272                }
273            }
274        }
275
276        return features;
277    }
278
279    @Override
280    public String toString() {
281        StringBuilder sb = new StringBuilder();
282        for (Map.Entry<String,Node<Regressor>> curRoot : roots.entrySet()) {
283            sb.append("Output '");
284            sb.append(curRoot.getKey());
285            sb.append("' - tree = ");
286            sb.append(curRoot.getValue().toString());
287            sb.append('\n');
288        }
289        return "IndependentTreeModel(description="+provenance.toString()+",\n"+sb.toString()+")";
290    }
291
292    /**
293     * Returns an unmodifiable view on the root node collection.
294     * <p>
295     * The nodes themselves are immutable.
296     * @return The root node collection.
297     */
298    public Map<String,Node<Regressor>> getRoots() {
299        return Collections.unmodifiableMap(roots);
300    }
301
302    /**
303     * Returns null, as this model contains multiple roots, one per regression output dimension.
304     * <p>
305     * Use {@link #getRoots()} instead.
306     * @return null.
307     */
308    @Override
309    public Node<Regressor> getRoot() {
310        return null;
311    }
312}