package org.linqs.psl.cli;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintStream;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.configuration2.ex.ConfigurationException;
import org.apache.commons.configuration2.tree.DefaultExpressionEngineSymbols;
import org.linqs.psl.application.inference.InferenceApplication;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.database.DataStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.Partition;
import org.linqs.psl.database.loading.Inserter;
import org.linqs.psl.database.rdbms.RDBMSDataStore;
import org.linqs.psl.database.rdbms.driver.H2DatabaseDriver;
import org.linqs.psl.database.rdbms.driver.PostgreSQLDriver;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.grounding.GroundRuleStore;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.UnweightedGroundRule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.model.term.Constant;
import org.linqs.psl.parser.CommandLineLoader;
import org.linqs.psl.parser.ModelLoader;
import org.linqs.psl.util.Reflection;
import org.linqs.psl.util.StringUtils;
import org.linqs.psl.util.Version;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/linqs/psl/cli/Launcher.class */
/**
 * Command-line entry point for PSL.
 *
 * <p>A run loads data and a model into a {@link DataStore}, then performs either inference or
 * weight learning, and optionally evaluates the results with one or more {@link Evaluator}s.
 */
public class Launcher {
    public static final String MODEL_FILE_EXTENSION = ".psl";
    public static final String PARTITION_NAME_OBSERVATIONS = "observations";
    public static final String PARTITION_NAME_TARGET = "targets";
    public static final String PARTITION_NAME_LABELS = "truth";

    /** Value written in the "Weight" column for ground rules that are not weighted. */
    private static final String NO_WEIGHT_PLACEHOLDER = ".";

    /** Suffix inserted before the model extension when writing a learned model. */
    private static final String LEARNED_MODEL_SUFFIX = "-learned";

    private static final Logger log = LoggerFactory.getLogger(Launcher.class);

    private final CommandLine parsedOptions;

    private Launcher(CommandLine parsedOptions) {
        this.parsedOptions = parsedOptions;
    }

    /**
     * Creates the data store selected on the command line.
     *
     * <p>An explicit H2 path wins; otherwise the "postgres" option selects a PostgreSQL database
     * (falling back to the default database name); otherwise an on-disk H2 database at the
     * default path is used.
     */
    private DataStore initDataStore() {
        if (parsedOptions.hasOption(CommandLineLoader.OPTION_DB_H2_PATH)) {
            return createH2DataStore(parsedOptions.getOptionValue(CommandLineLoader.OPTION_DB_H2_PATH));
        }

        if (parsedOptions.hasOption("postgres")) {
            String databaseName = parsedOptions.getOptionValue("postgres", CommandLineLoader.DEFAULT_POSTGRES_DB_NAME);
            return new RDBMSDataStore(new PostgreSQLDriver(databaseName, true));
        }

        return createH2DataStore(CommandLineLoader.DEFAULT_H2_DB_PATH);
    }

    /** Builds an RDBMS data store backed by an on-disk H2 database at {@code path}. */
    private static DataStore createH2DataStore(String path) {
        return new RDBMSDataStore(new H2DatabaseDriver(H2DatabaseDriver.Type.Disk, path, true));
    }

    /**
     * Loads the data described by the data option into {@code dataStore}.
     *
     * @return the set of predicates reported as closed by the data loader
     * @throws RuntimeException wrapping a file-not-found or configuration error
     */
    private Set<StandardPredicate> loadData(DataStore dataStore) {
        log.info("Loading data");

        Set<StandardPredicate> closedPredicates;
        try {
            closedPredicates = DataLoader.load(
                    dataStore,
                    parsedOptions.getOptionValue(CommandLineLoader.OPTION_DATA),
                    parsedOptions.hasOption(CommandLineLoader.OPTION_INT_IDS));
        } catch (FileNotFoundException | ConfigurationException e) {
            throw new RuntimeException("Failed to load data.", e);
        }

        log.info("Data loading complete");
        return closedPredicates;
    }

    /**
     * Writes one delimited line per ground rule: weight, squared flag, rule text and, optionally,
     * satisfaction (1 - incompatibility for weighted rules, 1 - infeasibility otherwise).
     *
     * @param groundRuleStore store whose ground rules are written
     * @param path output file, or {@code null} for stdout; if the file cannot be opened, stdout is
     *     used instead
     * @param includeSatisfaction whether to append the satisfaction column
     */
    private void outputGroundRules(GroundRuleStore groundRuleStore, String path, boolean includeSatisfaction) {
        PrintStream stream = System.out;
        boolean ownsStream = false;

        if (path != null) {
            try {
                stream = new PrintStream(path);
                ownsStream = true;
            } catch (IOException e) {
                log.error(String.format("Unable to open file (%s) for ground rules, using stdout instead.", path), e);
            }
        }

        try {
            String header = StringUtils.join(Inserter.DEFAULT_DELIMITER, "Weight", "Squared?", "Rule");
            if (includeSatisfaction) {
                header = StringUtils.join(Inserter.DEFAULT_DELIMITER, header, "Satisfaction");
            }
            stream.println(header);

            for (GroundRule groundRule : groundRuleStore.getGroundRules()) {
                String row;
                double satisfaction;

                if (groundRule instanceof WeightedGroundRule) {
                    WeightedGroundRule weightedRule = (WeightedGroundRule) groundRule;
                    row = StringUtils.join(Inserter.DEFAULT_DELIMITER,
                            String.valueOf(weightedRule.getWeight()),
                            String.valueOf(weightedRule.isSquared()),
                            groundRule.baseToString());
                    satisfaction = 1.0 - weightedRule.getIncompatibility();
                } else {
                    row = StringUtils.join(Inserter.DEFAULT_DELIMITER,
                            NO_WEIGHT_PLACEHOLDER, "false", groundRule.baseToString());
                    satisfaction = 1.0 - ((UnweightedGroundRule) groundRule).getInfeasibility();
                }

                if (includeSatisfaction) {
                    row = StringUtils.join(Inserter.DEFAULT_DELIMITER, row, String.valueOf(satisfaction));
                }
                stream.println(row);
            }
        } finally {
            // Never close System.out; only close a stream this method opened.
            if (ownsStream) {
                stream.close();
            }
        }
    }

    /**
     * Runs inference over the target partition (with observations as read-only data), optionally
     * dumps ground rules / satisfaction, and writes the inferred atoms.
     *
     * @return the open database holding the inferred values; the caller must close it
     */
    private Database runInference(Model model, DataStore dataStore, Set<StandardPredicate> closedPredicates, String iaName) {
        log.info("Starting inference with class: {}", iaName);

        Partition targetPartition = dataStore.getPartition(PARTITION_NAME_TARGET);
        Partition observationsPartition = dataStore.getPartition(PARTITION_NAME_OBSERVATIONS);
        Database database = dataStore.getDatabase(targetPartition, closedPredicates, observationsPartition);

        InferenceApplication inferenceApplication = InferenceApplication.getInferenceApplication(iaName, model, database);

        if (parsedOptions.hasOption(CommandLineLoader.OPTION_OUTPUT_GROUND_RULES_LONG)) {
            String path = parsedOptions.getOptionValue(CommandLineLoader.OPTION_OUTPUT_GROUND_RULES_LONG);
            outputGroundRules(inferenceApplication.getGroundRuleStore(), path, false);
        }

        inferenceApplication.inference(!parsedOptions.hasOption(CommandLineLoader.OPTION_SKIP_ATOM_COMMIT_LONG));

        if (parsedOptions.hasOption(CommandLineLoader.OPTION_OUTPUT_SATISFACTION_LONG)) {
            String path = parsedOptions.getOptionValue(CommandLineLoader.OPTION_OUTPUT_SATISFACTION_LONG);
            outputGroundRules(inferenceApplication.getGroundRuleStore(), path, true);
        }

        log.info("Inference Complete");

        outputResults(database, dataStore, closedPredicates);
        return database;
    }

    /**
     * Writes every ground random-variable atom of each non-closed registered predicate.
     *
     * <p>Without an output directory, atoms are printed to stdout as {@code atom = value}.
     * Otherwise one {@code <predicate>.txt} file per predicate is written, each line holding the
     * atom's arguments followed by its value, separated by {@link Inserter#DEFAULT_DELIMITER}.
     */
    private void outputResults(Database database, DataStore dataStore, Set<StandardPredicate> closedPredicates) {
        // Copy before removeAll so the data store's own predicate set is never modified.
        Set<StandardPredicate> openPredicates = new HashSet<>(dataStore.getRegisteredPredicates());
        openPredicates.removeAll(closedPredicates);

        if (!parsedOptions.hasOption(CommandLineLoader.OPTION_OUTPUT_DIR)) {
            for (StandardPredicate predicate : openPredicates) {
                for (RandomVariableAtom atom : database.getAllGroundRandomVariableAtoms(predicate)) {
                    System.out.println(atom.toString() + " = " + atom.getValue());
                }
            }
            return;
        }

        File outputDirectory = new File(parsedOptions.getOptionValue(CommandLineLoader.OPTION_OUTPUT_DIR));
        outputDirectory.mkdirs();

        for (StandardPredicate predicate : openPredicates) {
            File outputFile = new File(outputDirectory, predicate.getName() + ".txt");
            try (FileWriter writer = new FileWriter(outputFile)) {
                StringBuilder line = new StringBuilder();
                for (RandomVariableAtom atom : database.getAllGroundRandomVariableAtoms(predicate)) {
                    line.setLength(0);
                    for (Constant argument : atom.getArguments()) {
                        line.append(argument.rawToString());
                        line.append(Inserter.DEFAULT_DELIMITER);
                    }
                    line.append(Double.toString(atom.getValue()));
                    line.append("\n");
                    writer.write(line.toString());
                }
            } catch (IOException e) {
                // Best-effort per predicate: report (with cause) and continue with the next one.
                log.error("Exception writing predicate {}", predicate, e);
            }
        }
    }

    /**
     * Learns rule weights against the truth partition, optionally dumps ground rules /
     * satisfaction, and writes the learned model next to the input model.
     */
    private void learnWeights(Model model, DataStore dataStore, Set<StandardPredicate> closedPredicates, String wlaName) {
        log.info("Starting weight learning with learner: {}", wlaName);

        Partition targetPartition = dataStore.getPartition(PARTITION_NAME_TARGET);
        Partition observationsPartition = dataStore.getPartition(PARTITION_NAME_OBSERVATIONS);
        Partition truthPartition = dataStore.getPartition(PARTITION_NAME_LABELS);

        Database randomVariableDatabase = dataStore.getDatabase(targetPartition, closedPredicates, observationsPartition);
        Database truthDatabase = dataStore.getDatabase(truthPartition, dataStore.getRegisteredPredicates(), new Partition[0]);

        try {
            WeightLearningApplication learner = WeightLearningApplication.getWLA(wlaName, model.getRules(), randomVariableDatabase, truthDatabase);
            try {
                learner.learn();

                if (parsedOptions.hasOption(CommandLineLoader.OPTION_OUTPUT_GROUND_RULES_LONG)) {
                    String path = parsedOptions.getOptionValue(CommandLineLoader.OPTION_OUTPUT_GROUND_RULES_LONG);
                    outputGroundRules(learner.getGroundRuleStore(), path, false);
                }

                // Query the ground rule store before closing the learner, so this output does not
                // depend on the store remaining usable after close().
                if (parsedOptions.hasOption(CommandLineLoader.OPTION_OUTPUT_SATISFACTION_LONG)) {
                    String path = parsedOptions.getOptionValue(CommandLineLoader.OPTION_OUTPUT_SATISFACTION_LONG);
                    outputGroundRules(learner.getGroundRuleStore(), path, true);
                }
            } finally {
                learner.close();
            }
        } finally {
            randomVariableDatabase.close();
            truthDatabase.close();
        }

        log.info("Weight learning complete");

        writeLearnedModel(model);
    }

    /**
     * Serializes {@code model} to the learned-model path derived from the model option.
     *
     * @throws RuntimeException if the file cannot be written
     */
    private void writeLearnedModel(Model model) {
        String learnedPath = getLearnedModelPath(parsedOptions.getOptionValue(CommandLineLoader.OPTION_MODEL));
        log.info("Writing learned model to {}", learnedPath);

        // Removes every "( " and " )" sequence from the serialized model.
        // NOTE(review): this strips the parenthesis together with its padding space — confirm intended.
        String serializedModel = model.asString().replaceAll("\\( | \\)", "");

        try (FileWriter writer = new FileWriter(new File(learnedPath))) {
            writer.write(serializedModel);
        } catch (IOException e) {
            log.error("Failed to write learned model:\n" + serializedModel);
            throw new RuntimeException("Failed to write learned model to: " + learnedPath, e);
        }
    }

    /**
     * Derives the output path for a learned model.
     *
     * <p>{@code foo.psl} becomes {@code foo-learned.psl}; a path not ending in the model extension
     * simply gets the extension appended. Only a trailing extension is considered, so a ".psl"
     * occurring elsewhere in the path (e.g. in a directory name) is left untouched.
     */
    private static String getLearnedModelPath(String modelPath) {
        if (!modelPath.endsWith(MODEL_FILE_EXTENSION)) {
            return modelPath + MODEL_FILE_EXTENSION;
        }

        String stem = modelPath.substring(0, modelPath.length() - MODEL_FILE_EXTENSION.length());
        return stem + LEARNED_MODEL_SUFFIX + MODEL_FILE_EXTENSION;
    }

    /**
     * Evaluates every non-closed registered predicate that has ground-truth atoms, using the
     * {@link Evaluator} class named {@code evalClassName}, and logs the resulting statistics.
     *
     * @param predictionDatabase database holding predictions, or {@code null} to open one over the
     *     target/observation partitions (in which case this method also closes it)
     */
    private void evaluation(DataStore dataStore, Database predictionDatabase, Set<StandardPredicate> closedPredicates, String evalClassName) {
        log.info("Starting evaluation with class: {}.", evalClassName);

        // Copy before removeAll so the data store's own predicate set is never modified.
        Set<StandardPredicate> openPredicates = new HashSet<>(dataStore.getRegisteredPredicates());
        openPredicates.removeAll(closedPredicates);

        Partition targetPartition = dataStore.getPartition(PARTITION_NAME_TARGET);
        Partition observationsPartition = dataStore.getPartition(PARTITION_NAME_OBSERVATIONS);
        Partition truthPartition = dataStore.getPartition(PARTITION_NAME_LABELS);

        boolean ownsPredictionDatabase = (predictionDatabase == null);
        if (ownsPredictionDatabase) {
            predictionDatabase = dataStore.getDatabase(targetPartition, closedPredicates, observationsPartition);
        }

        Database truthDatabase = dataStore.getDatabase(truthPartition, dataStore.getRegisteredPredicates(), new Partition[0]);

        try {
            Evaluator evaluator = (Evaluator) Reflection.newObject(evalClassName);

            for (StandardPredicate predicate : openPredicates) {
                if (truthDatabase.countAllGroundAtoms(predicate) == 0) {
                    log.info("Skipping evaluation for {} since there are no ground truth atoms", predicate);
                    continue;
                }

                // The final flag is true exactly when the caller supplied the prediction database.
                evaluator.compute(predictionDatabase, truthDatabase, predicate, !ownsPredictionDatabase);
                log.info("Evaluation results for {} -- {}", predicate.getName(), evaluator.getAllStats());
            }
        } finally {
            if (ownsPredictionDatabase) {
                predictionDatabase.close();
            }
            truthDatabase.close();
        }
    }

    /**
     * Parses the model file named by the model option into {@code dataStore}.
     *
     * @throws RuntimeException wrapping any {@link IOException} from opening, reading or closing
     */
    private Model loadModel(DataStore dataStore) {
        String modelPath = parsedOptions.getOptionValue(CommandLineLoader.OPTION_MODEL);
        log.info("Loading model from {}", modelPath);

        Model model;
        try (FileReader reader = new FileReader(new File(modelPath))) {
            model = ModelLoader.load(dataStore, reader);
        } catch (IOException e) {
            throw new RuntimeException("Failed to load model from file: " + modelPath, e);
        }

        log.debug("Model:");
        for (Rule rule : model.getRules()) {
            log.debug("   {}", rule);
        }

        log.info("Model loading complete");
        return model;
    }

    /**
     * Executes one full CLI run: data and model loading, inference or learning, then optional
     * evaluation. The data store (and any inference database) is closed even if a step fails.
     */
    private void run() {
        log.info("Running PSL CLI Version {}", Version.getFull());

        DataStore dataStore = initDataStore();
        Database inferredDatabase = null;

        try {
            Set<StandardPredicate> closedPredicates = loadData(dataStore);
            Model model = loadModel(dataStore);

            if (parsedOptions.hasOption(CommandLineLoader.OPERATION_INFER)) {
                String iaName = parsedOptions.getOptionValue(CommandLineLoader.OPERATION_INFER, CommandLineLoader.DEFAULT_IA);
                inferredDatabase = runInference(model, dataStore, closedPredicates, iaName);
            } else if (parsedOptions.hasOption(CommandLineLoader.OPERATION_LEARN)) {
                String wlaName = parsedOptions.getOptionValue(CommandLineLoader.OPERATION_LEARN, CommandLineLoader.DEFAULT_WLA);
                learnWeights(model, dataStore, closedPredicates, wlaName);
            } else {
                throw new IllegalArgumentException("No valid operation provided.");
            }

            if (parsedOptions.hasOption(CommandLineLoader.OPTION_EVAL)) {
                for (String evalClassName : parsedOptions.getOptionValues(CommandLineLoader.OPTION_EVAL)) {
                    evaluation(dataStore, inferredDatabase, closedPredicates, evalClassName);
                }
                log.info("Evaluation complete.");
            }
        } finally {
            if (inferredDatabase != null) {
                inferredDatabase.close();
            }
            dataStore.close();
        }
    }

    /**
     * Checks that a run can proceed. Help/version requests never proceed; a missing data, model or
     * operation option prints a message plus usage and does not proceed.
     */
    private static boolean isCommandLineValid(CommandLine commandLine) {
        if (commandLine.hasOption(CommandLineLoader.OPTION_HELP) || commandLine.hasOption(CommandLineLoader.OPTION_VERSION)) {
            return false;
        }

        if (!commandLine.hasOption(CommandLineLoader.OPTION_DATA)) {
            return reportMissingOption(CommandLineLoader.OPTION_DATA_LONG, CommandLineLoader.OPTION_DATA);
        }

        if (!commandLine.hasOption(CommandLineLoader.OPTION_MODEL)) {
            return reportMissingOption(CommandLineLoader.OPTION_MODEL_LONG, CommandLineLoader.OPTION_MODEL);
        }

        if (!commandLine.hasOption(CommandLineLoader.OPERATION_INFER) && !commandLine.hasOption(CommandLineLoader.OPERATION_LEARN)) {
            return reportMissingOption(CommandLineLoader.OPERATION_INFER_LONG, CommandLineLoader.OPERATION_INFER);
        }

        return true;
    }

    /** Prints a missing-option message followed by the usage text; always returns {@code false}. */
    private static boolean reportMissingOption(String longName, String shortName) {
        System.out.println(String.format("Missing required option: --%s/-%s.", longName, shortName));
        new HelpFormatter().printHelp("psl", CommandLineLoader.getOptions(), true);
        return false;
    }

    public static void main(String[] args) {
        main(args, false);
    }

    /**
     * Runs the CLI.
     *
     * @param args command-line arguments
     * @param rethrow if true, any failure is rethrown as a {@link RuntimeException}; otherwise the
     *     stack trace is printed to stderr and the JVM exits with status 1
     */
    public static void main(String[] args, boolean rethrow) {
        try {
            CommandLine parsedOptions = new CommandLineLoader(args).getParsedOptions();
            if (parsedOptions == null || !isCommandLineValid(parsedOptions)) {
                return;
            }

            new Launcher(parsedOptions).run();
        } catch (Exception ex) {
            if (rethrow) {
                throw new RuntimeException("Failed to run CLI: " + ex.getMessage(), ex);
            }

            System.err.println("Unexpected exception!");
            ex.printStackTrace(System.err);
            System.exit(1);
        }
    }
}
