package org.flinkextended.flink.ml.cluster.rpc;

import com.google.common.util.concurrent.Futures;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.testing.GrpcCleanupRule;
import java.io.IOException;
import java.time.Duration;
import org.flinkextended.flink.ml.cluster.master.meta.AMMeta;
import org.flinkextended.flink.ml.proto.AppMasterServiceGrpc;
import org.flinkextended.flink.ml.proto.GetClusterInfoRequest;
import org.flinkextended.flink.ml.proto.GetFinishNodeResponse;
import org.flinkextended.flink.ml.proto.GetFinishedNodeRequest;
import org.flinkextended.flink.ml.proto.GetTaskIndexRequest;
import org.flinkextended.flink.ml.proto.GetTaskIndexResponse;
import org.flinkextended.flink.ml.proto.HeartBeatRequest;
import org.flinkextended.flink.ml.proto.MLClusterDef;
import org.flinkextended.flink.ml.proto.MLJobDef;
import org.flinkextended.flink.ml.proto.NodeRestartResponse;
import org.flinkextended.flink.ml.proto.NodeSpec;
import org.flinkextended.flink.ml.proto.NodeStopResponse;
import org.flinkextended.flink.ml.proto.SimpleResponse;
import org.flinkextended.flink.ml.proto.StopAllWorkerRequest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.mockito.Mockito;

/**
 * Unit tests for {@link AppMasterServiceImpl}.
 *
 * <p>The service under test is a Mockito spy backed by a mocked {@link AppMasterServer}. Tests
 * either call the spy directly or go through a blocking gRPC stub connected to it over an
 * in-process channel.
 */
public class AppMasterServiceImplTest {

    /** Shuts down the in-process server and channel registered in {@link #setUp()} after each test. */
    @Rule
    public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule();

    private AppMasterServiceGrpc.AppMasterServiceBlockingStub stub;
    private AppMasterServer appMasterServer;
    private AppMasterServiceImpl appMasterService;

    /**
     * Starts an in-process gRPC server hosting a spied {@link AppMasterServiceImpl} (backed by a
     * mocked {@link AppMasterServer}) and connects a blocking stub to it.
     *
     * @throws Exception if the in-process server fails to start
     */
    @Before
    public void setUp() throws Exception {
        String serverName = InProcessServerBuilder.generateName();
        appMasterServer = Mockito.mock(AppMasterServer.class);
        // A spy rather than a plain instance so tests can verify internal calls such as
        // stopHeartBeatMonitorNode. NOTE(review): the meaning of the constructor arguments `2` and
        // `Duration.ofMinutes(1L)` is not visible from this file — confirm against AppMasterServiceImpl.
        appMasterService =
                Mockito.spy(new AppMasterServiceImpl(appMasterServer, 2, Duration.ofMinutes(1L)));
        // directExecutor() runs RPC handling on the calling thread, so mock verifications placed
        // right after a blocking stub call observe the handler's effects deterministically.
        grpcCleanupRule.register(
                InProcessServerBuilder.forName(serverName)
                        .directExecutor()
                        .addService(appMasterService)
                        .build()
                        .start());
        stub =
                AppMasterServiceGrpc.newBlockingStub(
                        grpcCleanupRule.register(
                                InProcessChannelBuilder.forName(serverName)
                                        .directExecutor()
                                        .build()));
    }

    /** A heartbeat with the default version returns OK and is forwarded to the server's last-contact update. */
    @Test
    public void testHeartBeatNode() {
        SimpleResponse response = stub.heartBeatNode(HeartBeatRequest.newBuilder().build());
        Mockito.verify(appMasterServer).updateRpcLastContact();
        Assert.assertEquals(RpcCode.OK.ordinal(), response.getCode());
    }

    /** A heartbeat carrying version 1 is rejected with VERSION_ERROR. */
    @Test
    public void testHeartBeatNodeWithVersionError() {
        SimpleResponse response =
                stub.heartBeatNode(HeartBeatRequest.newBuilder().setVersion(1L).build());
        Assert.assertEquals(RpcCode.VERSION_ERROR.ordinal(), response.getCode());
    }

    /**
     * Restarting a registered worker delegates to its {@link NodeClient} and stops the heartbeat
     * monitor for that node.
     *
     * @throws Exception propagated from {@link AppMasterServiceImpl#restartNode}
     */
    @Test
    public void testRestartNode() throws Exception {
        NodeClient nodeClient = Mockito.mock(NodeClient.class);
        Mockito.when(nodeClient.restartNode())
                .thenReturn(Futures.immediateFuture(NodeRestartResponse.newBuilder().build()));
        NodeSpec worker = registerWorker(0, nodeClient);

        appMasterService.restartNode(worker);

        Mockito.verify(nodeClient).restartNode();
        Mockito.verify(appMasterService).stopHeartBeatMonitorNode(Mockito.anyString());
    }

    /**
     * Stopping a registered worker delegates to its {@link NodeClient}.
     *
     * @throws Exception propagated from {@link AppMasterServiceImpl#stopNode}
     */
    @Test
    public void testStopNode() throws Exception {
        NodeClient nodeClient = mockStoppableNodeClient();
        NodeSpec worker = registerWorker(0, nodeClient);

        appMasterService.stopNode(worker);

        Mockito.verify(nodeClient).stopNode();
    }

    /** stopAllNodes() asks every registered node client to stop. */
    @Test
    public void testStopAllNode() {
        NodeClient firstClient = mockStoppableNodeClient();
        NodeClient secondClient = mockStoppableNodeClient();
        registerWorker(0, firstClient);
        registerWorker(1, secondClient);

        appMasterService.stopAllNodes();

        Mockito.verify(firstClient).stopNode();
        Mockito.verify(secondClient).stopNode();
    }

    /** A cluster-info request carrying version 1 is rejected with VERSION_ERROR. */
    @Test
    public void testGetClusterInfoWithVersionError() {
        Assert.assertEquals(
                RpcCode.VERSION_ERROR.ordinal(),
                stub.getClusterInfo(GetClusterInfoRequest.newBuilder().setVersion(1L).build())
                        .getCode());
    }

    /** The stopAllWorker RPC returns OK. */
    @Test
    public void testStopAllWorker() {
        Assert.assertEquals(
                RpcCode.OK.ordinal(),
                stub.stopAllWorker(StopAllWorkerRequest.newBuilder().build()).getCode());
    }

    /**
     * Task indices are handed out per scope. As asserted below, two different keys in the default
     * scope get 0 and 1, and the first key in another scope starts again at 0.
     */
    @Test
    public void testGetTaskIndex() {
        GetTaskIndexResponse defaultKey = stub.getTaskIndex(GetTaskIndexRequest.newBuilder().build());
        Assert.assertEquals(RpcCode.OK.ordinal(), defaultKey.getCode());
        Assert.assertEquals(0L, defaultKey.getIndex());

        GetTaskIndexResponse secondKey =
                stub.getTaskIndex(GetTaskIndexRequest.newBuilder().setKey("key1").build());
        Assert.assertEquals(RpcCode.OK.ordinal(), secondKey.getCode());
        Assert.assertEquals(1L, secondKey.getIndex());

        GetTaskIndexResponse otherScope =
                stub.getTaskIndex(
                        GetTaskIndexRequest.newBuilder().setScope("scope1").setKey("key1").build());
        Assert.assertEquals(RpcCode.OK.ordinal(), otherScope.getCode());
        Assert.assertEquals(0L, otherScope.getIndex());
    }

    /** A task-index request carrying version 1 is rejected with VERSION_ERROR. */
    @Test
    public void testGetTaskIndexWithVersionError() {
        Assert.assertEquals(
                RpcCode.VERSION_ERROR.ordinal(),
                stub.getTaskIndex(GetTaskIndexRequest.newBuilder().setVersion(1L).build())
                        .getCode());
    }

    /**
     * getFinishedNode reports the finished workers restored from {@link AMMeta}.
     *
     * <p>Here the restored cluster definition has a "worker" job with two tasks.
     *
     * @throws IOException declared by {@link AMMeta#restoreFinishClusterDef()}
     */
    @Test
    public void testGetFinishedNode() throws IOException {
        AMMeta amMeta = Mockito.mock(AMMeta.class);
        MLClusterDef finishedCluster =
                MLClusterDef.newBuilder()
                        .addJob(
                                MLJobDef.newBuilder()
                                        .setName("worker")
                                        .putTasks(0, NodeSpec.newBuilder().build())
                                        .putTasks(1, NodeSpec.newBuilder().build()))
                        .build();
        Mockito.when(amMeta.restoreFinishClusterDef()).thenReturn(finishedCluster);
        Mockito.when(appMasterServer.getAmMeta()).thenReturn(amMeta);

        GetFinishNodeResponse response =
                stub.getFinishedNode(GetFinishedNodeRequest.newBuilder().build());

        Assert.assertEquals(RpcCode.OK.ordinal(), response.getCode());
        Assert.assertEquals(2L, response.getWorkersCount());
    }

    /** Creates a mocked {@link NodeClient} whose stopNode() completes immediately with an empty response. */
    private static NodeClient mockStoppableNodeClient() {
        NodeClient client = Mockito.mock(NodeClient.class);
        Mockito.when(client.stopNode())
                .thenReturn(Futures.immediateFuture(NodeStopResponse.newBuilder().build()));
        return client;
    }

    /**
     * Registers {@code client} with the service under the key of a "worker" node with the given
     * index.
     *
     * @param index task index of the worker node
     * @param client node client to register for that worker
     * @return the {@link NodeSpec} describing the registered worker
     */
    private NodeSpec registerWorker(int index, NodeClient client) {
        NodeSpec worker = NodeSpec.newBuilder().setRoleName("worker").setIndex(index).build();
        appMasterService.updateNodeClient(AppMasterServer.getNodeClientKey(worker), client);
        return worker;
    }
}
