package org.apache.twill.yarn;

import com.google.common.base.Charsets;
import com.google.common.collect.Sets;
import com.google.common.io.LineReader;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Writer;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.apache.twill.api.Command;
import org.apache.twill.api.ResourceSpecification;
import org.apache.twill.api.TwillController;
import org.apache.twill.api.logging.PrinterLogHandler;
import org.apache.twill.discovery.Discoverable;
import org.apache.twill.discovery.ServiceDiscovered;
import org.junit.Assert;
import org.junit.Test;

/**
 * YARN integration test checking that a runnable instance which fails (throws from {@code run()})
 * comes back with the same instance id, so that both instances are discoverable again.
 */
public final class FailureRestartTestRun extends BaseYarnTest {

    /**
     * Socket server runnable that can be told to fail.
     *
     * <p>Sending the command {@code "kill<instanceId>"} makes the targeted instance stop serving and
     * then throw a {@link RuntimeException} out of {@link #run()}. Every request line is echoed back
     * prefixed with the instance id, which lets the test identify which instance answered.
     */
    public static final class FailureRunnable extends SocketServer {
        /** Set by {@link #handleCommand} so {@link #run()} knows the shutdown was a requested failure. */
        private volatile boolean killed;

        /**
         * Runs the socket server and throws once it returns, if this instance was killed.
         *
         * @throws RuntimeException if a matching kill command was received while serving
         */
        @Override // org.apache.twill.yarn.SocketServer
        public void run() {
            killed = false;
            super.run();
            if (killed) {
                throw new RuntimeException("Exception");
            }
        }

        /**
         * Handles {@code "kill<instanceId>"} addressed to this instance; other commands are ignored.
         *
         * @param command the command sent through the controller
         * @throws Exception if closing the server socket fails
         */
        public void handleCommand(Command command) throws Exception {
            if (command.getCommand().equals("kill" + getContext().getInstanceId())) {
                killed = true;
                // NOTE(review): `running` and `serverSocket` are inherited from SocketServer (not visible
                // here); presumably clearing the flag and closing the socket makes super.run() return.
                running = false;
                serverSocket.close();
            }
        }

        /**
         * Echoes one request line back, prefixed with this runnable's instance id.
         *
         * @param reader request input
         * @param writer response output; flushed after the reply is written
         * @throws IOException if reading the request fails
         */
        @Override // org.apache.twill.yarn.SocketServer
        public void handleRequest(BufferedReader reader, PrintWriter writer) throws IOException {
            writer.println(getContext().getInstanceId() + reader.readLine());
            writer.flush();
        }
    }

    /**
     * Tests the failure/restart cycle.
     *
     * <p>Steps:
     * <ol>
     *   <li>Start two instances and check their discovered ids are {0, 1}.</li>
     *   <li>Kill instance 0.</li>
     *   <li>Wait for discovery to drop to one endpoint and then recover to two.</li>
     *   <li>Check the ids are {0, 1} again.</li>
     * </ol>
     */
    @Test
    public void testFailureRestart() throws Exception {
        ResourceSpecification resource = ResourceSpecification.Builder.with()
                .setVirtualCores(1)
                .setMemory(512, ResourceSpecification.SizeUnit.MEGA)
                .setInstances(2)
                .build();

        TwillController controller = getTwillRunner().prepare(new FailureRunnable(), resource)
                .withApplicationArguments("failure")
                .withArguments(FailureRunnable.class.getSimpleName(), "failure2")
                .addLogHandler(new PrinterLogHandler(new PrintWriter(System.out, true)))
                .start();

        // Always terminate the application, even when an assertion fails, so a still-running
        // application does not leak into later tests.
        try {
            ServiceDiscovered discoverables = controller.discoverService("failure");
            Assert.assertTrue(waitForSize(discoverables, 2, 120));
            Assert.assertEquals(Sets.newHashSet(0, 1), getInstances(discoverables));

            // Fail instance 0; the endpoint count should drop and then recover.
            controller.sendCommand(FailureRunnable.class.getSimpleName(), Command.Builder.of("kill0").build());
            Assert.assertTrue(waitForSize(discoverables, 1, 120));
            Assert.assertTrue(waitForSize(discoverables, 2, 120));

            // The restarted instance must come back with the same instance id.
            Assert.assertEquals(Sets.newHashSet(0, 1), getInstances(discoverables));
        } finally {
            controller.terminate().get(120L, TimeUnit.SECONDS);
        }
    }

    /**
     * Asks every discovered endpoint which instance it is.
     *
     * <p>Sends {@code "Failure"} to each endpoint and parses the instance id that
     * {@link FailureRunnable#handleRequest} prefixes to the echoed line.
     *
     * @param discoverables endpoints to query
     * @return the set of instance ids that answered
     * @throws IOException if connecting to or talking with an endpoint fails
     */
    private Set<Integer> getInstances(Iterable<Discoverable> discoverables) throws IOException {
        final String message = "Failure";
        Set<Integer> instances = Sets.newHashSet();
        for (Discoverable discoverable : discoverables) {
            InetSocketAddress address = discoverable.getSocketAddress();
            // Closing the socket also closes its streams, so the writer/reader need no separate close.
            try (Socket socket = new Socket(address.getAddress(), address.getPort())) {
                PrintWriter writer = new PrintWriter(
                        new OutputStreamWriter(socket.getOutputStream(), Charsets.UTF_8), true);
                LineReader reader = new LineReader(
                        new InputStreamReader(socket.getInputStream(), Charsets.UTF_8));

                writer.println(message);
                String reply = reader.readLine();
                // Fail with a clear message instead of an NPE if the server hung up without replying.
                Assert.assertNotNull("No reply from " + address, reply);
                Assert.assertTrue("Unexpected reply: " + reply, reply.endsWith(message));
                instances.add(Integer.parseInt(reply.substring(0, reply.length() - message.length())));
            }
        }
        return instances;
    }
}
