package org.apache.hadoop.yarn.server.federation.policies.router;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.hadoop.yarn.api.records.ApplicationSubmissionContext;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.server.federation.policies.dao.WeightedPolicyInfo;
import org.apache.hadoop.yarn.server.federation.store.records.SubClusterId;
import org.apache.hadoop.yarn.server.federation.store.records.SubClusterIdInfo;
import org.apache.hadoop.yarn.server.federation.store.records.SubClusterInfo;
import org.apache.hadoop.yarn.server.federation.store.records.SubClusterState;
import org.apache.hadoop.yarn.server.federation.utils.FederationPoliciesTestUtil;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;

/* JADX WARN: Classes with same name are omitted:
  input_file:hadoop-yarn-server-common-2.10.0-tests.jar:org/apache/hadoop/yarn/server/federation/policies/router/TestWeightedRandomRouterPolicy.class
 */
/* loaded from: input_file:test-classes/org/apache/hadoop/yarn/server/federation/policies/router/TestWeightedRandomRouterPolicy.class */
/**
 * Tests {@link WeightedRandomRouterPolicy}. Routing decisions must follow the
 * configured router weights among active subclusters and must never land on a
 * weighted subcluster that is not active.
 */
public class TestWeightedRandomRouterPolicy extends BaseRouterPoliciesTest {

    /** Number of candidate subclusters generated in {@link #setUp()}. */
    private static final int NUM_SUBCLUSTERS = 20;

    /** Number of routing decisions sampled in the probability test. */
    private static final int NUM_DRAWS = 10000;

    /** Maximum allowed gap between observed and expected routing frequency. */
    private static final double TOLERANCE = 0.01;

    /**
     * Builds a randomized federation of {@link #NUM_SUBCLUSTERS} subclusters
     * named {@code sc0..sc19}.
     *
     * <p>Each subcluster is registered as an active, running subcluster with
     * probability 0.95. Each subcluster also gets an identical router and AMRM
     * weight of {@code 0.8 / NUM_SUBCLUSTERS + 0.2 * U}, where U is drawn from
     * {@code getRand()}. Subclusters {@code sc0..sc5} always get a weight; the
     * others are left unweighted with probability 0.05. The policy is then
     * initialized with these weights and the active subclusters.
     */
    @Before
    public void setUp() throws Exception {
        setPolicy(new WeightedRandomRouterPolicy());
        setPolicyInfo(new WeightedPolicyInfo());
        Map<SubClusterIdInfo, Float> routerWeights = new HashMap<>();
        Map<SubClusterIdInfo, Float> amrmWeights = new HashMap<>();

        for (int i = 0; i < NUM_SUBCLUSTERS; i++) {
            SubClusterIdInfo sc = new SubClusterIdInfo("sc" + i);

            // About 5% of subclusters are not registered as active.
            if (getRand().nextFloat() < 0.95f) {
                SubClusterInfo sci = Mockito.mock(SubClusterInfo.class);
                Mockito.when(sci.getState()).thenReturn(SubClusterState.SC_RUNNING);
                Mockito.when(sci.getSubClusterId()).thenReturn(sc.toId());
                getActiveSubclusters().put(sc.toId(), sci);
            }

            // 80% of the weight is an even share, 20% is random.
            float weight = (0.8f / NUM_SUBCLUSTERS) + (0.2f * getRand().nextFloat());

            // The first six subclusters always get a weight. The short-circuit
            // means no random number is drawn for them, which keeps the random
            // sequence identical to the original test.
            if (i <= 5 || getRand().nextFloat() > 0.05f) {
                routerWeights.put(sc, weight);
                amrmWeights.put(sc, weight);
            }
        }
        getPolicyInfo().setRouterPolicyWeights(routerWeights);
        getPolicyInfo().setAMRMPolicyWeights(amrmWeights);
        FederationPoliciesTestUtil.initializePolicyContext(
                getPolicy(), getPolicyInfo(), getActiveSubclusters());
    }

    /**
     * Samples {@link #NUM_DRAWS} routing decisions and checks two things for
     * every weighted subcluster:
     * <ul>
     *   <li>If it is active, its observed frequency is within
     *       {@link #TOLERANCE} of its weight divided by the total weight of
     *       the active subclusters.</li>
     *   <li>If it is not active, it was never chosen.</li>
     * </ul>
     * The test also fails if the policy picks a subcluster that has no router
     * weight.
     *
     * @throws YarnException if the policy fails to choose a home subcluster
     */
    @Test
    public void testClusterChosenWithRightProbability() throws YarnException {
        ApplicationSubmissionContext context = Mockito.mock(ApplicationSubmissionContext.class);
        Mockito.when(context.getQueue()).thenReturn("queue1");
        setApplicationSubmissionContext(context);

        Map<SubClusterIdInfo, Float> routerWeights = getPolicyInfo().getRouterPolicyWeights();

        // One hit counter per weighted subcluster.
        Map<SubClusterId, AtomicLong> counter = new HashMap<>();
        for (SubClusterIdInfo id : routerWeights.keySet()) {
            counter.put(id.toId(), new AtomicLong(0L));
        }

        FederationRouterPolicy policy = (FederationRouterPolicy) getPolicy();
        for (int draw = 0; draw < NUM_DRAWS; draw++) {
            SubClusterId chosen = policy.getHomeSubcluster(getApplicationSubmissionContext(), null);
            AtomicLong hits = counter.get(chosen);
            // Fail with a diagnostic rather than an NPE if the policy returns an
            // unexpected subcluster.
            Assert.assertNotNull("Policy chose subcluster " + chosen
                    + " which has no router weight", hits);
            hits.incrementAndGet();
        }

        // Total weight of the subclusters that are both active and weighted.
        float totalActiveWeight = 0.0f;
        for (SubClusterId id : getActiveSubclusters().keySet()) {
            Float activeWeight = routerWeights.get(new SubClusterIdInfo(id));
            if (activeWeight != null) {
                totalActiveWeight += activeWeight;
            }
        }

        for (Map.Entry<SubClusterId, AtomicLong> entry : counter.entrySet()) {
            SubClusterId id = entry.getKey();
            float expectedWeight = routerWeights.get(new SubClusterIdInfo(id)) / totalActiveWeight;
            float actualWeight = entry.getValue().floatValue() / NUM_DRAWS;
            String message = "Id " + id + " Actual weight: " + actualWeight
                    + " expected weight: " + expectedWeight;
            if (getActiveSubclusters().containsKey(id)) {
                // Active subclusters must be chosen in proportion to their weight.
                Assert.assertTrue(message, Math.abs(actualWeight - expectedWeight) < TOLERANCE);
            } else {
                // A subcluster that is weighted but not active must never be chosen.
                Assert.assertEquals(message, 0L, entry.getValue().get());
            }
        }
    }
}
