package com.github.thorbenlindhauer.cluster.messagepassing;

import com.github.thorbenlindhauer.cluster.Cluster;
import com.github.thorbenlindhauer.cluster.ClusterGraph;
import com.github.thorbenlindhauer.cluster.Edge;
import com.github.thorbenlindhauer.exception.InferenceException;
import com.github.thorbenlindhauer.exception.ModelStructureException;
import com.github.thorbenlindhauer.factor.Factor;
import java.util.HashMap;
import java.util.Map;

/**
 * Base implementation of {@link MessagePassingContext} that keeps one pair of messages per edge of a
 * {@link ClusterGraph} and caches the potential of each cluster until a message update invalidates it.
 *
 * <p>Subclasses decide how a message is created ({@link #newMessage(Cluster, Edge)}) and how a cluster
 * potential is computed ({@link #calculateClusterPotential(Cluster)}).
 *
 * @param <T> the factor type
 */
public abstract class AbstractMessagePassingContext<T extends Factor<T>> implements MessagePassingContext<T> {

    /** Message pairs, one entry per edge of the cluster graph passed to the constructor. */
    protected Map<Edge<T>, EdgeContext<T>> messages = new HashMap<>();

    /**
     * Cached cluster potentials. A missing key or a {@code null} value means the potential has to be
     * (re)computed on the next call to {@link #getClusterPotential(Cluster)}.
     */
    // Initialized at declaration (not after the constructor loop) so that it is already non-null while
    // the subclass hook newMessage(...) runs during construction.
    protected Map<Cluster<T>, T> clusterPotentials = new HashMap<>();

    /**
     * Holds the two messages associated with a single edge, one per endpoint cluster.
     *
     * @param <T> the factor type
     */
    protected static class EdgeContext<T extends Factor<T>> {
        protected Cluster<T> cluster1;
        protected Message<T> message1;
        protected Cluster<T> cluster2;
        protected Message<T> message2;

        /**
         * Creates the message pair for {@code edge} by asking {@code context} for one message per endpoint.
         *
         * @param edge the edge whose endpoints are read via {@code getCluster1()} / {@code getCluster2()}
         * @param context the owning context whose {@code newMessage} hook creates the messages
         */
        public EdgeContext(Edge<T> edge, AbstractMessagePassingContext<T> context) {
            this.cluster1 = edge.getCluster1();
            this.message1 = context.newMessage(this.cluster1, edge);
            this.cluster2 = edge.getCluster2();
            this.message2 = context.newMessage(this.cluster2, edge);
        }

        /**
         * Returns the message that was created for the given endpoint of this edge.
         *
         * @param cluster one of the two endpoint clusters of this edge (matched by identity)
         * @return the message associated with {@code cluster}
         * @throws ModelStructureException if {@code cluster} is not an endpoint of this edge
         */
        public Message<T> getMessageFrom(Cluster<T> cluster) {
            // Identity comparison is intentional: endpoints are the exact instances read from the edge.
            if (cluster == this.cluster1) {
                return this.message1;
            }
            if (cluster == this.cluster2) {
                return this.message2;
            }
            throw new ModelStructureException("Cluster " + cluster + " is not an endpoint of the edge between "
                + this.cluster1 + " and " + this.cluster2);
        }
    }

    /**
     * Creates the message pairs for every edge of {@code clusterGraph}.
     *
     * <p><b>Note for subclasses:</b> this constructor invokes the abstract hook
     * {@link #newMessage(Cluster, Edge)} (twice per edge) before any subclass constructor body or
     * subclass field initializer has run, so implementations must not depend on subclass state.
     *
     * @param clusterGraph the graph whose edges this context tracks
     */
    public AbstractMessagePassingContext(ClusterGraph<T> clusterGraph) {
        for (Edge<T> edge : clusterGraph.getEdges()) {
            this.messages.put(edge, new EdgeContext<>(edge, this));
        }
    }

    /**
     * Returns the cached potential of {@code cluster}, computing it first if it is absent or was invalidated.
     *
     * @param cluster the cluster whose potential is requested
     * @return the (possibly freshly computed) potential
     */
    @Override // com.github.thorbenlindhauer.cluster.messagepassing.MessagePassingContext
    public T getClusterPotential(Cluster<T> cluster) {
        ensurePotentialInitialized(cluster);
        return this.clusterPotentials.get(cluster);
    }

    /**
     * Computes and caches the potential of {@code cluster} if no non-null value is cached yet.
     *
     * @param cluster the cluster whose potential should be available in {@link #clusterPotentials}
     */
    protected void ensurePotentialInitialized(Cluster<T> cluster) {
        // Explicit get/put rather than HashMap.computeIfAbsent: calculateClusterPotential is subclass code
        // and may itself read or modify clusterPotentials, which computeIfAbsent does not permit.
        if (this.clusterPotentials.get(cluster) == null) {
            this.clusterPotentials.put(cluster, calculateClusterPotential(cluster));
        }
    }

    /**
     * Computes the potential of {@code cluster}; the result is cached by {@link #ensurePotentialInitialized}.
     *
     * @param cluster the cluster whose potential is computed
     * @return the computed potential
     */
    protected abstract T calculateClusterPotential(Cluster<T> cluster);

    /**
     * Returns the message associated with the given endpoint {@code cluster} of {@code edge}.
     *
     * @param edge an edge of the cluster graph this context was created for
     * @param cluster one of the two endpoint clusters of {@code edge}
     * @return the message for that endpoint
     * @throws InferenceException if {@code edge} is not known to this context
     * @throws ModelStructureException if {@code cluster} is not an endpoint of {@code edge}
     */
    @Override // com.github.thorbenlindhauer.cluster.messagepassing.MessagePassingContext
    public Message<T> getMessage(Edge<T> edge, Cluster<T> cluster) {
        EdgeContext<T> edgeContext = this.messages.get(edge);
        if (edgeContext == null) {
            throw new InferenceException("Edge " + edge + " is not known to this context");
        }
        return edgeContext.getMessageFrom(cluster);
    }

    /**
     * Creates the message for endpoint {@code cluster} of {@code edge}. Called from this class's constructor;
     * see the constructor documentation for the resulting restrictions.
     *
     * @param cluster the endpoint cluster the message is created for
     * @param edge the edge the message belongs to
     * @return the new message
     */
    protected abstract Message<T> newMessage(Cluster<T> cluster, Edge<T> edge);

    /**
     * Invalidates the cached potential of a message's target cluster when that message is updated.
     *
     * @param str the event name; only {@link MessageListener#UPDATE_EVENT} is handled
     * @param message the message the event refers to
     */
    @Override // com.github.thorbenlindhauer.Listener
    public void notify(String str, Message<T> message) {
        if (MessageListener.UPDATE_EVENT.equals(str)) {
            // A null entry is treated as "not computed" by ensurePotentialInitialized, forcing recomputation.
            this.clusterPotentials.put(message.getTargetCluster(), null);
        }
    }
}
