001/*
002 * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.baseline;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.config.PropertyException;
021import com.oracle.labs.mlrg.olcut.provenance.Provenance;
022import org.tribuo.Dataset;
023import org.tribuo.ImmutableFeatureMap;
024import org.tribuo.ImmutableOutputInfo;
025import org.tribuo.Model;
026import org.tribuo.MutableOutputInfo;
027import org.tribuo.Trainer;
028import org.tribuo.classification.Label;
029import org.tribuo.provenance.ModelProvenance;
030import org.tribuo.provenance.TrainerProvenance;
031import org.tribuo.provenance.impl.TrainerProvenanceImpl;
032
033import java.time.OffsetDateTime;
034import java.util.Map;
035
036/**
037 * A trainer for simple baseline classifiers. Use this only for comparison purposes, if you can't beat these
038 * baselines, your ML system doesn't work.
039 */
040public final class DummyClassifierTrainer implements Trainer<Label> {
041
042    /**
043     * Types of dummy classifier.
044     */
045    public enum DummyType {
046        /**
047         * Samples the label proprotional to the training label frequencies.
048         */
049        STRATIFIED,
050        /**
051         * Returns the most frequent training label.
052         */
053        MOST_FREQUENT,
054        /**
055         * Samples uniformly from the label domain.
056         */
057        UNIFORM,
058        /**
059         * Returns the supplied label for all inputs.
060         */
061        CONSTANT
062    }
063
064    @Config(mandatory = true,description="Type of dummy classifier.")
065    private DummyType dummyType;
066
067    @Config(description="Label to use for the constant classifier.")
068    private String constantLabel;
069
070    @Config(description="Seed for the RNG.")
071    private long seed = 1L;
072
073    private int invocationCount = 0;
074
075    private DummyClassifierTrainer() {}
076
077    /**
078     * Used by the OLCUT configuration system, and should not be called by external code.
079     */
080    @Override
081    public void postConfig() {
082        if ((dummyType == DummyType.CONSTANT) && (constantLabel == null)) {
083            throw new PropertyException("","constantLabel","Please supply a label string when using the type CONSTANT.");
084        }
085    }
086
087    @Override
088    public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> instanceProvenance) {
089        return train(examples, instanceProvenance, INCREMENT_INVOCATION_COUNT) ;
090    }
091
092    @Override
093    public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> instanceProvenance, int invocationCount) {
094        if(invocationCount != INCREMENT_INVOCATION_COUNT) {
095            this.invocationCount = invocationCount;
096        }
097        ModelProvenance provenance = new ModelProvenance(DummyClassifierModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), getProvenance(), instanceProvenance);
098        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
099        this.invocationCount++;
100        switch (dummyType) {
101            case CONSTANT:
102                MutableOutputInfo<Label> labelInfo = examples.getOutputInfo().generateMutableOutputInfo();
103                Label constLabel = new Label(constantLabel);
104                labelInfo.observe(constLabel);
105                return new DummyClassifierModel(provenance,featureMap,labelInfo.generateImmutableOutputInfo(),constLabel);
106            case MOST_FREQUENT: {
107                ImmutableOutputInfo<Label> immutableLabelInfo = examples.getOutputIDInfo();
108                return new DummyClassifierModel(provenance, featureMap, immutableLabelInfo);
109            }
110            case UNIFORM:
111            case STRATIFIED: {
112                ImmutableOutputInfo<Label> immutableLabelInfo = examples.getOutputIDInfo();
113                return new DummyClassifierModel(provenance, featureMap, immutableLabelInfo, dummyType, seed);
114            }
115            default:
116                throw new IllegalStateException("Unknown dummyType " + dummyType);
117        }
118    }
119
120    @Override
121    public int getInvocationCount() {
122        return invocationCount;
123    }
124
125    @Override
126    public synchronized void setInvocationCount(int invocationCount){
127        if(invocationCount < 0){
128            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
129        }
130
131        this.invocationCount = invocationCount;
132    }
133
134    @Override
135    public String toString() {
136        switch (dummyType) {
137            case CONSTANT:
138                return "DummyClassifierTrainer(dummyType="+dummyType+",constantLabel="+constantLabel+")";
139            case MOST_FREQUENT: {
140                return "DummyClassifierTrainer(dummyType="+dummyType+")";
141            }
142            case UNIFORM:
143            case STRATIFIED: {
144                return "DummyClassifierTrainer(dummyType="+dummyType+",seed="+seed+")";
145            }
146            default:
147                return "DummyClassifierTrainer(dummyType="+dummyType+")";
148        }
149    }
150
151    @Override
152    public TrainerProvenance getProvenance() {
153        return new TrainerProvenanceImpl(this);
154    }
155
156    /**
157     * Creates a trainer which creates models which return random labels sampled from the training label distribution.
158     * @param seed The RNG seed to use.
159     * @return A classification trainer.
160     */
161    public static DummyClassifierTrainer createStratifiedTrainer(long seed) {
162        DummyClassifierTrainer trainer = new DummyClassifierTrainer();
163        trainer.dummyType = DummyType.STRATIFIED;
164        trainer.seed = seed;
165        return trainer;
166    }
167
168    /**
169     * Creates a trainer which creates models which return a fixed label.
170     * @param constantLabel The label to return.
171     * @return A classification trainer.
172     */
173    public static DummyClassifierTrainer createConstantTrainer(String constantLabel) {
174        DummyClassifierTrainer trainer = new DummyClassifierTrainer();
175        trainer.dummyType = DummyType.CONSTANT;
176        trainer.constantLabel = constantLabel;
177        return trainer;
178    }
179
180    /**
181     * Creates a trainer which creates models which return random labels sampled uniformly from the labels seen at training time.
182     * @param seed The RNG seed to use.
183     * @return A classification trainer.
184     */
185    public static DummyClassifierTrainer createUniformTrainer(long seed) {
186        DummyClassifierTrainer trainer = new DummyClassifierTrainer();
187        trainer.dummyType = DummyType.UNIFORM;
188        trainer.seed = seed;
189        return trainer;
190    }
191
192    /**
193     * Creates a trainer which creates models which return a fixed label, the one which was most frequent in the training data.
194     * @return A classification trainer.
195     */
196    public static DummyClassifierTrainer createMostFrequentTrainer() {
197        DummyClassifierTrainer trainer = new DummyClassifierTrainer();
198        trainer.dummyType = DummyType.MOST_FREQUENT;
199        return trainer;
200    }
201}