001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.rtree.impl;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Dataset;
021import org.tribuo.Example;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.common.tree.AbstractTrainingNode;
025import org.tribuo.common.tree.LeafNode;
026import org.tribuo.common.tree.Node;
027import org.tribuo.common.tree.SplitNode;
028import org.tribuo.common.tree.impl.IntArrayContainer;
029import org.tribuo.math.la.SparseVector;
030import org.tribuo.math.la.VectorTuple;
031import org.tribuo.regression.ImmutableRegressionInfo;
032import org.tribuo.regression.Regressor;
033import org.tribuo.regression.Regressor.DimensionTuple;
034import org.tribuo.regression.rtree.impurity.RegressorImpurity;
035import org.tribuo.regression.rtree.impurity.RegressorImpurity.ImpurityTuple;
036import org.tribuo.util.Util;
037
038import java.io.IOException;
039import java.io.NotSerializableException;
040import java.util.ArrayList;
041import java.util.Collections;
042import java.util.List;
043import java.util.SplittableRandom;
044import java.util.logging.Logger;
045
046/**
047 * A decision tree node used at training time.
048 * Contains a list of the example indices currently found in this node,
049 * the current impurity and a bunch of other statistics.
050 */
051public class RegressorTrainingNode extends AbstractTrainingNode<Regressor> {
052    private static final long serialVersionUID = 1L;
053
054    private static final Logger logger = Logger.getLogger(RegressorTrainingNode.class.getName());
055
056    private static final ThreadLocal<IntArrayContainer> mergeBufferOne = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));
057    private static final ThreadLocal<IntArrayContainer> mergeBufferTwo = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));
058
059    private transient ArrayList<TreeFeature> data;
060
061    private final ImmutableOutputInfo<Regressor> labelIDMap;
062
063    private final ImmutableFeatureMap featureIDMap;
064
065    private final RegressorImpurity impurity;
066
067    private final int[] indices;
068
069    private final float[] targets;
070
071    private final float[] weights;
072
073    private final String dimName;
074
075    private final float weightSum;
076
077    /**
078     * Constructs a tree training node for regression problems.
079     * @param impurity The impurity function.
080     * @param tuple The data tuple.
081     * @param dimIndex The output dimension index of this node.
082     * @param dimName The output dimension name.
083     * @param numExamples The number of examples.
084     * @param featureIDMap The feature domain.
085     * @param outputInfo The output domain.
086     * @param leafDeterminer The leaf determination parameters.
087     */
088    public RegressorTrainingNode(RegressorImpurity impurity, InvertedData tuple, int dimIndex, String dimName,
089                                 int numExamples, ImmutableFeatureMap featureIDMap,
090                                 ImmutableOutputInfo<Regressor> outputInfo, LeafDeterminer leafDeterminer) {
091        this(impurity,tuple.copyData(),tuple.indices,tuple.targets[dimIndex],tuple.weights,dimName,numExamples,0,
092                featureIDMap,outputInfo, leafDeterminer);
093    }
094
095    private RegressorTrainingNode(RegressorImpurity impurity, ArrayList<TreeFeature> data, int[] indices,
096                                  float[] targets, float[] weights, String dimName, int numExamples, int depth,
097                                  ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap,
098                                  LeafDeterminer leafDeterminer) {
099        super(depth, numExamples, leafDeterminer);
100        this.data = data;
101        this.featureIDMap = featureIDMap;
102        this.labelIDMap = labelIDMap;
103        this.impurity = impurity;
104        this.indices = indices;
105        this.targets = targets;
106        this.weights = weights;
107        this.dimName = dimName;
108        this.weightSum = Util.sum(indices,indices.length,weights);
109        this.impurityScore = impurity.impurity(indices, targets, weights);
110    }
111
112    private RegressorTrainingNode(RegressorImpurity impurity, ArrayList<TreeFeature> data, int[] indices,
113                                  float[] targets, float[] weights, String dimName, int numExamples, int depth,
114                                  ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap,
115                                  LeafDeterminer leafDeterminer, float weightSum, double impurityScore) {
116        super(depth, numExamples, leafDeterminer);
117        this.data = data;
118        this.featureIDMap = featureIDMap;
119        this.labelIDMap = labelIDMap;
120        this.impurity = impurity;
121        this.indices = indices;
122        this.targets = targets;
123        this.weights = weights;
124        this.dimName = dimName;
125        this.weightSum = weightSum;
126        this.impurityScore = impurityScore;
127    }
128
129    @Override
130    public double getImpurity() {
131        return impurityScore;
132    }
133
134    @Override
135    public float getWeightSum() {
136        return weightSum;
137    }
138
139    /**
140     * Builds a tree according to CART (as it does not do multi-way splits on categorical values like C4.5).
141     * @param featureIDs Indices of the features available in this split.
142     * @param rng Splittable random number generator.
143     * @param useRandomSplitPoints Whether to choose split points for features at random.
144     * @return A possibly empty list of TrainingNodes.
145     */
146    @Override
147    public List<AbstractTrainingNode<Regressor>> buildTree(int[] featureIDs, SplittableRandom rng,
148                                                           boolean useRandomSplitPoints) {
149        if (useRandomSplitPoints) {
150            return buildRandomTree(featureIDs, rng);
151        } else {
152            return buildGreedyTree(featureIDs);
153        }
154    }
155
156    /**
157     * Builds a tree according to CART
158     * @param featureIDs Indices of the features available in this split.
159     * @return A possibly empty list of TrainingNodes.
160     */
161    private List<AbstractTrainingNode<Regressor>> buildGreedyTree(int[] featureIDs) {
162        int bestID = -1;
163        double bestSplitValue = 0.0;
164        double bestScore = getImpurity();
165        //logger.info("Cur node score = " + bestScore);
166        List<int[]> curIndices = new ArrayList<>();
167        List<int[]> bestLeftIndices = new ArrayList<>();
168        List<int[]> bestRightIndices = new ArrayList<>();
169        for (int i = 0; i < featureIDs.length; i++) {
170            List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature();
171
172            curIndices.clear();
173            for (int j = 0; j < feature.size(); j++) {
174                InvertedFeature f = feature.get(j);
175                int[] curFeatureIndices = f.indices();
176                curIndices.add(curFeatureIndices);
177            }
178
179            // searching for the intervals between features.
180            for (int j = 0; j < feature.size()-1; j++) {
181                List<int[]> curLeftIndices = curIndices.subList(0,j+1);
182                List<int[]> curRightIndices = curIndices.subList(j+1,feature.size());
183                ImpurityTuple lessThanScore = impurity.impurityTuple(curLeftIndices,targets,weights);
184                ImpurityTuple greaterThanScore = impurity.impurityTuple(curRightIndices,targets,weights);
185                double score = (lessThanScore.impurity*lessThanScore.weight + greaterThanScore.impurity*greaterThanScore.weight) / weightSum;
186                if (score < bestScore) {
187                    bestID = i;
188                    bestScore = score;
189                    bestSplitValue = (feature.get(j).value + feature.get(j + 1).value) / 2.0;
190                    // Clear out the old best indices before storing the new ones.
191                    bestLeftIndices.clear();
192                    bestLeftIndices.addAll(curLeftIndices);
193                    bestRightIndices.clear();
194                    bestRightIndices.addAll(curRightIndices);
195                    //logger.info("id = " + featureIDs[i] + ", split = " + bestSplitValue + ", score = " + score);
196                    //logger.info("less score = " +lessThanScore+", less size = "+lessThanIndices.size+", greater score = " + greaterThanScore+", greater size = "+greaterThanIndices.size);
197                }
198            }
199        }
200        List<AbstractTrainingNode<Regressor>> output;
201        double impurityDecrease = weightSum * (getImpurity() - bestScore);
202        // If we found a split better than the current impurity.
203        if ((bestID != -1) && (impurityDecrease >= leafDeterminer.getScaledMinImpurityDecrease())) {
204            output = splitAtBest(featureIDs, bestID, bestSplitValue, bestLeftIndices, bestRightIndices);
205        } else {
206            output = Collections.emptyList();
207        }
208        data = null;
209        return output;
210    }
211
212    /**
213     * Builds a CART tree with randomly chosen split points.
214     * @param featureIDs Indices of the features available in this split.
215     * @param rng Splittable random number generator.
216     * @return A possibly empty list of TrainingNodes.
217     */
218    private List<AbstractTrainingNode<Regressor>> buildRandomTree(int[] featureIDs, SplittableRandom rng) {
219        int bestID = -1;
220        double bestSplitValue = 0.0;
221        double bestScore = getImpurity();
222        //logger.info("Cur node score = " + bestScore);
223        List<int[]> curLeftIndices = new ArrayList<>();
224        List<int[]> curRightIndices = new ArrayList<>();
225        List<int[]> bestLeftIndices = new ArrayList<>();
226        List<int[]> bestRightIndices = new ArrayList<>();
227
228        // split each feature once randomly and record the least impure amongst these
229        for (int i = 0; i < featureIDs.length; i++) {
230            List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature();
231            // if there is only 1 inverted feature for this feature, it has only 1 value, so cannot be split
232            if (feature.size() == 1) {
233                continue;
234            }
235
236            int splitIdx = rng.nextInt(feature.size()-1);
237
238            for (int j = 0; j < splitIdx + 1; j++) {
239                InvertedFeature vf;
240                vf = feature.get(j);
241                curLeftIndices.add(vf.indices());
242            }
243            for (int j = splitIdx + 1; j < feature.size(); j++) {
244                InvertedFeature vf;
245                vf = feature.get(j);
246                curRightIndices.add(vf.indices());
247            }
248
249            ImpurityTuple lessThanScore = impurity.impurityTuple(curLeftIndices,targets,weights);
250            ImpurityTuple greaterThanScore = impurity.impurityTuple(curRightIndices,targets,weights);
251            double score = (lessThanScore.impurity*lessThanScore.weight + greaterThanScore.impurity*greaterThanScore.weight) / weightSum;
252            if (score < bestScore) {
253                bestID = i;
254                bestScore = score;
255                bestSplitValue = (feature.get(splitIdx).value + feature.get(splitIdx + 1).value) / 2.0;
256                // Clear out the old best indices before storing the new ones.
257                bestLeftIndices.clear();
258                bestLeftIndices.addAll(curLeftIndices);
259                bestRightIndices.clear();
260                bestRightIndices.addAll(curRightIndices);
261                //logger.info("id = " + featureIDs[i] + ", split = " + bestSplitValue + ", score = " + score);
262                //logger.info("less score = " +lessThanScore+", less size = "+lessThanIndices.size+", greater score = " + greaterThanScore+", greater size = "+greaterThanIndices.size);
263            }
264        }
265
266        List<AbstractTrainingNode<Regressor>> output;
267        double impurityDecrease = weightSum * (getImpurity() - bestScore);
268        // If we found a split better than the current impurity.
269        if ((bestID != -1) && (impurityDecrease >= leafDeterminer.getScaledMinImpurityDecrease())) {
270            output = splitAtBest(featureIDs, bestID, bestSplitValue, bestLeftIndices, bestRightIndices);
271        } else {
272            output = Collections.emptyList();
273        }
274        data = null;
275        return output;
276    }
277
278    /**
279     * Splits the data to form two nodes.
280     * @param featureIDs Indices of the features available in this split.
281     * @param bestID ID of the feature on which the split should be based.
282     * @param bestSplitValue Feature value to use for splitting the data.
283     * @param bestLeftIndices The indices of the examples less than or equal to the split value for the given feature.
284     * @param bestRightIndices The indices of the examples greater than the split value for the given feature.
285     * @return A list of training nodes resulting from the split.
286     */
287    private List<AbstractTrainingNode<Regressor>> splitAtBest(int[] featureIDs, int bestID, double bestSplitValue,
288                                                             List<int[]> bestLeftIndices, List<int[]> bestRightIndices){
289
290        splitID = featureIDs[bestID];
291        split = true;
292        splitValue = bestSplitValue;
293        IntArrayContainer firstBuffer = mergeBufferOne.get();
294        firstBuffer.size = 0;
295        firstBuffer.grow(indices.length);
296        IntArrayContainer secondBuffer = mergeBufferTwo.get();
297        secondBuffer.size = 0;
298        secondBuffer.grow(indices.length);
299        int[] leftIndices = IntArrayContainer.merge(bestLeftIndices, firstBuffer, secondBuffer);
300        int[] rightIndices = IntArrayContainer.merge(bestRightIndices, firstBuffer, secondBuffer);
301
302        float leftWeightSum = Util.sum(leftIndices,leftIndices.length,weights);
303        double leftImpurityScore = impurity.impurity(leftIndices, targets, weights);
304
305        float rightWeightSum = Util.sum(rightIndices,rightIndices.length,weights);
306        double rightImpurityScore = impurity.impurity(rightIndices, targets, weights);
307
308        boolean shouldMakeLeftLeaf = shouldMakeLeaf(leftImpurityScore, leftWeightSum);
309        boolean shouldMakeRightLeaf = shouldMakeLeaf(rightImpurityScore, rightWeightSum);
310
311        if (shouldMakeLeftLeaf && shouldMakeRightLeaf) {
312            lessThanOrEqual = createLeaf(leftImpurityScore, leftIndices);
313            greaterThan = createLeaf(rightImpurityScore, rightIndices);
314            return Collections.emptyList();
315        }
316
317        //logger.info("Splitting on feature " + bestID + " with value " + bestSplitValue + " at depth " + depth + ", " + numExamples + " examples in node.");
318        //logger.info("left indices length = " + leftIndices.length);
319        ArrayList<TreeFeature> lessThanData = new ArrayList<>(data.size());
320        ArrayList<TreeFeature> greaterThanData = new ArrayList<>(data.size());
321        for (TreeFeature feature : data) {
322            Pair<TreeFeature,TreeFeature> split = feature.split(leftIndices, rightIndices, firstBuffer, secondBuffer);
323            lessThanData.add(split.getA());
324            greaterThanData.add(split.getB());
325        }
326
327        List<AbstractTrainingNode<Regressor>> output = new ArrayList<>(2);
328        AbstractTrainingNode<Regressor> tmpNode;
329        if (shouldMakeLeftLeaf) {
330            lessThanOrEqual = createLeaf(leftImpurityScore, leftIndices);
331        } else {
332            tmpNode = new RegressorTrainingNode(impurity, lessThanData, leftIndices, targets, weights, dimName,
333                    leftIndices.length, depth + 1, featureIDMap, labelIDMap, leafDeterminer, leftWeightSum, leftImpurityScore);
334            lessThanOrEqual = tmpNode;
335            output.add(tmpNode);
336        }
337
338        if (shouldMakeRightLeaf) {
339            greaterThan = createLeaf(rightImpurityScore, rightIndices);
340        } else {
341            tmpNode = new RegressorTrainingNode(impurity, greaterThanData, rightIndices, targets, weights, dimName,
342                    rightIndices.length, depth + 1, featureIDMap, labelIDMap, leafDeterminer, rightWeightSum, rightImpurityScore);
343            greaterThan = tmpNode;
344            output.add(tmpNode);
345        }
346        return output;
347    }
348
349    /**
350     * Generates a test time tree (made of {@link SplitNode} and {@link LeafNode}) from the tree rooted at this node.
351     * @return A subtree using the SplitNode and LeafNode classes.
352     */
353    @Override
354    public Node<Regressor> convertTree() {
355        if (split) {
356            return createSplitNode();
357        } else {
358            return createLeaf(getImpurity(), indices);
359        }
360    }
361
362    /**
363     * Makes a {@link LeafNode}
364     * @param impurityScore the impurity score for the node.
365     * @param leafIndices the indices of the examples to be placed in the node.
366     * @return A {@link LeafNode}
367     */
368    private LeafNode<Regressor> createLeaf(double impurityScore, int[] leafIndices) {
369        double mean = 0.0;
370        double leafWeightSum = 0.0;
371        double variance = 0.0;
372        for (int i = 0; i < leafIndices.length; i++) {
373            int idx = leafIndices[i];
374            float value = targets[idx];
375            float weight = weights[idx];
376
377            leafWeightSum += weight;
378            double oldMean = mean;
379            mean += (weight / leafWeightSum) * (value - oldMean);
380            variance += weight * (value - oldMean) * (value - mean);
381        }
382        variance = leafIndices.length > 1 ? variance / (leafWeightSum-1) : 0;
383        DimensionTuple leafPred = new DimensionTuple(dimName,mean,variance);
384        return new LeafNode<>(impurityScore,leafPred,Collections.emptyMap(),false);
385    }
386
387    /**
388     * Inverts a training dataset from row major to column major. This partially de-sparsifies the dataset
389     * so it's very expensive in terms of memory.
390     * @param examples An input dataset.
391     * @return A list of TreeFeatures which contain {@link InvertedFeature}s.
392     */
393    public static InvertedData invertData(Dataset<Regressor> examples) {
394        ImmutableFeatureMap featureInfos = examples.getFeatureIDMap();
395        ImmutableOutputInfo<Regressor> labelInfo = examples.getOutputIDInfo();
396        int numLabels = labelInfo.size();
397        int numFeatures = featureInfos.size();
398        int[] indices = new int[examples.size()];
399        float[][] targets = new float[numLabels][examples.size()];
400        float[] weights = new float[examples.size()];
401
402        logger.fine("Building initial List<TreeFeature> for " + numFeatures + " features and " + numLabels + " outputs");
403        ArrayList<TreeFeature> data = new ArrayList<>(featureInfos.size());
404
405        for (int i = 0; i < featureInfos.size(); i++) {
406            data.add(new TreeFeature(i));
407        }
408
409        int[] ids = ((ImmutableRegressionInfo) labelInfo).getNaturalOrderToIDMapping();
410        for (int i = 0; i < examples.size(); i++) {
411            Example<Regressor> e = examples.getExample(i);
412            indices[i] = i;
413            weights[i] = e.getWeight();
414            double[] output = e.getOutput().getValues();
415            for (int j = 0; j < output.length; j++) {
416                targets[ids[j]][i] = (float) output[j];
417            }
418            SparseVector vec = SparseVector.createSparseVector(e,featureInfos,false);
419            int lastID = 0;
420            for (VectorTuple f : vec) {
421                int curID = f.index;
422                for (int j = lastID; j < curID; j++) {
423                    data.get(j).observeValue(0.0,i);
424                }
425                data.get(curID).observeValue(f.value,i);
426                //
427                // These two checks should never occur as SparseVector deals with
428                // collisions, and Dataset prevents repeated features.
429                // They are left in just to make sure.
430                if (lastID > curID) {
431                    logger.severe("Example = " + e.toString());
432                    throw new IllegalStateException("Features aren't ordered. At id " + i + ", lastID = " + lastID + ", curID = " + curID);
433                } else if (lastID-1 == curID) {
434                    logger.severe("Example = " + e.toString());
435                    throw new IllegalStateException("Features are repeated. At id " + i + ", lastID = " + lastID + ", curID = " + curID);
436                }
437                lastID = curID + 1;
438            }
439            for (int j = lastID; j < numFeatures; j++) {
440                data.get(j).observeValue(0.0,i);
441            }
442            if (i % 1000 == 0) {
443                logger.fine("Processed example " + i);
444            }
445        }
446
447        logger.fine("Sorting features");
448
449        data.forEach(TreeFeature::sort);
450
451        logger.fine("Fixing InvertedFeature sizes");
452
453        data.forEach(TreeFeature::fixSize);
454
455        logger.fine("Built initial List<TreeFeature>");
456
457        return new InvertedData(data,indices,targets,weights);
458    }
459
460    /**
461     * Tuple containing an inverted dataset (i.e., feature-wise not exmaple-wise).
462     */
463    public static class InvertedData {
464        final ArrayList<TreeFeature> data;
465        final int[] indices;
466        final float[][] targets;
467        final float[] weights;
468
469        InvertedData(ArrayList<TreeFeature> data, int[] indices, float[][] targets, float[] weights) {
470            this.data = data;
471            this.indices = indices;
472            this.targets = targets;
473            this.weights = weights;
474        }
475
476        ArrayList<TreeFeature> copyData() {
477            ArrayList<TreeFeature> newData = new ArrayList<>();
478
479            for (TreeFeature f : data) {
480                newData.add(f.deepCopy());
481            }
482
483            return newData;
484        }
485    }
486
487    private void writeObject(java.io.ObjectOutputStream stream)
488            throws IOException {
489        throw new NotSerializableException("RegressorTrainingNode is a runtime class only, and should not be serialized.");
490    }
491}