001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.example;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Dataset;
021import org.tribuo.Example;
022import org.tribuo.MutableDataset;
023import org.tribuo.classification.Label;
024import org.tribuo.classification.LabelFactory;
025import org.tribuo.impl.ArrayExample;
026import org.tribuo.provenance.DataSourceProvenance;
027import org.tribuo.provenance.SimpleDataSourceProvenance;
028
029import java.time.OffsetDateTime;
030
031/**
032 * Generates three example train and test datasets, used for unit testing.
033 * They don't necessarily have sensible classification boundaries,
034 * it's for testing the machinery rather than accuracy.
035 */
036public final class LabelledDataGenerator {
037
038    // final class with private constructor to ensure it's not instantiated.
039    private LabelledDataGenerator() {}
040
041    private static final LabelFactory labelFactory = new LabelFactory();
042
043    /**
044     * Generates a train/test dataset pair which is dense in the features,
045     * each example has 4 features,{A,B,C,D}, and there are 4 classes,
046     * {Foo,Bar,Baz,Quux}.
047     * @return A pair of datasets.
048     */
049    public static Pair<Dataset<Label>,Dataset<Label>> denseTrainTest() {
050        return denseTrainTest(-1.0);
051    }
052
053    /**
054     * Generates a train/test dataset pair which is dense in the features,
055     * each example has 4 features,{A,B,C,D}, and there are 4 classes,
056     * {Foo,Bar,Baz,Quux}.
057     * @param negate Supply -1.0 to insert some negative values into the dataset.
058     * @return A pair of datasets.
059     */
060    public static Pair<Dataset<Label>,Dataset<Label>> denseTrainTest(double negate) {
061        DataSourceProvenance provenance = new SimpleDataSourceProvenance("TrainingData", OffsetDateTime.now(),labelFactory);
062        MutableDataset<Label> train = new MutableDataset<>(provenance,labelFactory);
063
064        String[] names = new String[]{"A","B","C","D"};
065        double[] values = new double[]{1.0,0.5,1.0,negate*1.0};
066        train.add(new ArrayExample<>(new Label("Foo"),names,values));
067        values = new double[]{1.5,0.35,1.3,negate*1.2};
068        train.add(new ArrayExample<>(new Label("Foo"),names.clone(),values));
069        values = new double[]{1.2,0.45,1.5,negate*1.0};
070        train.add(new ArrayExample<>(new Label("Foo"),names.clone(),values));
071
072        values = new double[]{negate*1.1,0.55,negate*1.5,0.5};
073        train.add(new ArrayExample<>(new Label("Bar"),names.clone(),values));
074        values = new double[]{negate*1.5,0.25,negate*1,0.125};
075        train.add(new ArrayExample<>(new Label("Bar"),names.clone(),values));
076        values = new double[]{negate*1,0.5,negate*1.123,0.123};
077        train.add(new ArrayExample<>(new Label("Bar"),names.clone(),values));
078
079        values = new double[]{1.5,5.0,0.5,4.5};
080        train.add(new ArrayExample<>(new Label("Baz"),names.clone(),values));
081        values = new double[]{1.234,5.1235,0.1235,6.0};
082        train.add(new ArrayExample<>(new Label("Baz"),names.clone(),values));
083        values = new double[]{1.734,4.5,0.5123,5.5};
084        train.add(new ArrayExample<>(new Label("Baz"),names.clone(),values));
085
086        values = new double[]{negate*1,0.25,5,10.0};
087        train.add(new ArrayExample<>(new Label("Quux"),names.clone(),values));
088        values = new double[]{negate*1.4,0.55,5.65,12.0};
089        train.add(new ArrayExample<>(new Label("Quux"),names.clone(),values));
090        values = new double[]{negate*1.9,0.25,5.9,15};
091        train.add(new ArrayExample<>(new Label("Quux"),names.clone(),values));
092
093        DataSourceProvenance testProvenance = new SimpleDataSourceProvenance("TestingData", OffsetDateTime.now(),labelFactory);
094        MutableDataset<Label> test = new MutableDataset<>(testProvenance,labelFactory);
095
096        values = new double[]{2.0,0.45,3.5,negate*2.0};
097        test.add(new ArrayExample<>(new Label("Foo"),names.clone(),values));
098        values = new double[]{negate*2.0,0.55,negate*2.5,2.5};
099        test.add(new ArrayExample<>(new Label("Bar"),names.clone(),values));
100        values = new double[]{1.75,5.0,1.0,6.5};
101        test.add(new ArrayExample<>(new Label("Baz"),names.clone(),values));
102        values = new double[]{negate*1.5,0.25,5.0,20.0};
103        test.add(new ArrayExample<>(new Label("Quux"),names.clone(),values));
104
105        return new Pair<>(train,test);
106    }
107
108    /**
109     * Generates a pair of datasets, where the features are sparse,
110     * and unknown features appear in the test data. It has the same
111     * 4 classes {Foo,Bar,Baz,Quux}.
112     * @return A pair of train and test datasets.
113     */
114    public static Pair<Dataset<Label>,Dataset<Label>> sparseTrainTest() {
115        return sparseTrainTest(-1.0);
116    }
117
118    /**
119     * Generates a pair of datasets, where the features are sparse,
120     * and unknown features appear in the test data. It has the same
121     * 4 classes {Foo,Bar,Baz,Quux}.
122     * @param negate Supply -1.0 to negate some values in this dataset.
123     * @return A pair of train and test datasets.
124     */
125    public static Pair<Dataset<Label>,Dataset<Label>> sparseTrainTest(double negate) {
126        DataSourceProvenance provenance = new SimpleDataSourceProvenance("TrainingData", OffsetDateTime.now(),labelFactory);
127        MutableDataset<Label> train = new MutableDataset<>(provenance,labelFactory);
128
129        String[] names = new String[]{"A","B","C","D"};
130        double[] values = new double[]{1.0,0.5,1.0,negate*1.0};
131        train.add(new ArrayExample<>(new Label("Foo"),names,values));
132        names = new String[]{"B","D","F","H"};
133        values = new double[]{1.5,0.35,1.3,negate*1.2};
134        train.add(new ArrayExample<>(new Label("Foo"),names,values));
135        names = new String[]{"A","J","D","M"};
136        values = new double[]{1.2,0.45,1.5,negate*1.0};
137        train.add(new ArrayExample<>(new Label("Foo"),names,values));
138
139        names = new String[]{"C","E","F","H"};
140        values = new double[]{negate*1.1,0.55,negate*1.5,0.5};
141        train.add(new ArrayExample<>(new Label("Bar"),names,values));
142        names = new String[]{"E","G","F","I"};
143        values = new double[]{negate*1.5,0.25,negate*1,0.125};
144        train.add(new ArrayExample<>(new Label("Bar"),names,values));
145        names = new String[]{"J","K","C","E"};
146        values = new double[]{negate*1,0.5,negate*1.123,0.123};
147        train.add(new ArrayExample<>(new Label("Bar"),names,values));
148
149        names = new String[]{"E","A","K","J"};
150        values = new double[]{1.5,5.0,0.5,4.5};
151        train.add(new ArrayExample<>(new Label("Baz"),names,values));
152        names = new String[]{"B","C","E","H"};
153        values = new double[]{1.234,5.1235,0.1235,6.0};
154        train.add(new ArrayExample<>(new Label("Baz"),names,values));
155        names = new String[]{"A","M","I","J"};
156        values = new double[]{1.734,4.5,0.5123,5.5};
157        train.add(new ArrayExample<>(new Label("Baz"),names,values));
158
159        names = new String[]{"Z","A","B","C"};
160        values = new double[]{negate*1,0.25,5,10.0};
161        train.add(new ArrayExample<>(new Label("Quux"),names,values));
162        names = new String[]{"K","V","E","D"};
163        values = new double[]{negate*1.4,0.55,5.65,12.0};
164        train.add(new ArrayExample<>(new Label("Quux"),names,values));
165        names = new String[]{"B","G","E","A"};
166        values = new double[]{negate*1.9,0.25,5.9,15};
167        train.add(new ArrayExample<>(new Label("Quux"),names,values));
168
169        DataSourceProvenance testProvenance = new SimpleDataSourceProvenance("TestingData", OffsetDateTime.now(),labelFactory);
170        MutableDataset<Label> test = new MutableDataset<>(testProvenance,labelFactory);
171
172        names = new String[]{"AA","B","C","D"};
173        values = new double[]{2.0,0.45,3.5,negate*2.0};
174        test.add(new ArrayExample<>(new Label("Foo"),names,values));
175        names = new String[]{"B","BB","F","E"};
176        values = new double[]{negate*2.0,0.55,negate*2.5,2.5};
177        test.add(new ArrayExample<>(new Label("Bar"),names,values));
178        names = new String[]{"B","E","G","H"};
179        values = new double[]{1.75,5.0,1.0,6.5};
180        test.add(new ArrayExample<>(new Label("Baz"),names,values));
181        names = new String[]{"B","CC","DD","EE"};
182        values = new double[]{negate*1.5,0.25,5.0,20.0};
183        test.add(new ArrayExample<>(new Label("Quux"),names,values));
184
185        return new Pair<>(train,test);
186    }
187
188    /**
189     * Generates a pair of datasets with sparse features and unknown features
190     * in the test data. Has binary labels {Foo,Bar}.
191     * @return A pair of train and test datasets.
192     */
193    public static Pair<Dataset<Label>,Dataset<Label>> binarySparseTrainTest() {
194        return binarySparseTrainTest(-1.0);
195    }
196
197    /**
198     * Generates a pair of datasets with sparse features and unknown features
199     * in the test data. Has binary labels {Foo,Bar}.
200     * @param negate Supply -1.0 to negate some values in this dataset.
201     * @return A pair of train and test datasets.
202     */
203    public static Pair<Dataset<Label>,Dataset<Label>> binarySparseTrainTest(double negate) {
204        DataSourceProvenance provenance = new SimpleDataSourceProvenance("TrainingData", OffsetDateTime.now(),labelFactory);
205        MutableDataset<Label> train = new MutableDataset<>(provenance,labelFactory);
206
207        String[] names = new String[]{"A","B","C","D"};
208        double[] values = new double[]{1.0,0.5,1.0,negate*1.0};
209        train.add(new ArrayExample<>(new Label("Foo"),names,values));
210        names = new String[]{"B","D","F","H"};
211        values = new double[]{1.5,0.35,1.3,negate*1.2};
212        train.add(new ArrayExample<>(new Label("Foo"),names,values));
213        names = new String[]{"A","J","D","M"};
214        values = new double[]{1.2,0.45,1.5,negate*1.0};
215        train.add(new ArrayExample<>(new Label("Foo"),names,values));
216
217        names = new String[]{"C","E","F","H"};
218        values = new double[]{negate*1.1,0.55,negate*1.5,0.5};
219        train.add(new ArrayExample<>(new Label("Bar"),names,values));
220        names = new String[]{"E","G","F","I"};
221        values = new double[]{negate*1.5,0.25,negate*1,0.125};
222        train.add(new ArrayExample<>(new Label("Bar"),names,values));
223        names = new String[]{"J","K","C","E"};
224        values = new double[]{negate*1,0.5,negate*1.123,0.123};
225        train.add(new ArrayExample<>(new Label("Bar"),names,values));
226
227        names = new String[]{"E","A","K","J"};
228        values = new double[]{1.5,5.0,0.5,4.5};
229        train.add(new ArrayExample<>(new Label("Foo"),names,values));
230        names = new String[]{"B","C","E","H"};
231        values = new double[]{1.234,5.1235,0.1235,6.0};
232        train.add(new ArrayExample<>(new Label("Foo"),names,values));
233        names = new String[]{"A","M","I","J"};
234        values = new double[]{1.734,4.5,0.5123,5.5};
235        train.add(new ArrayExample<>(new Label("Foo"),names,values));
236
237        names = new String[]{"Z","A","B","C"};
238        values = new double[]{negate*1,0.25,5,10.0};
239        train.add(new ArrayExample<>(new Label("Bar"),names,values));
240        names = new String[]{"K","V","E","D"};
241        values = new double[]{negate*1.4,0.55,5.65,12.0};
242        train.add(new ArrayExample<>(new Label("Bar"),names,values));
243        names = new String[]{"B","G","E","A"};
244        values = new double[]{negate*1.9,0.25,5.9,15};
245        train.add(new ArrayExample<>(new Label("Bar"),names,values));
246
247        DataSourceProvenance testProvenance = new SimpleDataSourceProvenance("TestingData", OffsetDateTime.now(),labelFactory);
248        MutableDataset<Label> test = new MutableDataset<>(testProvenance,labelFactory);
249
250        names = new String[]{"AA","B","C","D"};
251        values = new double[]{2.0,0.45,3.5,negate*2.0};
252        test.add(new ArrayExample<>(new Label("Foo"),names,values));
253        names = new String[]{"B","BB","F","E"};
254        values = new double[]{negate*2.0,0.55,negate*2.5,2.5};
255        test.add(new ArrayExample<>(new Label("Bar"),names,values));
256        names = new String[]{"B","E","G","H"};
257        values = new double[]{1.75,5.0,1.0,6.5};
258        test.add(new ArrayExample<>(new Label("Foo"),names,values));
259        names = new String[]{"B","CC","DD","EE"};
260        values = new double[]{negate*1.5,0.25,5.0,20.0};
261        test.add(new ArrayExample<>(new Label("Bar"),names,values));
262
263        return new Pair<>(train,test);
264    }
265
266    /**
267     * Generates an example with the feature ids 1,5,8, which does not intersect with the
268     * ids used elsewhere in this class. This should make the example empty at prediction time.
269     * @return An example with features {1:1.0,5:5.0,8:8.0}.
270     */
271    public static Example<Label> invalidSparseExample() {
272        return new ArrayExample<>(new Label("Foo"),new String[]{"1","5","8"},new double[]{1.0,5.0,8.0});
273    }
274
275    /**
276     * Generates an example with no features.
277     * @return An example with no features.
278     */
279    public static Example<Label> emptyExample() {
280        return new ArrayExample<>(new Label("Foo"),new String[]{},new double[]{});
281    }
282
283}