package org.hipparchus.stat.fitting;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.hipparchus.distribution.multivariate.MixtureMultivariateNormalDistribution;
import org.hipparchus.distribution.multivariate.MultivariateNormalDistribution;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.exception.MathIllegalStateException;
import org.hipparchus.linear.Array2DRowRealMatrix;
import org.hipparchus.linear.RealMatrix;
import org.hipparchus.util.Pair;
import org.junit.Assert;
import org.junit.Test;

/**
 * Unit tests for {@link MultivariateNormalMixtureExpectationMaximization}.
 *
 * <p>Covers input validation (constructor and {@code fit} arguments), detection of an initial
 * mixture whose dimension does not match the data, the initial-mixture estimate produced by
 * {@code estimate}, and the parameters/log-likelihood of a full fit on a fixed 100-sample,
 * two-column data set.
 */
public class MultivariateNormalMixtureExpectationMaximizationTest {

    @Test(expected = MathIllegalArgumentException.class)
    public void testNonEmptyData() {
        // A data set with zero rows must be rejected by the constructor.
        new MultivariateNormalMixtureExpectationMaximization(new double[][] {});
    }

    @Test(expected = MathIllegalArgumentException.class)
    public void testNonJaggedData() {
        // Rows of different lengths (3 vs 4 columns) must be rejected.
        final double[][] data = {
            { 1, 2, 3 },
            { 4, 5, 6, 7 },
        };
        new MultivariateNormalMixtureExpectationMaximization(data);
    }

    @Test(expected = MathIllegalArgumentException.class)
    public void testMultipleColumnsRequired() {
        // Single-column data must be rejected.
        final double[][] data = { { 1 }, { 2 } };
        new MultivariateNormalMixtureExpectationMaximization(data);
    }

    @Test(expected = MathIllegalArgumentException.class)
    public void testMaxIterationsPositive() {
        final double[][] data = getTestSamples();
        final MultivariateNormalMixtureExpectationMaximization fitter =
            new MultivariateNormalMixtureExpectationMaximization(data);
        // Zero iterations is an invalid argument.
        fitter.fit(MultivariateNormalMixtureExpectationMaximization.estimate(data, 2), 0, 1E-5);
    }

    @Test(expected = MathIllegalArgumentException.class)
    public void testThresholdPositive() {
        final double[][] data = getTestSamples();
        final MultivariateNormalMixtureExpectationMaximization fitter =
            new MultivariateNormalMixtureExpectationMaximization(data);
        // A zero convergence threshold is an invalid argument.
        fitter.fit(MultivariateNormalMixtureExpectationMaximization.estimate(data, 2), 1000, 0);
    }

    @Test(expected = MathIllegalStateException.class)
    public void testConvergenceException() {
        final double[][] data = getTestSamples();
        final MultivariateNormalMixtureExpectationMaximization fitter =
            new MultivariateNormalMixtureExpectationMaximization(data);
        // Five iterations are not enough to reach the 1E-5 threshold on this data set.
        fitter.fit(MultivariateNormalMixtureExpectationMaximization.estimate(data, 2), 5, 1E-5);
    }

    @Test(expected = MathIllegalArgumentException.class)
    public void testIncompatibleIntialMixture() {
        final double[] weights = { 0.5, 0.5 };

        // These components are two-dimensional, deliberately incompatible with the
        // three-column data below.
        final MultivariateNormalDistribution[] mvns = {
            new MultivariateNormalDistribution(
                new double[] { -0.0021722935000328823, 3.5432892936887908 },
                new double[][] {
                    { 4.537422569229048, 3.5266152281729304 },
                    { 3.5266152281729304, 6.175448814169779 },
                }),
            new MultivariateNormalDistribution(
                new double[] { 5.090902706507635, 8.68540656355283 },
                new double[][] {
                    { 2.886778573963039, 1.5257474543463154 },
                    { 1.5257474543463154, 3.3794567673616918 },
                }),
        };

        final List<Pair<Double, MultivariateNormalDistribution>> components = new ArrayList<>();
        components.add(new Pair<>(weights[0], mvns[0]));
        components.add(new Pair<>(weights[1], mvns[1]));

        final double[][] data = {
            { 1, 2, 3 },
            { 4, 5, 6 },
            { 7, 8, 9 },
        };

        new MultivariateNormalMixtureExpectationMaximization(data)
            .fit(new MixtureMultivariateNormalDistribution(components));
    }

    @Test
    public void testInitialMixture() {
        // Expected initial mixture estimated from the test samples.
        final double[] correctWeights = { 0.5, 0.5 };

        final double[][] correctMeans = {
            { -0.0021722935000328823, 3.5432892936887908 },
            { 5.090902706507635, 8.68540656355283 },
        };

        final RealMatrix[] correctCovMats = {
            new Array2DRowRealMatrix(new double[][] {
                { 4.537422569229048, 3.5266152281729304 },
                { 3.5266152281729304, 6.175448814169779 },
            }),
            new Array2DRowRealMatrix(new double[][] {
                { 2.886778573963039, 1.5257474543463154 },
                { 1.5257474543463154, 3.3794567673616918 },
            }),
        };

        final List<Pair<Double, MultivariateNormalDistribution>> components =
            MultivariateNormalMixtureExpectationMaximization.estimate(getTestSamples(), 2)
                .getComponents();

        // Without this, the per-component loop would pass vacuously on a short/empty list.
        Assert.assertEquals(correctWeights.length, components.size());

        for (int i = 0; i < components.size(); i++) {
            final Pair<Double, MultivariateNormalDistribution> component = components.get(i);
            Assert.assertEquals(correctWeights[i], component.getFirst().doubleValue(), Math.ulp(1d));

            final MultivariateNormalDistribution mvn = component.getValue();
            Assert.assertTrue(Arrays.equals(correctMeans[i], mvn.getMeans()));
            Assert.assertEquals(correctCovMats[i], mvn.getCovariances());
        }
    }

    @Test
    public void testFit() {
        final double[][] data = getTestSamples();

        // Expected results of the fit with default iteration/threshold settings.
        final double correctLogLikelihood = -4.292431006791994;
        final double[] correctWeights = { 0.2962324189652912, 0.7037675810347089 };

        final double[][] correctMeans = {
            { -1.4213112715121132, 1.6924690505757753 },
            { 4.213612224374709, 7.975621325853645 },
        };

        final RealMatrix[] correctCovMats = {
            new Array2DRowRealMatrix(new double[][] {
                { 1.739356907285747, -0.5867644251487614 },
                { -0.5867644251487614, 1.0232932029324642 },
            }),
            new Array2DRowRealMatrix(new double[][] {
                { 4.245384898007161, 2.5797798966382155 },
                { 2.5797798966382155, 3.9200272522448367 },
            }),
        };

        final MultivariateNormalMixtureExpectationMaximization fitter =
            new MultivariateNormalMixtureExpectationMaximization(data);
        fitter.fit(MultivariateNormalMixtureExpectationMaximization.estimate(data, 2));

        final List<Pair<Double, MultivariateNormalDistribution>> components =
            fitter.getFittedModel().getComponents();

        Assert.assertEquals(correctLogLikelihood, fitter.getLogLikelihood(), Math.ulp(1d));
        // Without this, the per-component loop would pass vacuously on a short/empty list.
        Assert.assertEquals(correctWeights.length, components.size());

        for (int i = 0; i < components.size(); i++) {
            final Pair<Double, MultivariateNormalDistribution> component = components.get(i);
            Assert.assertEquals(correctWeights[i], component.getFirst().doubleValue(), Math.ulp(1d));

            final MultivariateNormalDistribution mvn = component.getSecond();
            Assert.assertTrue(Arrays.equals(correctMeans[i], mvn.getMeans()));
            Assert.assertEquals(correctCovMats[i], mvn.getCovariances());
        }
    }

    /**
     * Returns the fixed two-column data set (100 rows) used by the fitting tests.
     *
     * <p>A fresh array is built on each call, so no test can affect another through shared
     * mutable state.
     *
     * @return a new {@code 100 x 2} sample array
     */
    private double[][] getTestSamples() {
        return new double[][] {
            { 7.358553610469948, 11.31260831446758 },
            { 7.175770420124739, 8.988812210204454 },
            { 4.324151905768422, 6.837727899051482 },
            { 2.157832219173036, 6.317444585521968 },
            { -1.890157421896651, 1.74271202875498 },
            { 0.8922409354455803, 1.999119343923781 },
            { 3.396949764787055, 6.813170372579068 },
            { -2.057498232686068, -0.002522983830852255 },
            { 6.359932157365045, 8.343600029975851 },
            { 3.353102234276168, 7.087541882898689 },
            { -1.763877221595639, 0.9688890460330644 },
            { 6.151457185125111, 9.075011757431174 },
            { 4.281597398048899, 5.953270070976117 },
            { 3.549576703974894, 8.616038155992861 },
            { 6.004706732349854, 8.95942339108747 },
            { 2.802915014676262, 6.285676742173564 },
            { -0.6029879029880616, 1.083332958357485 },
            { 3.631827105398369, 6.743428504049444 },
            { 6.161125014007315, 9.60920569689001 },
            { -1.049582894255342, 0.2020017892080281 },
            { 3.910573022688315, 8.19609909534937 },
            { 8.180454017634863, 7.861055769719962 },
            { 1.488945440439716, 8.02699903761247 },
            { 4.813750847823778, 12.34416881332515 },
            { 0.0443208501259158, 5.901148093240691 },
            { 4.416417235068346, 4.465243084006094 },
            { 4.0002433603072, 6.721937850166174 },
            { 3.190113818788205, 10.51648348411058 },
            { 4.493600914967883, 7.938224231022314 },
            { -3.675669533266189, 4.472845076673303 },
            { 6.648645511703989, 12.03544085965724 },
            { -1.330031331404445, 1.33931042964811 },
            { -3.812111460708707, 2.50534195568356 },
            { 5.669339356648331, 6.214488981177026 },
            { 1.006596727153816, 1.51165463112716 },
            { 5.039466365033024, 7.476532610478689 },
            { 4.349091929968925, 7.446356406259756 },
            { -1.220289665119069, 3.403926955951437 },
            { 5.553003979122395, 6.886518211202239 },
            { 2.274487732222856, 7.009541508533196 },
            { 4.147567059965864, 7.34025244349202 },
            { 4.083882618965819, 6.362852861075623 },
            { 2.203122344647599, 7.260295257904624 },
            { -2.147497550770442, 1.262293431529498 },
            { 2.473700950426512, 6.558900135505638 },
            { 8.267081298847554, 12.10214104577748 },
            { 6.91977329776865, 9.91998488301285 },
            { 0.1680479852730894, 6.28286034168897 },
            { -1.268578659195158, 2.326711221485755 },
            { 1.829966451374701, 6.254187605304518 },
            { 5.648849025754848, 9.33000204075029 },
            { -2.302874793257666, 3.585545172776065 },
            { -2.629218791709046, 2.156215538500288 },
            { 4.036618140700114, 10.2962785719958 },
            { 0.4616386422783874, 0.6782756325806778 },
            { -0.3447896073408363, 0.4999834691645118 },
            { -0.475281453118318, 1.931470384180492 },
            { 2.382509690609731, 6.071782429815853 },
            { -3.203934441889096, 2.572079552602468 },
            { 8.465636032165087, 13.96462998683518 },
            { 2.36755660870416, 5.7844595007273 },
            { 0.5935496528993371, 1.374615871358943 },
            { -2.467481505748694, 2.097224634713005 },
            { 4.27867444328542, 10.24772361238549 },
            { -2.013791907543137, 2.013799426047639 },
            { 6.424588084404173, 9.185334939684516 },
            { -0.8448238876802175, 0.5447382022282812 },
            { 1.342955703473923, 8.645456317633556 },
            { 3.108712208751979, 8.512156853800064 },
            { 4.343205178315472, 8.056869549234374 },
            { -2.971767642212396, 3.201180146824761 },
            { 2.583820931523672, 5.459873414473854 },
            { 4.209139115268925, 8.171098193546225 },
            { 0.4064909057902746, 1.454390775518743 },
            { 3.068642411145223, 6.959485153620035 },
            { 6.085968972900461, 7.391429799500965 },
            { -1.342265795764202, 1.454550012997143 },
            { 6.249773274516883, 6.290269880772023 },
            { 4.986225847822566, 7.75266344868907 },
            { 7.642443254378944, 10.19914817500263 },
            { 6.438181159163673, 8.464396764810347 },
            { 2.520859761025108, 7.68222425260111 },
            { 2.883699944257541, 6.777960331348503 },
            { 2.788004550956599, 6.634735386652733 },
            { 3.331661231995638, 5.794191300046592 },
            { 3.526172276645504, 6.710802266815884 },
            { 3.188298528138741, 10.34495528210205 },
            { 0.7345539486114623, 5.807604004180681 },
            { 1.165044595880125, 7.830121829295257 },
            { 7.146962523500671, 11.62995162065415 },
            { 7.813872137162087, 10.62827008714735 },
            { 3.118099164870063, 8.28600314818637 },
            { -1.708739286262571, 1.561026755374264 },
            { 1.786163047580084, 4.172394388214604 },
            { 3.718506403232386, 7.807752990130349 },
            { 6.167414046828899, 10.01104941031293 },
            { -1.063477247689196, 1.61176085846339 },
            { -3.396739609433642, 0.7127911050002151 },
            { 2.438885945896797, 7.353011138689225 },
            { -0.2073204144780931, 0.850771146627012 },
        };
    }
}
