001/*
002 * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.example;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.config.PropertyException;
021import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
022import com.oracle.labs.mlrg.olcut.provenance.Provenance;
023import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
024import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
025import org.tribuo.ConfigurableDataSource;
026import org.tribuo.Example;
027import org.tribuo.classification.Label;
028import org.tribuo.classification.LabelFactory;
029import org.tribuo.provenance.ConfiguredDataSourceProvenance;
030import org.tribuo.provenance.DataSourceProvenance;
031
032import java.util.Collections;
033import java.util.HashMap;
034import java.util.Iterator;
035import java.util.List;
036import java.util.Map;
037import java.util.Random;
038
039/**
040 * The base class for the 2d binary classification data sources in {@link org.tribuo.classification.example}.
041 * <p>
042 * The feature names are {@link #X1} and {@link #X2} and the labels are {@link #FIRST_CLASS} and {@link #SECOND_CLASS}.
043 * <p>
044 * Likely to be sealed to the classes in this package when we adopt Java 17.
045 */
046public abstract class DemoLabelDataSource implements ConfigurableDataSource<Label> {
047
048    protected static final LabelFactory factory = new LabelFactory();
049
050    /**
051     * The first feature name.
052     */
053    public static final String X1 = "X1";
054    /**
055     * The second feature name.
056     */
057    public static final String X2 = "X2";
058
059    /**
060     * The feature names array.
061     */
062    static final String[] FEATURE_NAMES = new String[]{X1,X2};
063
064    /**
065     * The first class.
066     */
067    public static final Label FIRST_CLASS = new Label("X");
068    /**
069     * The second class.
070     */
071    public static final Label SECOND_CLASS = new Label("O");
072
073    @Config(mandatory = true, description = "Number of samples to generate.")
074    protected int numSamples;
075
076    @Config(mandatory = true, description = "RNG seed.")
077    protected long seed;
078
079    // Uses java.util.Random as SplittableRandom is missing nextGaussian in versions before 17.
080    protected Random rng;
081
082    protected List<Example<Label>> examples;
083
084    /**
085     * For OLCUT.
086     */
087    DemoLabelDataSource() {}
088
089    /**
090     * Stores the numSamples and the seed.
091     * <p>
092     * Note does not call {@link #postConfig} to generate the examples,
093     * this must be called by the subclass's constructor.
094     * @param numSamples The number of samples to generate.
095     * @param seed The RNG seed.
096     */
097    DemoLabelDataSource(int numSamples, long seed) {
098        this.numSamples = numSamples;
099        this.seed = seed;
100    }
101
102    /**
103     * Configures the class. Should be called in sub-classes' postConfigs
104     * after they've validated their parameters.
105     */
106    @Override
107    public void postConfig() {
108        if (numSamples < 1) {
109            throw new PropertyException("","numSamples","Number of samples must be positive, found " + numSamples);
110        }
111        this.rng = new Random(seed);
112        this.examples = Collections.unmodifiableList(generate());
113    }
114
115    /**
116     * Generates the examples using the configured fields.
117     * <p>
118     * Is called internally by {@link #postConfig}.
119     * @return The generated examples.
120     */
121    protected abstract List<Example<Label>> generate();
122
123    @Override
124    public LabelFactory getOutputFactory() {
125        return factory;
126    }
127
128    @Override
129    public DataSourceProvenance getProvenance() {
130        return new DemoLabelDataSourceProvenance(this);
131    }
132
133    @Override
134    public Iterator<Example<Label>> iterator() {
135        return examples.iterator();
136    }
137
138    /**
139     * Provenance for {@link DemoLabelDataSource}.
140     */
141    public static final class DemoLabelDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements ConfiguredDataSourceProvenance {
142        private static final long serialVersionUID = 1L;
143
144        /**
145         * Constructs a provenance from the host data source.
146         *
147         * @param host The host to read.
148         */
149        DemoLabelDataSourceProvenance(DemoLabelDataSource host) {
150            super(host, "DataSource");
151        }
152
153        /**
154         * Constructs a provenance from the marshalled form.
155         *
156         * @param map The map of field values.
157         */
158        public DemoLabelDataSourceProvenance(Map<String, Provenance> map) {
159            this(extractProvenanceInfo(map));
160        }
161
162        private DemoLabelDataSourceProvenance(SkeletalConfiguredObjectProvenance.ExtractedInfo info) {
163            super(info);
164        }
165
166        /**
167         * Extracts the relevant provenance information fields for this class.
168         *
169         * @param map The map to remove values from.
170         * @return The extracted information.
171         */
172        static ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) {
173            Map<String, Provenance> configuredParameters = new HashMap<>(map);
174            String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters, CLASS_NAME, StringProvenance.class, DemoLabelDataSourceProvenance.class.getSimpleName()).getValue();
175            String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters, HOST_SHORT_NAME, StringProvenance.class, DemoLabelDataSourceProvenance.class.getSimpleName()).getValue();
176
177            return new ExtractedInfo(className, hostTypeStringName, configuredParameters, Collections.emptyMap());
178        }
179    }
180}