package us.ihmc.math.linearAlgebra.careSolvers;

import java.util.ArrayList;
import java.util.List;
import org.ejml.EjmlUnitTests;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import us.ihmc.matrixlib.NativeCommonOps;

/* loaded from: input_file:us/ihmc/math/linearAlgebra/careSolvers/CARESolversTest.class */
/**
 * Runs every {@link CARESolver} configuration returned by {@link #getSolvers()} against a few small
 * continuous algebraic Riccati equation problems.
 *
 * <p>Each test passes A, B, Q and R to {@link CARESolver#setMatrices} (with identity matrices for
 * the third and fourth arguments and {@code null} for the last), calls
 * {@link CARESolver#computeP()}, and checks that the Riccati rate computed by
 * {@link CARETools#computeRiccatiRate} for the returned P is zero within {@link #EPSILON}.
 */
public class CARESolversTest {
    /** Absolute tolerance used for every element-wise matrix comparison in this class. */
    private static final double EPSILON = 1.0E-4;

    /** Returns a fresh instance of every solver configuration under test. */
    private List<CARESolver> getSolvers() {
        List<CARESolver> solvers = new ArrayList<>();
        solvers.add(new EigenvectorCARESolver());
        solvers.add(new NewtonCARESolver(new EigenvectorCARESolver()));
        solvers.add(new Newton2CARESolver(new EigenvectorCARESolver()));
        solvers.add(new DefectCorrectionCARESolver(new EigenvectorCARESolver()));
        solvers.add(new SignFunctionCARESolver());
        solvers.add(new NewtonCARESolver(new SignFunctionCARESolver()));
        solvers.add(new Newton2CARESolver(new SignFunctionCARESolver()));
        solvers.add(new DefectCorrectionCARESolver(new SignFunctionCARESolver()));
        return solvers;
    }

    /**
     * A = B = Q = R = I (2x2).
     *
     * <p>Checks four things:
     * <ul>
     *   <li>the solver does not modify its inputs;</li>
     *   <li>P is symmetric;</li>
     *   <li>P satisfies the Riccati equation via {@link CARETools};</li>
     *   <li>P satisfies the Riccati equation via an independent, explicit computation.</li>
     * </ul>
     */
    @Test
    public void testSimple() {
        for (CARESolver solver : getSolvers()) {
            DMatrixRMaj A = CommonOps_DDRM.identity(2);
            DMatrixRMaj B = CommonOps_DDRM.identity(2);
            DMatrixRMaj Q = CommonOps_DDRM.identity(2);
            DMatrixRMaj R = CommonOps_DDRM.identity(2);

            // Snapshots taken before solving, so we can verify the solver leaves its inputs untouched.
            DMatrixRMaj originalA = new DMatrixRMaj(A);
            DMatrixRMaj originalB = new DMatrixRMaj(B);
            DMatrixRMaj originalQ = new DMatrixRMaj(Q);
            DMatrixRMaj originalR = new DMatrixRMaj(R);

            solver.setMatrices(A, B, CommonOps_DDRM.identity(2), CommonOps_DDRM.identity(2), Q, R, (DMatrixRMaj) null);
            solver.computeP();
            DMatrixRMaj P = solver.getP();

            EjmlUnitTests.assertEquals(originalA, A, EPSILON);
            EjmlUnitTests.assertEquals(originalB, B, EPSILON);
            EjmlUnitTests.assertEquals(originalQ, Q, EPSILON);
            EjmlUnitTests.assertEquals(originalR, R, EPSILON);

            assertIsSymmetric(P, EPSILON);
            assertSolutionIsValid(originalA, originalB, originalQ, originalR, P, EPSILON);
            // The original test computed this residual by hand but never asserted on it.
            assertSatisfiesRiccatiEquation(originalA, originalB, originalQ, originalR, P, EPSILON);
        }
    }

    /** 2-state, 1-input problem whose reference P (4 decimals) comes from MATLAB's {@code care}. */
    @Test
    public void testMatlabCare() {
        for (CARESolver solver : getSolvers()) {
            DMatrixRMaj A = new DMatrixRMaj(new double[][] {{-3.0, 2.0}, {1.0, 1.0}});
            DMatrixRMaj B = new DMatrixRMaj(new double[][] {{0.0}, {1.0}});
            DMatrixRMaj C = new DMatrixRMaj(new double[][] {{1.0, -1.0}});
            DMatrixRMaj R = new DMatrixRMaj(new double[][] {{3.0}});

            // Q = C^T C
            DMatrixRMaj Q = new DMatrixRMaj(2, 2);
            CommonOps_DDRM.multInner(C, Q);

            solver.setMatrices(A, B, CommonOps_DDRM.identity(2), CommonOps_DDRM.identity(2), Q, R, (DMatrixRMaj) null);
            solver.computeP();

            DMatrixRMaj expectedP = new DMatrixRMaj(new double[][] {{0.5895, 1.8216}, {1.8216, 8.8188}});
            DMatrixRMaj P = solver.getP();
            assertSolutionIsValid(A, B, Q, R, P, EPSILON);
            EjmlUnitTests.assertEquals(expectedP, P, EPSILON);
        }
    }

    /** 3-state, 1-input problem; only the Riccati residual is checked (no reference P). */
    @Test
    public void testMatlabCare2() {
        for (CARESolver solver : getSolvers()) {
            DMatrixRMaj A = new DMatrixRMaj(new double[][] {{1.0, -2.0, 3.0}, {-4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}});
            DMatrixRMaj B = new DMatrixRMaj(new double[][] {{5.0}, {6.0}, {-7.0}});
            DMatrixRMaj C = new DMatrixRMaj(new double[][] {{7.0, -8.0, 9.0}});
            DMatrixRMaj R = new DMatrixRMaj(new double[][] {{1.0}});

            // Q = C^T C
            DMatrixRMaj Q = new DMatrixRMaj(3, 3);
            CommonOps_DDRM.multInner(C, Q);

            solver.setMatrices(A, B, CommonOps_DDRM.identity(3), CommonOps_DDRM.identity(3), Q, R, (DMatrixRMaj) null);
            solver.computeP();

            assertSolutionIsValid(A, B, Q, R, solver.getP(), EPSILON);
        }
    }

    /** Asserts that {@code matrix.get(i, j) == matrix.get(j, i)} within {@code tolerance} for all i, j. */
    private static void assertIsSymmetric(DMatrixRMaj matrix, double tolerance) {
        for (int row = 0; row < matrix.getNumRows(); row++) {
            // Only the strict upper triangle needs checking; the comparison is symmetric.
            for (int col = row + 1; col < matrix.getNumCols(); col++) {
                Assertions.assertEquals(matrix.get(row, col), matrix.get(col, row), tolerance, "Not symmetric!");
            }
        }
    }

    /**
     * Asserts that P is a valid solution using {@link CARETools}.
     *
     * <p>The check works in two steps:
     * <ul>
     *   <li>build M with {@link CARETools#computeM} from Bᵀ and R (no cross term);</li>
     *   <li>assert that {@link CARETools#computeRiccatiRate} for (P, A, Q, M) is the n x n zero
     *       matrix within {@code tolerance}.</li>
     * </ul>
     *
     * @param A state matrix (n x n)
     * @param B input matrix (n x m)
     * @param Q state weight
     * @param R input weight
     * @param P candidate solution
     * @param tolerance absolute per-element tolerance
     */
    static void assertSolutionIsValid(DMatrixRMaj A, DMatrixRMaj B, DMatrixRMaj Q, DMatrixRMaj R, DMatrixRMaj P, double tolerance) {
        int n = A.getNumRows();
        int m = B.getNumCols();

        DMatrixRMaj BTranspose = new DMatrixRMaj(m, n);
        CommonOps_DDRM.transpose(B, BTranspose);

        // Allocated m x m as in the original; its final shape is left to computeM.
        DMatrixRMaj M = new DMatrixRMaj(m, m);
        CARETools.computeM(BTranspose, R, (DMatrixRMaj) null, M);

        DMatrixRMaj riccatiRate = new DMatrixRMaj(n, n);
        CARETools.computeRiccatiRate(P, A, Q, M, riccatiRate);
        EjmlUnitTests.assertEquals(new DMatrixRMaj(n, n), riccatiRate, tolerance);
    }

    /**
     * Independently of {@link CARETools}, asserts that
     * {@code AᵀP + PA − P B R⁻¹ Bᵀ P + Q} is the zero matrix within {@code tolerance}.
     */
    private static void assertSatisfiesRiccatiEquation(DMatrixRMaj A, DMatrixRMaj B, DMatrixRMaj Q, DMatrixRMaj R, DMatrixRMaj P, double tolerance) {
        int n = A.getNumRows();
        int m = B.getNumCols();

        DMatrixRMaj residual = new DMatrixRMaj(n, n);
        CommonOps_DDRM.multTransA(A, P, residual); // A^T P
        CommonOps_DDRM.multAdd(P, A, residual); // + P A

        DMatrixRMaj RInverse = new DMatrixRMaj(m, m);
        NativeCommonOps.invert(R, RInverse);
        DMatrixRMaj BRInverse = new DMatrixRMaj(n, m);
        CommonOps_DDRM.mult(B, RInverse, BRInverse);
        DMatrixRMaj PBRInverse = new DMatrixRMaj(n, m);
        CommonOps_DDRM.mult(P, BRInverse, PBRInverse);
        DMatrixRMaj BTransposeP = new DMatrixRMaj(m, n);
        CommonOps_DDRM.multTransA(B, P, BTransposeP);
        CommonOps_DDRM.multAdd(-1.0, PBRInverse, BTransposeP, residual); // - P B R^-1 B^T P

        CommonOps_DDRM.addEquals(residual, Q); // + Q
        EjmlUnitTests.assertEquals(new DMatrixRMaj(n, n), residual, tolerance);
    }
}
