001/*
002 * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.rtree;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.Provenance;
021import org.tribuo.Dataset;
022import org.tribuo.Example;
023import org.tribuo.ImmutableFeatureMap;
024import org.tribuo.ImmutableOutputInfo;
025import org.tribuo.Trainer;
026import org.tribuo.common.tree.AbstractCARTTrainer;
027import org.tribuo.common.tree.AbstractTrainingNode;
028import org.tribuo.common.tree.Node;
029import org.tribuo.common.tree.TreeModel;
030import org.tribuo.provenance.ModelProvenance;
031import org.tribuo.provenance.TrainerProvenance;
032import org.tribuo.provenance.impl.TrainerProvenanceImpl;
033import org.tribuo.regression.Regressor;
034import org.tribuo.regression.rtree.impl.RegressorTrainingNode;
035import org.tribuo.regression.rtree.impl.RegressorTrainingNode.InvertedData;
036import org.tribuo.regression.rtree.impurity.MeanSquaredError;
037import org.tribuo.regression.rtree.impurity.RegressorImpurity;
038import org.tribuo.util.Util;
039
040import java.time.OffsetDateTime;
041import java.util.ArrayDeque;
042import java.util.Deque;
043import java.util.HashMap;
044import java.util.List;
045import java.util.Map;
046import java.util.Set;
047import java.util.SplittableRandom;
048
049/**
050 * A {@link org.tribuo.Trainer} that uses an approximation of the CART algorithm to build a decision tree.
051 * Trains an independent tree for each output dimension.
052 * <p>
053 * See:
054 * <pre>
055 * J. Friedman, T. Hastie, &amp; R. Tibshirani.
056 * "The Elements of Statistical Learning"
057 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a>
058 * </pre>
059 */
public final class CARTRegressionTrainer extends AbstractCARTTrainer<Regressor> {

    /**
     * Impurity measure used to determine split quality.
     */
    @Config(description="Regression impurity measure used to determine split quality.")
    private RegressorImpurity impurity = new MeanSquaredError();

    /**
     * Creates a CART Trainer.
     *
     * @param maxDepth The maximum depth of the tree.
     * @param minChildWeight The minimum node weight to consider it for a split.
     * @param minImpurityDecrease The minimum decrease in impurity necessary to split a node.
     * @param fractionFeaturesInSplit The fraction of features available in each split.
     * @param useRandomSplitPoints Whether to choose split points for features at random.
     * @param impurity The impurity function to use to determine split quality.
     * @param seed The RNG seed.
     */
    public CARTRegressionTrainer(
            int maxDepth,
            float minChildWeight,
            float minImpurityDecrease,
            float fractionFeaturesInSplit,
            boolean useRandomSplitPoints,
            RegressorImpurity impurity,
            long seed
    ) {
        super(maxDepth, minChildWeight, minImpurityDecrease, fractionFeaturesInSplit, useRandomSplitPoints, seed);
        this.impurity = impurity;
        // Re-run the post-configuration hook now that all fields (including
        // impurity) are set, mirroring how OLCUT would initialise this object.
        postConfig();
    }

    /**
     * Creates a CART Trainer.
     * <p>
     * Computes the exact split point.
     * @param maxDepth The maximum depth of the tree.
     * @param minChildWeight The minimum node weight to consider it for a split.
     * @param minImpurityDecrease The minimum decrease in impurity necessary to split a node.
     * @param fractionFeaturesInSplit The fraction of features available in each split.
     * @param impurity The impurity function to use to determine split quality.
     * @param seed The RNG seed.
     */
    public CARTRegressionTrainer(
            int maxDepth,
            float minChildWeight,
            float minImpurityDecrease,
            float fractionFeaturesInSplit,
            RegressorImpurity impurity,
            long seed
    ) {
        this(maxDepth,minChildWeight,minImpurityDecrease,fractionFeaturesInSplit,false,impurity,seed);
    }

    /**
     * Creates a CART trainer.
     * <p>
     * Builds a tree of effectively unbounded depth, sets the impurity
     * to the {@link MeanSquaredError}, uses all the features, computes
     * the exact split point, and sets the minimum number of examples
     * in a leaf to {@link #MIN_EXAMPLES}.
     */
    public CARTRegressionTrainer() {
        this(Integer.MAX_VALUE);
    }

    /**
     * Creates a CART trainer.
     * <p>
     * Sets the impurity to the {@link MeanSquaredError}, uses
     * all the features, computes the exact split point and sets
     * the minimum number of examples in a leaf to {@link #MIN_EXAMPLES}.
     * @param maxDepth The maximum depth of the tree.
     */
    public CARTRegressionTrainer(int maxDepth) {
        this(maxDepth, MIN_EXAMPLES, 0.0f, 1.0f, false, new MeanSquaredError(), Trainer.DEFAULT_SEED);
    }

    // This trainer fully overrides train(...) below and constructs one root
    // per output dimension itself, so the single-root factory inherited from
    // AbstractCARTTrainer is never invoked; reaching it indicates a bug.
    @Override
    protected AbstractTrainingNode<Regressor> mkTrainingNode(Dataset<Regressor> examples,
                                                             AbstractTrainingNode.LeafDeterminer leafDeterminer) {
        throw new IllegalStateException("Shouldn't reach here.");
    }

    @Override
    public TreeModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
        // Delegate to the three-argument overload, incrementing the invocation count.
        return train(examples, runProvenance, INCREMENT_INVOCATION_COUNT);
    }

    /**
     * Trains one independent CART tree per output dimension and wraps them in
     * an {@link IndependentRegressionTreeModel}.
     * <p>
     * Rejects datasets containing unknown (unlabelled) outputs, snapshots the
     * RNG and provenance under a lock for reproducibility, then grows each
     * dimension's tree breadth-first (using the deque as a stack) until the
     * depth, node weight, or impurity-decrease thresholds stop splitting.
     */
    @Override
    public TreeModel<Regressor> train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance, int invocationCount) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        // Creates a new RNG, adds one to the invocation count.
        SplittableRandom localRNG;
        TrainerProvenance trainerProvenance;
        // Synchronized so the RNG split, provenance capture, and counter update
        // form a single atomic snapshot when trainers are shared across threads.
        synchronized(this) {
            if(invocationCount != INCREMENT_INVOCATION_COUNT) {
                setInvocationCount(invocationCount);
            }
            localRNG = rng.split();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }

        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<Regressor> outputIDInfo = examples.getOutputIDInfo();
        Set<Regressor> domain = outputIDInfo.getDomain();

        // Number of candidate features examined at each split; clamped to the
        // total feature count.
        int numFeaturesInSplit = Math.min(Math.round(fractionFeaturesInSplit * featureIDMap.size()),featureIDMap.size());
        int[] indices;
        int[] originalIndices = new int[featureIDMap.size()];
        for (int i = 0; i < originalIndices.length; i++) {
            originalIndices[i] = i;
        }
        if (numFeaturesInSplit != featureIDMap.size()) {
            // Subsampling: indices is refilled from a fresh permutation of
            // originalIndices before each split (see the loop below).
            indices = new int[numFeaturesInSplit];
        } else {
            // Using all features: indices aliases originalIndices, which is
            // never permuted in this branch.
            indices = originalIndices;
        }

        // The configured minImpurityDecrease is per unit weight, so scale it
        // by the total example weight before comparing against node impurity.
        float weightSum = 0.0f;
        for (Example<Regressor> e : examples) {
            weightSum += e.getWeight();
        }
        float scaledMinImpurityDecrease = getMinImpurityDecrease() * weightSum;
        AbstractTrainingNode.LeafDeterminer leafDeterminer = new AbstractTrainingNode.LeafDeterminer(maxDepth,
                minChildWeight, scaledMinImpurityDecrease);

        // Column-oriented (inverted) view of the dataset, shared by every
        // per-dimension tree.
        InvertedData data = RegressorTrainingNode.invertData(examples);

        // One fully-grown tree per output dimension, keyed by dimension name.
        Map<String, Node<Regressor>> nodeMap = new HashMap<>();
        for (Regressor r : domain) {
            String dimName = r.getNames()[0];
            int dimIdx = outputIDInfo.getID(r);

            AbstractTrainingNode<Regressor> root = new RegressorTrainingNode(impurity,data,dimIdx,dimName,
                    examples.size(),featureIDMap,outputIDInfo, leafDeterminer);
            Deque<AbstractTrainingNode<Regressor>> queue = new ArrayDeque<>();
            queue.add(root);

            while (!queue.isEmpty()) {
                AbstractTrainingNode<Regressor> node = queue.poll();
                // Only attempt a split when the node is impure and both the
                // depth and minimum-weight constraints still allow it.
                if ((node.getImpurity() > 0.0) && (node.getDepth() < maxDepth) &&
                        (node.getWeightSum() >= minChildWeight)) {
                    if (numFeaturesInSplit != featureIDMap.size()) {
                        // Draw a fresh random feature subset for this split.
                        Util.randpermInPlace(originalIndices, localRNG);
                        System.arraycopy(originalIndices, 0, indices, 0, numFeaturesInSplit);
                    }
                    List<AbstractTrainingNode<Regressor>> nodes = node.buildTree(indices, localRNG,
                            getUseRandomSplitPoints());
                    // Use the queue as a stack to improve cache locality.
                    for (AbstractTrainingNode<Regressor> newNode : nodes) {
                        queue.addFirst(newNode);
                    }
                }
            }

            // Convert the mutable training tree into the immutable model tree.
            nodeMap.put(dimName,root.convertTree());
        }

        ModelProvenance provenance = new ModelProvenance(TreeModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
        return new IndependentRegressionTreeModel("cart-tree", provenance, featureIDMap, outputIDInfo, false, nodeMap);
    }

    @Override
    public String toString() {
        StringBuilder buffer = new StringBuilder();

        buffer.append("CARTRegressionTrainer(maxDepth=");
        buffer.append(maxDepth);
        buffer.append(",minChildWeight=");
        buffer.append(minChildWeight);
        buffer.append(",minImpurityDecrease=");
        buffer.append(minImpurityDecrease);
        buffer.append(",fractionFeaturesInSplit=");
        buffer.append(fractionFeaturesInSplit);
        buffer.append(",useRandomSplitPoints=");
        buffer.append(useRandomSplitPoints);
        buffer.append(",impurity=");
        buffer.append(impurity.toString());
        buffer.append(",seed=");
        buffer.append(seed);
        buffer.append(")");

        return buffer.toString();
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}