package com.github.thorbenlindhauer.inference;

import com.github.thorbenlindhauer.cluster.Cluster;
import com.github.thorbenlindhauer.cluster.ClusterGraph;
import com.github.thorbenlindhauer.cluster.messagepassing.Message;
import com.github.thorbenlindhauer.cluster.messagepassing.MessageListener;
import com.github.thorbenlindhauer.cluster.messagepassing.MessagePassingContext;
import com.github.thorbenlindhauer.cluster.messagepassing.MessagePassingContextFactory;
import com.github.thorbenlindhauer.exception.ModelStructureException;
import com.github.thorbenlindhauer.factor.Factor;
import com.github.thorbenlindhauer.inference.loopy.ClusterGraphCalibrationContext;
import com.github.thorbenlindhauer.inference.loopy.ClusterGraphCalibrationContextFactory;
import com.github.thorbenlindhauer.variable.Scope;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;

/* loaded from: input_file:com/github/thorbenlindhauer/inference/AbstractClusterGraphInferencer.class */
/**
 * Base class for inferencers that answer queries by passing messages over a {@link ClusterGraph}.
 *
 * <p>Messages are propagated lazily: the first query triggers {@link #propagateMessages()}, which
 * repeatedly updates the next uncalibrated message until the calibration context reports none
 * left or an iteration budget is exhausted. After each update every registered
 * {@link MessageListener} is notified with {@link MessageListener#UPDATE_EVENT}.
 *
 * @param <T> the factor type used for cluster potentials and messages
 */
public abstract class AbstractClusterGraphInferencer<T extends Factor<T>> {

    /**
     * Per-edge share of the update budget. {@link #propagateMessages()} performs at most this value
     * times the number of cluster-graph edges message updates.
     */
    protected static final int MAX_ITERATIONS_PER_EDGE = 10;

    protected ClusterGraph<T> clusterGraph;
    protected MessagePassingContext<T> messagePassingContext;
    protected ClusterGraphCalibrationContext<T> calibrationContext;

    /** Set once {@link #propagateMessages()} has run; later queries reuse its results. */
    protected boolean messagesPropagated = false;

    /** Listeners notified after every message update, in registration order. */
    protected List<MessageListener<T>> messagePassingListeners = new ArrayList<>();

    /**
     * Creates an inferencer for the given cluster graph.
     *
     * <p>The message-passing context and the calibration context built from the factories are
     * registered as the first two listeners, so they observe every message update.
     *
     * @param clusterGraph the graph to run inference on
     * @param messagePassingContextFactory creates the context that holds message/potential state
     * @param clusterGraphCalibrationContextFactory creates the context that selects the next
     *     uncalibrated message
     */
    public AbstractClusterGraphInferencer(ClusterGraph<T> clusterGraph, MessagePassingContextFactory messagePassingContextFactory, ClusterGraphCalibrationContextFactory<T> clusterGraphCalibrationContextFactory) {
        this.clusterGraph = clusterGraph;
        this.messagePassingContext = messagePassingContextFactory.newMessagePassingContext(clusterGraph);
        this.calibrationContext = clusterGraphCalibrationContextFactory.buildCalibrationContext(clusterGraph, this.messagePassingContext);
        this.messagePassingListeners.add(this.messagePassingContext);
        this.messagePassingListeners.add(this.calibrationContext);
    }

    /**
     * Returns the marginal over {@code scope} of the first cluster (in
     * {@code clusterGraph.getClusters()} order) whose scope contains {@code scope}. Messages are
     * propagated first if that has not happened yet.
     *
     * <p>Note: the decompiled class reports this method as originally {@code protected}. It is kept
     * {@code public} here so that existing callers keep compiling.
     *
     * @param scope the query scope
     * @return the marginal of that cluster's potential restricted to {@code scope}
     * @throws ModelStructureException if no single cluster contains {@code scope} entirely
     */
    public T getClusterFactorContainingScope(Scope scope) {
        ensureMessagesPropagated();
        for (Cluster<T> cluster : this.clusterGraph.getClusters()) {
            if (cluster.getScope().contains(scope)) {
                return (T) this.messagePassingContext.getClusterPotential(cluster).marginal2(scope);
            }
        }
        throw new ModelStructureException("There is no cluster that contains scope " + scope + " entirely and queries spanning multiple clusters are not yet implemented");
    }

    /** Runs {@link #propagateMessages()} unless it has already been run. */
    protected void ensureMessagesPropagated() {
        if (!this.messagesPropagated) {
            propagateMessages();
        }
    }

    /**
     * Updates uncalibrated messages until none remain or the budget of
     * {@link #MAX_ITERATIONS_PER_EDGE} times the edge count is spent, notifying listeners after
     * each update.
     *
     * <p>{@link #messagesPropagated} is set to {@code true} even when the budget runs out before the
     * calibration context stops returning messages, so results may be uncalibrated in that case.
     */
    protected void propagateMessages() {
        Message<T> nextUncalibratedMessage = this.calibrationContext.getNextUncalibratedMessage();
        for (int i = 0; nextUncalibratedMessage != null && i < MAX_ITERATIONS_PER_EDGE * this.clusterGraph.getEdges().size(); i++) {
            nextUncalibratedMessage.update(this.messagePassingContext);
            notifyListeners(MessageListener.UPDATE_EVENT, nextUncalibratedMessage);
            nextUncalibratedMessage = this.calibrationContext.getNextUncalibratedMessage();
        }
        this.messagesPropagated = true;
    }

    /**
     * Delivers {@code str} and {@code message} to every registered listener in registration order.
     *
     * @param str the event name, e.g. {@link MessageListener#UPDATE_EVENT}
     * @param message the message the event refers to
     */
    protected void notifyListeners(String str, Message<T> message) {
        for (MessageListener<T> listener : this.messagePassingListeners) {
            listener.notify(str, message);
        }
    }

    /**
     * Registers an additional listener for message events. Listeners added after messages have
     * been propagated receive no events for that propagation.
     *
     * @param messageListener the listener to add
     * @throws NullPointerException if {@code messageListener} is {@code null}; this fails here
     *     rather than later inside {@link #notifyListeners(String, Message)}
     */
    public void addMessageListener(MessageListener<T> messageListener) {
        this.messagePassingListeners.add(Objects.requireNonNull(messageListener, "messageListener"));
    }
}
