001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.ensemble;
018
019import com.oracle.labs.mlrg.olcut.config.Option;
020import com.oracle.labs.mlrg.olcut.config.Options;
021import org.tribuo.Trainer;
022import org.tribuo.classification.Label;
023import org.tribuo.common.tree.DecisionTreeTrainer;
024import org.tribuo.common.tree.ExtraTreesTrainer;
025import org.tribuo.common.tree.RandomForestTrainer;
026import org.tribuo.ensemble.BaggingTrainer;
027
028import java.util.logging.Logger;
029
030/**
031 * Options for building a classification ensemble.
032 */
033public class ClassificationEnsembleOptions implements Options {
034    private static final Logger logger = Logger.getLogger(ClassificationEnsembleOptions.class.getName());
035
036    /**
037     * The type of ensemble.
038     */
039    public enum EnsembleType {
040        /**
041         * Creates an {@link AdaBoostTrainer}.
042         */
043        ADABOOST,
044        /**
045         * Creates a {@link BaggingTrainer}.
046         */
047        BAGGING,
048        /**
049         * Creates an {@link ExtraTreesTrainer}.
050         */
051        EXTRA_TREES,
052        /**
053         * Creates a {@link RandomForestTrainer}.
054         */
055        RF
056    }
057
058    /**
059     * Ensemble method, options are {ADABOOST, BAGGING, RF}.
060     */
061    @Option(longName = "ensemble-type", usage = "Ensemble method, options are {ADABOOST, BAGGING, RF}.")
062    public EnsembleType type = EnsembleType.BAGGING;
063    /**
064     * Number of base learners in the ensemble.
065     */
066    @Option(longName = "ensemble-size", usage = "Number of base learners in the ensemble.")
067    public int ensembleSize = -1;
068    /**
069     * RNG seed.
070     */
071    @Option(longName = "ensemble-seed", usage = "RNG seed.")
072    public long seed = Trainer.DEFAULT_SEED;
073
074    /**
075     * Wraps the supplied trainer using the ensemble trainer described by these options.
076     *
077     * @param trainer The trainer to wrap.
078     * @return An ensemble trainer.
079     */
080    public Trainer<Label> wrapTrainer(Trainer<Label> trainer) {
081        if ((ensembleSize > 0) && (type != null)) {
082            switch (type) {
083                case ADABOOST:
084                    logger.info("Using Adaboost with " + ensembleSize + " members.");
085                    return new AdaBoostTrainer(trainer, ensembleSize, seed);
086                case BAGGING:
087                    logger.info("Using Bagging with " + ensembleSize + " members.");
088                    return new BaggingTrainer<>(trainer, new VotingCombiner(), ensembleSize, seed);
089                case EXTRA_TREES:
090                    if (trainer instanceof DecisionTreeTrainer) {
091                        logger.info("Using Extra Trees with " + ensembleSize + " members.");
092                        return new ExtraTreesTrainer<>((DecisionTreeTrainer<Label>) trainer, new VotingCombiner(), ensembleSize, seed);
093                    } else {
094                        throw new IllegalArgumentException("ExtraTreesTrainer requires a DecisionTreeTrainer");
095                    }
096                case RF:
097                    if (trainer instanceof DecisionTreeTrainer) {
098                        logger.info("Using Random Forests with " + ensembleSize + " members.");
099                        return new RandomForestTrainer<>((DecisionTreeTrainer<Label>) trainer, new VotingCombiner(), ensembleSize, seed);
100                    } else {
101                        throw new IllegalArgumentException("RandomForestTrainer requires a DecisionTreeTrainer");
102                    }
103                default:
104                    throw new IllegalArgumentException("Unknown ensemble type :" + type);
105            }
106        } else {
107            return trainer;
108        }
109    }
110}