package org.linqs.psl.runtime;

import java.io.BufferedWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.linqs.psl.application.inference.InferenceApplication;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.config.Config;
import org.linqs.psl.config.Option;
import org.linqs.psl.config.Options;
import org.linqs.psl.config.RuntimeOptions;
import org.linqs.psl.database.DataStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.Partition;
import org.linqs.psl.database.loading.Inserter;
import org.linqs.psl.database.rdbms.RDBMSDataStore;
import org.linqs.psl.database.rdbms.driver.DatabaseDriver;
import org.linqs.psl.database.rdbms.driver.H2DatabaseDriver;
import org.linqs.psl.database.rdbms.driver.PostgreSQLDriver;
import org.linqs.psl.database.rdbms.driver.SQLiteDriver;
import org.linqs.psl.evaluation.EvaluationInstance;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.grounding.Grounding;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.DeepPredicate;
import org.linqs.psl.model.predicate.Predicate;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.runtime.RuntimeConfig;
import org.linqs.psl.util.FileUtils;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.Parallel;
import org.linqs.psl.util.Reflection;
import org.linqs.psl.util.StringUtils;
import org.linqs.psl.util.Version;

/* loaded from: input_file:org/linqs/psl/runtime/Runtime.class */
/**
 * Entry point for running PSL from a {@link RuntimeConfig}.
 *
 * <p>A run optionally performs weight learning (when {@code RuntimeOptions.LEARN} is true) and then
 * inference (when {@code RuntimeOptions.INFERENCE} is true), using the learned model when one was
 * produced. Predicate data is loaded into the "observations", "targets" and "truth" partitions; for
 * validation data these names are prefixed with {@code RuntimeConfig.KEY_VALIDATION + "_"}.
 */
public class Runtime {
    private static final Logger log = Logger.getLogger(Runtime.class);

    public static final String PARTITION_NAME_OBSERVATIONS = "observations";
    public static final String PARTITION_NAME_TARGET = "targets";
    public static final String PARTITION_NAME_TRUTH = "truth";

    /** Partition names, in the order {@link #loadData} fills them (observations, targets, truth). */
    private static final String[] PARTITION_NAMES = {PARTITION_NAME_OBSERVATIONS, PARTITION_NAME_TARGET, PARTITION_NAME_TRUTH};

    /** Backing database engines, selected by name through {@code RuntimeOptions.DB_TYPE}. */
    public enum DatabaseType {
        H2,
        Postgres,
        SQLite
    }

    /**
     * Grounding callback that writes each ground rule as one delimited line
     * ({@code weight, squared, rule}) preceded by a single header line.
     *
     * <p>Writes to stdout when constructed with a null path, otherwise to the given file (which is
     * closed by {@link #close()}). Writes happen under this object's monitor so lines produced by
     * concurrent callbacks do not interleave.
     */
    public static class GroundRuleOutputter implements Grounding.GroundRuleCallback {
        private boolean headerWritten = false;
        private final PrintWriter out;
        /** True only when {@link #out} wraps a file we opened; stdout is never closed. */
        private final boolean closeOut;

        /**
         * @param path the file to write ground rules to, or null to write to stdout.
         */
        public GroundRuleOutputter(String path) {
            if (path == null) {
                out = new PrintWriter(System.out);
                closeOut = false;
            } else {
                out = new PrintWriter(FileUtils.getBufferedWriter(path));
                closeOut = true;
            }
        }

        @Override // org.linqs.psl.grounding.Grounding.GroundRuleCallback
        public void call(GroundRule groundRule) {
            String line;
            if (groundRule instanceof WeightedGroundRule) {
                WeightedGroundRule weightedRule = (WeightedGroundRule) groundRule;
                line = StringUtils.join(Inserter.DEFAULT_DELIMITER,
                        "" + weightedRule.getWeight(), "" + weightedRule.isSquared(), groundRule.baseToString());
            } else {
                // Non-weighted ground rules have no weight to report; "." is written as a placeholder.
                line = StringUtils.join(Inserter.DEFAULT_DELIMITER, ".", "false", groundRule.baseToString());
            }
            output(line);
        }

        /**
         * Write the header (exactly once), then {@code line} if it is non-null.
         * Passing null writes only the header, which {@link #close()} relies on so that
         * the output always has a header even when no ground rules were produced.
         */
        private synchronized void output(String line) {
            if (!headerWritten) {
                headerWritten = true;
                out.println(StringUtils.join(Inserter.DEFAULT_DELIMITER, "Weight", "Squared?", "Rule"));
            }

            if (line != null) {
                out.println(line);
            }
        }

        /** Ensure the header is written, flush, and close the underlying file (if any). */
        public synchronized void close() {
            output(null);
            out.flush();
            if (closeOut) {
                out.close();
            }
        }
    }

    /** Create a runtime, applying the log level if {@code RuntimeOptions.LOG_LEVEL} is already set. */
    public Runtime() {
        initLogger();
    }

    /** Run with a freshly constructed {@link RuntimeConfig}; no result is collected. */
    public RuntimeResult run() {
        return run(new RuntimeConfig());
    }

    /** Run with the config file at {@code configPath}; no result is collected. */
    public RuntimeResult run(String configPath) {
        return run(configPath, false);
    }

    /**
     * Run with the config file at {@code configPath}.
     *
     * @param collectResults when true, a {@link RuntimeResult} is built and returned; otherwise null is returned.
     */
    public RuntimeResult run(String configPath, boolean collectResults) {
        return run(RuntimeConfig.fromFile(configPath), collectResults);
    }

    /** Run with {@code config}; no result is collected. */
    public RuntimeResult run(RuntimeConfig config) {
        return run(config, false);
    }

    /**
     * Run with a JSON config and return the collected result serialized as JSON.
     * Both arguments are forwarded unchanged to {@code RuntimeConfig.fromJSON}
     * (NOTE(review): the second is presumably a base path — confirm against RuntimeConfig).
     */
    public static String serializedRun(String jsonConfig, String basePath) {
        Runtime runtime = new Runtime();
        RuntimeResult result = runtime.run(RuntimeConfig.fromJSON(jsonConfig, basePath), true);
        runtime.cleanup();
        return result.toJSON();
    }

    /**
     * Run with {@code config} inside a fresh {@link Config} layer.
     * The layer is always popped and {@link #cleanup()} always runs, even on failure.
     *
     * @param collectResults when true, a {@link RuntimeResult} is built and returned; otherwise null is returned.
     */
    public RuntimeResult run(RuntimeConfig config, boolean collectResults) {
        Config.pushLayer();
        try {
            return runInternal(config, collectResults);
        } finally {
            Config.popLayer();
            cleanup();
        }
    }

    /**
     * Apply the config's options, handle help/version requests, then run learning and/or inference.
     *
     * @return the collected result, or null when {@code collectResults} is false.
     */
    protected RuntimeResult runInternal(RuntimeConfig config, boolean collectResults) {
        RuntimeResult result = collectResults ? new RuntimeResult() : null;

        // Options are applied before validation so the log level and help/version flags take effect first.
        applyOptions(config.options);
        initLogger();

        if (checkHelp() || checkVersion()) {
            return result;
        }

        log.info("PSL Runtime Version {}", Version.getFull());
        config.validate();

        // Options are applied a second time after validate().
        // NOTE(review): presumably validate() can adjust them — confirm this second pass is required.
        applyOptions(config.options);
        Config.setProperty("runtime.relativebasepath", config.relativeBasePath, false);

        for (RuntimeConfig.PredicateConfigInfo info : config.predicates.values()) {
            Predicate predicate = Predicate.get(info.name);
            for (Map.Entry<String, String> option : info.options.entrySet()) {
                predicate.setPredicateOption(option.getKey(), option.getValue());
            }
        }

        Model learnedModel = null;
        if (RuntimeOptions.LEARN.getBoolean()) {
            learnedModel = runLearning(config, result);
        }

        if (RuntimeOptions.INFERENCE.getBoolean()) {
            runInference(config, learnedModel, result);
        }

        for (Predicate predicate : Predicate.getAll()) {
            if (predicate instanceof DeepPredicate) {
                predicate.close();
            }
        }

        return result;
    }

    /** Print the runtime options and return true if {@code RuntimeOptions.HELP} is set. */
    protected boolean checkHelp() {
        if (!RuntimeOptions.HELP.getBoolean()) {
            return false;
        }

        System.out.println("PSL Runtime Version " + Version.getFull());
        System.out.println("Options used by the PSL runtime:");

        List<Option> options = Options.fetchClassOptions(RuntimeOptions.class);
        Collections.sort(options);
        for (Option option : options) {
            System.out.println("    " + option.toString());
        }

        return true;
    }

    /** Print the version and return true if {@code RuntimeOptions.VERSION} is set. */
    protected boolean checkVersion() {
        if (!RuntimeOptions.VERSION.getBoolean()) {
            return false;
        }

        System.out.println("PSL Version " + Version.getFull());
        return true;
    }

    /** Release process-wide resources; currently closes {@link Parallel}. */
    public void cleanup() {
        Parallel.close();
    }

    /**
     * Compute each evaluation against the results/truth databases, log its output,
     * and record it in {@code result} when one is being collected.
     */
    protected void evaluate(Database resultsDatabase, Database truthDatabase, List<EvaluationInstance> evaluations, RuntimeResult result) {
        for (EvaluationInstance evaluation : evaluations) {
            evaluation.compute(resultsDatabase, truthDatabase);
            log.info("Evaluation results: {}", evaluation.getOutput());

            if (result != null) {
                result.addEvaluation(evaluation.getOutput());
            }
        }
    }

    /**
     * Build the data store selected by {@code RuntimeOptions.DB_TYPE} and register every
     * {@link StandardPredicate} declared in {@code config}.
     *
     * @throws IllegalArgumentException if the configured database type name is not a {@link DatabaseType}.
     */
    public DataStore initDataStore(RuntimeConfig config) {
        String typeName = RuntimeOptions.DB_TYPE.getString();

        DatabaseType type;
        try {
            type = DatabaseType.valueOf(typeName);
        } catch (IllegalArgumentException ex) {
            throw new IllegalArgumentException("Unknown database type: " + typeName, ex);
        }

        DatabaseDriver driver;
        switch (type) {
            case H2: {
                String path = RuntimeOptions.DB_H2_PATH.getString();
                H2DatabaseDriver.Type h2Type = RuntimeOptions.DB_H2_INMEMORY.getBoolean()
                        ? H2DatabaseDriver.Type.Memory
                        : H2DatabaseDriver.Type.Disk;
                driver = new H2DatabaseDriver(h2Type, path, true);
                break;
            }
            case Postgres:
                driver = new PostgreSQLDriver(RuntimeOptions.DB_PG_NAME.getString(), true);
                break;
            case SQLite:
                driver = new SQLiteDriver(RuntimeOptions.DB_SQLITE_INMEMORY.getBoolean(), RuntimeOptions.DB_SQLITE_PATH.getString(), true);
                break;
            default:
                throw new IllegalStateException("Unknown database type: " + typeName);
        }

        RDBMSDataStore dataStore = new RDBMSDataStore(driver);
        for (RuntimeConfig.PredicateConfigInfo info : config.predicates.values()) {
            Predicate predicate = Predicate.get(info.name);
            if (predicate instanceof StandardPredicate) {
                dataStore.registerPredicate((StandardPredicate) predicate);
            }
        }

        return dataStore;
    }

    /** Apply {@code RuntimeOptions.LOG_LEVEL} to the logger, if it is set. */
    public void initLogger() {
        if (RuntimeOptions.LOG_LEVEL.isSet()) {
            Logger.setLevel(RuntimeOptions.LOG_LEVEL.getString());
        }
    }

    /**
     * Load every predicate's data (file paths and embedded points) for the given split.
     *
     * @param splitKey which split to load (e.g. "learn", "infer", or {@code RuntimeConfig.KEY_VALIDATION});
     *                 validation data goes into partitions prefixed with {@code KEY_VALIDATION + "_"}.
     */
    public void loadData(DataStore dataStore, RuntimeConfig config, String splitKey) {
        log.debug("Data loading start");

        String partitionPrefix = RuntimeConfig.KEY_VALIDATION.equals(splitKey)
                ? RuntimeConfig.KEY_VALIDATION + "_"
                : "";

        for (RuntimeConfig.PredicateConfigInfo info : config.predicates.values()) {
            if (info.dataSize() == 0) {
                continue;
            }

            StandardPredicate predicate = StandardPredicate.get(info.name);

            // Both lists are indexed in PARTITION_NAMES order: observations, targets, truth.
            List<Iterable<String>> paths = Arrays.asList(
                    info.observations.getDataPaths(splitKey),
                    info.targets.getDataPaths(splitKey),
                    info.truth.getDataPaths(splitKey));
            List<Iterable<List<String>>> points = Arrays.asList(
                    info.observations.getDataPoints(splitKey),
                    info.targets.getDataPoints(splitKey),
                    info.truth.getDataPoints(splitKey));

            for (int i = 0; i < PARTITION_NAMES.length; i++) {
                String partitionName = partitionPrefix + PARTITION_NAMES[i];
                loadDataPaths(dataStore, predicate, partitionName, paths.get(i));
                loadDataPoints(dataStore, predicate, partitionName, points.get(i));
            }
        }

        log.debug("Data loading complete");
    }

    /** Load each delimited data file in {@code paths} into {@code partitionName} for {@code predicate}. */
    protected void loadDataPaths(DataStore dataStore, StandardPredicate predicate, String partitionName, Iterable<String> paths) {
        Inserter inserter = dataStore.getInserter(predicate, dataStore.getPartition(partitionName));
        for (String path : paths) {
            log.debug("Loading data for {} ({} partition) from {}", predicate, partitionName, path);
            inserter.loadDelimitedDataAutomatic(path);
        }
    }

    /**
     * Insert embedded data points into {@code partitionName}.
     * Each point holds {@code arity} arguments, optionally followed by a truth value parsed as a double.
     *
     * @throws IllegalArgumentException if a point has neither {@code arity} nor {@code arity + 1} entries.
     */
    protected void loadDataPoints(DataStore dataStore, StandardPredicate predicate, String partitionName, Iterable<List<String>> points) {
        Iterator<List<String>> pointIterator = points.iterator();
        if (!pointIterator.hasNext()) {
            return;
        }

        Inserter inserter = dataStore.getInserter(predicate, dataStore.getPartition(partitionName));
        log.debug("Loading embedded data for {} ({} partition)", predicate, partitionName);

        int arity = predicate.getArity();
        int rowCount = 0;
        while (pointIterator.hasNext()) {
            List<String> point = pointIterator.next();
            if (point.size() < arity || point.size() > arity + 1) {
                throw new IllegalArgumentException(String.format("Provided data point for predicate %s does not have the correct number of arguments. Expecting %d or %d arguments. Offending data point: %s.", predicate.getName(), Integer.valueOf(arity), Integer.valueOf(arity + 1), point));
            }

            // A fresh array per row, so no row's arguments can alias another's.
            Object[] arguments = point.subList(0, arity).toArray();
            if (point.size() == arity + 1) {
                inserter.insertValue(Double.parseDouble(point.get(arity)), arguments);
            } else {
                inserter.insert(arguments);
            }

            rowCount++;
        }

        log.trace("Loaded {} rows of embeded data for {} ({} partition)", Integer.valueOf(rowCount), predicate, partitionName);
    }

    /** Run {@link #runInferenceInternal} inside its own {@link Config} layer. */
    protected void runInference(RuntimeConfig config, Model model, RuntimeResult result) {
        Config.pushLayer();
        try {
            runInferenceInternal(config, model, result);
        } finally {
            Config.popLayer();
        }
    }

    /**
     * Run inference and optionally write results, collect atoms, and evaluate.
     *
     * @param model the learned model, or null when learning did not run (rules then come from the config).
     * @param result where atoms and evaluations are recorded, or null to skip collection.
     * @throws RuntimeException if the resulting model has no rules.
     */
    protected void runInferenceInternal(RuntimeConfig config, Model model, RuntimeResult result) {
        applyOptions(config.infer.options);

        // A null model (no learning step) has nothing to clear; calling clear() on it would throw.
        // NOTE(review): when clearing is requested without a learned model, the base config rules
        // are still used below — confirm that is the intended semantics of INFERENCE_CLEAR_RULES.
        if (model != null && RuntimeOptions.INFERENCE_CLEAR_RULES.getBoolean()) {
            model.clear();
        }

        if (model == null) {
            model = new Model();
            addAllRules(model, config.rules.getRules().iterator());
        }
        addAllRules(model, config.infer.rules.getRules().iterator());

        if (model.getRules().size() == 0) {
            throw new RuntimeException("No rules found for inference.");
        }
        logModelRules(model);

        DataStore dataStore = initDataStore(config);
        loadData(dataStore, config, "infer");

        Set<StandardPredicate> closedPredicates = config.getClosedPredicates("infer");
        Partition targetPartition = dataStore.getPartition(PARTITION_NAME_TARGET);
        Partition observationsPartition = dataStore.getPartition(PARTITION_NAME_OBSERVATIONS);
        Partition truthPartition = dataStore.getPartition(PARTITION_NAME_TRUTH);

        Database database = dataStore.getDatabase(targetPartition, closedPredicates, observationsPartition);
        Database truthDatabase = dataStore.getDatabase(truthPartition, dataStore.getRegisteredPredicates(), new Partition[0]);

        List<EvaluationInstance> evaluations = getEvaluations(config);

        GroundRuleOutputter groundRuleOutputter = null;
        if (RuntimeOptions.INFERENCE_OUTPUT_GROUNDRULES.getBoolean()) {
            groundRuleOutputter = new GroundRuleOutputter(RuntimeOptions.INFERENCE_OUTPUT_GROUNDRULES_PATH.getString());
            Grounding.setGroundRuleCallback(groundRuleOutputter);
        }

        InferenceApplication inferenceApplication;
        try {
            inferenceApplication = InferenceApplication.getInferenceApplication(RuntimeOptions.INFERENCE_METHOD.getString(), model.getRules(), database);
            inferenceApplication.loadDeepPredicates("inference");
            inferenceApplication.inference(RuntimeOptions.INFERENCE_COMMIT.getBoolean(), false, evaluations, truthDatabase);
        } finally {
            // The callback is process-global: always uninstall it (and close its output), even on failure.
            if (groundRuleOutputter != null) {
                groundRuleOutputter.close();
                Grounding.setGroundRuleCallback(null);
            }
        }

        if (RuntimeOptions.INFERENCE_OUTPUT_RESULTS.getBoolean()) {
            String outputDir = RuntimeOptions.INFERENCE_OUTPUT_RESULTS_DIR.getString();
            if (outputDir == null) {
                log.info("Writing inferred predicates to stdout.");
                database.outputRandomVariableAtoms();
            } else {
                log.info("Writing inferred predicates to directory: " + outputDir);
                database.outputRandomVariableAtoms(outputDir);
            }
        }

        if (result != null) {
            if (RuntimeOptions.OUTPUT_ALL_ATOMS.getBoolean()) {
                Iterator<GroundAtom> atoms = database.getAtomStore().iterator();
                while (atoms.hasNext()) {
                    result.addAtom(atoms.next());
                }
            } else {
                Iterator<RandomVariableAtom> atoms = database.getAtomStore().getRandomVariableAtoms().iterator();
                while (atoms.hasNext()) {
                    result.addAtom(atoms.next());
                }
            }
        }

        evaluate(database, truthDatabase, evaluations, result);

        inferenceApplication.close();
        database.close();
        truthDatabase.close();
        dataStore.close();
    }

    /** Run {@link #runLearningInternal} inside its own {@link Config} layer. */
    protected Model runLearning(RuntimeConfig config, RuntimeResult result) {
        Config.pushLayer();
        try {
            return runLearningInternal(config, result);
        } finally {
            Config.popLayer();
        }
    }

    /**
     * Run weight learning and return the model with learned weights.
     * When {@code RuntimeOptions.LEARN_OUTPUT_MODEL_PATH} is set, the learned model is also written there.
     *
     * @param result where learned rules are recorded, or null to skip collection.
     * @throws RuntimeException if there are no rules, or if the learned model cannot be written.
     */
    protected Model runLearningInternal(RuntimeConfig config, RuntimeResult result) {
        applyOptions(config.learn.options);

        Model model = new Model();
        addAllRules(model, config.rules.getRules().iterator());
        addAllRules(model, config.learn.rules.getRules().iterator());

        if (model.getRules().size() == 0) {
            throw new RuntimeException("No rules found for learning.");
        }
        logModelRules(model);

        DataStore dataStore = initDataStore(config);

        // Training data.
        loadData(dataStore, config, "learn");
        Set<StandardPredicate> closedPredicates = config.getClosedPredicates("learn");
        Partition targetPartition = dataStore.getPartition(PARTITION_NAME_TARGET);
        Partition observationsPartition = dataStore.getPartition(PARTITION_NAME_OBSERVATIONS);
        Partition truthPartition = dataStore.getPartition(PARTITION_NAME_TRUTH);
        Database trainDatabase = dataStore.getDatabase(targetPartition, closedPredicates, observationsPartition);
        Database trainTruthDatabase = dataStore.getDatabase(truthPartition, dataStore.getRegisteredPredicates(), new Partition[0]);

        // Validation data, stored in partitions prefixed with the validation key.
        loadData(dataStore, config, RuntimeConfig.KEY_VALIDATION);
        Set<StandardPredicate> validationClosedPredicates = config.getClosedPredicates(RuntimeConfig.KEY_VALIDATION);
        Partition validationTargetPartition = dataStore.getPartition(String.format("%s_%s", RuntimeConfig.KEY_VALIDATION, PARTITION_NAME_TARGET));
        Partition validationObservationsPartition = dataStore.getPartition(String.format("%s_%s", RuntimeConfig.KEY_VALIDATION, PARTITION_NAME_OBSERVATIONS));
        Partition validationTruthPartition = dataStore.getPartition(String.format("%s_%s", RuntimeConfig.KEY_VALIDATION, PARTITION_NAME_TRUTH));
        Database validationDatabase = dataStore.getDatabase(validationTargetPartition, validationClosedPredicates, validationObservationsPartition);
        Database validationTruthDatabase = dataStore.getDatabase(validationTruthPartition, dataStore.getRegisteredPredicates(), new Partition[0]);

        // Learning is driven by the first primary evaluation (null if none are configured).
        EvaluationInstance primaryEvaluation = null;
        for (EvaluationInstance evaluation : getEvaluations(config)) {
            if (evaluation.isPrimary()) {
                primaryEvaluation = evaluation;
                break;
            }
        }

        WeightLearningApplication wla = WeightLearningApplication.getWLA(RuntimeOptions.LEARN_METHOD.getString(), model.getRules(),
                trainDatabase, trainTruthDatabase, validationDatabase, validationTruthDatabase, RuntimeOptions.VALIDATION.getBoolean());
        wla.setEvaluation(primaryEvaluation);
        wla.learn();

        wla.close();
        trainDatabase.close();
        trainTruthDatabase.close();
        validationDatabase.close();
        validationTruthDatabase.close();
        dataStore.close();

        log.info("Learned Model:");
        for (Rule rule : model.getRules()) {
            log.info("   " + rule);
            if (result != null) {
                result.addRule(rule);
            }
        }

        String modelPath = RuntimeOptions.LEARN_OUTPUT_MODEL_PATH.getString();
        if (modelPath != null) {
            log.debug("Writing learned model to {}.", modelPath);

            // Removes every "( " and " )" sequence (a parenthesis together with its adjacent space).
            String modelString = model.asString().replaceAll("\\( | \\)", "");

            try (BufferedWriter writer = FileUtils.getBufferedWriter(modelPath)) {
                writer.write(modelString);
            } catch (IOException ex) {
                log.error("Failed to write learned model:" + System.lineSeparator() + modelString);
                throw new RuntimeException("Failed to write learned model to: " + modelPath, ex);
            }
        }

        return model;
    }

    /**
     * Build an {@link EvaluationInstance} for every evaluation declared on a {@link StandardPredicate}.
     * Each evaluator is constructed inside its own {@link Config} layer holding that evaluation's options.
     * If evaluations exist but none is marked primary, the first one is made primary.
     *
     * @return the evaluations, possibly empty (never null).
     */
    protected List<EvaluationInstance> getEvaluations(RuntimeConfig config) {
        List<EvaluationInstance> evaluations = new ArrayList<>();
        boolean hasPrimary = false;

        for (RuntimeConfig.PredicateConfigInfo info : config.predicates.values()) {
            if (info.evaluations.size() == 0) {
                continue;
            }

            Predicate predicate = Predicate.get(info.name);
            if (!(predicate instanceof StandardPredicate)) {
                continue;
            }

            for (RuntimeConfig.EvalInfo evalInfo : info.evaluations) {
                Evaluator evaluator;
                Config.pushLayer();
                try {
                    applyOptions(evalInfo.options);
                    evaluator = (Evaluator) Reflection.newObject(evalInfo.evaluator);
                } finally {
                    // Popped exactly once, whether or not construction succeeded.
                    Config.popLayer();
                }

                evaluations.add(new EvaluationInstance((StandardPredicate) predicate, evaluator, evalInfo.primary));
                hasPrimary |= evalInfo.primary;
            }
        }

        if (evaluations.isEmpty()) {
            return evaluations;
        }

        if (!hasPrimary) {
            if (evaluations.size() > 1) {
                log.info("Multiple evaluations declared, but no primary evaluation specified. Using the first evaluation instance: {}.", evaluations.get(0));
            }
            evaluations.get(0).setPrimary(true);
        }

        return evaluations;
    }

    /** Copy every entry of {@code options} into {@link Config} (third argument {@code false}, as elsewhere in this class). */
    private static void applyOptions(Map<String, String> options) {
        for (Map.Entry<String, String> option : options.entrySet()) {
            Config.setProperty(option.getKey(), option.getValue(), false);
        }
    }

    /** Add every rule produced by {@code rules} to {@code model}, in iteration order. */
    private static void addAllRules(Model model, Iterator<Rule> rules) {
        while (rules.hasNext()) {
            model.addRule(rules.next());
        }
    }

    /** Log the model's rules at debug level. */
    private static void logModelRules(Model model) {
        log.debug("Model:");
        for (Rule rule : model.getRules()) {
            log.debug("   " + rule);
        }
    }

    /** Command-line entry point: expects exactly one argument, the path to a JSON config. */
    public static void main(String[] args) {
        if (args == null || args.length != 1) {
            System.out.println("USAGE: " + Runtime.class.getName() + " <path to JSON config>");
            return;
        }

        System.out.println(new Runtime().run(args[0], true));
    }
}
