001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.example; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import org.tribuo.Dataset; 021import org.tribuo.Example; 022import org.tribuo.MutableDataset; 023import org.tribuo.classification.Label; 024import org.tribuo.classification.LabelFactory; 025import org.tribuo.impl.ArrayExample; 026import org.tribuo.provenance.DataSourceProvenance; 027import org.tribuo.provenance.SimpleDataSourceProvenance; 028 029import java.time.OffsetDateTime; 030 031/** 032 * Generates three example train and test datasets, used for unit testing. 033 * They don't necessarily have sensible classification boundaries, 034 * it's for testing the machinery rather than accuracy. 035 */ 036public final class LabelledDataGenerator { 037 038 // final class with private constructor to ensure it's not instantiated. 039 private LabelledDataGenerator() {} 040 041 private static final LabelFactory labelFactory = new LabelFactory(); 042 043 /** 044 * Generates a train/test dataset pair which is dense in the features, 045 * each example has 4 features,{A,B,C,D}, and there are 4 classes, 046 * {Foo,Bar,Baz,Quux}. 047 * @return A pair of datasets. 048 */ 049 public static Pair<Dataset<Label>,Dataset<Label>> denseTrainTest() { 050 return denseTrainTest(-1.0); 051 } 052 053 /** 054 * Generates a train/test dataset pair which is dense in the features, 055 * each example has 4 features,{A,B,C,D}, and there are 4 classes, 056 * {Foo,Bar,Baz,Quux}. 057 * @param negate Supply -1.0 to insert some negative values into the dataset. 058 * @return A pair of datasets. 059 */ 060 public static Pair<Dataset<Label>,Dataset<Label>> denseTrainTest(double negate) { 061 DataSourceProvenance provenance = new SimpleDataSourceProvenance("TrainingData", OffsetDateTime.now(),labelFactory); 062 MutableDataset<Label> train = new MutableDataset<>(provenance,labelFactory); 063 064 String[] names = new String[]{"A","B","C","D"}; 065 double[] values = new double[]{1.0,0.5,1.0,negate*1.0}; 066 train.add(new ArrayExample<>(new Label("Foo"),names,values)); 067 values = new double[]{1.5,0.35,1.3,negate*1.2}; 068 train.add(new ArrayExample<>(new Label("Foo"),names.clone(),values)); 069 values = new double[]{1.2,0.45,1.5,negate*1.0}; 070 train.add(new ArrayExample<>(new Label("Foo"),names.clone(),values)); 071 072 values = new double[]{negate*1.1,0.55,negate*1.5,0.5}; 073 train.add(new ArrayExample<>(new Label("Bar"),names.clone(),values)); 074 values = new double[]{negate*1.5,0.25,negate*1,0.125}; 075 train.add(new ArrayExample<>(new Label("Bar"),names.clone(),values)); 076 values = new double[]{negate*1,0.5,negate*1.123,0.123}; 077 train.add(new ArrayExample<>(new Label("Bar"),names.clone(),values)); 078 079 values = new double[]{1.5,5.0,0.5,4.5}; 080 train.add(new ArrayExample<>(new Label("Baz"),names.clone(),values)); 081 values = new double[]{1.234,5.1235,0.1235,6.0}; 082 train.add(new ArrayExample<>(new Label("Baz"),names.clone(),values)); 083 values = new double[]{1.734,4.5,0.5123,5.5}; 084 train.add(new ArrayExample<>(new Label("Baz"),names.clone(),values)); 085 086 values = new double[]{negate*1,0.25,5,10.0}; 087 train.add(new ArrayExample<>(new Label("Quux"),names.clone(),values)); 088 values = new double[]{negate*1.4,0.55,5.65,12.0}; 089 train.add(new ArrayExample<>(new Label("Quux"),names.clone(),values)); 090 values = new double[]{negate*1.9,0.25,5.9,15}; 091 train.add(new ArrayExample<>(new Label("Quux"),names.clone(),values)); 092 093 DataSourceProvenance testProvenance = new SimpleDataSourceProvenance("TestingData", OffsetDateTime.now(),labelFactory); 094 MutableDataset<Label> test = new MutableDataset<>(testProvenance,labelFactory); 095 096 values = new double[]{2.0,0.45,3.5,negate*2.0}; 097 test.add(new ArrayExample<>(new Label("Foo"),names.clone(),values)); 098 values = new double[]{negate*2.0,0.55,negate*2.5,2.5}; 099 test.add(new ArrayExample<>(new Label("Bar"),names.clone(),values)); 100 values = new double[]{1.75,5.0,1.0,6.5}; 101 test.add(new ArrayExample<>(new Label("Baz"),names.clone(),values)); 102 values = new double[]{negate*1.5,0.25,5.0,20.0}; 103 test.add(new ArrayExample<>(new Label("Quux"),names.clone(),values)); 104 105 return new Pair<>(train,test); 106 } 107 108 /** 109 * Generates a pair of datasets, where the features are sparse, 110 * and unknown features appear in the test data. It has the same 111 * 4 classes {Foo,Bar,Baz,Quux}. 112 * @return A pair of train and test datasets. 113 */ 114 public static Pair<Dataset<Label>,Dataset<Label>> sparseTrainTest() { 115 return sparseTrainTest(-1.0); 116 } 117 118 /** 119 * Generates a pair of datasets, where the features are sparse, 120 * and unknown features appear in the test data. It has the same 121 * 4 classes {Foo,Bar,Baz,Quux}. 122 * @param negate Supply -1.0 to negate some values in this dataset. 123 * @return A pair of train and test datasets. 124 */ 125 public static Pair<Dataset<Label>,Dataset<Label>> sparseTrainTest(double negate) { 126 DataSourceProvenance provenance = new SimpleDataSourceProvenance("TrainingData", OffsetDateTime.now(),labelFactory); 127 MutableDataset<Label> train = new MutableDataset<>(provenance,labelFactory); 128 129 String[] names = new String[]{"A","B","C","D"}; 130 double[] values = new double[]{1.0,0.5,1.0,negate*1.0}; 131 train.add(new ArrayExample<>(new Label("Foo"),names,values)); 132 names = new String[]{"B","D","F","H"}; 133 values = new double[]{1.5,0.35,1.3,negate*1.2}; 134 train.add(new ArrayExample<>(new Label("Foo"),names,values)); 135 names = new String[]{"A","J","D","M"}; 136 values = new double[]{1.2,0.45,1.5,negate*1.0}; 137 train.add(new ArrayExample<>(new Label("Foo"),names,values)); 138 139 names = new String[]{"C","E","F","H"}; 140 values = new double[]{negate*1.1,0.55,negate*1.5,0.5}; 141 train.add(new ArrayExample<>(new Label("Bar"),names,values)); 142 names = new String[]{"E","G","F","I"}; 143 values = new double[]{negate*1.5,0.25,negate*1,0.125}; 144 train.add(new ArrayExample<>(new Label("Bar"),names,values)); 145 names = new String[]{"J","K","C","E"}; 146 values = new double[]{negate*1,0.5,negate*1.123,0.123}; 147 train.add(new ArrayExample<>(new Label("Bar"),names,values)); 148 149 names = new String[]{"E","A","K","J"}; 150 values = new double[]{1.5,5.0,0.5,4.5}; 151 train.add(new ArrayExample<>(new Label("Baz"),names,values)); 152 names = new String[]{"B","C","E","H"}; 153 values = new double[]{1.234,5.1235,0.1235,6.0}; 154 train.add(new ArrayExample<>(new Label("Baz"),names,values)); 155 names = new String[]{"A","M","I","J"}; 156 values = new double[]{1.734,4.5,0.5123,5.5}; 157 train.add(new ArrayExample<>(new Label("Baz"),names,values)); 158 159 names = new String[]{"Z","A","B","C"}; 160 values = new double[]{negate*1,0.25,5,10.0}; 161 train.add(new ArrayExample<>(new Label("Quux"),names,values)); 162 names = new String[]{"K","V","E","D"}; 163 values = new double[]{negate*1.4,0.55,5.65,12.0}; 164 train.add(new ArrayExample<>(new Label("Quux"),names,values)); 165 names = new String[]{"B","G","E","A"}; 166 values = new double[]{negate*1.9,0.25,5.9,15}; 167 train.add(new ArrayExample<>(new Label("Quux"),names,values)); 168 169 DataSourceProvenance testProvenance = new SimpleDataSourceProvenance("TestingData", OffsetDateTime.now(),labelFactory); 170 MutableDataset<Label> test = new MutableDataset<>(testProvenance,labelFactory); 171 172 names = new String[]{"AA","B","C","D"}; 173 values = new double[]{2.0,0.45,3.5,negate*2.0}; 174 test.add(new ArrayExample<>(new Label("Foo"),names,values)); 175 names = new String[]{"B","BB","F","E"}; 176 values = new double[]{negate*2.0,0.55,negate*2.5,2.5}; 177 test.add(new ArrayExample<>(new Label("Bar"),names,values)); 178 names = new String[]{"B","E","G","H"}; 179 values = new double[]{1.75,5.0,1.0,6.5}; 180 test.add(new ArrayExample<>(new Label("Baz"),names,values)); 181 names = new String[]{"B","CC","DD","EE"}; 182 values = new double[]{negate*1.5,0.25,5.0,20.0}; 183 test.add(new ArrayExample<>(new Label("Quux"),names,values)); 184 185 return new Pair<>(train,test); 186 } 187 188 /** 189 * Generates a pair of datasets with sparse features and unknown features 190 * in the test data. Has binary labels {Foo,Bar}. 191 * @return A pair of train and test datasets. 192 */ 193 public static Pair<Dataset<Label>,Dataset<Label>> binarySparseTrainTest() { 194 return binarySparseTrainTest(-1.0); 195 } 196 197 /** 198 * Generates a pair of datasets with sparse features and unknown features 199 * in the test data. Has binary labels {Foo,Bar}. 200 * @param negate Supply -1.0 to negate some values in this dataset. 201 * @return A pair of train and test datasets. 202 */ 203 public static Pair<Dataset<Label>,Dataset<Label>> binarySparseTrainTest(double negate) { 204 DataSourceProvenance provenance = new SimpleDataSourceProvenance("TrainingData", OffsetDateTime.now(),labelFactory); 205 MutableDataset<Label> train = new MutableDataset<>(provenance,labelFactory); 206 207 String[] names = new String[]{"A","B","C","D"}; 208 double[] values = new double[]{1.0,0.5,1.0,negate*1.0}; 209 train.add(new ArrayExample<>(new Label("Foo"),names,values)); 210 names = new String[]{"B","D","F","H"}; 211 values = new double[]{1.5,0.35,1.3,negate*1.2}; 212 train.add(new ArrayExample<>(new Label("Foo"),names,values)); 213 names = new String[]{"A","J","D","M"}; 214 values = new double[]{1.2,0.45,1.5,negate*1.0}; 215 train.add(new ArrayExample<>(new Label("Foo"),names,values)); 216 217 names = new String[]{"C","E","F","H"}; 218 values = new double[]{negate*1.1,0.55,negate*1.5,0.5}; 219 train.add(new ArrayExample<>(new Label("Bar"),names,values)); 220 names = new String[]{"E","G","F","I"}; 221 values = new double[]{negate*1.5,0.25,negate*1,0.125}; 222 train.add(new ArrayExample<>(new Label("Bar"),names,values)); 223 names = new String[]{"J","K","C","E"}; 224 values = new double[]{negate*1,0.5,negate*1.123,0.123}; 225 train.add(new ArrayExample<>(new Label("Bar"),names,values)); 226 227 names = new String[]{"E","A","K","J"}; 228 values = new double[]{1.5,5.0,0.5,4.5}; 229 train.add(new ArrayExample<>(new Label("Foo"),names,values)); 230 names = new String[]{"B","C","E","H"}; 231 values = new double[]{1.234,5.1235,0.1235,6.0}; 232 train.add(new ArrayExample<>(new Label("Foo"),names,values)); 233 names = new String[]{"A","M","I","J"}; 234 values = new double[]{1.734,4.5,0.5123,5.5}; 235 train.add(new ArrayExample<>(new Label("Foo"),names,values)); 236 237 names = new String[]{"Z","A","B","C"}; 238 values = new double[]{negate*1,0.25,5,10.0}; 239 train.add(new ArrayExample<>(new Label("Bar"),names,values)); 240 names = new String[]{"K","V","E","D"}; 241 values = new double[]{negate*1.4,0.55,5.65,12.0}; 242 train.add(new ArrayExample<>(new Label("Bar"),names,values)); 243 names = new String[]{"B","G","E","A"}; 244 values = new double[]{negate*1.9,0.25,5.9,15}; 245 train.add(new ArrayExample<>(new Label("Bar"),names,values)); 246 247 DataSourceProvenance testProvenance = new SimpleDataSourceProvenance("TestingData", OffsetDateTime.now(),labelFactory); 248 MutableDataset<Label> test = new MutableDataset<>(testProvenance,labelFactory); 249 250 names = new String[]{"AA","B","C","D"}; 251 values = new double[]{2.0,0.45,3.5,negate*2.0}; 252 test.add(new ArrayExample<>(new Label("Foo"),names,values)); 253 names = new String[]{"B","BB","F","E"}; 254 values = new double[]{negate*2.0,0.55,negate*2.5,2.5}; 255 test.add(new ArrayExample<>(new Label("Bar"),names,values)); 256 names = new String[]{"B","E","G","H"}; 257 values = new double[]{1.75,5.0,1.0,6.5}; 258 test.add(new ArrayExample<>(new Label("Foo"),names,values)); 259 names = new String[]{"B","CC","DD","EE"}; 260 values = new double[]{negate*1.5,0.25,5.0,20.0}; 261 test.add(new ArrayExample<>(new Label("Bar"),names,values)); 262 263 return new Pair<>(train,test); 264 } 265 266 /** 267 * Generates an example with the feature ids 1,5,8, which does not intersect with the 268 * ids used elsewhere in this class. This should make the example empty at prediction time. 269 * @return An example with features {1:1.0,5:5.0,8:8.0}. 270 */ 271 public static Example<Label> invalidSparseExample() { 272 return new ArrayExample<>(new Label("Foo"),new String[]{"1","5","8"},new double[]{1.0,5.0,8.0}); 273 } 274 275 /** 276 * Generates an example with no features. 277 * @return An example with no features. 278 */ 279 public static Example<Label> emptyExample() { 280 return new ArrayExample<>(new Label("Foo"),new String[]{},new double[]{}); 281 } 282 283}