package org.neo4j.graphalgo.beta.pregel;

import java.util.function.LongConsumer;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.beta.pregel.Messages;
import org.neo4j.graphalgo.beta.pregel.Messages.MessageIterator;
import org.neo4j.graphalgo.beta.pregel.PregelConfig;
import org.neo4j.graphalgo.beta.pregel.context.ComputeContext;
import org.neo4j.graphalgo.beta.pregel.context.InitContext;
import org.neo4j.graphalgo.core.utils.ProgressLogger;
import org.neo4j.graphalgo.core.utils.paged.HugeAtomicBitSet;
import org.neo4j.graphalgo.core.utils.partition.Partition;

/* loaded from: input_file:org/neo4j/graphalgo/beta/pregel/ComputeStep.class */
/**
 * Processes one batch of nodes ({@link #nodeBatch()}) with a {@link PregelComputation}.
 *
 * <p>Implementations supply the collaborating components through the abstract accessors.
 * The default methods implement the batch loop ({@link #computeBatch()}) and convenience
 * operations that delegate to those components.
 *
 * @param <CONFIG>   the Pregel configuration type
 * @param <ITERATOR> the message iterator type produced by {@link #messenger()}
 */
public interface ComputeStep<CONFIG extends PregelConfig, ITERATOR extends Messages.MessageIterator> {

    /** Graph queried by the default degree, count and neighbor helpers. */
    Graph graph();

    /**
     * Per-node halt flags. Set by {@link #voteToHalt(long)}; read and cleared by
     * {@link #computeBatch()}.
     */
    HugeAtomicBitSet voteBits();

    /**
     * The computation whose {@code init}/{@code compute} methods {@link #computeBatch()}
     * invokes. It also supplies {@code applyRelationshipWeight} for
     * {@link #sendToNeighborsWeighted(long, double)}.
     */
    PregelComputation<CONFIG> computation();

    /** Node property store read and written by the {@code *NodeValue} helpers. */
    NodeValue nodeValue();

    /** Provides and positions the message iterator used by {@link #computeBatch()}. */
    Messenger<ITERATOR> messenger();

    /** The nodes processed by {@link #computeBatch()}. */
    Partition nodeBatch();

    /** Context handed to {@code computation().init(...)} during the initial superstep. */
    InitContext<CONFIG> initContext();

    /** Context handed to {@code computation().compute(...)} for every active node. */
    ComputeContext<CONFIG> computeContext();

    /** Receives the batch size after {@link #computeBatch()} finishes a batch. */
    ProgressLogger progressLogger();

    /** Presumably the current superstep index (not used by the defaults here — confirm). */
    int iteration();

    /** Delegates to {@link Graph#isMultiGraph()} on {@link #graph()}. */
    default boolean isMultiGraph() {
        return graph().isMultiGraph();
    }

    /** Delegates to {@code graph().nodeCount()}. */
    default long nodeCount() {
        return graph().nodeCount();
    }

    /** Delegates to {@code graph().relationshipCount()}. */
    default long relationshipCount() {
        return graph().relationshipCount();
    }

    /** Returns the degree of {@code nodeId} as reported by {@link #graph()}. */
    default int degree(long nodeId) {
        return graph().degree(nodeId);
    }

    /**
     * Marks {@code nodeId} as halted by setting its vote bit. A halted node is skipped by
     * {@link #computeBatch()} unless it has incoming messages.
     */
    default void voteToHalt(long nodeId) {
        voteBits().set(nodeId);
    }

    /**
     * Delivers {@code message} to {@code targetNodeId}. The delivery mechanism is
     * implementation-defined.
     */
    void sendTo(long targetNodeId, double message);

    /**
     * Runs the computation over every node of {@link #nodeBatch()}.
     *
     * <p>Per node, in order:
     * <ol>
     *   <li>In the initial superstep, the init context is pointed at the node and
     *       {@code computation.init} is invoked.</li>
     *   <li>The message iterator is initialized for the node.</li>
     *   <li>If the node has no messages and has voted to halt, it is skipped.</li>
     *   <li>Otherwise its vote bit is cleared, the compute context is pointed at the node, and
     *       {@code computation.compute} is invoked.</li>
     * </ol>
     * After the whole batch, progress equal to the batch's node count is logged.
     */
    default void computeBatch() {
        Messenger<ITERATOR> messenger = messenger();
        ITERATOR messageIterator = messenger.messageIterator();
        // One Messages view wraps the shared iterator and is reused for every node.
        // initMessageIterator re-initializes that iterator for each node below.
        Messages messages = new Messages(messageIterator);

        // Resolve all collaborators once per batch rather than once per node.
        Partition nodeBatch = nodeBatch();
        PregelComputation<CONFIG> computation = computation();
        InitContext<CONFIG> initContext = initContext();
        ComputeContext<CONFIG> computeContext = computeContext();
        HugeAtomicBitSet voteBits = voteBits();

        nodeBatch.consume(nodeId -> {
            if (computeContext.isInitialSuperstep()) {
                initContext.setNodeId(nodeId);
                computation.init(initContext);
            }

            messenger.initMessageIterator(messageIterator, nodeId, computeContext.isInitialSuperstep());

            // A halted node stays inactive until it receives a message.
            if (messages.isEmpty() && voteBits.get(nodeId)) {
                return;
            }

            // Any node that runs compute is re-activated; it must vote again to halt.
            voteBits.clear(nodeId);
            computeContext.setNodeId(nodeId);
            computation.compute(computeContext, messages);
        });

        progressLogger().logProgress(nodeBatch.nodeCount());
    }

    /**
     * Sends {@code message} unchanged to the target of every relationship of
     * {@code sourceNodeId}.
     */
    default void sendToNeighbors(long sourceNodeId, double message) {
        graph().forEachRelationship(sourceNodeId, (ignoredSource, targetNodeId) -> {
            sendTo(targetNodeId, message);
            return true;
        });
    }

    /**
     * Sends {@code message} to the target of every relationship of {@code sourceNodeId}.
     * Each message is first transformed by
     * {@code computation().applyRelationshipWeight(message, weight)}.
     *
     * <p>{@code 1.0} is passed to {@code forEachRelationship} as the property fallback.
     * NOTE(review): presumably this is the weight used when a relationship has no property;
     * confirm against the Graph API.
     */
    default void sendToNeighborsWeighted(long sourceNodeId, double message) {
        // Resolve the computation once per call instead of once per relationship. This is
        // consistent with computeBatch(), which captures it once per batch.
        PregelComputation<CONFIG> computation = computation();
        graph().forEachRelationship(sourceNodeId, 1.0d, (ignoredSource, targetNodeId, weight) -> {
            sendTo(targetNodeId, computation.applyRelationshipWeight(message, weight));
            return true;
        });
    }

    /**
     * Passes the target of every relationship of {@code nodeId} to {@code targetConsumer},
     * including repeated targets.
     */
    default void forEachNeighbor(long nodeId, LongConsumer targetConsumer) {
        graph().forEachRelationship(nodeId, (ignoredSource, targetNodeId) -> {
            targetConsumer.accept(targetNodeId);
            return true;
        });
    }

    /**
     * Passes relationship targets of {@code nodeId} to {@code targetConsumer}, skipping a
     * target that equals the immediately preceding one.
     *
     * <p>Only consecutive duplicates are suppressed. The result is truly distinct only if
     * relationships arrive grouped by target. NOTE(review): presumably the graph's adjacency
     * lists are sorted; confirm.
     */
    default void forEachDistinctNeighbor(long nodeId, LongConsumer targetConsumer) {
        // -1 means "no previous target yet"; this assumes node ids are non-negative.
        MutableLong previousTarget = new MutableLong(-1L);
        graph().forEachRelationship(nodeId, (ignoredSource, targetNodeId) -> {
            if (previousTarget.longValue() != targetNodeId) {
                targetConsumer.accept(targetNodeId);
                previousTarget.setValue(targetNodeId);
            }
            return true;
        });
    }

    /** Reads the {@code double} property {@code key} of {@code nodeId} from {@link #nodeValue()}. */
    default double doubleNodeValue(String key, long nodeId) {
        return nodeValue().doubleValue(key, nodeId);
    }

    /** Reads the {@code long} property {@code key} of {@code nodeId} from {@link #nodeValue()}. */
    default long longNodeValue(String key, long nodeId) {
        return nodeValue().longValue(key, nodeId);
    }

    /** Reads the {@code long[]} property {@code key} of {@code nodeId} from {@link #nodeValue()}. */
    default long[] longArrayNodeValue(String key, long nodeId) {
        return nodeValue().longArrayValue(key, nodeId);
    }

    /** Reads the {@code double[]} property {@code key} of {@code nodeId} from {@link #nodeValue()}. */
    default double[] doubleArrayNodeValue(String key, long nodeId) {
        return nodeValue().doubleArrayValue(key, nodeId);
    }

    /** Writes the {@code double} property {@code key} of {@code nodeId} via {@link #nodeValue()}. */
    default void setNodeValue(String key, long nodeId, double value) {
        nodeValue().set(key, nodeId, value);
    }

    /** Writes the {@code long} property {@code key} of {@code nodeId} via {@link #nodeValue()}. */
    default void setNodeValue(String key, long nodeId, long value) {
        nodeValue().set(key, nodeId, value);
    }

    /** Writes the {@code long[]} property {@code key} of {@code nodeId} via {@link #nodeValue()}. */
    default void setNodeValue(String key, long nodeId, long[] value) {
        nodeValue().set(key, nodeId, value);
    }

    /** Writes the {@code double[]} property {@code key} of {@code nodeId} via {@link #nodeValue()}. */
    default void setNodeValue(String key, long nodeId, double[] value) {
        nodeValue().set(key, nodeId, value);
    }
}
