001/* 002 * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.example; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.config.PropertyException; 021import org.tribuo.Example; 022import org.tribuo.classification.Label; 023import org.tribuo.impl.ArrayExample; 024 025import java.util.ArrayList; 026import java.util.List; 027 028/** 029 * A data source for two concentric circles, one per class. 030 */ 031public final class ConcentricCirclesDataSource extends DemoLabelDataSource { 032 033 @Config(description = "The radius of the outer circle.") 034 private double radius = 2; 035 036 @Config(description = "The proportion of the circle radius that forms class one.") 037 private double classProportion = 0.5; 038 039 /** 040 * For OLCUT. 041 */ 042 private ConcentricCirclesDataSource() { 043 super(); 044 } 045 046 /** 047 * Constructs a data source for two concentric circles, one per class. 048 * 049 * @param numSamples The number of samples to generate. 050 * @param seed The RNG seed. 051 * @param radius The radius of the outer circle. 052 * @param classProportion The proportion of the circle area that forms class 1. 053 */ 054 public ConcentricCirclesDataSource(int numSamples, long seed, double radius, double classProportion) { 055 super(numSamples, seed); 056 this.radius = radius; 057 this.classProportion = classProportion; 058 postConfig(); 059 } 060 061 /** 062 * Used by the OLCUT configuration system, and should not be called by external code. 063 */ 064 @Override 065 public void postConfig() { 066 if ((classProportion <= 0.0) || (classProportion >= 1.0)) { 067 throw new PropertyException("", "classProportion", "Class proportion must be between zero and one, found " + classProportion); 068 } 069 if (radius <= 0) { 070 throw new PropertyException("", "radius", "Radius must be positive, found " + radius); 071 } 072 super.postConfig(); 073 } 074 075 @Override 076 protected List<Example<Label>> generate() { 077 List<Example<Label>> list = new ArrayList<>(); 078 079 for (int i = 0; i < numSamples; i++) { 080 double rotation = rng.nextDouble() * 2 * Math.PI; 081 double distance = Math.sqrt(rng.nextDouble()) * radius; 082 double[] values = new double[2]; 083 values[0] = distance * Math.cos(rotation); 084 values[1] = distance * Math.sin(rotation); 085 086 double labelDistance = (values[0] * values[0]) + (values[1] * values[1]); 087 Label label; 088 if (labelDistance < classProportion * radius * radius) { 089 label = FIRST_CLASS; 090 } else { 091 label = SECOND_CLASS; 092 } 093 094 list.add(new ArrayExample<>(label, FEATURE_NAMES, values)); 095 } 096 097 return list; 098 } 099 100 @Override 101 public String toString() { 102 return "ConcentricCircles(numSamples=" + numSamples + ",seed=" + seed + ",radius=" + radius + ",classProportion=" + classProportion + ")"; 103 } 104}