package com.oracle.truffle.api.dsl.test;

import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.NodeFactory;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.dsl.TypeSystemReference;
import com.oracle.truffle.api.dsl.internal.SpecializationNode;
import com.oracle.truffle.api.dsl.internal.SpecializedNode;
import com.oracle.truffle.api.dsl.test.MergeSpecializationsTestFactory;
import com.oracle.truffle.api.dsl.test.TypeBoxingTest;
import com.oracle.truffle.api.dsl.test.TypeSystemTest;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.test.ReflectionUtils;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.Assert;
import org.junit.Test;

/**
 * Checks the specialization state of a single node that many threads specialize concurrently.
 *
 * <p>Each scenario starts {@link #THREADS} worker threads that execute the same root node in three
 * lock-stepped stages, one value per stage (see {@link Executions}). After every stage the main
 * thread inspects the node: its {@link SpecializationNode} chain if the node is a
 * {@link SpecializedNode}, otherwise the bits of its {@code state_} field.
 */
public class MergeSpecializationsTest {
    /** Number of worker threads racing on the same node. */
    private static final int THREADS = 50;

    /** How many times each scenario is repeated, to give races a chance to show up. */
    private static final int REPETITIONS = 100;

    /** 2^32: does not fit in an {@code int}; the scenarios expect it to select {@code s2(long)}. */
    private static final long BEYOND_INT = 4294967296L;

    /** The three values every worker executes, one per stage, in this order. */
    public static class Executions {
        public final Object firstValue;
        public final Object secondValue;
        public final Object thirdValue;

        Executions(Object firstValue, Object secondValue, Object thirdValue) {
            this.firstValue = firstValue;
            this.secondValue = secondValue;
            this.thirdValue = thirdValue;
        }
    }

    /** Node whose {@code int} specialization is guarded by a cached value (limit 3). */
    @TypeSystemReference(TypeBoxingTest.TypeBoxingTypeSystem.class)
    @NodeChild
    public abstract static class TestCachedNode extends TypeSystemTest.ValueNode {

        // The guard and the @Cached expression refer to parameters by name, so the parameters
        // must be called exactly "a" and "cachedA".
        @Specialization(guards = {"a == cachedA"}, limit = "3")
        public int s1(int a, @Cached("a") int cachedA) {
            return 1;
        }

        @Specialization
        public int s2(long a) {
            return 2;
        }

        @Specialization
        public int s3(double a) {
            return 3;
        }
    }

    /** Node with three unguarded specializations, for int, long and double. */
    @TypeSystemReference(TypeBoxingTest.TypeBoxingTypeSystem.class)
    @NodeChild
    public abstract static class TestNode extends TypeSystemTest.ValueNode {

        @Specialization
        public int s1(int a) {
            return 1;
        }

        @Specialization
        public int s2(long a) {
            return 2;
        }

        @Specialization
        public int s3(double a) {
            return 3;
        }
    }

    @Test
    public void testMultithreadedMergeInOrder() throws Exception {
        for (int i = 0; i < REPETITIONS; i++) {
            multithreadedMerge(MergeSpecializationsTestFactory.TestNodeFactory.getInstance(), new Executions(1, BEYOND_INT, 1.0d), 1, 2, 3);
        }
    }

    @Test
    public void testMultithreadedMergeReverse() throws Exception {
        for (int i = 0; i < REPETITIONS; i++) {
            multithreadedMerge(MergeSpecializationsTestFactory.TestNodeFactory.getInstance(), new Executions(1.0d, BEYOND_INT, 1), 3, 2, 1);
        }
    }

    @Test
    public void testMultithreadedMergeCachedInOrder() throws Exception {
        for (int i = 0; i < REPETITIONS; i++) {
            multithreadedMerge(MergeSpecializationsTestFactory.TestCachedNodeFactory.getInstance(), new Executions(1, BEYOND_INT, 1.0d), 1, 2, 3);
        }
    }

    @Test
    public void testMultithreadedMergeCachedTwoEntries() throws Exception {
        for (int i = 0; i < REPETITIONS; i++) {
            multithreadedMerge(MergeSpecializationsTestFactory.TestCachedNodeFactory.getInstance(), new Executions(1, 2, 1.0d), 1, 1, 3);
        }
    }

    @Test
    public void testMultithreadedMergeCachedThreeEntries() throws Exception {
        for (int i = 0; i < REPETITIONS; i++) {
            multithreadedMerge(MergeSpecializationsTestFactory.TestCachedNodeFactory.getInstance(), new Executions(1, 2, 3), 1, 1, 1);
        }
    }

    /**
     * Runs one concurrent specialization scenario.
     *
     * @param nodeFactory factory for the node under test
     * @param executions the value each worker executes in stage 1, 2 and 3
     * @param order the specialization number ({@code sN}) each stage's value is expected to
     *            activate; must have exactly three entries and is not modified
     */
    private static <T extends TypeSystemTest.ValueNode> void multithreadedMerge(NodeFactory<T> nodeFactory, final Executions executions, int... order) throws Exception {
        Assert.assertEquals(3, order.length);
        final TypeSystemTest.TestRootNode root = TestHelper.createRoot(nodeFactory);
        final Object[] stageValues = {executions.firstValue, executions.secondValue, executions.thirdValue};

        final CountDownLatch threadsStarted = new CountDownLatch(THREADS);
        final CountDownLatch[] stageRelease = new CountDownLatch[stageValues.length];
        final CountDownLatch[] stageFinished = new CountDownLatch[stageValues.length];
        for (int stage = 0; stage < stageValues.length; stage++) {
            stageRelease[stage] = new CountDownLatch(1);
            stageFinished[stage] = new CountDownLatch(THREADS);
        }
        // First failure seen by any worker; checked by the main thread after every stage.
        final AtomicReference<Throwable> workerFailure = new AtomicReference<>();

        Thread[] threads = new Thread[THREADS];
        try {
            for (int i = 0; i < threads.length; i++) {
                threads[i] = new Thread(new Runnable() {
                    @Override
                    public void run() {
                        threadsStarted.countDown();
                        for (int stage = 0; stage < stageValues.length; stage++) {
                            try {
                                MergeSpecializationsTest.await(stageRelease[stage]);
                                TestHelper.executeWith(root, stageValues[stage]);
                            } catch (Throwable t) {
                                // Record the failure instead of letting this thread die silently.
                                workerFailure.compareAndSet(null, t);
                            } finally {
                                // Always count down, so the main thread never waits forever on us.
                                stageFinished[stage].countDown();
                            }
                        }
                    }
                });
                threads[i].start();
            }

            Node node = root.getNode();
            assertMergedState(node, order, 0);
            await(threadsStarted);
            for (int stage = 0; stage < stageValues.length; stage++) {
                runStage(stageRelease[stage], stageFinished[stage], workerFailure);
                assertMergedState(node, order, stage + 1);
            }
        } finally {
            // Also runs after a failed assertion: lets blocked workers finish instead of leaking.
            releaseAndJoin(stageRelease, threads);
        }
    }

    /** Releases the workers into one stage, waits for all of them, and rethrows a worker failure. */
    private static void runStage(CountDownLatch release, CountDownLatch finished, AtomicReference<Throwable> workerFailure) {
        release.countDown();
        await(finished);
        Throwable failure = workerFailure.get();
        if (failure != null) {
            throw new AssertionError("worker thread failed", failure);
        }
    }

    /** Opens every stage latch (no-op for already opened ones) and joins all started workers. */
    private static void releaseAndJoin(CountDownLatch[] stageRelease, Thread[] threads) {
        for (CountDownLatch release : stageRelease) {
            release.countDown();
        }
        for (Thread thread : threads) {
            if (thread == null) {
                continue; // thread creation failed before reaching this slot
            }
            try {
                thread.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new AssertionError("interrupted", e);
            }
        }
    }

    /**
     * Asserts the state of {@code node} after the first {@code executedStages} stages have run.
     * Dispatches on whether the node exposes a specialization chain.
     */
    private static void assertMergedState(Node node, int[] order, int executedStages) throws IllegalAccessException, NoSuchFieldException {
        if (node instanceof SpecializedNode) {
            assertSpecializationChain(((SpecializedNode) node).getSpecializationNode(), order, executedStages);
        } else {
            assertState(node, order, executedStages);
        }
    }

    /**
     * Asserts the expected shape of the specialization chain.
     *
     * <ul>
     * <li>After zero or one stage the chain is the activated specialization (if any) followed by
     * {@code UninitializedNode_}.</li>
     * <li>From two stages on it is {@code PolymorphicNode_}, then the activated specializations
     * in ascending number, then {@code UninitializedNode_}.</li>
     * </ul>
     */
    private static void assertSpecializationChain(SpecializationNode head, int[] order, int executedStages) {
        int[] expected = Arrays.copyOf(order, executedStages); // copy: never reorder the caller's array
        int depth = 0;
        if (executedStages >= 2) {
            Arrays.sort(expected);
            Assert.assertEquals("PolymorphicNode_", nthChild(depth++, head).getClass().getSimpleName());
        }
        for (int specialization : expected) {
            Assert.assertEquals("S" + specialization + "Node_", nthChild(depth++, head).getClass().getSimpleName());
        }
        Assert.assertEquals("UninitializedNode_", nthChild(depth, head).getClass().getSimpleName());
    }

    /**
     * Asserts the activated specializations recorded in the node's {@code state_} field.
     *
     * @param node node whose class declares a {@code state_} field holding a {@link Number}
     * @param order expected specialization number per stage; not modified
     * @param executedStages how many leading entries of {@code order} have been executed
     */
    private static void assertState(Node node, int[] order, int executedStages) throws IllegalArgumentException, IllegalAccessException, NoSuchFieldException, SecurityException {
        Field stateField = node.getClass().getDeclaredField("state_");
        ReflectionUtils.setAccessible(stateField, true);
        // Bit 0 is shifted out. The next three bits are expected to flag s1, s2, s3 (sN -> bit N-1).
        int activeMask = (((Number) stateField.get(node)).intValue() >> 1) & 0b111;
        int expectedMask = 0;
        for (int k = 0; k < executedStages; k++) {
            expectedMask |= 1 << (order[k] - 1); // order-independent, so no sorting is needed
        }
        Assert.assertEquals(expectedMask, activeMask);
    }

    /**
     * Waits for {@code countDownLatch}. If interrupted, restores the interrupt flag and fails
     * with an {@link AssertionError} that carries the cause.
     */
    public static void await(CountDownLatch countDownLatch) {
        try {
            countDownLatch.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new AssertionError("interrupted", e);
        }
    }

    /** Returns the first child of {@code node}; fails if it has none. */
    private static Node firstChild(Node node) {
        return (Node) node.getChildren().iterator().next();
    }

    /** Follows the first-child link {@code i} times starting from {@code node} (0 = node itself). */
    private static Node nthChild(int i, Node node) {
        return i == 0 ? node : nthChild(i - 1, firstChild(node));
    }
}
