package org.tribuo.clustering.example;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.MutableDataset;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ClusteringFactory;
import org.tribuo.datasource.ListDataSource;
import org.tribuo.impl.ArrayExample;
import org.tribuo.math.distributions.MultivariateNormalDistribution;
import org.tribuo.provenance.SimpleDataSourceProvenance;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/clustering/example/ClusteringDataGenerator.class */
/**
 * Static factory methods producing small clustering datasets and examples.
 * <p>
 * All datasets are labelled with {@link ClusterID} outputs and share a single
 * {@link ClusteringFactory}. The class only exposes static methods.
 */
public abstract class ClusteringDataGenerator {
    /** Output factory shared by every dataset and provenance object built here. */
    private static final ClusteringFactory clusteringFactory = new ClusteringFactory();

    /**
     * Generates a dataset of two-dimensional points (features "A" and "B") drawn
     * from a mixture of five multivariate normal distributions.
     * <p>
     * For each example a component index is sampled via
     * {@link Util#sampleFromCDF} from the CDF built out of the weights
     * {0.1, 0.35, 0.05, 0.25, 0.25}; that index is used both as the
     * {@link ClusterID} and to select the distribution the point is sampled from.
     *
     * @param j  the number of examples to generate; must be at least 1.
     * @param j2 the seed for the random number generator that drives both the
     *           component selection and the per-distribution seeds.
     * @return a dataset containing {@code j} generated examples.
     * @throws IllegalArgumentException if {@code j} is less than 1.
     */
    public static Dataset<ClusterID> gaussianClusters(long j, long j2) {
        if (j < 1) {
            throw new IllegalArgumentException("Size must be a positive number, received " + j);
        }
        Random rng = new Random(j2);
        String[] featureNames = {"A", "B"};
        double[] mixingCdf = Util.generateCDF(new double[]{0.1, 0.35, 0.05, 0.25, 0.25});

        // The array initializer is evaluated left to right, so each distribution
        // receives the next int from rng in declaration order.
        MultivariateNormalDistribution[] gaussians = {
            new MultivariateNormalDistribution(
                new double[]{0.0, 0.0}, new double[][]{{1.0, 0.0}, {0.0, 1.0}}, rng.nextInt(), true),
            new MultivariateNormalDistribution(
                new double[]{5.0, 5.0}, new double[][]{{1.0, 0.0}, {0.0, 1.0}}, rng.nextInt(), true),
            new MultivariateNormalDistribution(
                new double[]{2.5, 2.5}, new double[][]{{1.0, 0.5}, {0.5, 1.0}}, rng.nextInt(), true),
            new MultivariateNormalDistribution(
                new double[]{10.0, 0.0}, new double[][]{{0.1, 0.0}, {0.0, 0.1}}, rng.nextInt(), true),
            new MultivariateNormalDistribution(
                new double[]{-1.0, 0.0}, new double[][]{{1.0, 0.0}, {0.0, 0.1}}, rng.nextInt(), true)
        };

        List<Example<ClusterID>> examples = new ArrayList<>();
        // A long counter is required: an int counter would wrap around before
        // reaching a size larger than Integer.MAX_VALUE and never terminate.
        for (long i = 0; i < j; i++) {
            int centroid = Util.sampleFromCDF(mixingCdf, rng);
            double[] sample = gaussians[centroid].sampleArray();
            examples.add(new ArrayExample<>(new ClusterID(centroid), featureNames, sample));
        }

        SimpleDataSourceProvenance provenance =
                new SimpleDataSourceProvenance("Generated clustering data", clusteringFactory);
        ListDataSource<ClusterID> source = new ListDataSource<>(examples, clusteringFactory, provenance);
        return new MutableDataset<>(source);
    }

    /**
     * Builds the dense train/test pair using {@code -1.0} as the multiplier.
     *
     * @return a pair of (training dataset, testing dataset).
     */
    public static Pair<Dataset<ClusterID>, Dataset<ClusterID>> denseTrainTest() {
        return denseTrainTest(-1.0);
    }

    /**
     * Builds a fixed train/test pair in which every example carries the same four
     * features "A", "B", "C" and "D".
     * <p>
     * The training set holds three examples for each of the cluster ids 1, 2, 3
     * and 0; the test set holds one example per cluster id.
     *
     * @param d multiplier applied to a fixed subset of the feature values.
     * @return a pair of (training dataset, testing dataset).
     */
    public static Pair<Dataset<ClusterID>, Dataset<ClusterID>> denseTrainTest(double d) {
        MutableDataset<ClusterID> train = new MutableDataset<>(
                new SimpleDataSourceProvenance("TrainingData", OffsetDateTime.now(), clusteringFactory),
                clusteringFactory);
        String[] names = {"A", "B", "C", "D"};

        train.add(example(1, names, 1.0, 0.5, 1.0, d * 1.0));
        train.add(example(1, names.clone(), 1.5, 0.35, 1.3, d * 1.2));
        train.add(example(1, names.clone(), 1.2, 0.45, 1.5, d * 1.0));

        train.add(example(2, names.clone(), d * 1.1, 0.55, d * 1.5, 0.5));
        train.add(example(2, names.clone(), d * 1.5, 0.25, d * 1.0, 0.125));
        train.add(example(2, names.clone(), d * 1.0, 0.5, d * 1.123, 0.123));

        train.add(example(3, names.clone(), 1.5, 5.0, 0.5, 4.5));
        train.add(example(3, names.clone(), 1.234, 5.1235, 0.1235, 6.0));
        train.add(example(3, names.clone(), 1.734, 4.5, 0.5123, 5.5));

        train.add(example(0, names.clone(), d * 1.0, 0.25, 5.0, 10.0));
        train.add(example(0, names.clone(), d * 1.4, 0.55, 5.65, 12.0));
        train.add(example(0, names.clone(), d * 1.9, 0.25, 5.9, 15.0));

        MutableDataset<ClusterID> test = new MutableDataset<>(
                new SimpleDataSourceProvenance("TestingData", OffsetDateTime.now(), clusteringFactory),
                clusteringFactory);

        test.add(example(1, names.clone(), 2.0, 0.45, 3.5, d * 2.0));
        test.add(example(2, names.clone(), d * 2.0, 0.55, d * 2.5, 2.5));
        test.add(example(3, names.clone(), 1.75, 5.0, 1.0, 6.5));
        test.add(example(0, names.clone(), d * 1.5, 0.25, 5.0, 20.0));

        return new Pair<>(train, test);
    }

    /**
     * Builds the sparse train/test pair using {@code -1.0} as the multiplier.
     *
     * @return a pair of (training dataset, testing dataset).
     */
    public static Pair<Dataset<ClusterID>, Dataset<ClusterID>> sparseTrainTest() {
        return sparseTrainTest(-1.0);
    }

    /**
     * Builds a fixed train/test pair in which each example carries its own set of
     * four feature names, so the feature sets differ between examples.
     * <p>
     * The feature values match those used by {@link #denseTrainTest(double)}; only
     * the feature names differ. Several test examples use names ("AA", "BB",
     * "CC", "DD", "EE") that appear in no training example.
     *
     * @param d multiplier applied to a fixed subset of the feature values.
     * @return a pair of (training dataset, testing dataset).
     */
    public static Pair<Dataset<ClusterID>, Dataset<ClusterID>> sparseTrainTest(double d) {
        MutableDataset<ClusterID> train = new MutableDataset<>(
                new SimpleDataSourceProvenance("TrainingData", OffsetDateTime.now(), clusteringFactory),
                clusteringFactory);

        train.add(example(1, new String[]{"A", "B", "C", "D"}, 1.0, 0.5, 1.0, d * 1.0));
        train.add(example(1, new String[]{"B", "D", "F", "H"}, 1.5, 0.35, 1.3, d * 1.2));
        train.add(example(1, new String[]{"A", "J", "D", "M"}, 1.2, 0.45, 1.5, d * 1.0));

        train.add(example(2, new String[]{"C", "E", "F", "H"}, d * 1.1, 0.55, d * 1.5, 0.5));
        train.add(example(2, new String[]{"E", "G", "F", "I"}, d * 1.5, 0.25, d * 1.0, 0.125));
        train.add(example(2, new String[]{"J", "K", "C", "E"}, d * 1.0, 0.5, d * 1.123, 0.123));

        train.add(example(3, new String[]{"E", "A", "K", "J"}, 1.5, 5.0, 0.5, 4.5));
        train.add(example(3, new String[]{"B", "C", "E", "H"}, 1.234, 5.1235, 0.1235, 6.0));
        train.add(example(3, new String[]{"A", "M", "I", "J"}, 1.734, 4.5, 0.5123, 5.5));

        train.add(example(0, new String[]{"Z", "A", "B", "C"}, d * 1.0, 0.25, 5.0, 10.0));
        train.add(example(0, new String[]{"K", "V", "E", "D"}, d * 1.4, 0.55, 5.65, 12.0));
        train.add(example(0, new String[]{"B", "G", "E", "A"}, d * 1.9, 0.25, 5.9, 15.0));

        MutableDataset<ClusterID> test = new MutableDataset<>(
                new SimpleDataSourceProvenance("TestingData", OffsetDateTime.now(), clusteringFactory),
                clusteringFactory);

        test.add(example(1, new String[]{"AA", "B", "C", "D"}, 2.0, 0.45, 3.5, d * 2.0));
        test.add(example(2, new String[]{"B", "BB", "F", "E"}, d * 2.0, 0.55, d * 2.5, 2.5));
        test.add(example(3, new String[]{"B", "E", "G", "H"}, 1.75, 5.0, 1.0, 6.5));
        test.add(example(0, new String[]{"B", "CC", "DD", "EE"}, d * 1.5, 0.25, 5.0, 20.0));

        return new Pair<>(train, test);
    }

    /**
     * Creates an example with cluster id 1 and the features "1", "5" and "8"
     * valued 1.0, 5.0 and 8.0 respectively.
     * <p>
     * NOTE(review): the method name says this example is invalid; presumably its
     * feature names are absent from the datasets built above - confirm against
     * the callers.
     *
     * @return the example.
     */
    public static Example<ClusterID> invalidSparseExample() {
        return example(1, new String[]{"1", "5", "8"}, 1.0, 5.0, 8.0);
    }

    /**
     * Creates an example with cluster id 1 and no features.
     *
     * @return the example.
     */
    public static Example<ClusterID> emptyExample() {
        return example(1, new String[0]);
    }

    /**
     * Wraps a cluster id, feature names and feature values in an {@link ArrayExample}.
     *
     * @param clusterId the id used to construct the {@link ClusterID} output.
     * @param names     the feature names.
     * @param values    the feature values, positionally matched to {@code names}.
     * @return the new example.
     */
    private static Example<ClusterID> example(int clusterId, String[] names, double... values) {
        return new ArrayExample<>(new ClusterID(clusterId), names, values);
    }
}
