package org.tribuo.clustering.example;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.tribuo.ConfigurableDataSource;
import org.tribuo.Example;
import org.tribuo.MutableDataset;
import org.tribuo.OutputFactory;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ClusteringFactory;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.ConfiguredDataSourceProvenance;
import org.tribuo.provenance.DataSourceProvenance;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/clustering/example/GaussianClusterDataSource.class */
/**
 * A data source which draws clustering examples from a mixture of five
 * multivariate Gaussians.
 *
 * <p>Each example is produced by picking one of the five Gaussians according
 * to {@code mixingDistribution} and sampling a point from it. The index of the
 * chosen Gaussian (0-4) becomes the example's {@link ClusterID}, and the
 * sampled coordinates become features named "A", "B", "C", "D" (only as many
 * as there are dimensions, so between one and four).
 *
 * <p>All examples are generated eagerly in {@link #postConfig()} and held in an
 * unmodifiable list. Generation is driven by a {@link Random} built from
 * {@code seed}, so the same configuration yields the same examples.
 */
public final class GaussianClusterDataSource implements ConfigurableDataSource<ClusterID> {
    private static final ClusteringFactory factory = new ClusteringFactory();

    /** Feature names in order; a d-dimensional source uses the first d of them. */
    private static final String[] allFeatureNames = {"A", "B", "C", "D"};

    // The field initialisers below are the defaults used whenever a value is not
    // supplied explicitly. They run before every constructor body, so they are
    // declared once here instead of being repeated in each constructor.

    @Config(mandatory = true, description = "The number of samples to draw.")
    private int numSamples;

    @Config(description = "The probability of sampling from each Gaussian, must sum to 1.0.")
    private double[] mixingDistribution = {0.1, 0.35, 0.05, 0.25, 0.25};

    @Config(description = "The mean of the first Gaussian.")
    private double[] firstMean = {0.0, 0.0};

    @Config(description = "A vector representing the first Gaussian's covariance matrix.")
    private double[] firstVariance = {1.0, 0.0, 0.0, 1.0};

    @Config(description = "The mean of the second Gaussian.")
    private double[] secondMean = {5.0, 5.0};

    @Config(description = "A vector representing the second Gaussian's covariance matrix.")
    private double[] secondVariance = {1.0, 0.0, 0.0, 1.0};

    @Config(description = "The mean of the third Gaussian.")
    private double[] thirdMean = {2.5, 2.5};

    @Config(description = "A vector representing the third Gaussian's covariance matrix.")
    private double[] thirdVariance = {1.0, 0.5, 0.5, 1.0};

    @Config(description = "The mean of the fourth Gaussian.")
    private double[] fourthMean = {10.0, 0.0};

    @Config(description = "A vector representing the fourth Gaussian's covariance matrix.")
    private double[] fourthVariance = {0.1, 0.0, 0.0, 0.1};

    @Config(description = "The mean of the fifth Gaussian.")
    private double[] fifthMean = {-1.0, 0.0};

    @Config(description = "A vector representing the fifth Gaussian's covariance matrix.")
    private double[] fifthVariance = {1.0, 0.0, 0.0, 0.1};

    @Config(description = "The RNG seed.")
    private long seed = 12345L;

    /** The generated examples; assigned by {@link #postConfig()}. */
    private List<Example<ClusterID>> examples;

    /**
     * Provenance for {@link GaussianClusterDataSource}.
     */
    public static final class GaussianClusterDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements ConfiguredDataSourceProvenance {
        private static final long serialVersionUID = 1L;

        /**
         * Builds the provenance from a live data source.
         *
         * @param host The data source to describe.
         */
        GaussianClusterDataSourceProvenance(GaussianClusterDataSource host) {
            super(host, "DataSource");
        }

        /**
         * Rebuilds the provenance from a map of named provenance values.
         *
         * @param map The provenance values; must contain "class-name" and
         *            "host-short-name" entries holding string provenances.
         */
        public GaussianClusterDataSourceProvenance(Map<String, Provenance> map) {
            this(extractProvenanceInfo(map));
        }

        private GaussianClusterDataSourceProvenance(SkeletalConfiguredObjectProvenance.ExtractedInfo info) {
            super(info);
        }

        /**
         * Splits a provenance map into the pieces the superclass constructor needs.
         *
         * <p>The class name and host short name are pulled out first; the working
         * copy of the map is then passed on as the configured parameters, with
         * no instance values. The caller's map is never modified because the
         * extraction works on a copy.
         *
         * @param map The provenance values to unpack.
         * @return The extracted information.
         */
        protected static SkeletalConfiguredObjectProvenance.ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) {
            // Must be a parameterised map: a raw map would turn the calls below into
            // unchecked invocations, erasing their return type so that getValue()
            // would no longer resolve.
            Map<String, Provenance> configuredParameters = new HashMap<>(map);
            String provenanceName = GaussianClusterDataSourceProvenance.class.getSimpleName();
            String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters, "class-name", StringProvenance.class, provenanceName).getValue();
            String hostShortName = ObjectProvenance.checkAndExtractProvenance(configuredParameters, "host-short-name", StringProvenance.class, provenanceName).getValue();
            return new SkeletalConfiguredObjectProvenance.ExtractedInfo(className, hostShortName, configuredParameters, Collections.emptyMap());
        }
    }

    /**
     * Leaves every field at its default and does not generate any examples.
     * NOTE(review): presumably reserved for the OLCUT configuration system, which
     * would populate the {@code @Config} fields and then call
     * {@link #postConfig()} - confirm against the OLCUT documentation.
     */
    private GaussianClusterDataSource() {
    }

    /**
     * Generates samples from the default five-Gaussian mixture.
     *
     * @param numSamples The number of samples to draw, must be positive.
     * @param seed The RNG seed.
     * @throws PropertyException If {@code numSamples} is less than one.
     */
    public GaussianClusterDataSource(int numSamples, long seed) {
        this.numSamples = numSamples;
        this.seed = seed;
        postConfig();
    }

    /**
     * Generates samples from a fully specified five-Gaussian mixture.
     *
     * <p>Each covariance matrix is supplied as a flat, row-major vector of
     * {@code d * d} elements, where {@code d} is the length of the mean vectors.
     * The arrays are stored as given, not copied.
     *
     * @param numSamples The number of samples to draw, must be positive.
     * @param mixingDistribution The probability of each Gaussian; five non-negative values summing to 1.0.
     * @param firstMean The mean of the first Gaussian, one to four elements.
     * @param firstVariance The covariance matrix of the first Gaussian.
     * @param secondMean The mean of the second Gaussian.
     * @param secondVariance The covariance matrix of the second Gaussian.
     * @param thirdMean The mean of the third Gaussian.
     * @param thirdVariance The covariance matrix of the third Gaussian.
     * @param fourthMean The mean of the fourth Gaussian.
     * @param fourthVariance The covariance matrix of the fourth Gaussian.
     * @param fifthMean The mean of the fifth Gaussian.
     * @param fifthVariance The covariance matrix of the fifth Gaussian.
     * @param seed The RNG seed.
     * @throws PropertyException If any argument fails the checks in {@link #postConfig()}.
     */
    public GaussianClusterDataSource(int numSamples, double[] mixingDistribution,
                                     double[] firstMean, double[] firstVariance,
                                     double[] secondMean, double[] secondVariance,
                                     double[] thirdMean, double[] thirdVariance,
                                     double[] fourthMean, double[] fourthVariance,
                                     double[] fifthMean, double[] fifthVariance,
                                     long seed) {
        this.numSamples = numSamples;
        this.mixingDistribution = mixingDistribution;
        this.firstMean = firstMean;
        this.firstVariance = firstVariance;
        this.secondMean = secondMean;
        this.secondVariance = secondVariance;
        this.thirdMean = thirdMean;
        this.thirdVariance = thirdVariance;
        this.fourthMean = fourthMean;
        this.fourthVariance = fourthVariance;
        this.fifthMean = fifthMean;
        this.fifthVariance = fifthVariance;
        this.seed = seed;
        postConfig();
    }

    /**
     * Validates the configured fields and then generates all the examples.
     *
     * @throws PropertyException If the sample count, mixing distribution, means
     *         or covariance vectors are invalid.
     */
    public void postConfig() {
        validateConfiguration();
        this.examples = Collections.unmodifiableList(generateExamples());
    }

    /**
     * Checks the sample count, the mixing distribution and the dimensions of
     * every mean and covariance vector, throwing on the first problem found.
     */
    private void validateConfiguration() {
        if (numSamples < 1) {
            throw new PropertyException("", "numSamples", "numSamples must be positive, found " + numSamples);
        }
        if (mixingDistribution.length != 5) {
            throw new PropertyException("", "mixingDistribution", "mixingDistribution must have 5 elements, found " + mixingDistribution.length);
        }
        double mixingSum = Util.sum(mixingDistribution);
        if (Math.abs(mixingSum - 1.0) > 1.0E-10) {
            throw new PropertyException("", "mixingDistribution", "mixingDistribution must sum to 1.0, found " + mixingSum);
        }
        if (firstMean.length > allFeatureNames.length || firstMean.length == 0) {
            throw new PropertyException("", "firstMean", "Must have 1-4 features, found " + firstMean.length);
        }
        // The first Gaussian fixes the dimensionality; the other four must match it.
        int covarianceSize = firstMean.length * firstMean.length;
        if (firstVariance.length != covarianceSize) {
            throw new PropertyException("", "firstVariance", "Invalid first covariance matrix, expected " + covarianceSize + " elements, found " + firstVariance.length);
        }
        checkMatchesFirstGaussian("secondMean", secondMean, "secondVariance", secondVariance, covarianceSize);
        checkMatchesFirstGaussian("thirdMean", thirdMean, "thirdVariance", thirdVariance, covarianceSize);
        checkMatchesFirstGaussian("fourthMean", fourthMean, "fourthVariance", fourthVariance, covarianceSize);
        checkMatchesFirstGaussian("fifthMean", fifthMean, "fifthVariance", fifthVariance, covarianceSize);
        for (double probability : mixingDistribution) {
            if (probability < 0.0) {
                throw new PropertyException("", "mixingDistribution", "Probability values in the mixing distribution must be non-negative, found " + Arrays.toString(mixingDistribution));
            }
        }
    }

    /**
     * Checks that a Gaussian's mean and covariance vector have the same sizes
     * as those of the first Gaussian.
     */
    private void checkMatchesFirstGaussian(String meanName, double[] mean, String varianceName, double[] variance, int covarianceSize) {
        if (mean.length != firstMean.length) {
            throw new PropertyException("", meanName, "All Gaussians must have the same number of dimensions, expected " + firstMean.length + ", found " + mean.length);
        }
        if (variance.length != covarianceSize) {
            throw new PropertyException("", varianceName, varianceName + " is invalid, expected " + covarianceSize + ", found " + variance.length);
        }
    }

    /**
     * Builds the five Gaussians and draws {@code numSamples} examples from the mixture.
     */
    private List<Example<ClusterID>> generateExamples() {
        double[] mixingCDF = Util.generateCDF(mixingDistribution);
        String[] featureNames = Arrays.copyOf(allFeatureNames, firstMean.length);
        Random rng = new Random(seed);
        // Each Gaussian takes its own seed from rng, in this fixed order, before
        // rng is used to choose mixture components below.
        MultivariateNormalDistribution[] gaussians = {
            buildGaussian(rng, firstMean, firstVariance, "firstVariance"),
            buildGaussian(rng, secondMean, secondVariance, "secondVariance"),
            buildGaussian(rng, thirdMean, thirdVariance, "thirdVariance"),
            buildGaussian(rng, fourthMean, fourthVariance, "fourthVariance"),
            buildGaussian(rng, fifthMean, fifthVariance, "fifthVariance")
        };
        List<Example<ClusterID>> samples = new ArrayList<>(numSamples);
        for (int sample = 0; sample < numSamples; sample++) {
            int component = Util.sampleFromCDF(mixingCDF, rng);
            samples.add(new ArrayExample<>(new ClusterID(component), featureNames, gaussians[component].sample()));
        }
        return samples;
    }

    /**
     * Creates one Gaussian whose generator is seeded with the next int from {@code rng}.
     */
    private static MultivariateNormalDistribution buildGaussian(Random rng, double[] mean, double[] variance, String varianceName) {
        // Consume the seed before validating so the draw order never depends on the matrix.
        JDKRandomGenerator generator = new JDKRandomGenerator(rng.nextInt());
        return new MultivariateNormalDistribution(generator, mean, reshapeAndValidate(variance, varianceName));
    }

    /**
     * Returns the output factory shared by all instances of this data source.
     *
     * @return The clustering output factory.
     */
    public OutputFactory<ClusterID> getOutputFactory() {
        return factory;
    }

    /**
     * Creates a provenance object describing this data source.
     *
     * @return A new provenance for this data source.
     */
    public DataSourceProvenance getProvenance() {
        return new GaussianClusterDataSourceProvenance(this);
    }

    /**
     * Alias for {@link #getProvenance()}. This is the name the decompiler gave
     * the method; it is retained so call sites using it keep compiling.
     *
     * @return A new provenance for this data source.
     */
    public DataSourceProvenance m13getProvenance() {
        return getProvenance();
    }

    /**
     * Iterates over the generated examples. The backing list is unmodifiable.
     *
     * @return An iterator over the examples.
     */
    public Iterator<Example<ClusterID>> iterator() {
        return examples.iterator();
    }

    /**
     * Reshapes a flat, row-major vector into a square matrix.
     *
     * <p>Every element, off-diagonal ones included, must be non-negative.
     *
     * @param vector The matrix elements in row-major order.
     * @param fieldName The configuration field the vector came from, used in the error.
     * @return The square matrix.
     * @throws IllegalArgumentException If the vector length is not a perfect square.
     * @throws PropertyException If any element is negative.
     */
    private static double[][] reshapeAndValidate(double[] vector, String fieldName) {
        int side = (int) Math.sqrt(vector.length);
        if (side * side != vector.length) {
            throw new IllegalArgumentException("The vector does not represent a square matrix, found " + vector.length + " elements, which is not square.");
        }
        double[][] matrix = new double[side][side];
        for (int index = 0; index < vector.length; index++) {
            double value = vector[index];
            if (value < 0.0) {
                throw new PropertyException("", fieldName, fieldName + " must have a non-negative covariance matrix, found " + Arrays.toString(vector));
            }
            matrix[index / side][index % side] = value;
        }
        return matrix;
    }

    /**
     * Generates a dataset sampled from a fully specified five-Gaussian mixture.
     *
     * <p>The arguments are passed unchanged to the thirteen-argument constructor
     * and are subject to the same validation.
     *
     * @param numSamples The number of samples to draw, must be positive.
     * @param mixingDistribution The probability of each Gaussian; five non-negative values summing to 1.0.
     * @param firstMean The mean of the first Gaussian, one to four elements.
     * @param firstVariance The covariance matrix of the first Gaussian, as a flat row-major vector.
     * @param secondMean The mean of the second Gaussian.
     * @param secondVariance The covariance matrix of the second Gaussian.
     * @param thirdMean The mean of the third Gaussian.
     * @param thirdVariance The covariance matrix of the third Gaussian.
     * @param fourthMean The mean of the fourth Gaussian.
     * @param fourthVariance The covariance matrix of the fourth Gaussian.
     * @param fifthMean The mean of the fifth Gaussian.
     * @param fifthVariance The covariance matrix of the fifth Gaussian.
     * @param seed The RNG seed.
     * @return A mutable dataset built from a new data source.
     * @throws PropertyException If any argument is invalid.
     */
    public static MutableDataset<ClusterID> generateDataset(int numSamples, double[] mixingDistribution,
                                                            double[] firstMean, double[] firstVariance,
                                                            double[] secondMean, double[] secondVariance,
                                                            double[] thirdMean, double[] thirdVariance,
                                                            double[] fourthMean, double[] fourthVariance,
                                                            double[] fifthMean, double[] fifthVariance,
                                                            long seed) {
        GaussianClusterDataSource source = new GaussianClusterDataSource(numSamples, mixingDistribution,
                firstMean, firstVariance, secondMean, secondVariance, thirdMean, thirdVariance,
                fourthMean, fourthVariance, fifthMean, fifthVariance, seed);
        return new MutableDataset<>(source);
    }
}
