package org.linqs.psl.model.deep;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.RandomAccessFile;
import java.io.Writer;
import java.net.Socket;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.HashMap;
import java.util.Map;
import org.apache.logging.log4j.core.jackson.JsonConstants;
import org.json.JSONObject;
import org.linqs.psl.config.Config;
import org.linqs.psl.config.Options;
import org.linqs.psl.runtime.RuntimeConfig;
import org.linqs.psl.util.FileUtils;
import org.linqs.psl.util.Logger;

/* loaded from: input_file:org/linqs/psl/model/deep/DeepModel.class */
/**
 * Base class for a model that is evaluated by an external Python server process.
 *
 * <p>The server is launched with {@code python3 -m <module> <port>} and reached over a socket on
 * {@code 127.0.0.1}. Control messages are single-line JSON objects (one request line, one response
 * line). Bulk data is exchanged through a memory-mapped file ({@link #sharedBuffer}) that
 * subclasses fill and read via the abstract {@code write*}/{@code read*} methods.
 *
 * <p>Instances are not synchronized; only the static port bookkeeping is.
 */
public abstract class DeepModel {
    protected static final String CONFIG_MODEL_PATH = "model-path";
    protected static final String CONFIG_RELATIVE_DIR = "relative-dir";

    /** Fixed pause used to give the Python server time to start up or wind down. */
    private static final long SERVER_SLEEP_TIME_MS = 500;

    private static final Logger log = Logger.getLogger(DeepModel.class);

    /** First port that is tried when assigning a port to a new model. */
    private static int startingPort = Options.PREDICATE_DEEP_PYTHON_PORT.getInt();

    /** Ports currently reserved by a model. Guarded by the class lock (see getOpenPort/freePort). */
    private static final Map<Integer, DeepModel> usedPorts = new HashMap<>();

    protected String deepModel;
    protected Map<String, String> pythonOptions = new HashMap<>();
    protected String application = null;
    protected int port = getOpenPort(this);
    protected String pythonModule = Options.PREDICATE_DEEP_PYTHON_WRAPPER_MODULE.getString();
    protected String sharedMemoryPath = Options.PREDICATE_DEEP_SHARED_MEMORY_PATH.getString();
    protected Process pythonServerProcess = null;
    protected RandomAccessFile sharedFile = null;
    protected MappedByteBuffer sharedBuffer = null;
    protected Socket socket = null;
    protected BufferedReader socketInput = null;
    protected PrintWriter socketOutput = null;
    protected boolean serverOpen = false;

    /** Ensures the JVM shutdown hook is only registered once, even if the server is restarted. */
    private boolean shutdownHookRegistered = false;

    /**
     * @param deepModel the value sent to the server under the {@code "deep_model"} key.
     */
    public DeepModel(String deepModel) {
        this.deepModel = deepModel;
    }

    /**
     * Subclass initialization hook, invoked by {@link #initDeepModel(String)} before the server is started.
     * By the time this method returns, {@link #pythonOptions} must contain {@link #CONFIG_MODEL_PATH}.
     *
     * @return the number of bytes to map from the shared memory file.
     */
    public abstract int init();

    /** Write the data for a "fit" request into {@link #sharedBuffer}. */
    public abstract void writeFitData();

    /** Write the data for a "predict"/"predict_learn" request into {@link #sharedBuffer}. */
    public abstract void writePredictData();

    /**
     * Read the server's prediction output from {@link #sharedBuffer}.
     *
     * @return the value that {@link #predictDeepModel(Boolean)} hands back to its caller.
     */
    public abstract float readPredictData();

    /** Write the data for an "eval" request into {@link #sharedBuffer}. */
    public abstract void writeEvalData();

    /**
     * Initialize this model: run {@link #init()}, start the Python server if it is not running,
     * and send it an "init" message.
     *
     * @param application the application name forwarded to the server.
     * @throws IllegalArgumentException if no model path was configured.
     * @throws RuntimeException if the server cannot be started or reports a failure.
     */
    public void initDeepModel(String application) {
        log.debug("Init deep model {}.", this);

        this.application = application;
        pythonOptions.put(CONFIG_RELATIVE_DIR, Config.getString("runtime.relativebasepath", null));

        int sharedMemorySize = init();

        if (pythonOptions.get(CONFIG_MODEL_PATH) == null) {
            throw new IllegalArgumentException(String.format(
                    "A DeepModel must have a model path (\"%s\") specified in predicate config.",
                    CONFIG_MODEL_PATH));
        }

        if (pythonServerProcess == null) {
            log.debug("DeepModel server not found for {}. Starting server.", this);
            initServer(sharedMemorySize);
        }

        JSONObject message = new JSONObject();
        message.put("task", "init");
        message.put("deep_model", deepModel);
        message.put("shared_memory_path", sharedMemoryPath);
        message.put("application", application);
        message.put(RuntimeConfig.KEY_OPTIONS, (Map<?, ?>) pythonOptions);

        log.debug("Sending init message to deep model server for {}.", this);
        log.debug("Message: {}", message);

        JSONObject response = sendSocketMessage(message);
        log.debug("Init deep model results for {} : {}", this, getResultString(response));
    }

    /**
     * Write the fit data to shared memory and ask the server to run a "fit" task.
     *
     * @throws IllegalStateException if the model has not been initialized.
     */
    public void fitDeepModel() {
        log.debug("Fit deep model {}.", this);
        requireSharedBuffer();

        sharedBuffer.clear();
        writeFitData();
        // Flush the mapped region to the file before the server is told to read it.
        sharedBuffer.force();

        JSONObject response = sendSocketMessage(newModelMessage("fit"));
        log.debug("Fit deep model results for {} : {}", this, getResultString(response));
    }

    /**
     * Write the prediction inputs to shared memory, ask the server to predict, and read back the result.
     *
     * @param learning when true the "predict_learn" task is requested, otherwise "predict". Must not be null.
     * @return whatever {@link #readPredictData()} returns after the server has responded.
     * @throws IllegalStateException if the model has not been initialized.
     */
    public float predictDeepModel(Boolean learning) {
        log.debug("Predict deep model {}.", this);
        requireSharedBuffer();

        sharedBuffer.clear();
        writePredictData();
        sharedBuffer.force();

        String task = learning.booleanValue() ? "predict_learn" : "predict";
        JSONObject response = sendSocketMessage(newModelMessage(task));

        // Rewind the buffer so the subclass reads the server's output from the start.
        sharedBuffer.clear();
        float result = readPredictData();

        log.debug("Predict deep model result for {} : {}", this, getResultString(response));
        return result;
    }

    /**
     * Write the evaluation data to shared memory and ask the server to run an "eval" task.
     *
     * @throws IllegalStateException if the model has not been initialized.
     */
    public void evalDeepModel() {
        log.debug("Eval deep model {}.", this);
        requireSharedBuffer();

        sharedBuffer.clear();
        writeEvalData();
        sharedBuffer.force();

        JSONObject response = sendSocketMessage(newModelMessage("eval"));
        log.debug("Eval deep model result for {} : {}", this, getResultString(response));
    }

    /** Ask the server to run a "save" task. */
    public void saveDeepModel() {
        log.debug("Save deep model {}.", this);

        JSONObject message = new JSONObject();
        message.put("task", "save");
        message.put(RuntimeConfig.KEY_OPTIONS, (Map<?, ?>) pythonOptions);

        JSONObject response = sendSocketMessage(message);
        log.debug("Save deep model result for {} : {}", this, getResultString(response));
    }

    /**
     * Tell the server to close (if a connection exists) and release every local resource.
     * Resources are released even when the "close" message fails; in that case the failure
     * is still propagated to the caller. Calling this on an already closed model is a no-op.
     */
    public void close() {
        log.debug("Close deep model {}.", this);

        if (pythonOptions != null) {
            pythonOptions.clear();
        }

        try {
            if (socketOutput != null) {
                JSONObject message = new JSONObject();
                message.put("task", "close");

                JSONObject response = sendSocketMessage(message);
                log.debug("Close deep model result for {} : {}", this, getResultString(response));
            }
        } finally {
            // Always tear down, otherwise a failed close message leaks the Python process.
            closeServer();
        }
    }

    /** Build a request carrying the task name, this model's name, and the python options. */
    private JSONObject newModelMessage(String task) {
        JSONObject message = new JSONObject();
        message.put("task", task);
        message.put("deep_model", deepModel);
        message.put(RuntimeConfig.KEY_OPTIONS, (Map<?, ?>) pythonOptions);
        return message;
    }

    /** Fail with a descriptive error when the shared buffer has not been mapped yet (or was released). */
    private void requireSharedBuffer() {
        if (sharedBuffer == null) {
            throw new IllegalStateException(
                    "Deep model is not initialized (or has been closed); call initDeepModel() first.");
        }
    }

    /**
     * Extract a printable form of the "result" object of a server response.
     *
     * @param response the server response; may be null when no message was sent because the server is not open.
     */
    private String getResultString(JSONObject response) {
        if (response == null) {
            return "<No Result Provided>";
        }

        JSONObject result = response.optJSONObject("result");
        return result == null ? "<No Result Provided>" : result.toString();
    }

    /**
     * Map the shared memory file, launch the Python server, and connect to it.
     * If anything fails after the shared file was opened, everything started so far is torn down again.
     *
     * @param sharedMemorySize number of bytes of the shared file to map.
     */
    private void initServer(int sharedMemorySize) {
        if (!shutdownHookRegistered) {
            Runtime.getRuntime().addShutdownHook(new Thread() {
                @Override
                public void run() {
                    // Must be qualified: a bare this.close() would refer to the Thread, not the model.
                    DeepModel.this.close();
                }
            });
            shutdownHookRegistered = true;
        }

        try {
            sharedFile = new RandomAccessFile(sharedMemoryPath, "rw");
        } catch (FileNotFoundException ex) {
            throw new RuntimeException("Could not open random access file: " + sharedMemoryPath, ex);
        }

        try {
            sharedBuffer = sharedFile.getChannel().map(FileChannel.MapMode.READ_WRITE, 0L, sharedMemorySize);
            sharedBuffer.clear();

            ProcessBuilder builder = new ProcessBuilder("python3", "-m", pythonModule, "" + port);
            builder.inheritIO();
            pythonServerProcess = builder.start();

            // Give the server a moment to start listening before connecting.
            sleepForServer();

            socket = new Socket("127.0.0.1", port);
            socketInput = new BufferedReader(new InputStreamReader(socket.getInputStream(), FileUtils.DEFAULT_CHARSET));
            socketOutput = new PrintWriter(new OutputStreamWriter(socket.getOutputStream(), FileUtils.DEFAULT_CHARSET), true);

            // Only mark the server usable once the connection actually exists.
            serverOpen = true;
        } catch (IOException ex) {
            // Do not leave a half-started server (process, mapped file, socket) behind.
            try {
                closeServer();
            } catch (RuntimeException cleanupEx) {
                ex.addSuppressed(cleanupEx);
            }

            throw new RuntimeException(ex);
        }
    }

    /**
     * Release the socket, the shared file, and the Python process.
     * Every step is attempted even if an earlier one fails; the first failure is rethrown at the end
     * with any later failures attached as suppressed exceptions.
     */
    private void closeServer() {
        RuntimeException failure = null;

        if (socketOutput != null) {
            serverOpen = false;
            // Give the server time to process a previously sent close message.
            sleepForServer();
            freePort(port);
            socketOutput.close();
            socketOutput = null;
        }

        if (socketInput != null) {
            try {
                socketInput.close();
            } catch (IOException ex) {
                failure = recordFailure(failure, new RuntimeException(ex));
            }

            socketInput = null;
        }

        if (socket != null) {
            if (!socket.isClosed()) {
                try {
                    socket.close();
                } catch (IOException ex) {
                    failure = recordFailure(failure, new RuntimeException(ex));
                }
            }

            socket = null;
        }

        sharedBuffer = null;

        if (sharedFile != null) {
            try {
                sharedFile.close();
                FileUtils.delete(sharedMemoryPath);
            } catch (IOException ex) {
                failure = recordFailure(failure,
                        new RuntimeException("Failed to clean up shared file: " + sharedMemoryPath, ex));
            }

            sharedFile = null;
        }

        if (pythonServerProcess != null) {
            if (pythonServerProcess.isAlive()) {
                pythonServerProcess.destroyForcibly();
            }

            pythonServerProcess = null;
        }

        if (failure != null) {
            throw failure;
        }
    }

    /** Keep the first cleanup failure and attach any later one to it as suppressed. */
    private static RuntimeException recordFailure(RuntimeException first, RuntimeException next) {
        if (first == null) {
            return next;
        }

        first.addSuppressed(next);
        return first;
    }

    /** Pause for {@link #SERVER_SLEEP_TIME_MS}; an interrupt ends the pause early and is preserved. */
    private void sleepForServer() {
        try {
            Thread.sleep(SERVER_SLEEP_TIME_MS);
        } catch (InterruptedException ex) {
            // Restore the flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Send one JSON message to the server and wait for its one-line JSON response.
     *
     * @return the parsed response when its status is "success", or null if the server is not open
     *         (in which case nothing is sent).
     * @throws RuntimeException on I/O errors, if the server closes the connection, or if it reports a non-success status.
     */
    private JSONObject sendSocketMessage(JSONObject message) {
        if (!serverOpen) {
            return null;
        }

        String payload = message.toString();
        log.trace(String.format("Sending server message: '%s'.", payload));

        String rawResponse;
        try {
            socketOutput.println(payload);
            rawResponse = socketInput.readLine();
        } catch (IOException ex) {
            throw new RuntimeException(ex);
        }

        log.trace(String.format("Received server message: '%s'.", rawResponse));

        // readLine() yields null at end of stream, i.e. the server went away without answering.
        if (rawResponse == null) {
            serverOpen = false;
            throw new RuntimeException("Deep model server closed the connection without sending a response.");
        }

        JSONObject response = new JSONObject(rawResponse);
        String status = response.optString("status", "<UNKNOWN>");
        if (status.equals("success")) {
            return response;
        }

        serverOpen = false;
        sleepForServer();
        throw new RuntimeException(String.format(
                "Server sent a failure status (%s): '%s'.",
                status, response.optString(JsonConstants.ELT_MESSAGE, "<no message provided>")));
    }

    /**
     * Reserve the lowest port at or above the starting port that no other model has reserved.
     * This only consults this class's own bookkeeping; it does not probe the operating system.
     */
    private static synchronized int getOpenPort(DeepModel deepModel) {
        int candidate = startingPort;
        while (usedPorts.containsKey(Integer.valueOf(candidate))) {
            candidate++;
        }

        usedPorts.put(Integer.valueOf(candidate), deepModel);
        return candidate;
    }

    /** Release a port reservation made by {@link #getOpenPort(DeepModel)}; unknown ports are ignored. */
    public static synchronized void freePort(int port) {
        usedPorts.remove(Integer.valueOf(port));
    }
}
