001/* 002 * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.example; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.config.PropertyException; 021import org.tribuo.Example; 022import org.tribuo.classification.Label; 023import org.tribuo.impl.ArrayExample; 024 025import java.util.ArrayList; 026import java.util.Arrays; 027import java.util.List; 028import java.util.Random; 029 030/** 031 * A data source for two classes generated from separate Gaussians. 032 */ 033public final class GaussianLabelDataSource extends DemoLabelDataSource { 034 035 @Config(mandatory = true, description = "2d mean of the first Gaussian.") 036 private double[] firstMean; 037 038 @Config(mandatory = true, description = "4 element covariance matrix of the first Gaussian.") 039 private double[] firstCovarianceMatrix; 040 041 @Config(mandatory = true, description = "2d mean of the second Gaussian.") 042 private double[] secondMean; 043 044 @Config(mandatory = true, description = "4 element covariance matrix of the second Gaussian.") 045 private double[] secondCovarianceMatrix; 046 047 private double[] firstCholesky; 048 049 private double[] secondCholesky; 050 051 /** 052 * For OLCUT. 053 */ 054 private GaussianLabelDataSource() { 055 super(); 056 } 057 058 /** 059 * Constructs a data source which contains two classes where each class is sampled from a 2d Gaussian with 060 * the specified parameters. 061 * 062 * @param numSamples The number of samples to draw. 063 * @param seed The RNG seed. 064 * @param firstMean The mean of class one's Gaussian. 065 * @param firstCovarianceMatrix The covariance matrix of class one's Gaussian. 066 * @param secondMean The mean of class two's Gaussian. 067 * @param secondCovarianceMatrix The covariance matrix of class two's Gaussian. 068 */ 069 public GaussianLabelDataSource(int numSamples, long seed, double[] firstMean, double[] firstCovarianceMatrix, double[] secondMean, double[] secondCovarianceMatrix) { 070 super(numSamples, seed); 071 this.firstMean = firstMean; 072 this.firstCovarianceMatrix = firstCovarianceMatrix; 073 this.secondMean = secondMean; 074 this.secondCovarianceMatrix = secondCovarianceMatrix; 075 postConfig(); 076 } 077 078 /** 079 * Used by the OLCUT configuration system, and should not be called by external code. 080 */ 081 @Override 082 public void postConfig() { 083 if (firstMean.length != 2) { 084 throw new PropertyException("", "firstMean", "firstMean is not the right length"); 085 } 086 if (secondMean.length != 2) { 087 throw new PropertyException("", "secondMean", "secondMean is not the right length"); 088 } 089 if (firstCovarianceMatrix.length != 4) { 090 throw new PropertyException("", "firstCovarianceMatrix", "firstCovarianceMatrix is not the right length"); 091 } 092 if (secondCovarianceMatrix.length != 4) { 093 throw new PropertyException("", "secondCovarianceMatrix", "secondCovarianceMatrix is not the right length"); 094 } 095 096 for (int i = 0; i < firstCovarianceMatrix.length; i++) { 097 if (firstCovarianceMatrix[i] < 0) { 098 throw new PropertyException("", "firstCovarianceMatrix", "First covariance matrix is not positive semi-definite"); 099 } 100 if (secondCovarianceMatrix[i] < 0) { 101 throw new PropertyException("", "secondCovarianceMatrix", "Second covariance matrix is not positive semi-definite"); 102 } 103 } 104 105 if (firstCovarianceMatrix[1] != firstCovarianceMatrix[2]) { 106 throw new PropertyException("", "firstCovarianceMatrix", "First covariance matrix is not a covariance matrix"); 107 } 108 109 if (secondCovarianceMatrix[1] != secondCovarianceMatrix[2]) { 110 throw new PropertyException("", "secondCovarianceMatrix", "Second covariance matrix is not a covariance matrix"); 111 } 112 113 firstCholesky = new double[3]; 114 firstCholesky[0] = Math.sqrt(firstCovarianceMatrix[0]); 115 firstCholesky[1] = firstCovarianceMatrix[1] / Math.sqrt(firstCovarianceMatrix[0]); 116 firstCholesky[2] = Math.sqrt(firstCovarianceMatrix[3] * firstCovarianceMatrix[0] - firstCovarianceMatrix[1] * firstCovarianceMatrix[1]) / Math.sqrt(firstCovarianceMatrix[0]); 117 118 secondCholesky = new double[3]; 119 secondCholesky[0] = Math.sqrt(secondCovarianceMatrix[0]); 120 secondCholesky[1] = secondCovarianceMatrix[1] / Math.sqrt(secondCovarianceMatrix[0]); 121 secondCholesky[2] = Math.sqrt(secondCovarianceMatrix[3] * secondCovarianceMatrix[0] - secondCovarianceMatrix[1] * secondCovarianceMatrix[1]) / Math.sqrt(secondCovarianceMatrix[0]); 122 super.postConfig(); 123 } 124 125 @Override 126 protected List<Example<Label>> generate() { 127 List<Example<Label>> list = new ArrayList<>(); 128 129 for (int i = 0; i < numSamples / 2; i++) { 130 double[] sample = sampleGaussian(rng, firstMean, firstCholesky); 131 ArrayExample<Label> datapoint = new ArrayExample<>(FIRST_CLASS, FEATURE_NAMES, sample); 132 list.add(datapoint); 133 } 134 135 for (int i = numSamples / 2; i < numSamples; i++) { 136 double[] sample = sampleGaussian(rng, secondMean, secondCholesky); 137 ArrayExample<Label> datapoint = new ArrayExample<>(SECOND_CLASS, FEATURE_NAMES, sample); 138 list.add(datapoint); 139 } 140 141 return list; 142 } 143 144 /** 145 * Samples from a 2d Gaussian specified by the mean vector and the Cholesky factorization. 146 * 147 * @param rng The RNG to use. 148 * @param means The mean of the Gaussian. 149 * @param cholesky The Cholesky factorization. 150 * @return A sample from a 2d Gaussian. 151 */ 152 private static double[] sampleGaussian(Random rng, double[] means, double[] cholesky) { 153 double[] sample = new double[2]; 154 155 double first = rng.nextGaussian(); 156 sample[0] = means[0] + first * cholesky[0]; 157 double second = rng.nextGaussian(); 158 sample[1] = means[1] + (first * cholesky[1]) + (second * cholesky[2]); 159 160 return sample; 161 } 162 163 @Override 164 public String toString() { 165 String sb = "GaussianGenerator(numSamples=" + 166 numSamples + 167 ",seed=" + 168 seed + 169 ",firstMean=" + 170 Arrays.toString(firstMean) + 171 ",firstCovarianceMatrix=" + 172 Arrays.toString(firstCovarianceMatrix) + 173 ",secondMean=" + 174 Arrays.toString(secondMean) + 175 ",secondCovarianceMatrix=" + 176 Arrays.toString(secondCovarianceMatrix) + 177 ')'; 178 179 return sb; 180 } 181}