001/* 002 * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.example; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.config.PropertyException; 021import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; 022import com.oracle.labs.mlrg.olcut.provenance.Provenance; 023import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance; 024import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; 025import org.tribuo.ConfigurableDataSource; 026import org.tribuo.Example; 027import org.tribuo.classification.Label; 028import org.tribuo.classification.LabelFactory; 029import org.tribuo.provenance.ConfiguredDataSourceProvenance; 030import org.tribuo.provenance.DataSourceProvenance; 031 032import java.util.Collections; 033import java.util.HashMap; 034import java.util.Iterator; 035import java.util.List; 036import java.util.Map; 037import java.util.Random; 038 039/** 040 * The base class for the 2d binary classification data sources in {@link org.tribuo.classification.example}. 041 * <p> 042 * The feature names are {@link #X1} and {@link #X2} and the labels are {@link #FIRST_CLASS} and {@link #SECOND_CLASS}. 043 * <p> 044 * Likely to be sealed to the classes in this package when we adopt Java 17. 045 */ 046public abstract class DemoLabelDataSource implements ConfigurableDataSource<Label> { 047 048 protected static final LabelFactory factory = new LabelFactory(); 049 050 /** 051 * The first feature name. 052 */ 053 public static final String X1 = "X1"; 054 /** 055 * The second feature name. 056 */ 057 public static final String X2 = "X2"; 058 059 /** 060 * The feature names array. 061 */ 062 static final String[] FEATURE_NAMES = new String[]{X1,X2}; 063 064 /** 065 * The first class. 066 */ 067 public static final Label FIRST_CLASS = new Label("X"); 068 /** 069 * The second class. 070 */ 071 public static final Label SECOND_CLASS = new Label("O"); 072 073 @Config(mandatory = true, description = "Number of samples to generate.") 074 protected int numSamples; 075 076 @Config(mandatory = true, description = "RNG seed.") 077 protected long seed; 078 079 // Uses java.util.Random as SplittableRandom is missing nextGaussian in versions before 17. 080 protected Random rng; 081 082 protected List<Example<Label>> examples; 083 084 /** 085 * For OLCUT. 086 */ 087 DemoLabelDataSource() {} 088 089 /** 090 * Stores the numSamples and the seed. 091 * <p> 092 * Note does not call {@link #postConfig} to generate the examples, 093 * this must be called by the subclass's constructor. 094 * @param numSamples The number of samples to generate. 095 * @param seed The RNG seed. 096 */ 097 DemoLabelDataSource(int numSamples, long seed) { 098 this.numSamples = numSamples; 099 this.seed = seed; 100 } 101 102 /** 103 * Configures the class. Should be called in sub-classes' postConfigs 104 * after they've validated their parameters. 105 */ 106 @Override 107 public void postConfig() { 108 if (numSamples < 1) { 109 throw new PropertyException("","numSamples","Number of samples must be positive, found " + numSamples); 110 } 111 this.rng = new Random(seed); 112 this.examples = Collections.unmodifiableList(generate()); 113 } 114 115 /** 116 * Generates the examples using the configured fields. 117 * <p> 118 * Is called internally by {@link #postConfig}. 119 * @return The generated examples. 120 */ 121 protected abstract List<Example<Label>> generate(); 122 123 @Override 124 public LabelFactory getOutputFactory() { 125 return factory; 126 } 127 128 @Override 129 public DataSourceProvenance getProvenance() { 130 return new DemoLabelDataSourceProvenance(this); 131 } 132 133 @Override 134 public Iterator<Example<Label>> iterator() { 135 return examples.iterator(); 136 } 137 138 /** 139 * Provenance for {@link DemoLabelDataSource}. 140 */ 141 public static final class DemoLabelDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements ConfiguredDataSourceProvenance { 142 private static final long serialVersionUID = 1L; 143 144 /** 145 * Constructs a provenance from the host data source. 146 * 147 * @param host The host to read. 148 */ 149 DemoLabelDataSourceProvenance(DemoLabelDataSource host) { 150 super(host, "DataSource"); 151 } 152 153 /** 154 * Constructs a provenance from the marshalled form. 155 * 156 * @param map The map of field values. 157 */ 158 public DemoLabelDataSourceProvenance(Map<String, Provenance> map) { 159 this(extractProvenanceInfo(map)); 160 } 161 162 private DemoLabelDataSourceProvenance(SkeletalConfiguredObjectProvenance.ExtractedInfo info) { 163 super(info); 164 } 165 166 /** 167 * Extracts the relevant provenance information fields for this class. 168 * 169 * @param map The map to remove values from. 170 * @return The extracted information. 171 */ 172 static ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) { 173 Map<String, Provenance> configuredParameters = new HashMap<>(map); 174 String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters, CLASS_NAME, StringProvenance.class, DemoLabelDataSourceProvenance.class.getSimpleName()).getValue(); 175 String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters, HOST_SHORT_NAME, StringProvenance.class, DemoLabelDataSourceProvenance.class.getSimpleName()).getValue(); 176 177 return new ExtractedInfo(className, hostTypeStringName, configuredParameters, Collections.emptyMap()); 178 } 179 } 180}