package gov.sandia.cognition.graph.inference;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.util.Pair;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;

/**
 * Base class for multi-threaded sum-product (loopy belief propagation) solvers.
 *
 * <p>After {@link #init} has built one {@link Node} per energy-function node,
 * {@link #solve()} repeats the following until convergence or until the iteration cap is hit:
 * <ol>
 *   <li>{@link #computeTemporaryMessage(int)} is called once for every edge index, spread
 *       across the worker threads;</li>
 *   <li>every node normalizes its messages and applies {@code update()}; the largest value
 *       returned by any node's {@code update()} is the iteration's change.</li>
 * </ol>
 * Iteration stops as soon as that change is below {@code eps}. Beliefs are then computed
 * for every node and can be read with {@link #getBelief(int, int)}.
 *
 * <p>An instance reuses internal work queues, so {@link #solve()} must not be run
 * concurrently on the same instance.
 *
 * @param <LabelType> the type of label a node may take
 */
@PublicationReference(
    author = {"Jonathan S. Yedidia", "William T. Freeman", "Yair Weiss"},
    title = "Understanding Belief Propagation and its Generalizations",
    type = PublicationType.TechnicalReport,
    year = 2001,
    notes = {"Institution: Mitsubishi Electric Research Laboratories"})
public abstract class SumProductInferencingAlgorithm<LabelType> implements EnergyFunctionSolver<LabelType> {

    /** Default convergence threshold on the largest per-iteration node change. */
    public static final double DEFAULT_EPS = 0.001d;

    /** Default cap on the number of message-passing iterations. */
    public static final int DEFAULT_MAX_ITERATIONS = 20;

    /** Default number of worker threads used by {@link #solve()}. */
    public static final int DEFAULT_NUM_THREADS = 4;

    /** Work is split into this many groups per worker thread so threads can balance load. */
    private static final int GROUPS_PER_THREAD = 10;

    private final double eps;

    private final int maxNumIterations;

    private final int numThreads;

    /** One node per energy-function node, indexed like the energy function; built by init. */
    protected List<Node<LabelType>> nodes;

    /** The energy function passed to the most recent {@link #init} call (null before that). */
    protected EnergyFunction<LabelType> fn;

    // Per-phase work queues drained by the workers; refilled from the masters each phase.
    private ConcurrentLinkedQueue<List<Integer>> edgeGroups;

    private ConcurrentLinkedQueue<List<Node<LabelType>>> nodeGroups;

    // Fixed partition of edge indices / nodes computed once by init().
    private List<List<Integer>> edgeGroupsMaster;

    private List<List<Node<LabelType>>> nodeGroupsMaster;

    /** The phase a {@link SolveThread} executes on its next run. */
    private enum SolverSetting {
        COMPUTE_MESSAGES,
        NORMALIZE_NODES,
        COMPUTE_BELIEFS
    }

    /**
     * Worker task that drains the shared work queue belonging to the current phase.
     * The same instances are resubmitted for every phase; the phase is set before submit.
     */
    private class SolveThread implements Runnable {

        /** Largest node change seen during this worker's most recent NORMALIZE_NODES run. */
        private double delta;

        private SolverSetting setting;

        @Override
        public void run() {
            this.delta = 0.0d;
            switch (this.setting) {
                case COMPUTE_MESSAGES:
                    computeMessages();
                    return;
                case NORMALIZE_NODES:
                    normalizeNodes();
                    return;
                case COMPUTE_BELIEFS:
                    computeBeliefs();
                    return;
                default:
                    throw new RuntimeException("Unhandled case, setting = " + this.setting);
            }
        }

        /** Computes temporary messages for every edge group this worker manages to claim. */
        private void computeMessages() {
            List<Integer> group;
            while ((group = edgeGroups.poll()) != null) {
                for (int edgeIndex : group) {
                    computeTemporaryMessage(edgeIndex);
                }
            }
        }

        /** Normalizes and updates claimed nodes, tracking the largest reported change. */
        private void normalizeNodes() {
            List<Node<LabelType>> group;
            while ((group = nodeGroups.poll()) != null) {
                for (Node<LabelType> node : group) {
                    node.normalizeMessagesForSumProductAlgorithm();
                    this.delta = Math.max(this.delta, node.update());
                }
            }
        }

        /** Computes final beliefs for every node group this worker manages to claim. */
        private void computeBeliefs() {
            List<Node<LabelType>> group;
            while ((group = nodeGroups.poll()) != null) {
                for (Node<LabelType> node : group) {
                    node.computeBeliefsForSumProductAlgorithm(fn);
                }
            }
        }

        public double getDelta() {
            return this.delta;
        }
    }

    /**
     * Creates a solver.
     *
     * @param maxNumIterations upper bound on message-passing iterations; must be positive
     *        (checked only when assertions are enabled)
     * @param eps iteration stops once the largest node change in an iteration is below this
     * @param numThreads number of worker threads used by {@link #solve()}
     */
    public SumProductInferencingAlgorithm(int maxNumIterations, double eps, int numThreads) {
        assert maxNumIterations > 0;
        this.maxNumIterations = maxNumIterations;
        this.eps = eps;
        this.numThreads = numThreads;
        this.fn = null;
    }

    /**
     * Creates a solver with the default threshold and thread count.
     *
     * @param maxNumIterations upper bound on message-passing iterations; must be positive
     */
    public SumProductInferencingAlgorithm(int maxNumIterations) {
        this(maxNumIterations, DEFAULT_EPS, DEFAULT_NUM_THREADS);
    }

    /** Creates a solver with all default settings. */
    public SumProductInferencingAlgorithm() {
        this(DEFAULT_MAX_ITERATIONS, DEFAULT_EPS, DEFAULT_NUM_THREADS);
    }

    /**
     * Runs message passing until convergence or the iteration cap, then computes beliefs.
     *
     * @return true if an iteration's largest node change fell below {@code eps}
     * @throws IllegalStateException if {@link #init} has not been called
     * @throws RuntimeException wrapping any failure or interruption of a worker task
     */
    @Override
    public boolean solve() {
        if (this.edgeGroupsMaster == null || this.nodeGroupsMaster == null) {
            throw new IllegalStateException("init must be called before solve");
        }
        // The belief phase only drains the node queue, and a failed earlier run may leave
        // partial work behind, so start from empty queues.
        this.edgeGroups.clear();
        this.nodeGroups.clear();
        ExecutorService pool = Executors.newFixedThreadPool(this.numThreads, new ThreadFactory() {
            private int counter = 0;

            @Override
            public Thread newThread(Runnable runnable) {
                return new Thread(runnable, "BpSolver-" + (this.counter++));
            }
        });
        try {
            List<SolveThread> workers = new ArrayList<>(this.numThreads);
            for (int i = 0; i < this.numThreads; i++) {
                workers.add(new SolveThread());
            }
            boolean converged = false;
            for (int iteration = 0; !converged && iteration < this.maxNumIterations; iteration++) {
                copyFromMasters();
                runPhase(pool, workers, SolverSetting.COMPUTE_MESSAGES);
                runPhase(pool, workers, SolverSetting.NORMALIZE_NODES);
                // Future.get() in runPhase makes the workers' delta writes visible here.
                double maxDelta = 0.0d;
                for (SolveThread worker : workers) {
                    maxDelta = Math.max(maxDelta, worker.getDelta());
                }
                converged = maxDelta < this.eps;
            }
            copyFromMasters();
            runPhase(pool, workers, SolverSetting.COMPUTE_BELIEFS);
            return converged;
        } finally {
            // Always release the (non-daemon) worker threads, even if a phase failed.
            pool.shutdown();
        }
    }

    /**
     * Submits every worker for the given phase and blocks until all of them have finished.
     *
     * @throws RuntimeException wrapping a task failure, or an interruption of the caller
     *         (the interrupt status is restored first)
     */
    private void runPhase(ExecutorService pool, List<SolveThread> workers, SolverSetting setting) {
        List<Future<?>> futures = new ArrayList<>(workers.size());
        for (SolveThread worker : workers) {
            worker.setting = setting;
            futures.add(pool.submit(worker));
        }
        for (Future<?> future : futures) {
            try {
                future.get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            } catch (ExecutionException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Computes the message(s) for one edge. Presumably the result is kept in temporary
     * storage consumed by the following node-normalization phase — see subclasses.
     *
     * <p>Called concurrently from the worker threads; within one phase each edge index is
     * passed exactly once.
     *
     * @param edgeIndex index of the edge in the energy function passed to {@link #init}
     */
    protected abstract void computeTemporaryMessage(int edgeIndex);

    /** Loads every master group into the (empty) per-phase work queues. */
    private void copyFromMasters() {
        if (!this.edgeGroups.isEmpty() || !this.nodeGroups.isEmpty()) {
            throw new IllegalStateException("Can't copy if the destinations aren't empty");
        }
        this.edgeGroups.addAll(this.edgeGroupsMaster);
        this.nodeGroups.addAll(this.nodeGroupsMaster);
    }

    /**
     * Sets up message state for one edge; called once per edge from {@link #init}, after
     * {@link #nodes} has been rebuilt but before {@link #fn} is assigned the new function.
     *
     * @param edge the value of {@code getEdge(i)} for the edge (presumably the indices of
     *        its two nodes — confirm against EnergyFunction)
     */
    abstract void initMessages(Pair<Integer, Integer> edge);

    /**
     * Builds nodes and per-edge message state for {@code energyFunction}, and partitions
     * edges and nodes into contiguous, near-equal work groups for the worker threads.
     *
     * @param energyFunction the function to solve
     */
    @Override
    public void init(EnergyFunction<LabelType> energyFunction) {
        int numNodes = energyFunction.numNodes();
        int numEdges = energyFunction.numEdges();
        this.nodes = new ArrayList<>(numNodes);
        for (int i = 0; i < numNodes; i++) {
            this.nodes.add(new Node<>(i, energyFunction.getPossibleLabels(i)));
        }
        for (int i = 0; i < numEdges; i++) {
            initMessages(energyFunction.getEdge(i));
        }
        for (Node<LabelType> node : this.nodes) {
            node.resetToOne();
        }
        this.fn = energyFunction;

        // Near-equal contiguous groups keep every thread busy even for small graphs;
        // empty groups are skipped.
        int numGroups = Math.max(1, this.numThreads * GROUPS_PER_THREAD);
        this.edgeGroupsMaster = new ArrayList<>(numGroups);
        this.nodeGroupsMaster = new ArrayList<>(numGroups);
        for (int g = 0; g < numGroups; g++) {
            int edgeStart = groupBoundary(numEdges, numGroups, g);
            int edgeEnd = groupBoundary(numEdges, numGroups, g + 1);
            if (edgeStart < edgeEnd) {
                List<Integer> edgeGroup = new ArrayList<>(edgeEnd - edgeStart);
                for (int e = edgeStart; e < edgeEnd; e++) {
                    edgeGroup.add(e);
                }
                this.edgeGroupsMaster.add(edgeGroup);
            }
            int nodeStart = groupBoundary(numNodes, numGroups, g);
            int nodeEnd = groupBoundary(numNodes, numGroups, g + 1);
            if (nodeStart < nodeEnd) {
                this.nodeGroupsMaster.add(new ArrayList<>(this.nodes.subList(nodeStart, nodeEnd)));
            }
        }
        this.nodeGroups = new ConcurrentLinkedQueue<>();
        this.edgeGroups = new ConcurrentLinkedQueue<>();
    }

    /**
     * Returns the start index of group {@code group} when {@code count} items are split into
     * {@code numGroups} contiguous groups whose sizes differ by at most one.
     */
    private static int groupBoundary(int count, int numGroups, int group) {
        // long arithmetic avoids overflow of count * group
        return (int) ((long) count * group / numGroups);
    }

    /**
     * Returns the belief computed by the last {@link #solve()} for one label of one node.
     *
     * @param i node index
     * @param i2 label index within that node
     */
    @Override
    public double getBelief(int i, int i2) {
        return this.nodes.get(i).getBelief(i2);
    }
}
