package com.github.thorbenlindhauer.inference;

import com.github.thorbenlindhauer.cluster.Cluster;
import com.github.thorbenlindhauer.cluster.ClusterGraph;
import com.github.thorbenlindhauer.cluster.Edge;
import com.github.thorbenlindhauer.cluster.messagepassing.Message;
import com.github.thorbenlindhauer.cluster.messagepassing.MessagePassingContext;
import com.github.thorbenlindhauer.cluster.messagepassing.MessagePassingContextFactory;
import com.github.thorbenlindhauer.exception.ModelStructureException;
import com.github.thorbenlindhauer.factor.DiscreteFactor;
import com.github.thorbenlindhauer.variable.Scope;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;

/* loaded from: input_file:com/github/thorbenlindhauer/inference/CliqueTreeInferencer.class */
/**
 * Answers probability queries over a cluster graph by two-pass message passing
 * (an upward pass from the leaves to a designated root cluster, then a downward pass
 * from the root back to the leaves).
 *
 * <p>Messages are propagated lazily on the first query and then reused; the
 * {@link #messagesPropagated} flag is never reset here. The scheduling in
 * {@link #executeMessagePass(Set, boolean)} presumably assumes the cluster graph is a
 * tree (clique tree) — NOTE(review): confirm callers only pass trees.
 *
 * <p>Queries are only answered when a single cluster's scope contains every queried
 * variable; otherwise a {@link ModelStructureException} is thrown.
 */
public class CliqueTreeInferencer implements DiscreteModelInferencer {
    protected ClusterGraph<DiscreteFactor> clusterGraph;
    protected Cluster<DiscreteFactor> rootCluster;
    protected boolean messagesPropagated = false;
    protected MessagePassingContext<DiscreteFactor> messagePassingContext;

    /**
     * Creates an inferencer for the given cluster graph. No messages are computed until
     * the first query.
     *
     * @param clusterGraph the cluster graph to run inference on
     * @param cluster the root cluster: upward messages converge on it and downward
     *     messages start from it
     * @param messagePassingContextFactory creates the context that stores messages and
     *     cluster potentials for {@code clusterGraph}
     */
    public CliqueTreeInferencer(ClusterGraph<DiscreteFactor> clusterGraph, Cluster<DiscreteFactor> cluster, MessagePassingContextFactory messagePassingContextFactory) {
        this.clusterGraph = clusterGraph;
        this.rootCluster = cluster;
        this.messagePassingContext = messagePassingContextFactory.newMessagePassingContext(clusterGraph);
    }

    /**
     * Returns the normalized marginal over {@code scope} evaluated at {@code assignment}.
     *
     * @throws ModelStructureException if no single cluster contains {@code scope}
     */
    @Override // com.github.thorbenlindhauer.inference.DiscreteModelInferencer
    public double jointProbability(Scope scope, int[] assignment) {
        // getClusterFactorContainingScope already marginalizes onto exactly `scope`,
        // so only normalization is left to do here.
        return getClusterFactorContainingScope(scope).normalize2().getValueForAssignment(assignment);
    }

    /**
     * Normalizes the marginal over {@code scope} and {@code observedScope} together,
     * applies the observation {@code observedScope = observedAssignment}, sums down to
     * {@code scope} and evaluates it at {@code assignment}. This is presumably
     * P(scope = assignment, observedScope = observedAssignment) — it depends on
     * {@code observation} not renormalizing; confirm against DiscreteFactor.
     *
     * @throws ModelStructureException if no single cluster contains both scopes
     */
    @Override // com.github.thorbenlindhauer.inference.DiscreteModelInferencer
    public double jointProbability(Scope scope, int[] assignment, Scope observedScope, int[] observedAssignment) {
        // Normalize BEFORE applying the observation so the evidence is not divided out.
        return getClusterFactorContainingScope(scope.union(observedScope))
            .normalize2()
            .observation(observedScope, observedAssignment)
            .marginal2(scope)
            .getValueForAssignment(assignment);
    }

    /**
     * Applies the observation {@code observedScope = observedAssignment}, sums down to
     * {@code scope}, normalizes, and evaluates at {@code assignment}; i.e. presumably
     * P(scope = assignment | observedScope = observedAssignment).
     *
     * @throws ModelStructureException if no single cluster contains both scopes
     */
    @Override // com.github.thorbenlindhauer.inference.DiscreteModelInferencer
    public double jointProbabilityConditionedOn(Scope scope, int[] assignment, Scope observedScope, int[] observedAssignment) {
        // Normalize AFTER applying the observation, which turns the result into a conditional.
        return getClusterFactorContainingScope(scope.union(observedScope))
            .observation(observedScope, observedAssignment)
            .marginal2(scope)
            .normalize2()
            .getValueForAssignment(assignment);
    }

    /**
     * Propagates messages if needed, then returns the (unnormalized) potential of the
     * first cluster whose scope contains {@code scope}, marginalized onto {@code scope}.
     *
     * @throws ModelStructureException if no single cluster's scope contains {@code scope}
     */
    protected DiscreteFactor getClusterFactorContainingScope(Scope scope) {
        ensureMessagesPropagated();
        for (Cluster<DiscreteFactor> cluster : this.clusterGraph.getClusters()) {
            if (cluster.getScope().contains(scope)) {
                return this.messagePassingContext.getClusterPotential(cluster).marginal2(scope);
            }
        }
        throw new ModelStructureException("There is no cluster that contains scope " + scope + " entirely and queries spanning multiple clusters are not yet implemented");
    }

    /** Runs {@link #propagateMessages()} unless it has already completed once. */
    protected void ensureMessagesPropagated() {
        if (!this.messagesPropagated) {
            propagateMessages();
        }
    }

    /**
     * Runs the upward pass (leaves towards the root) followed by the downward pass
     * (root towards the leaves), then marks messages as propagated.
     */
    protected void propagateMessages() {
        // Upward pass: seed with the message every leaf (a non-root cluster with exactly
        // one edge) sends over its single edge.
        Set<Message<DiscreteFactor>> upwardMessages = new HashSet<Message<DiscreteFactor>>();
        for (Cluster<DiscreteFactor> cluster : this.clusterGraph.getClusters()) {
            if (cluster != this.rootCluster && cluster.getEdges().size() == 1) {
                Edge<DiscreteFactor> onlyEdge = cluster.getEdges().iterator().next();
                upwardMessages.add(this.messagePassingContext.getMessage(onlyEdge, cluster));
            }
        }
        executeMessagePass(upwardMessages, true);

        // Downward pass: seed with the messages the root sends over each of its edges.
        Set<Message<DiscreteFactor>> downwardMessages = new HashSet<Message<DiscreteFactor>>();
        for (Edge<DiscreteFactor> rootEdge : this.rootCluster.getEdges()) {
            downwardMessages.add(this.messagePassingContext.getMessage(rootEdge, this.rootCluster));
        }
        executeMessagePass(downwardMessages, false);

        this.messagesPropagated = true;
    }

    /**
     * Processes pending messages until none remain. Each processed message is updated
     * in the context and then enqueues follow-up messages from its target cluster
     * (unless the target is the root).
     *
     * <p>{@code messages} is used as the work queue and is drained (left empty) by this
     * method.
     *
     * @param messages the initial messages to send; consumed by this call
     * @param upwardPass if {@code true}, a cluster only sends over an edge once messages
     *     have been sent over all of its other edges; if {@code false}, it sends over every
     *     edge not yet used in this pass
     */
    protected void executeMessagePass(Set<Message<DiscreteFactor>> messages, boolean upwardPass) {
        // Edges a message has already been sent over during this pass (direction is not recorded).
        Set<Edge<DiscreteFactor>> usedEdges = new HashSet<Edge<DiscreteFactor>>();
        while (!messages.isEmpty()) {
            Iterator<Message<DiscreteFactor>> pending = messages.iterator();
            Message<DiscreteFactor> message = pending.next();
            pending.remove();

            message.update(this.messagePassingContext);
            usedEdges.add(message.getEdge());

            Cluster<DiscreteFactor> target = message.getTargetCluster();
            if (target == this.rootCluster) {
                // The root never forwards messages: upward messages end here and the
                // downward pass seeds the root's outgoing messages itself.
                continue;
            }
            for (Edge<DiscreteFactor> outgoing : target.getOtherEdges(message.getEdge())) {
                if (usedEdges.contains(outgoing)) {
                    continue;
                }
                // Upward: only send once every other neighbour has reported in.
                boolean ready = !upwardPass || usedEdges.containsAll(target.getOtherEdges(outgoing));
                if (ready) {
                    messages.add(this.messagePassingContext.getMessage(outgoing, target));
                }
            }
        }
    }
}
