001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.rtree;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import org.tribuo.Dataset;
021import org.tribuo.Trainer;
022import org.tribuo.common.tree.AbstractCARTTrainer;
023import org.tribuo.common.tree.AbstractTrainingNode;
024import org.tribuo.provenance.TrainerProvenance;
025import org.tribuo.provenance.impl.TrainerProvenanceImpl;
026import org.tribuo.regression.Regressor;
027import org.tribuo.regression.rtree.impl.JointRegressorTrainingNode;
028import org.tribuo.regression.rtree.impurity.MeanSquaredError;
029import org.tribuo.regression.rtree.impurity.RegressorImpurity;
030
031/**
032 * A {@link org.tribuo.Trainer} that uses an approximation of the CART algorithm to build a decision tree.
033 * <p>
034 * Builds a single tree for all the regression dimensions.
035 * <p>
036 * See:
037 * <pre>
038 * J. Friedman, T. Hastie, &amp; R. Tibshirani.
039 * "The Elements of Statistical Learning"
040 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a>
041 * </pre>
042 */
043public class CARTJointRegressionTrainer extends AbstractCARTTrainer<Regressor> {
044
045    /**
046     * Impurity measure used to determine split quality.
047     */
048    @Config(description="The regression impurity to use.")
049    private RegressorImpurity impurity = new MeanSquaredError();
050
051    /**
052     * Normalizes the output of each leaf so it sums to one (i.e., is a probability distribution).
053     */
054    @Config(description="Normalize the output of each leaf so it sums to one.")
055    private boolean normalize = false;
056
057    /**
058     * Creates a CART Trainer.
059     *
060     * @param maxDepth maxDepth The maximum depth of the tree.
061     * @param minChildWeight minChildWeight The minimum node weight to consider it for a split.
062     * @param minImpurityDecrease The minimum decrease in impurity necessary to split a node.
063     * @param fractionFeaturesInSplit fractionFeaturesInSplit The fraction of features available in each split.
064     * @param useRandomSplitPoints Whether to choose split points for features at random.
065     * @param impurity impurity The impurity function to use to determine split quality.
066     * @param normalize Normalize the leaves so each output sums to one.
067     * @param seed The seed to use for the RNG.
068     */
069    public CARTJointRegressionTrainer(
070            int maxDepth,
071            float minChildWeight,
072            float minImpurityDecrease,
073            float fractionFeaturesInSplit,
074            boolean useRandomSplitPoints,
075            RegressorImpurity impurity,
076            boolean normalize,
077            long seed
078    ) {
079        super(maxDepth, minChildWeight, minImpurityDecrease, fractionFeaturesInSplit, useRandomSplitPoints, seed);
080        this.impurity = impurity;
081        this.normalize = normalize;
082        postConfig();
083    }
084
085    /**
086     * Creates a CART Trainer.
087     * <p>
088     * Computes the exact split point.
089     * @param maxDepth maxDepth The maximum depth of the tree.
090     * @param minChildWeight minChildWeight The minimum node weight to consider it for a split.
091     * @param minImpurityDecrease The minimum decrease in impurity necessary to split a node.
092     * @param fractionFeaturesInSplit fractionFeaturesInSplit The fraction of features available in each split.
093     * @param impurity impurity The impurity function to use to determine split quality.
094     * @param normalize Normalize the leaves so each output sums to one.
095     * @param seed The seed to use for the RNG.
096     */
097    public CARTJointRegressionTrainer(
098            int maxDepth,
099            float minChildWeight,
100            float minImpurityDecrease,
101            float fractionFeaturesInSplit,
102            RegressorImpurity impurity,
103            boolean normalize,
104            long seed
105    ) {
106        this(maxDepth, minChildWeight, minImpurityDecrease, fractionFeaturesInSplit, false, impurity, normalize, seed);
107    }
108
109    /**
110     * Creates a CART Trainer.
111     * <p>
112     * Sets the impurity to the {@link MeanSquaredError}, computes an arbitrary depth
113     * tree with exact split points using all the features, and does not normalize the outputs.
114     */
115    public CARTJointRegressionTrainer() {
116        this(Integer.MAX_VALUE, MIN_EXAMPLES, 0.0f, 1.0f, false, new MeanSquaredError(), false, Trainer.DEFAULT_SEED);
117    }
118
119    /**
120     * Creates a CART Trainer.
121     * <p>
122     * Sets the impurity to the {@link MeanSquaredError}, computes the exact split
123     * points using all the features, and does not normalize the outputs.
124     * @param maxDepth The maximum depth of the tree.
125     */
126    public CARTJointRegressionTrainer(int maxDepth) {
127        this(maxDepth, MIN_EXAMPLES, 0.0f, 1.0f, false, new MeanSquaredError(), false, Trainer.DEFAULT_SEED);
128    }
129
130    /**
131     * Creates a CART Trainer. Sets the impurity to the {@link MeanSquaredError}.
132     * @param maxDepth The maximum depth of the tree.
133     * @param normalize Normalises the leaves so each leaf has a distribution which sums to 1.0.
134     */
135    public CARTJointRegressionTrainer(int maxDepth, boolean normalize) {
136        this(maxDepth, MIN_EXAMPLES, 0.0f, 1.0f, false, new MeanSquaredError(), normalize, Trainer.DEFAULT_SEED);
137    }
138
139    @Override
140    protected AbstractTrainingNode<Regressor> mkTrainingNode(Dataset<Regressor> examples,
141                                                             AbstractTrainingNode.LeafDeterminer leafDeterminer) {
142        return new JointRegressorTrainingNode(impurity, examples, normalize, leafDeterminer);
143    }
144
145    @Override
146    public String toString() {
147        StringBuilder buffer = new StringBuilder();
148
149        buffer.append("CARTJointRegressionTrainer(maxDepth=");
150        buffer.append(maxDepth);
151        buffer.append(",minChildWeight=");
152        buffer.append(minChildWeight);
153        buffer.append(",minImpurityDecrease=");
154        buffer.append(minImpurityDecrease);
155        buffer.append(",fractionFeaturesInSplit=");
156        buffer.append(fractionFeaturesInSplit);
157        buffer.append(",useRandomSplitPoints=");
158        buffer.append(useRandomSplitPoints);
159        buffer.append(",impurity=");
160        buffer.append(impurity.toString());
161        buffer.append(",normalize=");
162        buffer.append(normalize);
163        buffer.append(",seed=");
164        buffer.append(seed);
165        buffer.append(")");
166
167        return buffer.toString();
168    }
169
170    @Override
171    public TrainerProvenance getProvenance() {
172        return new TrainerProvenanceImpl(this);
173    }
174}