package org.linqs.psl.application.inference.online;

import java.util.List;
import java.util.Set;
import org.linqs.psl.application.inference.InferenceApplication;
import org.linqs.psl.application.inference.online.messages.OnlineMessage;
import org.linqs.psl.application.inference.online.messages.actions.controls.Stop;
import org.linqs.psl.application.inference.online.messages.actions.controls.Sync;
import org.linqs.psl.application.inference.online.messages.actions.controls.WriteInferredPredicates;
import org.linqs.psl.application.inference.online.messages.actions.model.AddAtom;
import org.linqs.psl.application.inference.online.messages.actions.model.DeleteAtom;
import org.linqs.psl.application.inference.online.messages.actions.model.GetAtom;
import org.linqs.psl.application.inference.online.messages.actions.model.ObserveAtom;
import org.linqs.psl.application.inference.online.messages.actions.model.UpdateObservation;
import org.linqs.psl.application.inference.online.messages.actions.template.ActivateRule;
import org.linqs.psl.application.inference.online.messages.actions.template.AddRule;
import org.linqs.psl.application.inference.online.messages.actions.template.DeactivateRule;
import org.linqs.psl.application.inference.online.messages.actions.template.DeleteRule;
import org.linqs.psl.application.inference.online.messages.responses.ActionStatus;
import org.linqs.psl.application.inference.online.messages.responses.GetAtomResponse;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.atom.OnlineAtomManager;
import org.linqs.psl.database.atom.PersistedAtomManager;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.term.Constant;
import org.linqs.psl.reasoner.term.online.OnlineTermStore;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.StringUtils;

/* loaded from: input_file:org/linqs/psl/application/inference/online/OnlineInference.class */
/**
 * Inference application that is driven by actions pulled from an {@link OnlineServer}.
 *
 * <p>After an initial optimization, {@link #internalInference} repeatedly fetches the next
 * {@link OnlineMessage} from the server, applies it to the atom manager / term store, and
 * reports the outcome back to the server. The loop ends once a {@link Stop} action has been
 * executed or {@link #close()} has been called.
 *
 * <p>Actions that change the model only mark it as dirty; the reasoner is re-run lazily by the
 * actions that need up-to-date values (get atom, write inferred predicates, sync).
 */
public abstract class OnlineInference extends InferenceApplication {
    private static final Logger log = Logger.getLogger(OnlineInference.class);

    /** Source of actions and sink for their responses; null after the server has been closed. */
    private OnlineServer server;

    /** True while the model has changed since the last optimization. */
    private boolean modelUpdates;

    /** Set by a stop action or by {@link #close()}; terminates the action loop. */
    private boolean stopped;

    /** Result of the most recent reasoner optimization (0.0 right after initialization). */
    private double objective;

    /** Arguments of the current {@link #internalInference} call, handed to the reasoner on every optimization. */
    private List<Evaluator> evaluators;
    private TrainingMap trainingMap;
    private Set<StandardPredicate> evaluationPredicates;

    /**
     * Forwards both arguments to the superclass constructor.
     */
    protected OnlineInference(List<Rule> list, Database database) {
        super(list, database);
    }

    /**
     * Forwards all three arguments, unchanged, to the superclass constructor.
     *
     * <p>JADX note: the decompiler changed the access modifier of this constructor from protected.
     *
     * @param z passed straight through to the superclass; its meaning is defined there
     */
    public OnlineInference(List<Rule> list, Database database, boolean z) {
        super(list, database, z);
    }

    /**
     * Resets the per-run state, starts the server, runs the superclass initialization and
     * finally asks the term store to ensure capacity for the cached random variable and
     * observation counts reported by the atom manager.
     *
     * <p>Overrides {@code InferenceApplication.initialize()}.
     * JADX note: the decompiler changed the access modifier of this method from protected.
     */
    @Override
    public void initialize() {
        stopped = false;
        modelUpdates = true;
        objective = 0.0;

        evaluators = null;
        trainingMap = null;
        evaluationPredicates = null;

        // The server is started before the superclass initialization runs.
        startServer();
        super.initialize();

        termStore.ensureVariableCapacity(atomManager.getCachedRVACount() + atomManager.getCachedObsCount());
    }

    /**
     * Creates the {@link OnlineAtomManager} used by this application.
     *
     * <p>Overrides {@code InferenceApplication.createAtomManager(Database)}.
     */
    @Override
    protected PersistedAtomManager createAtomManager(Database database) {
        return new OnlineAtomManager(database, initialValue);
    }

    /**
     * Marks inference as stopped, closes the server (if any) and then closes the superclass.
     *
     * <p>Overrides {@code close()} from {@code InferenceApplication} /
     * {@code org.linqs.psl.application.ModelApplication}.
     */
    @Override
    public void close() {
        stopped = true;
        closeServer();
        super.close();
    }

    /** Closes and forgets the server; does nothing when there is no server. */
    private void closeServer() {
        if (server == null) {
            return;
        }

        server.close();
        server = null;
    }

    /** Creates a new server over this application's rules and starts it. */
    private void startServer() {
        server = new OnlineServer(rules);
        server.start();
    }

    /** The term store, viewed as the {@link OnlineTermStore} the action handlers require. */
    private OnlineTermStore onlineTermStore() {
        return (OnlineTermStore) termStore;
    }

    /** The atom manager, viewed as the {@link OnlineAtomManager} the action handlers require. */
    private OnlineAtomManager onlineAtomManager() {
        return (OnlineAtomManager) atomManager;
    }

    /**
     * Runs the handler matching the exact class of the action and reports a successful
     * {@link ActionStatus}, carrying the handler's message, to the server.
     *
     * @throws IllegalArgumentException if the action's class is not one of the supported actions
     */
    protected void executeAction(OnlineMessage onlineMessage) {
        String outcome = dispatchAction(onlineMessage);
        server.onActionExecution(onlineMessage, new ActionStatus(onlineMessage, true, outcome));
    }

    /**
     * Selects the handler by comparing the action's runtime class for identity (subclasses of a
     * supported action do not match) and returns the handler's message.
     */
    private String dispatchAction(OnlineMessage onlineMessage) {
        Class<?> actionType = onlineMessage.getClass();

        // Atom actions.
        if (actionType == AddAtom.class) {
            return doAddAtom((AddAtom) onlineMessage);
        }
        if (actionType == DeleteAtom.class) {
            return doDeleteAtom((DeleteAtom) onlineMessage);
        }
        if (actionType == ObserveAtom.class) {
            return doObserveAtom((ObserveAtom) onlineMessage);
        }
        if (actionType == UpdateObservation.class) {
            return doUpdateObservation((UpdateObservation) onlineMessage);
        }
        if (actionType == GetAtom.class) {
            return doGetAtom((GetAtom) onlineMessage);
        }

        // Rule actions.
        if (actionType == ActivateRule.class) {
            return doActivateRule((ActivateRule) onlineMessage);
        }
        if (actionType == AddRule.class) {
            return doAddRule((AddRule) onlineMessage);
        }
        if (actionType == DeactivateRule.class) {
            return doDeactivateRule((DeactivateRule) onlineMessage);
        }
        if (actionType == DeleteRule.class) {
            return doDeleteRule((DeleteRule) onlineMessage);
        }

        // Control actions.
        if (actionType == WriteInferredPredicates.class) {
            return doWriteInferredPredicates((WriteInferredPredicates) onlineMessage);
        }
        if (actionType == Sync.class) {
            return doSync();
        }
        if (actionType == Stop.class) {
            return doStop();
        }

        throw new IllegalArgumentException("Unsupported action: " + actionType.getName() + ".");
    }

    /**
     * Adds the atom described by the action, replacing any atom the database already has for the
     * same predicate and arguments.
     *
     * <p>The atom is added as an observed atom when the action's partition name is "READ"
     * (case-insensitive); otherwise it is added as a random variable atom and, when a training
     * map is set, registered with it as a target.
     *
     * @return a message describing the added atom
     */
    protected String doAddAtom(AddAtom addAtom) {
        StandardPredicate predicate = addAtom.getPredicate();
        Constant[] arguments = addAtom.getArguments();

        // Replace semantics: drop the existing atom and its local variable first.
        if (atomManager.getDatabase().hasAtom(predicate, arguments)) {
            onlineTermStore().deleteLocalVariable(deleteAtom(predicate, arguments));
        }

        GroundAtom added;
        if (addAtom.getPartitionName().equalsIgnoreCase("READ")) {
            added = onlineAtomManager().addObservedAtom(predicate, addAtom.getValue(), arguments);
        } else {
            added = onlineAtomManager().addRandomVariableAtom(predicate, addAtom.getValue(), arguments);
            if (trainingMap != null) {
                trainingMap.addRandomVariableTargetAtom((RandomVariableAtom) added);
            }
        }

        onlineTermStore().createLocalVariable(added);
        modelUpdates = true;

        return String.format("Added atom: %s", added.toStringWithValue());
    }

    /**
     * Deletes the atom described by the action and its local variable in the term store.
     *
     * @return a message describing the deleted atom, or stating that the database has no such atom
     */
    protected String doDeleteAtom(DeleteAtom deleteAtom) {
        StandardPredicate predicate = deleteAtom.getPredicate();
        Constant[] arguments = deleteAtom.getArguments();

        if (!atomManager.getDatabase().hasAtom(predicate, arguments)) {
            return String.format("Atom: %s(%s) does not exist in atom manager.", predicate, StringUtils.join(", ", arguments));
        }

        GroundAtom removed = deleteAtom(predicate, arguments);
        onlineTermStore().deleteLocalVariable(removed);
        modelUpdates = true;

        return String.format("Deleted atom: %s", removed);
    }

    /**
     * Turns an existing random variable atom into an observed atom with the action's value.
     *
     * <p>The random variable atom is deleted, an observed atom is added in its place, and the
     * term store's local variable is updated with the new atom and value.
     *
     * @return a message describing the change, or explaining why nothing was changed (the atom
     *         does not exist, or it is not a random variable atom)
     */
    protected String doObserveAtom(ObserveAtom observeAtom) {
        StandardPredicate predicate = observeAtom.getPredicate();
        Constant[] arguments = observeAtom.getArguments();

        if (!atomManager.getDatabase().hasAtom(predicate, arguments)) {
            return String.format("Atom: %s(%s) does not exist in atom manager.", predicate, StringUtils.join(", ", arguments));
        }

        GroundAtom previous = atomManager.getAtom(predicate, arguments);
        if (!(previous instanceof RandomVariableAtom)) {
            return String.format("Atom: %s(%s) already observed.", predicate, StringUtils.join(", ", arguments));
        }

        deleteAtom(predicate, arguments);
        ObservedAtom observed = onlineAtomManager().addObservedAtom(predicate, observeAtom.getValue(), false, arguments);
        onlineTermStore().updateLocalVariable(observed, observeAtom.getValue());
        modelUpdates = true;

        return String.format("Observed atom: %s => %s", previous.toStringWithValue(), observed.toStringWithValue());
    }

    /**
     * Changes the value of an existing observed atom to the action's value, both in the term
     * store and on the atom itself.
     *
     * @return a message with the old and new values, or explaining why nothing was changed (the
     *         atom does not exist, or it is not an observed atom)
     */
    protected String doUpdateObservation(UpdateObservation updateObservation) {
        if (!atomManager.getDatabase().hasAtom(updateObservation.getPredicate(), updateObservation.getArguments())) {
            return String.format("Atom: %s(%s) does not exist in atom manager.", updateObservation.getPredicate(), StringUtils.join(", ", updateObservation.getArguments()));
        }

        GroundAtom atom = atomManager.getAtom(updateObservation.getPredicate(), updateObservation.getArguments());
        if (!(atom instanceof ObservedAtom)) {
            return String.format("Atom: %s is not an observation.", atom);
        }

        ObservedAtom observation = (ObservedAtom) atom;
        // Remember the old value for the message before it is overwritten.
        float previousValue = atom.getValue();

        onlineTermStore().updateLocalVariable(observation, updateObservation.getValue());
        observation._assumeValue(updateObservation.getValue());
        modelUpdates = true;

        return String.format("Updated atom: %s: %f => %f", atom, previousValue, atom.getValue());
    }

    /**
     * Sends the current value of the requested atom to the server as a {@link GetAtomResponse}.
     *
     * <p>When the atom manager has the atom, pending model changes are optimized first so the
     * value is current. When it does not, the response carries the value -1.0.
     *
     * @return a message stating whether the atom was found
     */
    protected String doGetAtom(GetAtom getAtom) {
        if (!onlineAtomManager().hasAtom(getAtom.getPredicate(), getAtom.getArguments())) {
            server.onActionExecution(getAtom, new GetAtomResponse(getAtom, -1.0d));
            return String.format("Atom: %s(%s) not found.", getAtom.getPredicate(), StringUtils.join(", ", getAtom.getArguments()));
        }

        optimize();

        GroundAtom atom = atomManager.getAtom(getAtom.getPredicate(), getAtom.getArguments());
        server.onActionExecution(getAtom, new GetAtomResponse(getAtom, atom.getValue()));

        return String.format("Atom: %s(%s) found. Returned to client.", getAtom.getPredicate(), StringUtils.join(", ", getAtom.getArguments()));
    }

    /**
     * Activates the action's rule in the term store, unless the action reports the rule as new.
     *
     * @return a message describing the activation, or stating that the rule is not in the model
     */
    protected String doActivateRule(ActivateRule activateRule) {
        if (activateRule.isNewRule()) {
            return String.format("Rule: %s does not exist in model.", activateRule.getRule());
        }

        onlineTermStore().activateRule(activateRule.getRule());
        modelUpdates = true;

        return String.format("Activated rule: %s", activateRule.getRule());
    }

    /**
     * Adds the action's rule to the term store, provided the action reports the rule as new.
     *
     * @return a message describing the addition, or stating that the rule already exists
     */
    protected String doAddRule(AddRule addRule) {
        if (!addRule.isNewRule()) {
            return String.format("Rule: %s already exists in model.", addRule.getRule());
        }

        onlineTermStore().addRule(addRule.getRule());
        modelUpdates = true;

        return String.format("Added rule: %s", addRule.getRule());
    }

    /**
     * Deactivates the action's rule in the term store, unless the action reports the rule as new.
     *
     * @return a message describing the deactivation, or stating that the rule is not in the model
     */
    protected String doDeactivateRule(DeactivateRule deactivateRule) {
        if (deactivateRule.isNewRule()) {
            return String.format("Rule: %s does not exist in model.", deactivateRule.getRule());
        }

        onlineTermStore().deactivateRule(deactivateRule.getRule());
        modelUpdates = true;

        return String.format("Deactivated rule: %s", deactivateRule.getRule());
    }

    /**
     * Deletes the action's rule from the term store and unregisters it, unless the action
     * reports the rule as new.
     *
     * @return a message describing the deletion, or stating that the rule is not in the model
     */
    protected String doDeleteRule(DeleteRule deleteRule) {
        if (deleteRule.isNewRule()) {
            return String.format("Rule: %s does not exist in model.", deleteRule.getRule());
        }

        onlineTermStore().deleteRule(deleteRule.getRule());
        deleteRule.getRule().unregister();
        modelUpdates = true;

        return String.format("Deleted rule: %s", deleteRule.getRule());
    }

    /**
     * Optimizes pending model changes and then has the database output its random variable
     * atoms, to the action's output directory path when one is given and through the
     * no-argument output call otherwise.
     *
     * @return a message stating where the inferred predicates were written
     */
    protected String doWriteInferredPredicates(WriteInferredPredicates writeInferredPredicates) {
        optimize();

        if (writeInferredPredicates.getOutputDirectoryPath() == null) {
            log.info("Writing inferred predicates to output stream.");
            database.outputRandomVariableAtoms();
            return "Wrote inferred predicates to output stream.";
        }

        log.info("Writing inferred predicates to file: " + writeInferredPredicates.getOutputDirectoryPath());
        database.outputRandomVariableAtoms(writeInferredPredicates.getOutputDirectoryPath());
        return "Wrote inferred predicates to file: " + writeInferredPredicates.getOutputDirectoryPath();
    }

    /**
     * Optimizes pending model changes.
     *
     * @return a fixed confirmation message
     */
    protected String doSync() {
        optimize();
        return "OnlinePSL inference synced.";
    }

    /**
     * Flags inference as stopped so the action loop exits after the current action.
     *
     * @return a fixed confirmation message
     */
    protected String doStop() {
        stopped = true;
        return "OnlinePSL inference stopped.";
    }

    /**
     * Deletes an atom through the online atom manager and, when something was deleted and a
     * training map is set, removes it from the training map as well.
     *
     * @return the deleted atom, or null when the atom manager returned null
     */
    private GroundAtom deleteAtom(StandardPredicate predicate, Constant[] arguments) {
        GroundAtom removed = onlineAtomManager().deleteAtom(predicate, arguments);

        if (removed != null && trainingMap != null) {
            trainingMap.deleteAtom(removed);
        }

        return removed;
    }

    /** Re-runs the reasoner and records its objective, but only when the model has pending updates. */
    private void optimize() {
        if (!modelUpdates) {
            return;
        }

        log.trace("Optimization Start");
        objective = reasoner.optimize(termStore, evaluators, trainingMap, evaluationPredicates);
        log.trace("Optimization End");

        modelUpdates = false;
    }

    /**
     * Runs an initial optimization and then executes server actions until inference is stopped.
     *
     * <p>A failing action is always reported to the server as an unsuccessful
     * {@link ActionStatus}. An {@link IllegalArgumentException} is otherwise tolerated and the
     * loop continues; any other {@link RuntimeException} is rethrown wrapped in a new
     * {@link RuntimeException}, which ends the loop without closing the server here.
     *
     * <p>Overrides {@code InferenceApplication.internalInference(...)}.
     *
     * @return the objective of the most recent optimization
     */
    @Override
    public double internalInference(List<Evaluator> list, TrainingMap trainingMap, Set<StandardPredicate> set) {
        // Keep the arguments around: every later optimization reuses them.
        this.evaluators = list;
        this.trainingMap = trainingMap;
        this.evaluationPredicates = set;

        optimize();

        while (!stopped) {
            OnlineMessage action = server.getAction();
            if (action == null) {
                continue;
            }

            try {
                log.trace(String.format("Executing action: %s", action));
                executeAction(action);
            } catch (RuntimeException failure) {
                server.onActionExecution(action, new ActionStatus(action, false, failure.getMessage()));

                // Only illegal-argument failures are recoverable; everything else is fatal.
                if (!(failure instanceof IllegalArgumentException)) {
                    throw new RuntimeException(String.format("Critically failed to execute action: %s", action), failure);
                }
            }
        }

        closeServer();
        return objective;
    }
}
