package org.neo4j.graphalgo.beta.pregel;

import com.carrotsearch.hppc.BitSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;
import java.util.stream.LongStream;
import org.jctools.queues.MpscLinkedQueue;
import org.neo4j.collection.primitive.PrimitiveLongCollections;
import org.neo4j.collection.primitive.PrimitiveLongIterable;
import org.neo4j.collection.primitive.PrimitiveLongIterator;
import org.neo4j.graphalgo.api.Degrees;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.api.NodeProperties;
import org.neo4j.graphalgo.api.RelationshipIterator;
import org.neo4j.graphalgo.core.concurrency.ParallelUtil;
import org.neo4j.graphalgo.core.utils.LazyBatchCollection;
import org.neo4j.graphalgo.core.utils.LazyMappingCollection;
import org.neo4j.graphalgo.core.utils.paged.AllocationTracker;
import org.neo4j.graphalgo.core.utils.paged.HugeDoubleArray;
import org.neo4j.graphalgo.core.utils.paged.HugeObjectArray;

/* loaded from: input_file:org/neo4j/graphalgo/beta/pregel/Pregel.class */
/**
 * Executes a vertex-centric ("Pregel-style") computation over a {@link Graph}.
 *
 * <p>Every node carries one {@code double} value and an incoming message queue. The computation
 * proceeds in supersteps. In each superstep the node id space is split into batches, and every
 * batch is processed by a {@link ComputeStep}. A step invokes the user {@link PregelComputation}
 * for each node that either received messages in the previous superstep or did not vote to halt.
 * Execution ends after the requested maximum number of supersteps, or as soon as a superstep
 * sends no messages.
 */
public final class Pregel {

    /**
     * Sentinel appended (in synchronous mode, from the second superstep on) to the queue of every
     * node that received messages in the previous superstep. It therefore separates those messages
     * from the ones sent during the current superstep.
     */
    private static final Double TERMINATION_SYMBOL = Double.NaN;

    private final PregelConfig config;
    private final PregelComputation computation;
    private final Graph graph;
    private final HugeDoubleArray nodeValues;
    private final HugeObjectArray<MpscLinkedQueue<Double>> messageQueues;
    private final int batchSize;
    private final int concurrency;
    private final ExecutorService executor;
    private int iterations;

    /**
     * Processes one batch of nodes for a single superstep.
     *
     * <p>Each step owns the {@link BitSet}s it writes to (message targets and votes), so parallel
     * steps never mutate a shared bitset. The receiver bits and the previous superstep's votes are
     * shared between steps but only read while the step runs.
     */
    public static final class ComputeStep implements Runnable {
        private final int iteration;
        private final PregelComputation computation;
        private final PregelContext pregelContext;
        /** Nodes this step sent at least one message to; owned by this step. */
        private final BitSet senderBits;
        /** Nodes that received messages in the previous superstep; shared, read-only. */
        private final BitSet receiverBits;
        /** Votes to halt as of the end of the previous superstep; shared, read-only. */
        private final BitSet previousVoteBits;
        /** Votes to halt produced by this step for the next superstep; owned by this step. */
        private final BitSet voteBits;
        private final PrimitiveLongIterable nodeBatch;
        private final Degrees degrees;
        private final HugeDoubleArray nodeValues;
        private final HugeObjectArray<? extends Queue<Double>> messageQueues;
        private final RelationshipIterator relationshipIterator;

        private ComputeStep(
                PregelComputation computation,
                PregelConfig config,
                long nodeCount,
                int iteration,
                PrimitiveLongIterable nodeBatch,
                Degrees degrees,
                HugeDoubleArray nodeValues,
                BitSet receiverBits,
                BitSet previousVoteBits,
                HugeObjectArray<? extends Queue<Double>> messageQueues,
                RelationshipIterator relationshipIterator) {
            this.iteration = iteration;
            this.computation = computation;
            this.senderBits = new BitSet(nodeCount);
            this.receiverBits = receiverBits;
            this.previousVoteBits = previousVoteBits;
            this.voteBits = new BitSet(nodeCount);
            this.nodeBatch = nodeBatch;
            this.degrees = degrees;
            this.nodeValues = nodeValues;
            this.messageQueues = messageQueues;
            // One copy per step; presumably the iterator must not be shared across threads.
            this.relationshipIterator = relationshipIterator.concurrentCopy();
            // Created last because the context is handed a reference to this step.
            this.pregelContext = new PregelContext(this, config);
        }

        /**
         * Runs the computation for every active node of the batch.
         *
         * <p>A node is active when it received messages or did not vote to halt in the previous
         * superstep. An inactive node stays halted, so its vote is carried over into this step's
         * vote bits.
         */
        @Override
        public void run() {
            PrimitiveLongIterator nodes = nodeBatch.iterator();
            while (nodes.hasNext()) {
                long nodeId = nodes.next();
                if (receiverBits.get(nodeId) || !previousVoteBits.get(nodeId)) {
                    // Only a vote cast while computing this node (this superstep) counts.
                    voteBits.clear(nodeId);
                    computation.compute(pregelContext, nodeId, receiveMessages(nodeId));
                } else {
                    voteBits.set(nodeId);
                }
            }
        }

        /** Returns the nodes this step sent messages to, i.e. next superstep's receivers. */
        BitSet getSenders() {
            return senderBits;
        }

        /** Returns the votes to halt produced by this step. */
        BitSet getVotes() {
            return voteBits;
        }

        /** Returns the zero-based superstep this step belongs to. */
        public int getIteration() {
            return iteration;
        }

        /** Returns the degree of {@code nodeId} as reported by the graph's {@link Degrees}. */
        public int getDegree(long nodeId) {
            return degrees.degree(nodeId);
        }

        /** Returns the current value of {@code nodeId}. */
        public double getNodeValue(long nodeId) {
            return nodeValues.get(nodeId);
        }

        /** Overwrites the value of {@code nodeId}. */
        public void setNodeValue(long nodeId, double value) {
            nodeValues.set(nodeId, value);
        }

        /**
         * Marks {@code nodeId} as halted. The node is skipped in the next superstep unless it
         * receives messages.
         */
        public void voteToHalt(long nodeId) {
            voteBits.set(nodeId);
        }

        /**
         * Enqueues {@code message} for every relationship target of {@code nodeId}. Each target
         * is recorded as a receiver for the next superstep.
         */
        public void sendMessages(long nodeId, double message) {
            Double boxedMessage = message;
            relationshipIterator.forEachRelationship(nodeId, (sourceNodeId, targetNodeId) -> {
                messageQueues.get(targetNodeId).add(boxedMessage);
                senderBits.set(targetNodeId);
                return true;
            });
        }

        /**
         * Returns the node's queue if it received messages in the previous superstep. Otherwise
         * returns {@code null}.
         */
        private Queue<Double> receiveMessages(long nodeId) {
            return receiverBits.get(nodeId) ? messageQueues.get(nodeId) : null;
        }
    }

    /**
     * Creates a Pregel instance in which every node starts with
     * {@link PregelConfig#getInitialNodeValue()}.
     *
     * @param graph       graph to run on
     * @param config      Pregel configuration
     * @param computation per-node computation
     * @param batchSize   number of node ids per compute batch
     * @param concurrency parallelism used for initialisation and execution
     * @param executor    executor running the compute steps
     * @param tracker     allocation tracker for the node value and queue arrays
     * @return a new Pregel instance
     */
    public static Pregel withDefaultNodeValues(
            Graph graph,
            PregelConfig config,
            PregelComputation computation,
            int batchSize,
            int concurrency,
            ExecutorService executor,
            AllocationTracker tracker) {
        double initialNodeValue = config.getInitialNodeValue();
        long nodeCount = graph.nodeCount();
        HugeDoubleArray nodeValues = HugeDoubleArray.newArray(nodeCount, tracker);
        ParallelUtil.parallelStreamConsume(
                LongStream.range(0L, nodeCount),
                concurrency,
                nodeIds -> nodeIds.forEach(nodeId -> nodeValues.set(nodeId, initialNodeValue)));
        return new Pregel(graph, config, computation, nodeValues, batchSize, concurrency, executor, tracker);
    }

    /**
     * Creates a Pregel instance in which each node starts with
     * {@code nodeProperties.nodeProperty(nodeId)}.
     *
     * @param graph          graph to run on
     * @param config         Pregel configuration
     * @param computation    per-node computation
     * @param nodeProperties source of the initial node values
     * @param batchSize      number of node ids per compute batch
     * @param concurrency    parallelism used for initialisation and execution
     * @param executor       executor running the compute steps
     * @param tracker        allocation tracker for the node value and queue arrays
     * @return a new Pregel instance
     */
    public static Pregel withInitialNodeValues(
            Graph graph,
            PregelConfig config,
            PregelComputation computation,
            NodeProperties nodeProperties,
            int batchSize,
            int concurrency,
            ExecutorService executor,
            AllocationTracker tracker) {
        long nodeCount = graph.nodeCount();
        HugeDoubleArray nodeValues = HugeDoubleArray.newArray(nodeCount, tracker);
        ParallelUtil.parallelStreamConsume(
                LongStream.range(0L, nodeCount),
                concurrency,
                nodeIds -> nodeIds.forEach(nodeId -> nodeValues.set(nodeId, nodeProperties.nodeProperty(nodeId))));
        return new Pregel(graph, config, computation, nodeValues, batchSize, concurrency, executor, tracker);
    }

    private Pregel(
            Graph graph,
            PregelConfig config,
            PregelComputation computation,
            HugeDoubleArray nodeValues,
            int batchSize,
            int concurrency,
            ExecutorService executor,
            AllocationTracker tracker) {
        this.graph = graph;
        this.config = config;
        this.computation = computation;
        this.nodeValues = nodeValues;
        this.batchSize = batchSize;
        // Must be assigned before initLinkedQueues, which reads it.
        this.concurrency = concurrency;
        this.executor = executor;
        this.messageQueues = initLinkedQueues(graph, tracker);
    }

    /**
     * Runs supersteps until {@code maxIterations} is reached or a superstep sends no messages.
     *
     * @param maxIterations upper bound on the number of supersteps
     * @return the node values array, which is updated in place by the computation
     */
    public HugeDoubleArray run(int maxIterations) {
        iterations = 0;
        long nodeCount = graph.nodeCount();
        // Nodes that received at least one message in the previous superstep.
        BitSet receiverBits = new BitSet(nodeCount);
        // Nodes that voted to halt as of the previous superstep.
        BitSet voteBits = new BitSet(nodeCount);
        // Contiguous id batches. Assumes PrimitiveLongCollections.range has an inclusive end,
        // hence "- 1".
        Collection<PrimitiveLongIterable> nodeBatches = LazyBatchCollection.of(
                nodeCount,
                batchSize,
                (start, length) -> () -> PrimitiveLongCollections.range(start, start + length - 1L));

        boolean canHalt = false;
        while (iterations < maxIterations && !canHalt) {
            int iteration = iterations++;
            List<ComputeStep> computeSteps = runComputeSteps(nodeBatches, iteration, receiverBits, voteBits);
            receiverBits = unionBitSets(computeSteps, ComputeStep::getSenders);
            voteBits = unionBitSets(computeSteps, ComputeStep::getVotes);
            // No messages were sent in this superstep.
            canHalt = receiverBits.nextSetBit(0) == -1;
        }
        return nodeValues;
    }

    /** Returns the number of supersteps executed by the most recent {@link #run(int)}. */
    public int getIterations() {
        return iterations;
    }

    /**
     * Unions the bitsets selected from each step. The first selected bitset is mutated and
     * returned. An empty collection yields a new empty bitset.
     */
    private BitSet unionBitSets(Collection<ComputeStep> computeSteps, Function<ComputeStep, BitSet> bitSetOf) {
        return ParallelUtil.parallelStream(
                computeSteps.stream(),
                concurrency,
                steps -> steps.map(bitSetOf).reduce((left, right) -> {
                    left.union(right);
                    return left;
                }).orElseGet(BitSet::new));
    }

    /**
     * Executes one superstep: one {@link ComputeStep} per node batch, run on the executor.
     *
     * @return the steps that were run, so their sender and vote bits can be merged
     */
    private List<ComputeStep> runComputeSteps(
            Collection<PrimitiveLongIterable> nodeBatches,
            int iteration,
            BitSet receiverBits,
            BitSet voteBits) {
        if (!config.isAsynchronous() && iteration > 0) {
            ParallelUtil.parallelStreamConsume(
                    LongStream.range(0L, graph.nodeCount()),
                    concurrency,
                    nodeIds -> nodeIds.forEach(nodeId -> {
                        if (receiverBits.get(nodeId)) {
                            messageQueues.get(nodeId).add(TERMINATION_SYMBOL);
                        }
                    }));
        }

        // Build all steps up front instead of collecting them from a lazy mapping function.
        // That way correctness does not depend on how often, or on which thread, the task
        // collection is iterated.
        long nodeCount = graph.nodeCount();
        List<ComputeStep> computeSteps = new ArrayList<>(nodeBatches.size());
        for (PrimitiveLongIterable nodeBatch : nodeBatches) {
            computeSteps.add(new ComputeStep(
                    computation,
                    config,
                    nodeCount,
                    iteration,
                    nodeBatch,
                    graph,
                    nodeValues,
                    receiverBits,
                    voteBits,
                    messageQueues,
                    graph));
        }
        ParallelUtil.runWithConcurrency(concurrency, computeSteps, executor);
        return computeSteps;
    }

    /** Allocates one empty message queue per node. */
    private HugeObjectArray<MpscLinkedQueue<Double>> initLinkedQueues(Graph graph, AllocationTracker tracker) {
        // Use the runtime class returned by the factory as the component type, i.e. the
        // concrete class of the instances stored below.
        @SuppressWarnings("unchecked")
        Class<MpscLinkedQueue<Double>> queueClass =
                (Class<MpscLinkedQueue<Double>>) (Class<?>) MpscLinkedQueue.newMpscLinkedQueue().getClass();
        long nodeCount = graph.nodeCount();
        HugeObjectArray<MpscLinkedQueue<Double>> queues = HugeObjectArray.newArray(queueClass, nodeCount, tracker);
        ParallelUtil.parallelStreamConsume(
                LongStream.range(0L, nodeCount),
                concurrency,
                nodeIds -> nodeIds.forEach(nodeId -> queues.set(nodeId, MpscLinkedQueue.newMpscLinkedQueue())));
        return queues;
    }
}
