/*
 * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.baseline;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import org.tribuo.Dataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.MutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

import java.time.OffsetDateTime;
import java.util.Map;
import java.util.Objects;

/**
 * A trainer for simple baseline classifiers. Use this only for comparison purposes: if you can't beat these
 * baselines, your ML system doesn't work.
 */
public final class DummyClassifierTrainer implements Trainer<Label> {

    /**
     * Types of dummy classifier.
     */
    public enum DummyType {
        /**
         * Samples the label proportional to the training label frequencies.
         */
        STRATIFIED,
        /**
         * Returns the most frequent training label.
         */
        MOST_FREQUENT,
        /**
         * Samples uniformly from the label domain.
         */
        UNIFORM,
        /**
         * Returns the supplied label for all inputs.
         */
        CONSTANT
    }

    @Config(mandatory = true,description="Type of dummy classifier.")
    private DummyType dummyType;

    @Config(description="Label to use for the constant classifier.")
    private String constantLabel;

    @Config(description="Seed for the RNG.")
    private long seed = 1L;

    private int invocationCount = 0;

    /**
     * For use by the OLCUT configuration system and the static factory methods below.
     */
    private DummyClassifierTrainer() {}

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     * <p>
     * Rejects a {@link DummyType#CONSTANT} configuration which does not supply a {@code constantLabel}.
     */
    @Override
    public void postConfig() {
        if ((dummyType == DummyType.CONSTANT) && (constantLabel == null)) {
            throw new PropertyException("","constantLabel","Please supply a label string when using the type CONSTANT.");
        }
    }

    /**
     * Trains a dummy model, continuing from the trainer's current invocation count.
     * @param examples The training dataset.
     * @param instanceProvenance Extra provenance to attach to the model.
     * @return A dummy classification model.
     */
    @Override
    public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> instanceProvenance) {
        return train(examples, instanceProvenance, INCREMENT_INVOCATION_COUNT);
    }

    /**
     * Trains a dummy model of the configured {@link DummyType}.
     * <p>
     * Synchronized so the read-and-increment of the invocation count cannot interleave with
     * {@link #setInvocationCount(int)} or another call to this method.
     * @param examples The training dataset.
     * @param instanceProvenance Extra provenance to attach to the model.
     * @param invocationCount The invocation count to set before training, or
     *                        {@link Trainer#INCREMENT_INVOCATION_COUNT} to continue from the current count.
     * @return A dummy classification model.
     * @throws IllegalArgumentException If {@code invocationCount} is negative and not
     *                                  {@link Trainer#INCREMENT_INVOCATION_COUNT}.
     */
    @Override
    public synchronized Model<Label> train(Dataset<Label> examples, Map<String, Provenance> instanceProvenance, int invocationCount) {
        if (invocationCount != INCREMENT_INVOCATION_COUNT) {
            // Route through the setter so invalid (negative) counts are rejected consistently.
            setInvocationCount(invocationCount);
        }
        // getProvenance() is captured before the counter is bumped below.
        ModelProvenance provenance = new ModelProvenance(DummyClassifierModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), getProvenance(), instanceProvenance);
        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        this.invocationCount++;
        switch (dummyType) {
            case CONSTANT: {
                // Copy the dataset's output info and observe the constant label so it is part of the model's output domain.
                MutableOutputInfo<Label> labelInfo = examples.getOutputInfo().generateMutableOutputInfo();
                Label constLabel = new Label(constantLabel);
                labelInfo.observe(constLabel);
                return new DummyClassifierModel(provenance, featureMap, labelInfo.generateImmutableOutputInfo(), constLabel);
            }
            case MOST_FREQUENT: {
                ImmutableOutputInfo<Label> immutableLabelInfo = examples.getOutputIDInfo();
                return new DummyClassifierModel(provenance, featureMap, immutableLabelInfo);
            }
            case UNIFORM:
            case STRATIFIED: {
                ImmutableOutputInfo<Label> immutableLabelInfo = examples.getOutputIDInfo();
                return new DummyClassifierModel(provenance, featureMap, immutableLabelInfo, dummyType, seed);
            }
            default:
                throw new IllegalStateException("Unknown dummyType " + dummyType);
        }
    }

    @Override
    public int getInvocationCount() {
        return invocationCount;
    }

    /**
     * Sets the invocation count.
     * @param invocationCount The new invocation count.
     * @throws IllegalArgumentException If {@code invocationCount} is negative.
     */
    @Override
    public synchronized void setInvocationCount(int invocationCount){
        if(invocationCount < 0){
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }

        this.invocationCount = invocationCount;
    }

    /**
     * Describes the trainer, including only the parameters relevant to its {@link DummyType}.
     * @return A string description of this trainer.
     */
    @Override
    public String toString() {
        switch (dummyType) {
            case CONSTANT:
                return "DummyClassifierTrainer(dummyType="+dummyType+",constantLabel="+constantLabel+")";
            case MOST_FREQUENT: {
                return "DummyClassifierTrainer(dummyType="+dummyType+")";
            }
            case UNIFORM:
            case STRATIFIED: {
                return "DummyClassifierTrainer(dummyType="+dummyType+",seed="+seed+")";
            }
            default:
                return "DummyClassifierTrainer(dummyType="+dummyType+")";
        }
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /**
     * Creates a trainer which creates models which return random labels sampled from the training label distribution.
     * @param seed The RNG seed to use.
     * @return A classification trainer.
     */
    public static DummyClassifierTrainer createStratifiedTrainer(long seed) {
        DummyClassifierTrainer trainer = new DummyClassifierTrainer();
        trainer.dummyType = DummyType.STRATIFIED;
        trainer.seed = seed;
        return trainer;
    }

    /**
     * Creates a trainer which creates models which return a fixed label.
     * @param constantLabel The label to return.
     * @return A classification trainer.
     * @throws NullPointerException If {@code constantLabel} is null (mirrors the check in {@link #postConfig()}).
     */
    public static DummyClassifierTrainer createConstantTrainer(String constantLabel) {
        // Fail fast here rather than at train time, as postConfig does for configured instances.
        Objects.requireNonNull(constantLabel, "Please supply a label string when using the type CONSTANT.");
        DummyClassifierTrainer trainer = new DummyClassifierTrainer();
        trainer.dummyType = DummyType.CONSTANT;
        trainer.constantLabel = constantLabel;
        return trainer;
    }

    /**
     * Creates a trainer which creates models which return random labels sampled uniformly from the labels seen at training time.
     * @param seed The RNG seed to use.
     * @return A classification trainer.
     */
    public static DummyClassifierTrainer createUniformTrainer(long seed) {
        DummyClassifierTrainer trainer = new DummyClassifierTrainer();
        trainer.dummyType = DummyType.UNIFORM;
        trainer.seed = seed;
        return trainer;
    }

    /**
     * Creates a trainer which creates models which return a fixed label, the one which was most frequent in the training data.
     * @return A classification trainer.
     */
    public static DummyClassifierTrainer createMostFrequentTrainer() {
        DummyClassifierTrainer trainer = new DummyClassifierTrainer();
        trainer.dummyType = DummyType.MOST_FREQUENT;
        return trainer;
    }
}