001/*
002 * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.example;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.config.PropertyException;
021import org.tribuo.Example;
022import org.tribuo.classification.Label;
023import org.tribuo.impl.ArrayExample;
024
025import java.util.ArrayList;
026import java.util.List;
027
028/**
029 * A data source of two interleaved half circles with some zero mean Gaussian noise applied to each point.
030 */
031public final class NoisyInterlockingCrescentsDataSource extends DemoLabelDataSource {
032
033    @Config(description = "Variance of the Gaussian noise")
034    private double variance = 0.1;
035
036    /**
037     * For OLCUT.
038     */
039    private NoisyInterlockingCrescentsDataSource() {
040        super();
041    }
042
043    /**
044     * Constructs a noisy interlocking crescents data source.
045     * <p>
046     * It's the same as {@link InterlockingCrescentsDataSource} but each point has Gaussian
047     * noise with zero mean and the specified variance added to it.
048     *
049     * @param numSamples The number of samples to generate.
050     * @param seed       The RNG seed.
051     * @param variance   The variance of the Gaussian noise.
052     */
053    public NoisyInterlockingCrescentsDataSource(int numSamples, long seed, double variance) {
054        super(numSamples, seed);
055        this.variance = variance;
056        postConfig();
057    }
058
059    /**
060     * Used by the OLCUT configuration system, and should not be called by external code.
061     */
062    @Override
063    public void postConfig() {
064        if (variance <= 0.0) {
065            throw new PropertyException("", "variance", "Variance must be positive, found " + variance);
066        }
067        super.postConfig();
068    }
069
070    @Override
071    protected List<Example<Label>> generate() {
072        List<Example<Label>> list = new ArrayList<>();
073
074        for (int i = 0; i < numSamples / 2; i++) {
075            double[] values = new double[2];
076            double u = rng.nextDouble();
077            values[0] = Math.cos(Math.PI * u) + rng.nextGaussian() * variance;
078            values[1] = Math.sin(Math.PI * u) + rng.nextGaussian() * variance;
079            list.add(new ArrayExample<>(FIRST_CLASS, FEATURE_NAMES, values));
080        }
081
082        for (int i = numSamples / 2; i < numSamples; i++) {
083            double[] values = new double[2];
084            double u = rng.nextDouble();
085            values[0] = (1 - Math.cos(Math.PI * u)) + rng.nextGaussian() * variance;
086            values[1] = (0.5 - Math.sin(Math.PI * u)) + rng.nextGaussian() * variance;
087            list.add(new ArrayExample<>(SECOND_CLASS, FEATURE_NAMES, values));
088        }
089
090        return list;
091    }
092
093    @Override
094    public String toString() {
095        return "NoisyInterlockingCrescents(numSamples=" + numSamples + ",seed=" + seed + ",variance=" + variance + ')';
096    }
097}