package org.flinkextended.flink.ml.cluster.node.runner;

import java.util.Map;
import java.util.concurrent.FutureTask;
import java.util.concurrent.TimeUnit;
import org.flinkextended.flink.ml.cluster.ExecutionMode;
import org.flinkextended.flink.ml.cluster.MLConfig;
import org.flinkextended.flink.ml.cluster.node.MLContext;
import org.flinkextended.flink.ml.cluster.role.AMRole;
import org.flinkextended.flink.ml.cluster.rpc.AppMasterServer;
import org.flinkextended.flink.ml.cluster.rpc.NodeServer;
import org.flinkextended.flink.ml.proto.NodeSpec;
import org.flinkextended.flink.ml.util.DummyContext;
import org.flinkextended.flink.ml.util.MLConstants;
import org.flinkextended.flink.ml.util.MLException;
import org.hamcrest.CoreMatchers;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;

/**
 * Unit tests for {@link CommonMLRunner}.
 *
 * <p>Before each test an {@link AppMasterServer} is started on a daemon thread. A spied {@link
 * CommonMLRunner} is created against a mocked {@link NodeServer}, with the {@code
 * script_runner_class} property pointing at {@link TestScriptRunner}, and its AM client is
 * initialised. After each test the AM server is told to end and the global {@link
 * MLConstants#TIMEOUT} is restored.
 */
public class CommonMLRunnerTest {
    /** Upper bound for every blocking wait here, so a regression fails instead of hanging. */
    private static final long WAIT_TIMEOUT_SECONDS = 60L;

    private static FutureTask<Void> amFuture;
    private static AppMasterServer amServer;
    private static MLConfig mlConfig;

    private CommonMLRunner mlRunner;
    private NodeServer nodeServer;
    private MLContext mlContext;
    /** Value of the global {@link MLConstants#TIMEOUT} before the test; restored in tearDown. */
    private long originalTimeout;

    @Before
    public void setUp() throws Exception {
        // Some tests shorten this global timeout; remember it first so tearDown can restore it.
        originalTimeout = MLConstants.TIMEOUT;
        mlConfig = DummyContext.createDummyMLConfig();
        startAMServer(mlConfig);
        nodeServer = Mockito.mock(NodeServer.class);
        mlContext = DummyContext.createDummyMLContext();
        mlContext.getProperties()
                .put("script_runner_class", TestScriptRunner.class.getCanonicalName());
        mlRunner = Mockito.spy(new CommonMLRunner(mlContext, nodeServer));
        mlRunner.initAMClient();
        Assert.assertNotNull(mlRunner.amClient);
    }

    @After
    public void tearDown() throws Exception {
        // MLConstants.TIMEOUT is a mutable static; leaking a shortened value would skew later tests.
        MLConstants.TIMEOUT = originalTimeout;
        // Guard against setUp having failed before the AM server was created.
        if (amServer != null) {
            amServer.setEnd(true);
        }
        if (amFuture != null) {
            // Bounded so a server that ignores setEnd(true) fails the test instead of hanging it.
            amFuture.get(WAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS);
        }
    }

    @Test
    public void testGetCurrentJobVersion() {
        mlRunner.getCurrentJobVersion();
        Assert.assertTrue(mlRunner.version > 0);
    }

    @Test
    public void testGetTaskIndex() throws MLException, InterruptedException {
        mlRunner.getCurrentJobVersion();
        mlRunner.mlContext.setIndex(-1);
        mlRunner.getTaskIndex();
        Assert.assertEquals(0L, mlRunner.mlContext.getIndex());
    }

    @Test
    public void testRegisterNode() throws Exception {
        mlRunner.registerNode();
        NodeSpec registered = (NodeSpec) mlRunner.amClient
                .getClusterInfo(mlRunner.version)
                .getClusterDef()
                .getJob(0)
                .getTasksMap()
                .get(0);
        Assert.assertNotNull("no task registered under key 0 of job 0", registered);
        Assert.assertEquals(mlContext.getRoleName(), registered.getRoleName());
        Assert.assertEquals(mlContext.getIndex(), registered.getIndex());
    }

    @Test
    public void testStartHeartbeat() throws Exception {
        Assert.assertNull(mlRunner.getHeartBeatRunnerFuture());
        mlRunner.startHeartBeat();
        Assert.assertNotNull(mlRunner.getHeartBeatRunnerFuture());
        Assert.assertFalse(mlRunner.getHeartBeatRunnerFuture().isDone());
    }

    @Test
    public void testWaitClusterRunning() throws Exception {
        MLConstants.TIMEOUT = 1000L;
        try {
            mlRunner.waitClusterRunning();
        } catch (MLException ignored) {
            // Tolerated: no node has registered yet, so this wait may fail under the shortened
            // TIMEOUT set above. The second call below must succeed.
        }
        mlRunner.registerNode();
        mlRunner.startHeartBeat();
        mlRunner.waitClusterRunning();
    }

    @Test
    public void testGetClusterInfo() throws Exception {
        MLConstants.TIMEOUT = 1000L;
        mlRunner.getCurrentJobVersion();
        mlRunner.getClusterInfo();
        Assert.assertNull(mlRunner.mlClusterDef);
        mlRunner.registerNode();
        mlRunner.getClusterInfo();
        Assert.assertNotNull(mlRunner.mlClusterDef);
    }

    @Test
    public void testResetMLContext() throws Exception {
        mlRunner.getCurrentJobVersion();
        mlRunner.registerNode();
        mlRunner.getClusterInfo();
        Assert.assertNull(mlRunner.mlContext.getProperties().get("cluster"));
        mlRunner.resetMLContext();
        Assert.assertNotNull(mlRunner.mlContext.getProperties().get("cluster"));
    }

    @Test
    public void testRunScript() throws Exception {
        mlRunner.runScript();
        Assert.assertThat(mlRunner.scriptRunner, CoreMatchers.instanceOf(TestScriptRunner.class));
        Assert.assertTrue(mlRunner.scriptRunner.isRan());
    }

    @Test
    public void testRun() throws Exception {
        mlRunner.run();
        Mockito.verify(mlRunner, Mockito.atLeastOnce()).initAMClient();
        Mockito.verify(mlRunner, Mockito.atLeastOnce()).getCurrentJobVersion();
        Mockito.verify(mlRunner).getTaskIndex();
        Mockito.verify(mlRunner).registerNode();
        Mockito.verify(mlRunner).startHeartBeat();
        Mockito.verify(mlRunner).waitClusterRunning();
        Mockito.verify(mlRunner).getClusterInfo();
        Mockito.verify(mlRunner).resetMLContext();
        Mockito.verify(mlRunner).runScript();
    }

    @Test
    public void testKillByFlink() throws InterruptedException {
        Thread thread = new Thread((Runnable) mlRunner);
        // Daemon so a runner that never stops cannot keep the test JVM alive.
        thread.setDaemon(true);
        thread.start();
        long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(WAIT_TIMEOUT_SECONDS);
        while (mlRunner.currentResultStatus != ExecutionStatus.RUNNING) {
            // Without these checks a runner that dies early or never starts would spin forever.
            if (!thread.isAlive() && mlRunner.currentResultStatus != ExecutionStatus.RUNNING) {
                Assert.fail("runner thread exited before reaching RUNNING, status="
                        + mlRunner.currentResultStatus);
            }
            if (System.nanoTime() - deadline > 0) {
                Assert.fail("runner did not reach RUNNING within " + WAIT_TIMEOUT_SECONDS + "s");
            }
            Thread.sleep(100L);
        }
        mlRunner.notifyStop();
        thread.join(TimeUnit.SECONDS.toMillis(WAIT_TIMEOUT_SECONDS));
        Assert.assertFalse("runner thread did not stop after notifyStop()", thread.isAlive());
        Assert.assertEquals(ExecutionStatus.KILLED_BY_FLINK, mlRunner.resultStatus);
    }

    /**
     * Starts an {@link AppMasterServer} for {@code mLConfig} on a background daemon thread.
     *
     * <p>The server and its wrapping future are stored in the static {@code amServer} and {@code
     * amFuture} fields so that tearDown can stop it.
     *
     * @param mLConfig configuration used to build the AM's {@link MLContext}
     * @return the future wrapping the running server
     * @throws MLException if the AM server cannot be constructed
     */
    private static FutureTask<Void> startAMServer(MLConfig mLConfig) throws MLException {
        MLContext amContext = new MLContext(
                ExecutionMode.TRAIN, mLConfig, new AMRole().name(), 0, (String) null, (Map) null);
        amServer = new AppMasterServer(amContext);
        amFuture = new FutureTask<>(amServer, null);
        Thread thread = new Thread(amFuture);
        thread.setDaemon(true);
        thread.start();
        return amFuture;
    }
}
