package org.linqs.psl.application.inference.online;

import java.io.EOFException;
import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.Thread;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;
import org.linqs.psl.application.inference.online.messages.ModelInformation;
import org.linqs.psl.application.inference.online.messages.OnlineMessage;
import org.linqs.psl.application.inference.online.messages.actions.controls.Exit;
import org.linqs.psl.application.inference.online.messages.actions.controls.Stop;
import org.linqs.psl.application.inference.online.messages.responses.ActionStatus;
import org.linqs.psl.application.inference.online.messages.responses.OnlineResponse;
import org.linqs.psl.config.Options;
import org.linqs.psl.model.predicate.FunctionalPredicate;
import org.linqs.psl.model.predicate.Predicate;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.util.FileUtils;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.SystemUtils;

/**
 * Socket server that lets remote clients submit {@link OnlineMessage}s to a running online
 * inference application.
 *
 * <p>A {@link ServerConnectionThread} listens on {@code Options.ONLINE_PORT_NUMBER}. Every accepted
 * connection is served by its own {@link ClientConnectionThread}. That thread first sends a
 * {@link ModelInformation} (non-functional predicates and the model's rules) and then forwards
 * every received message into a shared queue. The consumer pulls messages with
 * {@link #getAction()} and routes results back to the originating client with
 * {@link #onActionExecution(OnlineMessage, OnlineResponse)}.
 *
 * <p>While running, the server creates a lock file named {@link #TEMP_FILE_NAME} in a temp
 * directory. Startup fails if that file already exists.
 *
 * <p>NOTE(review): client messages are read with Java native deserialization
 * ({@link ObjectInputStream#readObject()}) directly from a network socket. Anyone who can reach
 * the port can send arbitrary serialized objects. Consider an {@code ObjectInputFilter} or
 * restricting the bind address.
 */
public class OnlineServer {
    private static final Logger log = Logger.getLogger(OnlineServer.class);

    public static final String TEMP_FILE_DIR_PREFIX = "onlinePSLServer";
    public static final String TEMP_FILE_NAME = "onlinePSLServer.lock";

    private final List<Rule> rules;

    // Written by the caller's thread (start/close) and read by the server thread's accept loop,
    // so it must be volatile for the loop to observe shutdown.
    private volatile boolean listening = false;

    private ServerConnectionThread serverThread = new ServerConnectionThread();
    private File tempFile = null;

    // Shared between client threads (producers) and the getAction() caller (consumer).
    // These are never nulled, so late calls from client threads after close() cannot NPE.
    private final BlockingQueue<OnlineMessage> queue = new LinkedBlockingQueue<>();
    private final ConcurrentMap<UUID, ClientConnectionThread> messageIDConnectionMap = new ConcurrentHashMap<>();
    private final Set<ClientConnectionThread> clientConnectionThreads =
            Collections.newSetFromMap(new ConcurrentHashMap<ClientConnectionThread, Boolean>());

    /**
     * Logs any exception that escapes a {@link ClientConnectionThread} and closes that client's
     * connection.
     */
    private class ClientConnectionExceptionHandler implements Thread.UncaughtExceptionHandler {
        private ClientConnectionExceptionHandler() {
        }

        @Override
        public void uncaughtException(Thread thread, Throwable throwable) {
            if (!(thread instanceof ClientConnectionThread)) {
                throw new RuntimeException("ClientConnectionExceptionHandler can only be used by ClientConnectionThreads.", throwable);
            }

            OnlineServer.log.warn(String.format("Uncaught exception in ClientConnectionThread.  Exception message: %s", throwable.getMessage()), throwable);
            OnlineServer.this.closeClient((ClientConnectionThread)thread);
        }
    }

    /**
     * Serves a single client connection.
     *
     * <p>It sends the model information, then reads messages and enqueues them until the client
     * sends {@link Exit} or {@link Stop}, or the socket is closed.
     */
    public class ClientConnectionThread extends Thread {
        public Socket socket;
        public ObjectInputStream inputStream;
        public ObjectOutputStream outputStream;

        public ClientConnectionThread(Socket socket) {
            this.socket = socket;
            setUncaughtExceptionHandler(new ClientConnectionExceptionHandler());
        }

        /**
         * Opens the object streams on the client socket.
         *
         * <p>The output stream is created and flushed first because the ObjectInputStream
         * constructor blocks until it has read the peer's stream header. Opening the input side
         * first can deadlock if the client does the same.
         */
        private void initializeConnection() {
            try {
                outputStream = new ObjectOutputStream(socket.getOutputStream());
                outputStream.flush();
                inputStream = new ObjectInputStream(socket.getInputStream());
            } catch (IOException ex) {
                throw new RuntimeException(ex);
            }
        }

        /**
         * Sends the client every non-functional predicate plus the server's rules.
         */
        private void sendModelInformation() {
            List<Predicate> predicates = new ArrayList<>();
            // Iterate over a copy, as the original did, rather than over Predicate.getAll() directly.
            for (Predicate predicate : new ArrayList<Predicate>(Predicate.getAll())) {
                if (!(predicate instanceof FunctionalPredicate)) {
                    predicates.add(predicate);
                }
            }

            try {
                outputStream.writeObject(new ModelInformation(
                        predicates.toArray(new Predicate[0]),
                        OnlineServer.this.rules.toArray(new Rule[0])));
                outputStream.flush();
            } catch (IOException ex) {
                throw new RuntimeException(ex);
            }
        }

        /**
         * Main loop: reads messages, records which client sent each one, and enqueues it for
         * {@link OnlineServer#getAction()}.
         *
         * <p>Returns after an Exit or Stop message, or when the socket has been closed locally.
         * Other stream failures propagate to {@link ClientConnectionExceptionHandler}. That
         * includes a non-OnlineMessage object, which raises a ClassCastException.
         */
        @Override
        public void run() {
            initializeConnection();
            sendModelInformation();

            while (true) {
                OnlineMessage message;
                try {
                    message = (OnlineMessage)inputStream.readObject();
                } catch (EOFException ex) {
                    throw new RuntimeException("Client closed socket without Exit or Stop action.", ex);
                } catch (IOException ex) {
                    if (!socket.isClosed()) {
                        throw new RuntimeException(ex);
                    }
                    // The socket was closed on our side (closeClient/close): a normal shutdown.
                    return;
                } catch (ClassNotFoundException ex) {
                    OnlineServer.log.warn("Failed to deserialized last OnlineMessage from client.");
                    continue;
                }

                if (message == null) {
                    OnlineServer.log.warn("Received a null OnlineMessage from client; ignoring it.");
                    continue;
                }

                OnlineServer.log.trace(String.format("Server received action from client: %s", message));

                // Register the route back to this client before the message becomes visible to the
                // consumer, so onActionExecution() can always find it.
                OnlineServer.this.messageIDConnectionMap.put(message.getIdentifier(), this);
                try {
                    OnlineServer.this.queue.put(message);
                } catch (InterruptedException ex) {
                    // The message was never queued, so nothing will ever answer it.
                    OnlineServer.this.messageIDConnectionMap.remove(message.getIdentifier());
                    OnlineServer.log.warn("Interrupted while queueing an online action; closing client connection.", ex);
                    Thread.currentThread().interrupt();
                    OnlineServer.this.closeClient(this);
                    return;
                }

                // Exit and Stop end this client's session; no further messages are read.
                if ((message instanceof Exit) || (message instanceof Stop)) {
                    return;
                }
            }
        }

        /**
         * Closes the client socket. This unblocks a pending readObject() in run().
         */
        public void close() {
            try {
                socket.close();
            } catch (IOException ex) {
                // Best effort: the connection is being torn down either way.
            }
        }
    }

    /**
     * Accepts client connections and spawns one {@link ClientConnectionThread} per client.
     */
    private class ServerConnectionThread extends Thread {
        private final int port = Options.ONLINE_PORT_NUMBER.getInt();
        private ServerSocket socket = null;

        // The only permit is taken in the constructor and returned once startup has finished,
        // whether it succeeded or failed.
        private final Semaphore readyLock = new Semaphore(1);

        // Set by run() when startup fails, so blockUntilReady() can report it instead of hanging.
        private volatile RuntimeException startupFailure = null;

        public ServerConnectionThread() {
            try {
                readyLock.acquire();
            } catch (InterruptedException ex) {
                throw new RuntimeException("Unable to acquire a new lock.", ex);
            }
        }

        private void openListenSocket() {
            try {
                socket = new ServerSocket(port);
            } catch (IOException ex) {
                throw new RuntimeException(String.format("Could not establish socket on port %s.", port), ex);
            }
        }

        /**
         * Blocks until run() has finished startup.
         *
         * @throws RuntimeException if startup failed (the cause is the startup error) or the wait
         *     was interrupted
         */
        public void blockUntilReady() {
            try {
                readyLock.acquire();
                readyLock.release();
            } catch (InterruptedException ex) {
                Thread.currentThread().interrupt();
                throw new RuntimeException("Unable to acquire ready lock.", ex);
            }

            if (startupFailure != null) {
                throw new RuntimeException("Online server failed to start.", startupFailure);
            }
        }

        @Override
        public void run() {
            try {
                openListenSocket();
                OnlineServer.this.createServerTempFile();
            } catch (RuntimeException ex) {
                // Record the failure before releasing the lock so the waiter sees it.
                startupFailure = ex;
                close();
                throw ex;
            } finally {
                readyLock.release();
            }

            OnlineServer.log.info(String.format("Online server started on port %s.", port));

            while (OnlineServer.this.listening) {
                try {
                    ClientConnectionThread client = new ClientConnectionThread(socket.accept());
                    OnlineServer.this.addClient(client);
                    client.start();
                } catch (IOException ex) {
                    if (socket.isClosed()) {
                        // close() shut the listen socket down; no further accepts are possible.
                        return;
                    }
                    throw new RuntimeException(ex);
                }
            }
        }

        public void close() {
            if (socket == null) {
                return;
            }

            try {
                socket.close();
            } catch (IOException ex) {
                // Best effort during shutdown.
            }
        }
    }

    /**
     * @param list the model's rules, sent to every client as part of its ModelInformation
     */
    public OnlineServer(List<Rule> list) {
        this.rules = list;
    }

    /**
     * Starts the listening thread and blocks until the listen socket and lock file are in place.
     *
     * @throws RuntimeException if the server could not start
     */
    public void start() {
        listening = true;
        serverThread.start();
        serverThread.blockUntilReady();
    }

    /**
     * Creates the server's lock file in a temp directory and registers it for deletion on exit.
     *
     * <p>{@code tempFile} is assigned only after the file has actually been created by this
     * server. This prevents a failed start followed by close() from deleting a lock file owned by
     * another server.
     *
     * @throws IllegalStateException if the lock file already exists
     * @throws RuntimeException if the file cannot be created
     */
    public void createServerTempFile() {
        String tempDir = SystemUtils.getTempDir(TEMP_FILE_DIR_PREFIX);
        FileUtils.mkdir(tempDir);
        File lockFile = new File(new File(tempDir), TEMP_FILE_NAME);

        try {
            if (!lockFile.createNewFile()) {
                throw new IllegalStateException(String.format("Temp file already exists at: %s", lockFile.getAbsolutePath()));
            }
        } catch (IOException ex) {
            throw new RuntimeException(String.format("Error creating temp file at: %s", lockFile.getAbsolutePath()), ex);
        }

        tempFile = lockFile;
        tempFile.deleteOnExit();
        log.debug(String.format("Temporary server config file at: %s", tempFile.getAbsolutePath()));
    }

    /**
     * Blocks until the next client message is available.
     *
     * <p>Exit messages are answered here and skipped. Exit ends only the sending client's session.
     *
     * @return the next non-Exit message, or null if the waiting thread was interrupted (the
     *     interrupt flag is restored)
     */
    public OnlineMessage getAction() {
        while (true) {
            OnlineMessage message;
            try {
                message = queue.take();
            } catch (InterruptedException ex) {
                log.warn("Interrupted while taking an online action from the queue.", ex);
                Thread.currentThread().interrupt();
                return null;
            }

            if (!(message instanceof Exit)) {
                return message;
            }

            onActionExecution(message, new ActionStatus(message, true, "Session Closed."));
        }
    }

    /**
     * Sends a response to the client that sent the message.
     *
     * <p>It closes that client after an Exit or Stop. It forgets the message's route once an
     * {@link ActionStatus} (final response) has been sent.
     */
    public void onActionExecution(OnlineMessage onlineMessage, OnlineResponse onlineResponse) {
        ClientConnectionThread client = messageIDConnectionMap.get(onlineMessage.getIdentifier());
        if (client == null) {
            log.warn(String.format("No client connection found for message %s; dropping response: %s",
                    onlineMessage.getIdentifier(), onlineResponse));
            return;
        }

        try {
            client.outputStream.writeObject(onlineResponse);
            client.outputStream.flush();
        } catch (IOException ex) {
            log.warn(String.format("Failed to send client onlineResponse: %s", onlineResponse), ex);
        }

        if ((onlineMessage instanceof Exit) || (onlineMessage instanceof Stop)) {
            closeClient(client);
        }

        if (onlineResponse instanceof ActionStatus) {
            messageIDConnectionMap.remove(onlineMessage.getIdentifier());
        }
    }

    /**
     * Closes a client's socket and stops tracking it. Safe to call more than once.
     */
    public void closeClient(ClientConnectionThread clientConnectionThread) {
        clientConnectionThread.close();
        clientConnectionThreads.remove(clientConnectionThread);
    }

    public void addClient(ClientConnectionThread clientConnectionThread) {
        clientConnectionThreads.add(clientConnectionThread);
    }

    /**
     * Stops listening, deletes the lock file, closes all client connections, and drops any
     * queued messages.
     *
     * <p>Safe to call more than once.
     */
    public void close() {
        listening = false;

        if (tempFile != null) {
            FileUtils.delete(tempFile);
            tempFile = null;
        }

        if (serverThread != null) {
            serverThread.close();
            serverThread = null;
        }

        // Iterate over a snapshot because closeClient() removes from the set.
        for (ClientConnectionThread client : new ArrayList<ClientConnectionThread>(clientConnectionThreads)) {
            closeClient(client);
        }
        clientConnectionThreads.clear();

        queue.clear();
    }
}
