001/* 002 * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.example; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.config.PropertyException; 021import org.tribuo.Example; 022import org.tribuo.classification.Label; 023import org.tribuo.impl.ArrayExample; 024 025import java.util.ArrayList; 026import java.util.List; 027 028/** 029 * Creates a data source using a 2d checkerboard of alternating classes. 030 */ 031public final class CheckerboardDataSource extends DemoLabelDataSource { 032 033 @Config(description = "The number of squares on each side.") 034 private int numSquares = 5; 035 036 @Config(description = "The minimum feature value.") 037 private double min = 0.0; 038 039 @Config(description = "The maximum feature value.") 040 private double max = 10.0; 041 042 private double range; 043 044 private double tileWidth; 045 046 /** 047 * For OLCUT. 048 */ 049 private CheckerboardDataSource() { 050 super(); 051 } 052 053 /** 054 * Creates a checkboard with the required number of squares per dimension, where each feature value lies between min and max. 055 * 056 * @param numSamples The number of samples to generate. 057 * @param seed The RNG seed. 058 * @param numSquares The number of squares. 059 * @param min The minimum feature value. 060 * @param max The maximum feature value. 061 */ 062 public CheckerboardDataSource(int numSamples, long seed, int numSquares, double min, double max) { 063 super(numSamples, seed); 064 this.numSquares = numSquares; 065 this.min = min; 066 this.max = max; 067 postConfig(); 068 } 069 070 /** 071 * Used by the OLCUT configuration system, and should not be called by external code. 072 */ 073 @Override 074 public void postConfig() { 075 if (max <= min) { 076 throw new PropertyException("", "min", "min must be strictly less than max, min = " + min + ", max = " + max); 077 } 078 if (numSquares < 2) { 079 throw new PropertyException("", "numSquares", "numSquares must be 2 or greater, found " + numSquares); 080 } 081 range = Math.abs(max - min); 082 tileWidth = range / numSquares; 083 super.postConfig(); 084 } 085 086 @Override 087 protected List<Example<Label>> generate() { 088 List<Example<Label>> list = new ArrayList<>(); 089 090 for (int i = 0; i < numSamples; i++) { 091 double[] values = new double[2]; 092 values[0] = (rng.nextDouble() * range); 093 values[1] = (rng.nextDouble() * range); 094 095 int modX1 = ((int) Math.floor(values[0] / tileWidth)) % 2; 096 int modX2 = ((int) Math.floor(values[1] / tileWidth)) % 2; 097 098 Label label; 099 if (modX1 == modX2) { 100 label = FIRST_CLASS; 101 } else { 102 label = SECOND_CLASS; 103 } 104 105 // Update the minimums after computing the label so we don't have to 106 // deal with tricky negative issues interacting with Math.floor(). 107 values[0] += min; 108 values[1] += min; 109 110 list.add(new ArrayExample<>(label, FEATURE_NAMES, values)); 111 } 112 113 return list; 114 } 115 116 @Override 117 public String toString() { 118 return "Checkerboard(numSamples=" + numSamples + ",seed=" + seed + ",numSquares=" + numSquares + ",min=" + min + ",max=" + max + ')'; 119 } 120}