package cc.redberry.core.transformations.collect;

import cc.redberry.core.TAssert;
import cc.redberry.core.combinatorics.Combinatorics;
import cc.redberry.core.context.CC;
import cc.redberry.core.tensor.Product;
import cc.redberry.core.tensor.SimpleTensor;
import cc.redberry.core.tensor.Sum;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.Tensors;
import cc.redberry.core.tensor.iterator.FromChildToParentIterator;
import cc.redberry.core.transformations.EliminateMetricsTransformation;
import cc.redberry.core.transformations.Transformation;
import cc.redberry.core.transformations.expand.ExpandTransformation;
import cc.redberry.core.transformations.factor.FactorTransformation;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Random;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;

/**
 * Unit tests for {@link CollectTransformation}: collecting scalar and indexed simple tensors,
 * tensor-valued function arguments, derivatives and powers, plus {@code matchFactors}.
 */
public class CollectTransformationTest {
    @Test
    public void test1() {
        TAssert.assertEquals(new CollectTransformation(new SimpleTensor[]{Tensors.parseSimple("a")}).transform(Tensors.parse("a*b + a*c")), "a*(b+c)");
    }

    @Test
    public void test2() {
        TAssert.assertEquals(new CollectTransformation(new SimpleTensor[]{Tensors.parseSimple("a"), Tensors.parseSimple("b")}).transform(Tensors.parse("a*b + a*c + a*d + b*e + b*r")), "a*b + a*(c+d) + b*(e+r)");
    }

    // Ignored in the source as given; the reason is not recorded here.
    @Test
    @Ignore
    public void test3() {
        TAssert.assertEquals(new CollectTransformation(new SimpleTensor[]{Tensors.parseSimple("A_m")}).transform(Tensors.parse("A_m*B_n + A_m*C_n")), "A_m*(B_n + C_n)");
    }

    /** Differently-indexed occurrences of A are collected by introducing Kronecker deltas. */
    @Test
    public void test4() {
        TAssert.assertEquals(new CollectTransformation(new SimpleTensor[]{Tensors.parseSimple("A_m")}).transform(Tensors.parse("A_m*B_n + A_n*C_m")), "A_i*(d^i_m*B_n + d^i_n*C_m)");
    }

    @Test
    public void test5() {
        TAssert.assertEquals(new CollectTransformation(new SimpleTensor[]{Tensors.parseSimple("A_mn")}).transform(Tensors.parse("A_mq*B_n^q + A_nq*C_m^q")), "A_iq*(d^i_m*B_n^q + d^i_n*C_m^q)");
    }

    // NOTE(review): identical to test5 (same pattern, input and expected result);
    // presumably a different case was intended here — confirm.
    @Test
    public void test6() {
        TAssert.assertEquals(new CollectTransformation(new SimpleTensor[]{Tensors.parseSimple("A_mn")}).transform(Tensors.parse("A_mq*B_n^q + A_nq*C_m^q")), "A_iq*(d^i_m*B_n^q + d^i_n*C_m^q)");
    }

    @Test
    public void test7() {
        TAssert.assertEquals(new CollectTransformation(new SimpleTensor[]{Tensors.parseSimple("A_mn")}).transform(Tensors.parse("A_mq*B_n^q + A_qn*C_m^q")), "A_iq*(d^i_m*B_n^q + d^q_n*C_m^i)");
    }

    @Test
    public void test8() {
        // Fixed name seed so the result is deterministic.
        CC.resetTensorNames(8816281755326274707L);
        CollectTransformation collectTransformation = new CollectTransformation(new SimpleTensor[]{Tensors.parseSimple("A_mn")});
        Tensor parse = Tensors.parse("A_mq*B_n^q + A^q_n*C_mq");
        TAssert.assertEquals(collectTransformation.transform(parse), "A_iq*(d^i_m*B_n^q + d^q_n*C_m^i)");
    }

    @Test
    public void test9() {
        SimpleTensor[] simpleTensorArr = {Tensors.parseSimple("A_mnpq")};
        assertCollectExpand(Tensors.parse("A_mnpq*B^np_ac + A_abcd*B^ndb_nmq"), simpleTensorArr);
        assertCollectExpand(Tensors.parse("A_mnpq*B^np_ac + A_acmq "), simpleTensorArr);
        assertCollectExpand(Tensors.parse("A_mnpq*B^np_ac + A_abcd*B^ndb_nmq + A_acmq + A_amqc + A_rsmq*C^rs_ac"), simpleTensorArr);
    }

    @Test
    public void test10() {
        Tensor transform = Tensors.parseExpression("G_gmn=(1/2)*(p_m*h_gn+p_n*h_gm-p_g*h_mn)").transform(Tensors.parseExpression("R^a_bmn=p_m*G^a_bn+p_n*G^a_bm+G^a_gm*G^g_bn-G^a_gn*G^g_bm").transform(Tensors.parseExpression("R_{mn}=g^ab*R_{bman}").transform(Tensors.parse("g_{mn}*R^{mn}"))));
        assertCollectExpand(transform, new SimpleTensor[]{Tensors.parseSimple("h_ab")});
        assertCollectExpand(transform, new SimpleTensor[]{Tensors.parseSimple("p_a")});
        assertCollectExpand(transform, new SimpleTensor[]{Tensors.parseSimple("p_a"), Tensors.parseSimple("h_ab")});
    }

    @Test
    public void test11() {
        Tensor transform = Tensors.parseExpression("Gf^a_mn[r^mn]=(1/2)*r^ag*(p_m*r_gn[x_a]+p_n*r_gm[x_z]-p_g*r_mn[x_z])").transform(Tensors.parseExpression("Rf^a_bmn[g^pq]=p_m*Gf^a_bn[g_ab]+p_n*Gf^a_bm[g_ab]+Gf^a_gm[g_ab]*Gf^g_bn[g_ab]-Gf^a_gn[g_ab]*Gf^g_bm[g_ab]").transform(Tensors.parseExpression("Rf_{mn}[g^mn]=Rf^{a}_{man}[g_pq]").transform(Tensors.parseExpression("Rf[g_ab]=g^ab*Rf_ab[g_mn]").transform(Tensors.parse("Rf[h_mn]")))));
        assertCollectExpand(transform, new SimpleTensor[]{Tensors.parseSimple("r_ab[x_a]")});
        assertCollectExpand(transform, new SimpleTensor[]{Tensors.parseSimple("p_a")});
        assertCollectExpand(transform, new SimpleTensor[]{Tensors.parseSimple("p_a"), Tensors.parseSimple("r_ab[x_a]")});
    }

    @Test
    public void test12() {
        assertCollectExpand(Tensors.parseExpression("Gf_a=f1*h^b_a*p_b+f2*g^pq*g_ab*h^b_q*p_p+f3*h^q_q*p_a").transform(Tensors.parseExpression("sqrt=1+h^a_a+(1/2)*(h^s_s*h^l_l-h^s_l*h^l_s)").transform(Tensors.parseExpression("G^ab=g^ab-g^ca*h^b_c-g^cb*h^a_c+g^cb*h^a_d*h^d_c+g^ca*h^b_d*h^d_c+g^cd*h^a_c*h^b_d").transform(Tensors.parseExpression("E^a_b=d^a_b-h^a_b+h^a_c*h^c_b").transform(Tensors.parseExpression("e^a_b=d^a_b+h^a_b").transform(Tensors.parseExpression("T^a_bc=i*h^a_c*p_b-i*h^a_b*p_c+w^a_bd*e^d_c-w^a_cd*e^d_b").transform(Tensors.parseExpression("R^a_bcd=i*w^a_db*p_c-i*w^a_cb*p_d+w^a_cr*w^r_db-w^a_dr*w^r_cb").transform(Tensors.parseExpression("Ric_ab=E^r_a*E^d_c*R^c_bdr").transform(Tensors.parse("sqrt*(g^ab*Ric_ab+(e1*g_ab*G^xp*G^yq+e2*E^x_a*E^p_b*G^yq+e3*E^x_b*E^p_a*G^yq)*T^a_xy*T^b_pq+e6*Ric_ab*Ric_cd*g^ab*g^cd+e5*Ric_ab*Ric_cd*g^ac*g^bd+Gf_a*Gf_b*g^ab)+f*g^pq*g_ab*i*h^b_q*p_p*g^cd*I*h^a_d*p_c"))))))))), new SimpleTensor[]{Tensors.parseSimple("h^a_b"), Tensors.parseSimple("w^a_bc")});
    }

    /** Function-valued patterns: f[x] matches only the argument x, not -x or y. */
    @Test
    public void test13() {
        TAssert.assertEquals(new CollectTransformation(new SimpleTensor[]{Tensors.parseSimple("f[x]")}).transform(Tensors.parse("a*f[x]*f[-x] + b*f[x]*f[-x] + x*f[x]*f[y] + y*f[y]*f[x]")), "(y+x)*f[x]*f[y]+(a+b)*f[-x]*f[x]");
    }

    /**
     * Expands {@code tensor} and eliminates metrics, then collects the result with respect to
     * {@code simpleTensorArr}. Asserts two things:
     * <ul>
     *   <li>every summand of the collected result is fully collected (see
     *       {@link #assertCollectedSummand});</li>
     *   <li>re-expanding the collected result, with metric elimination, gives back the
     *       expanded input.</li>
     * </ul>
     *
     * @param tensor          expression to collect
     * @param simpleTensorArr patterns to collect with respect to
     */
    private static void assertCollectExpand(Tensor tensor, SimpleTensor[] simpleTensorArr) {
        Tensor eliminate = EliminateMetricsTransformation.eliminate(ExpandTransformation.expand(tensor));
        Tensor transform = new CollectTransformation(simpleTensorArr).transform(eliminate);
        if (transform instanceof Sum) {
            for (Tensor summand : transform) {
                assertCollectedSummand(summand, simpleTensorArr);
            }
        } else {
            assertCollectedSummand(transform, simpleTensorArr);
        }
        TAssert.assertEquals(EliminateMetricsTransformation.eliminate(ExpandTransformation.expand(transform, new Transformation[]{EliminateMetricsTransformation.ELIMINATE_METRICS})), eliminate);
    }

    /**
     * Applies only when {@code tensor} is a product. For every sum factor of that product, asserts
     * that no simple tensor anywhere inside the sum has the same name as one of the patterns.
     * In other words, the patterns must have been pulled out of the brackets.
     *
     * @param tensor          a summand of a collected expression
     * @param simpleTensorArr patterns that were collected
     */
    private static void assertCollectedSummand(Tensor tensor, SimpleTensor[] simpleTensorArr) {
        if (!(tensor instanceof Product)) {
            return;
        }
        for (Tensor factor : tensor) {
            if (!(factor instanceof Sum)) {
                continue;
            }
            FromChildToParentIterator iterator = new FromChildToParentIterator(factor);
            Tensor current;
            // next() yields null once the traversal is exhausted. The loop must stop there;
            // the previous `while (true)` had no exit and never terminated.
            while ((current = iterator.next()) != null) {
                if (!(current instanceof SimpleTensor)) {
                    continue;
                }
                SimpleTensor simple = (SimpleTensor) current;
                for (SimpleTensor pattern : simpleTensorArr) {
                    Assert.assertFalse("pattern " + pattern + " left inside sum " + factor,
                            simple.getName() == pattern.getName());
                }
            }
        }
    }

    /**
     * matchFactors must produce a permutation that reorders a shuffled copy back onto the original.
     * The run uses one seeded Random, and the seed is reported on failure for reproducibility.
     */
    @Test
    public void testMatch() {
        long seed = System.nanoTime();
        Random random = new Random(seed);
        for (int i = 0; i < 100; i++) {
            CC.resetTensorNames();
            SimpleTensor[] simpleTensorArr = {Tensors.parseSimple("f_a[-x_a-y_a]"), Tensors.parseSimple("f_c[x_b]"), Tensors.parseSimple("f_d[y_d]"), Tensors.parseSimple("g[x]"), Tensors.parseSimple("g[-f-x]"), Tensors.parseSimple("g[f]")};
            SimpleTensor[] simpleTensorArr2 = simpleTensorArr.clone();
            Combinatorics.shuffle(simpleTensorArr2, random);
            Arrays.sort(simpleTensorArr);
            Arrays.sort(simpleTensorArr2);
            Assert.assertArrayEquals("seed = " + seed, simpleTensorArr, Combinatorics.reorder(simpleTensorArr2, CollectTransformation.matchFactors(simpleTensorArr, simpleTensorArr2)));
        }
    }

    /** A derivative f~(1)[x] is not matched by the pattern f[x], so nothing is collected. */
    @Test
    public void testDerivatives1() {
        Tensor parse = Tensors.parse("f~(1)[x] + f[x]");
        TAssert.assertEquals(new CollectTransformation(new SimpleTensor[]{Tensors.parseSimple("f[x]")}).transform(parse), parse);
    }

    @Test
    public void testDerivatives2() {
        TAssert.assertEquals(new CollectTransformation(new SimpleTensor[]{Tensors.parseSimple("f~(0,1)[x, y]"), Tensors.parseSimple("f~(1,0)[x, y]")}, new Transformation[]{FactorTransformation.FACTOR}).transform(Tensors.parse("D[x][x*f[x, x**2] + f[x, x**2]]")), "f[x,x**2]+(x+1)*f~(1,0)[x,x**2]+2*x*(x+1)*f~(0,1)[x,x**2]");
    }

    @Test
    public void testPower1() {
        TAssert.assertEquals(new CollectTransformation(new SimpleTensor[]{Tensors.parseSimple("x")}).transform(Tensors.parse("x**2 + x**2*a")), "x**2*(1+a)");
    }

    @Test
    public void testPower2() {
        TAssert.assertEquals(new CollectTransformation(new SimpleTensor[]{Tensors.parseSimple("x"), Tensors.parseSimple("y")}).transform(Tensors.parse("y**3*x**2*b*c + y**3*x**2*a**2")), "y**3*x**2*(b*c+a**2)");
    }

    @Test
    public void testPower3() {
        TAssert.assertEquals(EliminateMetricsTransformation.eliminate(new CollectTransformation(new SimpleTensor[]{Tensors.parseSimple("A_m")}).transform(Tensors.parse("(A_m*A^m*c)**2 + A_m*A^m*A_i*A^i"))), "A_m*A^m*A_i*A^i*(c**2 + 1)");
    }

    /** Collecting the expanded form with respect to x and y restores the original grouping. */
    @Test
    public void testPower4() {
        Tensor parse = Tensors.parse("x**2*y**3*(a + b + c) + x*y*(c + d) + x*(a+b) + y*(c+e) + r");
        TAssert.assertEquals(new CollectTransformation(new SimpleTensor[]{Tensors.parseSimple("x"), Tensors.parseSimple("y")}).transform(ExpandTransformation.expand(parse)), parse);
    }

    /** Collect on indexed factors, then re-expand: the result must equal the plain expansion. */
    @Test
    public void testPower5() {
        Tensor parse = Tensors.parse("x_m*y_n*x_a*(a^a + b^a + c^a) + x_m*y_n*(c + d) + x_m*(a_n+b_n) + y_n*(c_m+e_m) + r_mn");
        TAssert.assertEquals(EliminateMetricsTransformation.eliminate(ExpandTransformation.expand(EliminateMetricsTransformation.eliminate(new CollectTransformation(new SimpleTensor[]{Tensors.parseSimple("x_m"), Tensors.parseSimple("y_m")}).transform(ExpandTransformation.expand(parse))))), ExpandTransformation.expand(parse));
    }
}
