package mikera.matrixx.decompose.impl.svd;

import mikera.matrixx.Matrix;
import mikera.matrixx.algo.Multiplications;
import org.junit.Assert;

/* loaded from: input_file:mikera/matrixx/decompose/impl/svd/StandardSvdChecks.class */
/**
 * Reusable battery of checks for {@link SvdImplicitQr} decompositions.
 *
 * <p>Subclasses supply the decomposition under test through {@link #createSvd()} and
 * normally call {@link #allTests()}. Every check decomposes a matrix, then verifies the
 * singular values and/or that multiplying U, S and V-transpose reproduces the input.
 */
public abstract class StandardSvdChecks {
    /** 2^-52; threshold used by the rank and nullity checks. */
    private static final double EPS = Math.pow(2.0d, -52.0d);

    /** Tolerance used when matching expected singular values. */
    private static final double SINGULAR_VALUE_TOLERANCE = 1.0E-5d;

    /** Element-wise tolerance used when comparing whole matrices. */
    private static final double ELEMENT_TOLERANCE = 1.0E-6d;

    /** When {@code true}, {@link #allTests()} skips {@link #testVerySmallValue()}. */
    boolean omitVerySmallValues = false;

    /**
     * Creates the decomposition instance to be exercised.
     *
     * @return the SVD implementation under test
     */
    public abstract SvdImplicitQr createSvd();

    /** Runs every check in this class, honouring {@link #omitVerySmallValues}. */
    public void allTests() {
        testDecompositionOfTrivial();
        testWide();
        testTall();
        checkGetU_Transpose();
        if (!this.omitVerySmallValues) {
            testVerySmallValue();
        }
        testZero();
        testLargeToSmall();
        testIdentity();
        testLots();
    }

    /** Decomposes two fixed 3x3 matrices and checks rank, nullity and singular values. */
    public void testDecompositionOfTrivial() {
        Matrix fullRank = Matrix.create(new double[][] {
            {5.0d, 2.0d, 3.0d},
            {1.5d, -2.0d, 8.0d},
            {-3.0d, 4.7d, -0.5d}
        });
        SvdImplicitQr fullRankSvd = createSvd();
        Assert.assertNotNull(fullRankSvd._decompose(fullRank));
        Assert.assertEquals(3, rank(fullRankSvd, EPS));
        Assert.assertEquals(0, nullity(fullRankSvd, EPS));
        double[] fullRankValues = fullRankSvd.getSingularValues().toDoubleArray();
        checkNumFound(1, SINGULAR_VALUE_TOLERANCE, 9.59186d, fullRankValues);
        checkNumFound(1, SINGULAR_VALUE_TOLERANCE, 5.18005d, fullRankValues);
        checkNumFound(1, SINGULAR_VALUE_TOLERANCE, 4.55558d, fullRankValues);
        checkComponents(fullRankSvd, fullRank);

        Matrix sequential = Matrix.create(new double[][] {
            {1.0d, 2.0d, 3.0d},
            {4.0d, 5.0d, 6.0d},
            {7.0d, 8.0d, 9.0d}
        });
        SvdImplicitQr sequentialSvd = createSvd();
        Assert.assertNotNull(sequentialSvd._decompose(sequential));
        // NOTE(review): rank is measured against 10 * EPS while nullity uses EPS, so the
        // two counts need not add up to the column count. Thresholds and expectations are
        // kept exactly as found; confirm they are intended.
        Assert.assertEquals(2, rank(sequentialSvd, 10.0d * EPS));
        Assert.assertEquals(0, nullity(sequentialSvd, EPS));
        double[] sequentialValues = sequentialSvd.getSingularValues().toDoubleArray();
        checkNumFound(1, SINGULAR_VALUE_TOLERANCE, 16.848103d, sequentialValues);
        checkNumFound(1, SINGULAR_VALUE_TOLERANCE, 1.06837d, sequentialValues);
        checkNumFound(1, SINGULAR_VALUE_TOLERANCE, 0.0d, sequentialValues);
        checkComponents(sequentialSvd, sequential);
    }

    /** Decomposes a random matrix with more columns than rows (5x20). */
    public void testWide() {
        // Previously 1x1, which never exercised the wide case this check is named after.
        Matrix wide = randomMatrix(5, 20);
        SvdImplicitQr svd = createSvd();
        Assert.assertNotNull(svd._decompose(wide));
        checkComponents(svd, wide);
    }

    /** Decomposes a random matrix with more rows than columns (21x5). */
    public void testTall() {
        Matrix tall = randomMatrix(21, 5);
        SvdImplicitQr svd = createSvd();
        Assert.assertNotNull(svd._decompose(tall));
        checkComponents(svd, tall);
    }

    /**
     * Decomposes matrices built by {@code Matrix.create(rows, cols)} for every combination
     * of 1, 6 and 11 rows/columns and expects every singular value to be zero.
     */
    public void testZero() {
        for (int rows = 1; rows <= 11; rows += 5) {
            for (int cols = 1; cols <= 11; cols += 5) {
                Matrix zero = Matrix.create(rows, cols);
                SvdImplicitQr svd = createSvd();
                Assert.assertNotNull(svd._decompose(zero));
                // min(rows, cols) singular values are inspected and all must be zero.
                int expectedZeros = Math.min(rows, cols);
                double[] values = svd.getSingularValues().toDoubleArray();
                Assert.assertEquals(expectedZeros, checkOccurrence(0.0d, values, expectedZeros));
                checkComponents(svd, zero);
            }
        }
    }

    /** Decomposes the 6x6 identity and expects six singular values equal to one. */
    public void testIdentity() {
        Matrix identity = Matrix.createIdentity(6, 6);
        SvdImplicitQr svd = createSvd();
        Assert.assertNotNull(svd._decompose(identity));
        double[] values = svd.getSingularValues().toDoubleArray();
        Assert.assertEquals(6, checkOccurrence(1.0d, values, 6));
        checkComponents(svd, identity);
    }

    /** Decomposes a random 5x5 matrix whose elements were scaled down by 1e-200. */
    public void testVerySmallValue() {
        Matrix tiny = randomMatrix(5, 5);
        tiny.scale(1.0E-200d);
        SvdImplicitQr svd = createSvd();
        Assert.assertNotNull(svd._decompose(tiny));
        checkComponents(svd, tiny);
    }

    /**
     * Reuses one decomposition instance for random matrices of every combination of
     * 1, 3, 5 and 7 rows/columns.
     */
    public void testLots() {
        SvdImplicitQr svd = createSvd();
        for (int rows = 1; rows < 8; rows += 2) {
            for (int cols = 1; cols < 8; cols += 2) {
                Matrix random = randomMatrix(rows, cols);
                Assert.assertNotNull(svd._decompose(random));
                checkComponents(svd, random);
            }
        }
    }

    /**
     * Checks that the transpose view of U holds the same elements as a transposed copy
     * of U, for a random 5x7 input.
     */
    public void checkGetU_Transpose() {
        Matrix input = randomMatrix(5, 7);
        SvdImplicitQr svd = createSvd();
        Assert.assertNotNull(svd._decompose(input));
        Matrix u = svd.getU().toMatrix();
        Assert.assertArrayEquals(
                svd.getU().getTranspose().toMatrix().getElements(),
                u.getTransposeCopy().toMatrix().getElements(),
                ELEMENT_TOLERANCE);
    }

    /** Reuses one decomposition instance for a 10x10 matrix and then a 5x5 matrix. */
    public void testLargeToSmall() {
        SvdImplicitQr svd = createSvd();
        Matrix large = randomMatrix(10, 10);
        Assert.assertNotNull(svd._decompose(large));
        checkComponents(svd, large);
        Matrix small = randomMatrix(5, 5);
        Assert.assertNotNull(svd._decompose(small));
        checkComponents(svd, small);
    }

    /**
     * Builds a random test matrix: {@code Matrix.createRandom} output with 0.5 subtracted
     * from every element and the result doubled.
     *
     * <p>Presumably this maps elements into [-1, 1), assuming {@code createRandom} yields
     * values in [0, 1) - confirm against {@code Matrix.createRandom}.
     *
     * @param rows number of rows
     * @param cols number of columns
     * @return the shifted and scaled random matrix
     */
    private static Matrix randomMatrix(int rows, int cols) {
        Matrix random = Matrix.createRandom(rows, cols);
        random.sub(0.5d);
        random.scale(2.0d);
        return random;
    }

    /**
     * Counts how many of the first {@code i} entries of {@code dArr} lie within 1e-8
     * of {@code d}.
     *
     * @param d    value to look for
     * @param dArr values to scan
     * @param i    number of leading entries to inspect
     * @return the number of matching entries
     */
    private int checkOccurrence(double d, double[] dArr, int i) {
        int matches = 0;
        for (int index = 0; index < i; index++) {
            if (Math.abs(dArr[index] - d) < 1.0E-8d) {
                matches++;
            }
        }
        return matches;
    }

    /**
     * Verifies that U, S and V-transpose contain no uncountable elements, have mutually
     * compatible dimensions, and multiply back to the decomposed matrix.
     *
     * @param svdImplicitQr decomposition that has already processed {@code matrix}
     * @param matrix        the matrix that was decomposed
     */
    private void checkComponents(SvdImplicitQr svdImplicitQr, Matrix matrix) {
        Matrix u = svdImplicitQr.getU().toMatrix();
        Matrix vt = svdImplicitQr.getV().getTranspose().toMatrix();
        Matrix s = svdImplicitQr.getS().toMatrix();

        Assert.assertFalse(u.hasUncountable());
        Assert.assertFalse(vt.hasUncountable());
        Assert.assertFalse(s.hasUncountable());

        if (svdImplicitQr.isCompact()) {
            Assert.assertEquals(s.columnCount(), s.rowCount());
            Assert.assertEquals(u.columnCount(), s.rowCount());
            Assert.assertEquals(vt.rowCount(), s.columnCount());
        } else {
            Assert.assertEquals(u.columnCount(), s.rowCount());
            Assert.assertEquals(s.columnCount(), vt.rowCount());
            Assert.assertEquals(u.columnCount(), u.rowCount());
            Assert.assertEquals(vt.columnCount(), vt.rowCount());
        }

        // Reconstruct the input as U * (S * Vt) and compare element by element.
        Assert.assertArrayEquals(
                matrix.toDoubleArray(),
                Multiplications.multiply(u, Multiplications.multiply(s, vt)).toDoubleArray(),
                ELEMENT_TOLERANCE);
    }

    /**
     * Counts the singular values strictly greater than {@code d}.
     *
     * @param svdImplicitQr decomposition to inspect
     * @param d             threshold
     * @return number of singular values above the threshold
     */
    private static int rank(SvdImplicitQr svdImplicitQr, double d) {
        double[] values = svdImplicitQr.getSingularValues().toDoubleArray();
        int count = svdImplicitQr.numberOfSingularValues();
        int above = 0;
        for (int index = 0; index < count; index++) {
            if (values[index] > d) {
                above++;
            }
        }
        return above;
    }

    /**
     * Counts the singular values less than or equal to {@code d}, plus the number of
     * columns in excess of the number of singular values.
     *
     * @param svdImplicitQr decomposition to inspect
     * @param d             threshold
     * @return the computed nullity
     */
    public static int nullity(SvdImplicitQr svdImplicitQr, double d) {
        double[] values = svdImplicitQr.getSingularValues().toDoubleArray();
        int count = svdImplicitQr.numberOfSingularValues();
        int numCols = svdImplicitQr.numCols();
        int atOrBelow = 0;
        for (int index = 0; index < count; index++) {
            if (values[index] <= d) {
                atOrBelow++;
            }
        }
        return (atOrBelow + numCols) - count;
    }

    /**
     * Asserts that exactly {@code i} entries of {@code dArr} lie within {@code d}
     * of {@code d2}.
     *
     * @param i    expected number of matches
     * @param d    absolute tolerance
     * @param d2   value to look for
     * @param dArr values to scan
     */
    private static void checkNumFound(int i, double d, double d2, double[] dArr) {
        int found = 0;
        for (double value : dArr) {
            if (Math.abs(value - d2) <= d) {
                found++;
            }
        }
        Assert.assertEquals(i, found);
    }
}
