001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.rtree.impl;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Dataset;
021import org.tribuo.Example;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.common.tree.AbstractTrainingNode;
025import org.tribuo.common.tree.LeafNode;
026import org.tribuo.common.tree.Node;
027import org.tribuo.common.tree.SplitNode;
028import org.tribuo.common.tree.impl.IntArrayContainer;
029import org.tribuo.math.la.SparseVector;
030import org.tribuo.math.la.VectorTuple;
031import org.tribuo.regression.ImmutableRegressionInfo;
032import org.tribuo.regression.Regressor;
033import org.tribuo.regression.rtree.impurity.RegressorImpurity;
034import org.tribuo.regression.rtree.impurity.RegressorImpurity.ImpurityTuple;
035import org.tribuo.util.Util;
036
037import java.io.IOException;
038import java.io.NotSerializableException;
039import java.util.ArrayList;
040import java.util.Collections;
041import java.util.List;
042import java.util.SplittableRandom;
043import java.util.logging.Logger;
044
045/**
046 * A decision tree node used at training time.
047 * Contains a list of the example indices currently found in this node,
048 * the current impurity and a bunch of other statistics.
049 */
050public class JointRegressorTrainingNode extends AbstractTrainingNode<Regressor> {
051    private static final long serialVersionUID = 1L;
052
053    private static final Logger logger = Logger.getLogger(JointRegressorTrainingNode.class.getName());
054
055    private static final ThreadLocal<IntArrayContainer> mergeBufferOne = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));
056    private static final ThreadLocal<IntArrayContainer> mergeBufferTwo = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));
057
058    private transient ArrayList<TreeFeature> data;
059
060    private final boolean normalize;
061
062    private final ImmutableOutputInfo<Regressor> labelIDMap;
063
064    private final ImmutableFeatureMap featureIDMap;
065
066    private final RegressorImpurity impurity;
067
068    private final int[] indices;
069
070    private final float[][] targets;
071
072    private final float[] weights;
073
074    private final float weightSum;
075
076    /**
077     * Constructor which creates the inverted file.
078     * @param impurity The impurity function to use.
079     * @param examples The training data.
080     * @param normalize Normalizes the leaves so each leaf has a distribution which sums to 1.0.
081     * @param leafDeterminer Contains parameters needed to determine whether a node is a leaf.
082     */
083    public JointRegressorTrainingNode(RegressorImpurity impurity, Dataset<Regressor> examples, boolean normalize,
084                                      LeafDeterminer leafDeterminer) {
085        this(impurity, invertData(examples), examples.size(), examples.getFeatureIDMap(), examples.getOutputIDInfo(),
086                normalize, leafDeterminer);
087    }
088
089    private JointRegressorTrainingNode(RegressorImpurity impurity, InvertedData tuple, int numExamples,
090                                       ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputInfo,
091                                       boolean normalize, LeafDeterminer leafDeterminer) {
092        this(impurity,tuple.data,tuple.indices,tuple.targets,tuple.weights,numExamples,0,featureIDMap,outputInfo,
093                normalize, leafDeterminer);
094    }
095
096    private JointRegressorTrainingNode(RegressorImpurity impurity, ArrayList<TreeFeature> data, int[] indices,
097                                       float[][] targets, float[] weights, int numExamples, int depth,
098                                       ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap,
099                                       boolean normalize, LeafDeterminer leafDeterminer) {
100        super(depth, numExamples, leafDeterminer);
101        this.data = data;
102        this.normalize = normalize;
103        this.featureIDMap = featureIDMap;
104        this.labelIDMap = labelIDMap;
105        this.impurity = impurity;
106        this.indices = indices;
107        this.targets = targets;
108        this.weights = weights;
109        this.weightSum = Util.sum(indices,indices.length,weights);
110        this.impurityScore = calcImpurity(indices);
111    }
112
113    private JointRegressorTrainingNode(RegressorImpurity impurity, ArrayList<TreeFeature> data, int[] indices,
114                                       float[][] targets, float[] weights, int numExamples, int depth,
115                                       ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap,
116                                       boolean normalize, LeafDeterminer leafDeterminer, float weightSum,
117                                       double impurityScore) {
118        super(depth, numExamples, leafDeterminer);
119        this.data = data;
120        this.normalize = normalize;
121        this.featureIDMap = featureIDMap;
122        this.labelIDMap = labelIDMap;
123        this.impurity = impurity;
124        this.indices = indices;
125        this.targets = targets;
126        this.weights = weights;
127        this.weightSum = weightSum;
128        this.impurityScore = impurityScore;
129    }
130
131    @Override
132    public double getImpurity() {
133        return impurityScore;
134    }
135
136    @Override
137    public float getWeightSum() {
138        return weightSum;
139    }
140
141    /**
142     * Calculates the impurity score of the node.
143     * @return The impurity score of the node.
144     */
145    private double calcImpurity(int[] curIndices) {
146        double tmp = 0.0;
147        for (int i = 0; i < targets.length; i++) {
148            tmp += impurity.impurity(curIndices, targets[i], weights);
149        }
150        return tmp / targets.length;
151    }
152
153    /**
154     * Builds a tree according to CART (as it does not do multi-way splits on categorical values like C4.5).
155     * @param featureIDs Indices of the features available in this split.
156     * @param rng Splittable random number generator.
157     * @param useRandomSplitPoints Whether to choose split points for features at random.
158     * @return A possibly empty list of TrainingNodes.
159     */
160    @Override
161    public List<AbstractTrainingNode<Regressor>> buildTree(int[] featureIDs, SplittableRandom rng,
162                                                           boolean useRandomSplitPoints) {
163        if (useRandomSplitPoints) {
164            return buildRandomTree(featureIDs, rng);
165        } else {
166            return buildGreedyTree(featureIDs);
167        }
168    }
169
170    /**
171     * Builds a tree according to CART
172     * @param featureIDs Indices of the features available in this split.
173     * @return A possibly empty list of TrainingNodes.
174     */
175    private List<AbstractTrainingNode<Regressor>> buildGreedyTree(int[] featureIDs) {
176        int bestID = -1;
177        double bestSplitValue = 0.0;
178        double bestScore = getImpurity();
179        //logger.info("Cur node score = " + bestScore);
180        List<int[]> curIndices = new ArrayList<>();
181        List<int[]> bestLeftIndices = new ArrayList<>();
182        List<int[]> bestRightIndices = new ArrayList<>();
183        for (int i = 0; i < featureIDs.length; i++) {
184            List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature();
185
186            curIndices.clear();
187            for (int j = 0; j < feature.size(); j++) {
188                InvertedFeature f = feature.get(j);
189                int[] curFeatureIndices = f.indices();
190                curIndices.add(curFeatureIndices);
191            }
192
193            // searching for the intervals between features.
194            for (int j = 0; j < feature.size()-1; j++) {
195                List<int[]> curLeftIndices = curIndices.subList(0,j+1);
196                List<int[]> curRightIndices = curIndices.subList(j+1,feature.size());
197                double lessThanScore = 0.0;
198                double greaterThanScore = 0.0;
199                for (int k = 0; k < targets.length; k++) {
200                    ImpurityTuple left = impurity.impurityTuple(curLeftIndices,targets[k],weights);
201                    lessThanScore += left.impurity * left.weight;
202                    ImpurityTuple right = impurity.impurityTuple(curRightIndices,targets[k],weights);
203                    greaterThanScore += right.impurity * right.weight;
204                }
205                double score = (lessThanScore + greaterThanScore) / (targets.length * weightSum);
206                if (score < bestScore) {
207                    bestID = i;
208                    bestScore = score;
209                    bestSplitValue = (feature.get(j).value + feature.get(j + 1).value) / 2.0;
210                    // Clear out the old best indices before storing the new ones.
211                    bestLeftIndices.clear();
212                    bestLeftIndices.addAll(curLeftIndices);
213                    bestRightIndices.clear();
214                    bestRightIndices.addAll(curRightIndices);
215                    //logger.info("id = " + featureIDs[i] + ", split = " + bestSplitValue + ", score = " + score);
216                    //logger.info("less score = " +lessThanScore+", less size = "+lessThanIndices.size+", greater score = " + greaterThanScore+", greater size = "+greaterThanIndices.size);
217                }
218            }
219        }
220        List<AbstractTrainingNode<Regressor>> output;
221        double impurityDecrease = weightSum * (getImpurity() - bestScore);
222        // If we found a split better than the current impurity.
223        if ((bestID != -1) && (impurityDecrease >= leafDeterminer.getScaledMinImpurityDecrease())) {
224            output = splitAtBest(featureIDs, bestID, bestSplitValue, bestLeftIndices, bestRightIndices);
225        } else {
226            output = Collections.emptyList();
227        }
228        data = null;
229        return output;
230    }
231
232    /**
233     * Builds a CART tree with randomly chosen split points.
234     * @param featureIDs Indices of the features available in this split.
235     * @param rng Splittable random number generator.
236     * @return A possibly empty list of TrainingNodes.
237     */
238    private List<AbstractTrainingNode<Regressor>> buildRandomTree(int[] featureIDs, SplittableRandom rng) {
239        int bestID = -1;
240        double bestSplitValue = 0.0;
241        double bestScore = getImpurity();
242        //logger.info("Cur node score = " + bestScore);
243        List<int[]> curLeftIndices = new ArrayList<>();
244        List<int[]> curRightIndices = new ArrayList<>();
245        List<int[]> bestLeftIndices = new ArrayList<>();
246        List<int[]> bestRightIndices = new ArrayList<>();
247
248        // split each feature once randomly and record the least impure amongst these
249        for (int i = 0; i < featureIDs.length; i++) {
250            List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature();
251            // if there is only 1 inverted feature for this feature, it has only 1 value, so cannot be split
252            if (feature.size() == 1) {
253                continue;
254            }
255
256            double lessThanScore = 0.0;
257            double greaterThanScore = 0.0;
258
259            int splitIdx = rng.nextInt(feature.size()-1);
260
261            for (int j = 0; j < splitIdx + 1; j++) {
262                InvertedFeature vf;
263                vf = feature.get(j);
264                curLeftIndices.add(vf.indices());
265            }
266            for (int j = splitIdx + 1; j < feature.size(); j++) {
267                InvertedFeature vf;
268                vf = feature.get(j);
269                curRightIndices.add(vf.indices());
270            }
271
272            for (int k = 0; k < targets.length; k++) {
273                ImpurityTuple left = impurity.impurityTuple(curLeftIndices,targets[k],weights);
274                lessThanScore += left.impurity * left.weight;
275                ImpurityTuple right = impurity.impurityTuple(curRightIndices,targets[k],weights);
276                greaterThanScore += right.impurity * right.weight;
277            }
278
279            double score = (lessThanScore + greaterThanScore) / (targets.length * weightSum);
280            if (score < bestScore) {
281                bestID = i;
282                bestScore = score;
283                bestSplitValue = (feature.get(splitIdx).value + feature.get(splitIdx + 1).value) / 2.0;
284                // Clear out the old best indices before storing the new ones.
285                bestLeftIndices.clear();
286                bestLeftIndices.addAll(curLeftIndices);
287                bestRightIndices.clear();
288                bestRightIndices.addAll(curRightIndices);
289                //logger.info("id = " + featureIDs[i] + ", split = " + bestSplitValue + ", score = " + score);
290                //logger.info("less score = " +lessThanScore+", less size = "+lessThanIndices.size+", greater score = " + greaterThanScore+", greater size = "+greaterThanIndices.size);
291            }
292        }
293
294        List<AbstractTrainingNode<Regressor>> output;
295        double impurityDecrease = weightSum * (getImpurity() - bestScore);
296        // If we found a split better than the current impurity.
297        if ((bestID != -1) && (impurityDecrease >= leafDeterminer.getScaledMinImpurityDecrease())) {
298            output = splitAtBest(featureIDs, bestID, bestSplitValue, bestLeftIndices, bestRightIndices);
299        } else {
300            output = Collections.emptyList();
301        }
302        data = null;
303        return output;
304    }
305
306    /**
307     * Splits the data to form two nodes.
308     * @param featureIDs Indices of the features available in this split.
309     * @param bestID ID of the feature on which the split should be based.
310     * @param bestSplitValue Feature value to use for splitting the data.
311     * @param bestLeftIndices The indices of the examples less than or equal to the split value for the given feature.
312     * @param bestRightIndices The indices of the examples greater than the split value for the given feature.
313     * @return A list of training nodes resulting from the split.
314     */
315    private List<AbstractTrainingNode<Regressor>> splitAtBest(int[] featureIDs, int bestID, double bestSplitValue,
316                                                             List<int[]> bestLeftIndices, List<int[]> bestRightIndices){
317        splitID = featureIDs[bestID];
318        split = true;
319        splitValue = bestSplitValue;
320        IntArrayContainer firstBuffer = mergeBufferOne.get();
321        firstBuffer.size = 0;
322        firstBuffer.grow(indices.length);
323        IntArrayContainer secondBuffer = mergeBufferTwo.get();
324        secondBuffer.size = 0;
325        secondBuffer.grow(indices.length);
326        int[] leftIndices = IntArrayContainer.merge(bestLeftIndices, firstBuffer, secondBuffer);
327        int[] rightIndices = IntArrayContainer.merge(bestRightIndices, firstBuffer, secondBuffer);
328
329        float leftWeightSum = Util.sum(leftIndices,leftIndices.length,weights);
330        double leftImpurityScore = calcImpurity(leftIndices);
331
332        float rightWeightSum = Util.sum(rightIndices,rightIndices.length,weights);
333        double rightImpurityScore = calcImpurity(rightIndices);
334
335        boolean shouldMakeLeftLeaf = shouldMakeLeaf(leftImpurityScore, leftWeightSum);
336        boolean shouldMakeRightLeaf = shouldMakeLeaf(rightImpurityScore, rightWeightSum);
337
338        if (shouldMakeLeftLeaf && shouldMakeRightLeaf) {
339            lessThanOrEqual = createLeaf(leftImpurityScore, leftIndices);
340            greaterThan = createLeaf(rightImpurityScore, rightIndices);
341            return Collections.emptyList();
342        }
343        //logger.info("Splitting on feature " + bestID + " with value " + bestSplitValue + " at depth " + depth + ", " + numExamples + " examples in node.");
344        //logger.info("left indices length = " + leftIndices.length);
345        ArrayList<TreeFeature> lessThanData = new ArrayList<>(data.size());
346        ArrayList<TreeFeature> greaterThanData = new ArrayList<>(data.size());
347        for (TreeFeature feature : data) {
348            Pair<TreeFeature,TreeFeature> split = feature.split(leftIndices, rightIndices, firstBuffer, secondBuffer);
349            lessThanData.add(split.getA());
350            greaterThanData.add(split.getB());
351        }
352
353        List<AbstractTrainingNode<Regressor>> output = new ArrayList<>(2);
354        AbstractTrainingNode<Regressor> tmpNode;
355        if (shouldMakeLeftLeaf) {
356            lessThanOrEqual = createLeaf(leftImpurityScore, leftIndices);
357        } else {
358            tmpNode = new JointRegressorTrainingNode(impurity, lessThanData, leftIndices, targets,
359                    weights, leftIndices.length, depth + 1, featureIDMap, labelIDMap, normalize, leafDeterminer,
360                    leftWeightSum, leftImpurityScore);
361            lessThanOrEqual = tmpNode;
362            output.add(tmpNode);
363        }
364
365        if (shouldMakeRightLeaf) {
366            greaterThan = createLeaf(rightImpurityScore, rightIndices);
367        } else {
368            tmpNode = new JointRegressorTrainingNode(impurity, greaterThanData, rightIndices, targets,
369                    weights, rightIndices.length, depth + 1, featureIDMap, labelIDMap, normalize, leafDeterminer,
370                    rightWeightSum, rightImpurityScore);
371            greaterThan = tmpNode;
372            output.add(tmpNode);
373        }
374        return output;
375    }
376
377    /**
378     * Generates a test time tree (made of {@link SplitNode} and {@link LeafNode}) from the tree rooted at this node.
379     * @return A subtree using the SplitNode and LeafNode classes.
380     */
381    @Override
382    public Node<Regressor> convertTree() {
383        if (split) {
384            return createSplitNode();
385        } else {
386            return createLeaf(getImpurity(), indices);
387        }
388    }
389
390    /**
391     * Makes a {@link LeafNode}
392     * @param impurityScore the impurity score for the node.
393     * @param leafIndices the indices of the examples to be placed in the node.
394     * @return A {@link LeafNode}
395     */
396    private LeafNode<Regressor> createLeaf(double impurityScore, int[] leafIndices) {
397        double leafWeightSum = 0.0;
398        double[] mean = new double[targets.length];
399        Regressor leafPred;
400        if (normalize) {
401            for (int i = 0; i < leafIndices.length; i++) {
402                int idx = leafIndices[i];
403                float weight = weights[idx];
404                leafWeightSum += weight;
405                for (int j = 0; j < targets.length; j++) {
406                    float value = targets[j][idx];
407
408                    double oldMean = mean[j];
409                    mean[j] += (weight / leafWeightSum) * (value - oldMean);
410                }
411            }
412            String[] names = new String[targets.length];
413            double sum = 0.0;
414            for (int i = 0; i < targets.length; i++) {
415                names[i] = labelIDMap.getOutput(i).getNames()[0];
416                sum += mean[i];
417            }
418            // Normalize all the outputs so that they sum to 1.0.
419            for (int i = 0; i < targets.length; i++) {
420                mean[i] /= sum;
421            }
422            // Both names and mean are in id order, so the regressor constructor
423            // will convert them to natural order if they are different.
424            leafPred = new Regressor(names, mean);
425        } else {
426            double[] variance = new double[targets.length];
427            for (int i = 0; i < leafIndices.length; i++) {
428                int idx = leafIndices[i];
429                float weight = weights[idx];
430                leafWeightSum += weight;
431                for (int j = 0; j < targets.length; j++) {
432                    float value = targets[j][idx];
433
434                    double oldMean = mean[j];
435                    mean[j] += (weight / leafWeightSum) * (value - oldMean);
436                    variance[j] += weight * (value - oldMean) * (value - mean[j]);
437                }
438            }
439            String[] names = new String[targets.length];
440            for (int i = 0; i < targets.length; i++) {
441                names[i] = labelIDMap.getOutput(i).getNames()[0];
442                variance[i] = leafIndices.length > 1 ? variance[i] / (leafWeightSum - 1) : 0;
443            }
444            // Both names, mean and variance are in id order, so the regressor constructor
445            // will convert them to natural order if they are different.
446            leafPred = new Regressor(names, mean, variance);
447        }
448        return new LeafNode<>(impurityScore,leafPred,Collections.emptyMap(),false);
449    }
450
451    /**
452     * Inverts a training dataset from row major to column major. This partially de-sparsifies the dataset
453     * so it's very expensive in terms of memory.
454     * @param examples An input dataset.
455     * @return A list of TreeFeatures which contain {@link InvertedFeature}s.
456     */
457    private static InvertedData invertData(Dataset<Regressor> examples) {
458        ImmutableFeatureMap featureInfos = examples.getFeatureIDMap();
459        ImmutableOutputInfo<Regressor> labelInfo = examples.getOutputIDInfo();
460        int numLabels = labelInfo.size();
461        int numFeatures = featureInfos.size();
462        int[] indices = new int[examples.size()];
463        float[][] targets = new float[labelInfo.size()][examples.size()];
464        float[] weights = new float[examples.size()];
465
466        logger.fine("Building initial List<TreeFeature> for " + numFeatures + " features and " + numLabels + " outputs");
467        ArrayList<TreeFeature> data = new ArrayList<>(featureInfos.size());
468
469        for (int i = 0; i < featureInfos.size(); i++) {
470            data.add(new TreeFeature(i));
471        }
472
473        int[] ids = ((ImmutableRegressionInfo) labelInfo).getNaturalOrderToIDMapping();
474        for (int i = 0; i < examples.size(); i++) {
475            Example<Regressor> e = examples.getExample(i);
476            indices[i] = i;
477            weights[i] = e.getWeight();
478            double[] output = e.getOutput().getValues();
479            for (int j = 0; j < targets.length; j++) {
480                targets[ids[j]][i] = (float) output[j];
481            }
482            SparseVector vec = SparseVector.createSparseVector(e,featureInfos,false);
483            int lastID = 0;
484            for (VectorTuple f : vec) {
485                int curID = f.index;
486                for (int j = lastID; j < curID; j++) {
487                    data.get(j).observeValue(0.0,i);
488                }
489                data.get(curID).observeValue(f.value,i);
490                //
491                // These two checks should never occur as SparseVector deals with
492                // collisions, and Dataset prevents repeated features.
493                // They are left in just to make sure.
494                if (lastID > curID) {
495                    logger.severe("Example = " + e.toString());
496                    throw new IllegalStateException("Features aren't ordered. At id " + i + ", lastID = " + lastID + ", curID = " + curID);
497                } else if (lastID-1 == curID) {
498                    logger.severe("Example = " + e.toString());
499                    throw new IllegalStateException("Features are repeated. At id " + i + ", lastID = " + lastID + ", curID = " + curID);
500                }
501                lastID = curID + 1;
502            }
503            for (int j = lastID; j < numFeatures; j++) {
504                data.get(j).observeValue(0.0,i);
505            }
506            if (i % 1000 == 0) {
507                logger.fine("Processed example " + i);
508            }
509        }
510
511        logger.fine("Sorting features");
512
513        data.forEach(TreeFeature::sort);
514
515        logger.fine("Fixing InvertedFeature sizes");
516
517        data.forEach(TreeFeature::fixSize);
518
519        logger.fine("Built initial List<TreeFeature>");
520
521        return new InvertedData(data,indices,targets,weights);
522    }
523
524    private static class InvertedData {
525        final ArrayList<TreeFeature> data;
526        final int[] indices;
527        final float[][] targets;
528        final float[] weights;
529
530        InvertedData(ArrayList<TreeFeature> data, int[] indices, float[][] targets, float[] weights) {
531            this.data = data;
532            this.indices = indices;
533            this.targets = targets;
534            this.weights = weights;
535        }
536    }
537
538    private void writeObject(java.io.ObjectOutputStream stream)
539            throws IOException {
540        throw new NotSerializableException("JointRegressorTrainingNode is a runtime class only, and should not be serialized.");
541    }
542}