package org.neo4j.gds.beta.pregel.triangleCount;

import com.carrotsearch.hppc.LongHashSet;
import java.util.Objects;
import java.util.Optional;
import java.util.function.LongConsumer;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.api.nodeproperties.ValueType;
import org.neo4j.gds.beta.pregel.Messages;
import org.neo4j.gds.beta.pregel.PregelComputation;
import org.neo4j.gds.beta.pregel.PregelSchema;
import org.neo4j.gds.beta.pregel.Reducer;
import org.neo4j.gds.beta.pregel.annotation.GDSMode;
import org.neo4j.gds.beta.pregel.annotation.PregelProcedure;
import org.neo4j.gds.beta.pregel.context.ComputeContext;

/**
 * Pregel computation that writes a per-node triangle count into the {@link #TRIANGLE_COUNT}
 * node property.
 *
 * <p>The computation runs in three supersteps:
 * <ol>
 *   <li>initial superstep: every node sets its count to {@code 0};</li>
 *   <li>{@link Phase#MERGE_NEIGHBORS}: every node A looks for pairs B, C with
 *       {@code A < B < C} where B and C are neighbours of A and C is a neighbour of B,
 *       counts each such triangle for itself and sends one message to B and one to C;</li>
 *   <li>{@link Phase#COUNT_TRIANGLES}: every node adds the value it received to its own
 *       count and votes to halt.</li>
 * </ol>
 */
@PregelProcedure(name = "example.pregel.triangleCount", modes = {GDSMode.STREAM})
public class TriangleCountPregel implements PregelComputation<TriangleCountPregelConfig> {
    /** Key of the node property that holds the triangle count. */
    public static final String TRIANGLE_COUNT = "TRIANGLES";

    /** The two non-initial supersteps, each tagged with the superstep number it runs in. */
    enum Phase {
        MERGE_NEIGHBORS(1),
        COUNT_TRIANGLES(2);

        /** Superstep number that {@code compute} compares against {@code superstep()}. */
        final long step;

        Phase(int step) {
            this.step = step;
        }
    }

    /**
     * Declares the single {@code long} node property {@link #TRIANGLE_COUNT}.
     *
     * @param triangleCountPregelConfig the procedure configuration (not read here)
     * @return the schema containing only the triangle-count property
     */
    public PregelSchema schema(TriangleCountPregelConfig triangleCountPregelConfig) {
        return new PregelSchema.Builder().add(TRIANGLE_COUNT, ValueType.LONG).build();
    }

    /**
     * Runs one superstep for the current node; see the class documentation for the phases.
     *
     * @param computeContext context of the node being computed
     * @param messages messages delivered to this node for the current superstep
     */
    public void compute(ComputeContext<TriangleCountPregelConfig> computeContext, Messages messages) {
        if (computeContext.isInitialSuperstep()) {
            computeContext.setNodeValue(TRIANGLE_COUNT, 0L);
            return;
        }
        if (computeContext.superstep() == Phase.MERGE_NEIGHBORS.step) {
            countTrianglesFromLowestNode(computeContext);
            return;
        }
        if (computeContext.superstep() == Phase.COUNT_TRIANGLES.step) {
            addReceivedTriangles(computeContext, messages);
        }
    }

    /**
     * Declares a {@link Reducer.Count} as the message reducer for this computation.
     *
     * @return an {@link Optional} holding a new {@code Reducer.Count}
     */
    public Optional<Reducer> reducer() {
        return Optional.of(new Reducer.Count());
    }

    /**
     * MERGE_NEIGHBORS phase: counts the triangles in which the current node has the smallest id
     * and notifies the two other corners of each triangle.
     */
    private void countTrianglesFromLowestNode(ComputeContext<TriangleCountPregelConfig> context) {
        LongHashSet neighbours = new LongHashSet(context.degree());
        context.forEachDistinctNeighbor(neighbours::add);

        long nodeA = context.nodeId();
        MutableLong trianglesFromNodeA = new MutableLong();

        // Iterate over a snapshot array so that the three node ids (A, B, C) are plain,
        // distinctly named locals instead of nested lambda parameters.
        for (long nodeB : neighbours.toArray()) {
            // Only look "upwards" in id order so every triangle is found from exactly one corner.
            if (nodeB <= nodeA) {
                continue;
            }
            LongConsumer countTriangle = nodeC -> {
                // C must come after B and also be adjacent to A to close the triangle A-B-C.
                if (nodeC > nodeB && neighbours.contains(nodeC)) {
                    trianglesFromNodeA.increment();
                    context.sendTo(nodeB, 1.0d);
                    context.sendTo(nodeC, 1.0d);
                }
            };
            // On a multigraph only distinct neighbours are visited, so parallel
            // relationships do not make the same triangle count more than once.
            if (context.isMultiGraph()) {
                context.forEachDistinctNeighbor(nodeB, countTriangle);
            } else {
                context.forEachNeighbor(nodeB, countTriangle);
            }
        }
        context.setNodeValue(TRIANGLE_COUNT, trianglesFromNodeA.longValue());
    }

    /**
     * COUNT_TRIANGLES phase: adds the received value to the node's own count and halts the node.
     */
    private void addReceivedTriangles(ComputeContext<TriangleCountPregelConfig> context, Messages messages) {
        long triangles = context.longNodeValue(TRIANGLE_COUNT);
        if (!messages.isEmpty()) {
            // Only the first message value is read. NOTE(review): presumably the Reducer.Count
            // from reducer() collapses all incoming messages into one count - confirm.
            triangles = (long) (triangles + messages.doubleIterator().nextDouble());
        }
        context.setNodeValue(TRIANGLE_COUNT, triangles);
        context.voteToHalt();
    }
}
