001/*
002 * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.example;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.config.PropertyException;
021import org.tribuo.Example;
022import org.tribuo.classification.Label;
023import org.tribuo.impl.ArrayExample;
024
025import java.util.ArrayList;
026import java.util.List;
027
028/**
029 * A data source for two concentric circles, one per class.
030 */
031public final class ConcentricCirclesDataSource extends DemoLabelDataSource {
032
033    @Config(description = "The radius of the outer circle.")
034    private double radius = 2;
035
036    @Config(description = "The proportion of the circle radius that forms class one.")
037    private double classProportion = 0.5;
038
039    /**
040     * For OLCUT.
041     */
042    private ConcentricCirclesDataSource() {
043        super();
044    }
045
046    /**
047     * Constructs a data source for two concentric circles, one per class.
048     *
049     * @param numSamples      The number of samples to generate.
050     * @param seed            The RNG seed.
051     * @param radius          The radius of the outer circle.
052     * @param classProportion The proportion of the circle area that forms class 1.
053     */
054    public ConcentricCirclesDataSource(int numSamples, long seed, double radius, double classProportion) {
055        super(numSamples, seed);
056        this.radius = radius;
057        this.classProportion = classProportion;
058        postConfig();
059    }
060
061    /**
062     * Used by the OLCUT configuration system, and should not be called by external code.
063     */
064    @Override
065    public void postConfig() {
066        if ((classProportion <= 0.0) || (classProportion >= 1.0)) {
067            throw new PropertyException("", "classProportion", "Class proportion must be between zero and one, found " + classProportion);
068        }
069        if (radius <= 0) {
070            throw new PropertyException("", "radius", "Radius must be positive, found " + radius);
071        }
072        super.postConfig();
073    }
074
075    @Override
076    protected List<Example<Label>> generate() {
077        List<Example<Label>> list = new ArrayList<>();
078
079        for (int i = 0; i < numSamples; i++) {
080            double rotation = rng.nextDouble() * 2 * Math.PI;
081            double distance = Math.sqrt(rng.nextDouble()) * radius;
082            double[] values = new double[2];
083            values[0] = distance * Math.cos(rotation);
084            values[1] = distance * Math.sin(rotation);
085
086            double labelDistance = (values[0] * values[0]) + (values[1] * values[1]);
087            Label label;
088            if (labelDistance < classProportion * radius * radius) {
089                label = FIRST_CLASS;
090            } else {
091                label = SECOND_CLASS;
092            }
093
094            list.add(new ArrayExample<>(label, FEATURE_NAMES, values));
095        }
096
097        return list;
098    }
099
100    @Override
101    public String toString() {
102        return "ConcentricCircles(numSamples=" + numSamples + ",seed=" + seed + ",radius=" + radius + ",classProportion=" + classProportion + ")";
103    }
104}