001/*
002 * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.example;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.config.PropertyException;
021import org.tribuo.Example;
022import org.tribuo.classification.Label;
023import org.tribuo.impl.ArrayExample;
024
025import java.util.ArrayList;
026import java.util.List;
027
028/**
029 * Creates a data source using a 2d checkerboard of alternating classes.
030 */
031public final class CheckerboardDataSource extends DemoLabelDataSource {
032
033    @Config(description = "The number of squares on each side.")
034    private int numSquares = 5;
035
036    @Config(description = "The minimum feature value.")
037    private double min = 0.0;
038
039    @Config(description = "The maximum feature value.")
040    private double max = 10.0;
041
042    private double range;
043
044    private double tileWidth;
045
046    /**
047     * For OLCUT.
048     */
049    private CheckerboardDataSource() {
050        super();
051    }
052
053    /**
054     * Creates a checkboard with the required number of squares per dimension, where each feature value lies between min and max.
055     *
056     * @param numSamples The number of samples to generate.
057     * @param seed       The RNG seed.
058     * @param numSquares The number of squares.
059     * @param min        The minimum feature value.
060     * @param max        The maximum feature value.
061     */
062    public CheckerboardDataSource(int numSamples, long seed, int numSquares, double min, double max) {
063        super(numSamples, seed);
064        this.numSquares = numSquares;
065        this.min = min;
066        this.max = max;
067        postConfig();
068    }
069
070    /**
071     * Used by the OLCUT configuration system, and should not be called by external code.
072     */
073    @Override
074    public void postConfig() {
075        if (max <= min) {
076            throw new PropertyException("", "min", "min must be strictly less than max, min = " + min + ", max = " + max);
077        }
078        if (numSquares < 2) {
079            throw new PropertyException("", "numSquares", "numSquares must be 2 or greater, found " + numSquares);
080        }
081        range = Math.abs(max - min);
082        tileWidth = range / numSquares;
083        super.postConfig();
084    }
085
086    @Override
087    protected List<Example<Label>> generate() {
088        List<Example<Label>> list = new ArrayList<>();
089
090        for (int i = 0; i < numSamples; i++) {
091            double[] values = new double[2];
092            values[0] = (rng.nextDouble() * range);
093            values[1] = (rng.nextDouble() * range);
094
095            int modX1 = ((int) Math.floor(values[0] / tileWidth)) % 2;
096            int modX2 = ((int) Math.floor(values[1] / tileWidth)) % 2;
097
098            Label label;
099            if (modX1 == modX2) {
100                label = FIRST_CLASS;
101            } else {
102                label = SECOND_CLASS;
103            }
104
105            // Update the minimums after computing the label so we don't have to
106            // deal with tricky negative issues interacting with Math.floor().
107            values[0] += min;
108            values[1] += min;
109
110            list.add(new ArrayExample<>(label, FEATURE_NAMES, values));
111        }
112
113        return list;
114    }
115
116    @Override
117    public String toString() {
118        return "Checkerboard(numSamples=" + numSamples + ",seed=" + seed + ",numSquares=" + numSquares + ",min=" + min + ",max=" + max + ')';
119    }
120}