001/*
002 * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.example;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.config.PropertyException;
021import org.tribuo.Example;
022import org.tribuo.classification.Label;
023import org.tribuo.impl.ArrayExample;
024
025import java.util.ArrayList;
026import java.util.Arrays;
027import java.util.List;
028import java.util.Random;
029
030/**
031 * A data source for two classes generated from separate Gaussians.
032 */
033public final class GaussianLabelDataSource extends DemoLabelDataSource {
034
035    @Config(mandatory = true, description = "2d mean of the first Gaussian.")
036    private double[] firstMean;
037
038    @Config(mandatory = true, description = "4 element covariance matrix of the first Gaussian.")
039    private double[] firstCovarianceMatrix;
040
041    @Config(mandatory = true, description = "2d mean of the second Gaussian.")
042    private double[] secondMean;
043
044    @Config(mandatory = true, description = "4 element covariance matrix of the second Gaussian.")
045    private double[] secondCovarianceMatrix;
046
047    private double[] firstCholesky;
048
049    private double[] secondCholesky;
050
051    /**
052     * For OLCUT.
053     */
054    private GaussianLabelDataSource() {
055        super();
056    }
057
058    /**
059     * Constructs a data source which contains two classes where each class is sampled from a 2d Gaussian with
060     * the specified parameters.
061     *
062     * @param numSamples             The number of samples to draw.
063     * @param seed                   The RNG seed.
064     * @param firstMean              The mean of class one's Gaussian.
065     * @param firstCovarianceMatrix  The covariance matrix of class one's Gaussian.
066     * @param secondMean             The mean of class two's Gaussian.
067     * @param secondCovarianceMatrix The covariance matrix of class two's Gaussian.
068     */
069    public GaussianLabelDataSource(int numSamples, long seed, double[] firstMean, double[] firstCovarianceMatrix, double[] secondMean, double[] secondCovarianceMatrix) {
070        super(numSamples, seed);
071        this.firstMean = firstMean;
072        this.firstCovarianceMatrix = firstCovarianceMatrix;
073        this.secondMean = secondMean;
074        this.secondCovarianceMatrix = secondCovarianceMatrix;
075        postConfig();
076    }
077
078    /**
079     * Used by the OLCUT configuration system, and should not be called by external code.
080     */
081    @Override
082    public void postConfig() {
083        if (firstMean.length != 2) {
084            throw new PropertyException("", "firstMean", "firstMean is not the right length");
085        }
086        if (secondMean.length != 2) {
087            throw new PropertyException("", "secondMean", "secondMean is not the right length");
088        }
089        if (firstCovarianceMatrix.length != 4) {
090            throw new PropertyException("", "firstCovarianceMatrix", "firstCovarianceMatrix is not the right length");
091        }
092        if (secondCovarianceMatrix.length != 4) {
093            throw new PropertyException("", "secondCovarianceMatrix", "secondCovarianceMatrix is not the right length");
094        }
095
096        for (int i = 0; i < firstCovarianceMatrix.length; i++) {
097            if (firstCovarianceMatrix[i] < 0) {
098                throw new PropertyException("", "firstCovarianceMatrix", "First covariance matrix is not positive semi-definite");
099            }
100            if (secondCovarianceMatrix[i] < 0) {
101                throw new PropertyException("", "secondCovarianceMatrix", "Second covariance matrix is not positive semi-definite");
102            }
103        }
104
105        if (firstCovarianceMatrix[1] != firstCovarianceMatrix[2]) {
106            throw new PropertyException("", "firstCovarianceMatrix", "First covariance matrix is not a covariance matrix");
107        }
108
109        if (secondCovarianceMatrix[1] != secondCovarianceMatrix[2]) {
110            throw new PropertyException("", "secondCovarianceMatrix", "Second covariance matrix is not a covariance matrix");
111        }
112
113        firstCholesky = new double[3];
114        firstCholesky[0] = Math.sqrt(firstCovarianceMatrix[0]);
115        firstCholesky[1] = firstCovarianceMatrix[1] / Math.sqrt(firstCovarianceMatrix[0]);
116        firstCholesky[2] = Math.sqrt(firstCovarianceMatrix[3] * firstCovarianceMatrix[0] - firstCovarianceMatrix[1] * firstCovarianceMatrix[1]) / Math.sqrt(firstCovarianceMatrix[0]);
117
118        secondCholesky = new double[3];
119        secondCholesky[0] = Math.sqrt(secondCovarianceMatrix[0]);
120        secondCholesky[1] = secondCovarianceMatrix[1] / Math.sqrt(secondCovarianceMatrix[0]);
121        secondCholesky[2] = Math.sqrt(secondCovarianceMatrix[3] * secondCovarianceMatrix[0] - secondCovarianceMatrix[1] * secondCovarianceMatrix[1]) / Math.sqrt(secondCovarianceMatrix[0]);
122        super.postConfig();
123    }
124
125    @Override
126    protected List<Example<Label>> generate() {
127        List<Example<Label>> list = new ArrayList<>();
128
129        for (int i = 0; i < numSamples / 2; i++) {
130            double[] sample = sampleGaussian(rng, firstMean, firstCholesky);
131            ArrayExample<Label> datapoint = new ArrayExample<>(FIRST_CLASS, FEATURE_NAMES, sample);
132            list.add(datapoint);
133        }
134
135        for (int i = numSamples / 2; i < numSamples; i++) {
136            double[] sample = sampleGaussian(rng, secondMean, secondCholesky);
137            ArrayExample<Label> datapoint = new ArrayExample<>(SECOND_CLASS, FEATURE_NAMES, sample);
138            list.add(datapoint);
139        }
140
141        return list;
142    }
143
144    /**
145     * Samples from a 2d Gaussian specified by the mean vector and the Cholesky factorization.
146     *
147     * @param rng      The RNG to use.
148     * @param means    The mean of the Gaussian.
149     * @param cholesky The Cholesky factorization.
150     * @return A sample from a 2d Gaussian.
151     */
152    private static double[] sampleGaussian(Random rng, double[] means, double[] cholesky) {
153        double[] sample = new double[2];
154
155        double first = rng.nextGaussian();
156        sample[0] = means[0] + first * cholesky[0];
157        double second = rng.nextGaussian();
158        sample[1] = means[1] + (first * cholesky[1]) + (second * cholesky[2]);
159
160        return sample;
161    }
162
163    @Override
164    public String toString() {
165        String sb = "GaussianGenerator(numSamples=" +
166                numSamples +
167                ",seed=" +
168                seed +
169                ",firstMean=" +
170                Arrays.toString(firstMean) +
171                ",firstCovarianceMatrix=" +
172                Arrays.toString(firstCovarianceMatrix) +
173                ",secondMean=" +
174                Arrays.toString(secondMean) +
175                ",secondCovarianceMatrix=" +
176                Arrays.toString(secondCovarianceMatrix) +
177                ')';
178
179        return sb;
180    }
181}