001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.rtree;
018
019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
020import com.oracle.labs.mlrg.olcut.config.Option;
021import com.oracle.labs.mlrg.olcut.config.Options;
022import com.oracle.labs.mlrg.olcut.config.UsageException;
023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
024import com.oracle.labs.mlrg.olcut.util.Pair;
025import org.tribuo.Dataset;
026import org.tribuo.SparseModel;
027import org.tribuo.SparseTrainer;
028import org.tribuo.data.DataOptions;
029import org.tribuo.regression.RegressionFactory;
030import org.tribuo.regression.Regressor;
031import org.tribuo.regression.evaluation.RegressionEvaluation;
032import org.tribuo.regression.rtree.impurity.MeanAbsoluteError;
033import org.tribuo.regression.rtree.impurity.MeanSquaredError;
034import org.tribuo.regression.rtree.impurity.RegressorImpurity;
035import org.tribuo.util.Util;
036
037import java.io.IOException;
038import java.util.logging.Logger;
039
040/**
041 * Build and run a regression tree for a standard dataset.
042 */
043public class TrainTest {
044
045    private static final Logger logger = Logger.getLogger(TrainTest.class.getName());
046
047    /**
048     * Impurity function.
049     */
050    public enum ImpurityType {
051        /**
052         * Use {@link MeanSquaredError}.
053         */
054        MSE,
055        /**
056         * Use {@link MeanAbsoluteError}.
057         */
058        MAE
059    }
060
061    /**
062     * Type of tree trainer.
063     */
064    public enum TreeType {
065        /**
066         * Creates a {@link CARTRegressionTrainer} which treats
067         * each regression dimension independently.
068         */
069        CART_INDEPENDENT,
070        /**
071         * Creates a {@link CARTJointRegressionTrainer} which
072         * jointly minimises the impurity across all output dimensions.
073         */
074        CART_JOINT
075    }
076
077    /**
078     * Command line options.
079     */
080    public static class RegressionTreeOptions implements Options {
081        @Override
082        public String getOptionsDescription() {
083            return "Trains and tests a CART regression model on the specified datasets.";
084        }
085
086        /**
087         * The data loading options.
088         */
089        public DataOptions general;
090        /**
091         * Character to split the CSV response on to generate multiple regression dimensions. Defaults to ':'.
092         */
093        @Option(longName = "csv-response-split-char", usage = "Character to split the CSV response on to generate multiple regression dimensions. Defaults to ':'.")
094        public char splitChar = ':';
095        /**
096         * Maximum depth in the decision tree.
097         */
098        @Option(charName = 'd', longName = "max-depth", usage = "Maximum depth in the decision tree.")
099        public int depth = 6;
100        /**
101         * Fraction of features in split.
102         */
103        @Option(charName = 'e', longName = "split-fraction", usage = "Fraction of features in split.")
104        public float fraction = 1.0f;
105        /**
106         * Minimum child weight.
107         */
108        @Option(charName = 'm', longName = "min-child-weight", usage = "Minimum child weight.")
109        public float minChildWeight = 5.0f;
110        /**
111         * Minimumum decrease in impurity required in order for the node to be split.
112         */
113        @Option(charName = 'p', longName = "min-impurity-decrease", usage = "Minimumum decrease in impurity required in order for the node to be split.")
114        public float minImpurityDecrease = 0.0f;
115        /**
116         * Whether to choose split points for features at random.
117         */
118        @Option(charName = 'r', longName = "use-random-split-points", usage = "Whether to choose split points for features at random.")
119        public boolean useRandomSplitPoints = false;
120        /**
121         * Normalize the leaf outputs so each leaf sums to 1.0.
122         */
123        @Option(charName = 'n', longName = "normalize", usage = "Normalize the leaf outputs so each leaf sums to 1.0.")
124        public boolean normalize = false;
125        /**
126         * Impurity measure to use. Defaults to MSE.
127         */
128        @Option(charName = 'i', longName = "impurity", usage = "Impurity measure to use. Defaults to MSE.")
129        public ImpurityType impurityType = ImpurityType.MSE;
130        /**
131         * Tree type.
132         */
133        @Option(charName = 't', longName = "tree-type", usage = "Tree type.")
134        public TreeType treeType = TreeType.CART_INDEPENDENT;
135        /**
136         * Prints the decision tree.
137         */
138        @Option(longName = "print-tree", usage = "Prints the decision tree.")
139        public boolean printTree;
140    }
141
142    /**
143     * Runs a TrainTest CLI.
144     * @param args the command line arguments
145     * @throws IOException if there is any error reading the examples.
146     */
147    public static void main(String[] args) throws IOException {
148
149        //
150        // Use the labs format logging.
151        LabsLogFormatter.setAllLogFormatters();
152
153        RegressionTreeOptions o = new RegressionTreeOptions();
154        ConfigurationManager cm;
155        try {
156            cm = new ConfigurationManager(args,o);
157        } catch (UsageException e) {
158            logger.info(e.getMessage());
159            return;
160        }
161
162        RegressionFactory factory = new RegressionFactory(o.splitChar);
163
164        Pair<Dataset<Regressor>,Dataset<Regressor>> data = o.general.load(factory);
165        Dataset<Regressor> train = data.getA();
166        Dataset<Regressor> test = data.getB();
167
168        RegressorImpurity impurity;
169        switch (o.impurityType) {
170            case MAE:
171                impurity = new MeanAbsoluteError();
172                break;
173            case MSE:
174                impurity = new MeanSquaredError();
175                break;
176            default:
177                logger.severe("unknown impurity type " + o.impurityType);
178                return;
179        }
180
181        if (o.general.trainingPath == null || o.general.testingPath == null) {
182            logger.info(cm.usage());
183            return;
184        }
185
186        SparseTrainer<Regressor> trainer;
187        switch (o.treeType) {
188            case CART_INDEPENDENT:
189                    trainer = new CARTRegressionTrainer(o.depth, o.minChildWeight,o.minImpurityDecrease,o.fraction, o.useRandomSplitPoints, impurity,
190                            o.general.seed);
191                break;
192            case CART_JOINT:
193                    trainer = new CARTJointRegressionTrainer(o.depth, o.minChildWeight, o.minImpurityDecrease, o.fraction, o.useRandomSplitPoints,
194                            impurity, o.normalize, o.general.seed);
195                break;
196            default:
197                logger.severe("unknown tree type " + o.treeType);
198                return;
199        }
200
201        logger.info("Training using " + trainer.toString());
202
203        final long trainStart = System.currentTimeMillis();
204        SparseModel<Regressor> model = trainer.train(train);
205        final long trainStop = System.currentTimeMillis();
206
207        logger.info("Finished training regressor " + Util.formatDuration(trainStart,trainStop));
208
209        if (o.printTree) {
210            logger.info(model.toString());
211        }
212
213        logger.info("Selected features: " + model.getActiveFeatures());
214        final long testStart = System.currentTimeMillis();
215        RegressionEvaluation evaluation = factory.getEvaluator().evaluate(model,test);
216        final long testStop = System.currentTimeMillis();
217        logger.info("Finished evaluating model " + Util.formatDuration(testStart,testStop));
218        System.out.println(evaluation.toString());
219
220        if (o.general.outputPath != null) {
221            o.general.saveModel(model);
222        }
223    }
224}