package org.neo4j.graphalgo.beta.pregel;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ExecutorService;
import java.util.stream.LongStream;
import org.immutables.value.Value;
import org.jctools.queues.MpscLinkedQueue;
import org.jetbrains.annotations.Nullable;
import org.neo4j.graphalgo.annotation.ValueClass;
import org.neo4j.graphalgo.api.Degrees;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.api.RelationshipIterator;
import org.neo4j.graphalgo.api.nodeproperties.ValueType;
import org.neo4j.graphalgo.beta.pregel.Messages;
import org.neo4j.graphalgo.beta.pregel.PregelConfig;
import org.neo4j.graphalgo.beta.pregel.context.ComputeContext;
import org.neo4j.graphalgo.beta.pregel.context.InitContext;
import org.neo4j.graphalgo.beta.pregel.context.MasterComputeContext;
import org.neo4j.graphalgo.config.GraphCreateConfig;
import org.neo4j.graphalgo.core.concurrency.ParallelUtil;
import org.neo4j.graphalgo.core.loading.RelationshipsBatchBuffer;
import org.neo4j.graphalgo.core.utils.mem.AllocationTracker;
import org.neo4j.graphalgo.core.utils.mem.MemoryEstimation;
import org.neo4j.graphalgo.core.utils.mem.MemoryEstimations;
import org.neo4j.graphalgo.core.utils.mem.MemoryUsage;
import org.neo4j.graphalgo.core.utils.paged.HugeAtomicBitSet;
import org.neo4j.graphalgo.core.utils.paged.HugeDoubleArray;
import org.neo4j.graphalgo.core.utils.paged.HugeLongArray;
import org.neo4j.graphalgo.core.utils.paged.HugeObjectArray;
import org.neo4j.graphalgo.core.utils.partition.Partition;
import org.neo4j.graphalgo.core.utils.partition.PartitionUtils;
import org.neo4j.graphalgo.utils.StringFormatting;

@Value.Style(builderVisibility = Value.Style.BuilderVisibility.PUBLIC, depluralize = true, deepImmutablesDetection = true)
/* loaded from: input_file:org/neo4j/graphalgo/beta/pregel/Pregel.class */
public final class Pregel<CONFIG extends PregelConfig> {
    // Sentinel appended to message queues in synchronous mode by runComputeSteps; presumably the
    // Sync message iterator stops at it to separate supersteps — TODO confirm against Messages.
    private static final Double TERMINATION_SYMBOL = Double.valueOf(Double.NaN);
    // User-supplied configuration (max iterations, concurrency, sync/async mode).
    private final CONFIG config;
    // The user computation whose init/compute/masterCompute callbacks are driven by run().
    private final PregelComputation<CONFIG> computation;
    private final Graph graph;
    // Per-node property storage built from the computation's schema; returned in the result.
    private final CompositeNodeValue nodeValues;
    // One multi-producer/single-consumer message queue per node id.
    private final HugeObjectArray<MpscLinkedQueue<Double>> messageQueues;
    // Parallelism for compute steps and for the parallel initialisation loops.
    private final int concurrency;
    private final ExecutorService executor;
    private final AllocationTracker tracker;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.neo4j.graphalgo.beta.pregel.Pregel$1, reason: invalid class name */
    /* loaded from: input_file:org/neo4j/graphalgo/beta/pregel/Pregel$1.class */
    /**
     * Compiler-synthesized lookup table for switching on {@link ValueType}.
     * <p>
     * Maps {@code ValueType.ordinal()} to a dense case index: DOUBLE -> 1, LONG -> 2,
     * LONG_ARRAY -> 3, DOUBLE_ARRAY -> 4. Any other constant keeps the default value 0, so a
     * switch over {@code $SwitchMap$...[type.ordinal()]} falls through to its {@code default} branch.
     */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$org$neo4j$graphalgo$api$nodeproperties$ValueType = new int[ValueType.values().length];

        static {
            // Each entry is guarded separately: if an enum constant is missing from the ValueType
            // class present at runtime, NoSuchFieldError is ignored and that entry stays 0.
            try {
                $SwitchMap$org$neo4j$graphalgo$api$nodeproperties$ValueType[ValueType.DOUBLE.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
                // intentionally ignored (see above)
            }
            try {
                $SwitchMap$org$neo4j$graphalgo$api$nodeproperties$ValueType[ValueType.LONG.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
                // intentionally ignored (see above)
            }
            try {
                $SwitchMap$org$neo4j$graphalgo$api$nodeproperties$ValueType[ValueType.LONG_ARRAY.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
                // intentionally ignored (see above)
            }
            try {
                $SwitchMap$org$neo4j$graphalgo$api$nodeproperties$ValueType[ValueType.DOUBLE_ARRAY.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
                // intentionally ignored (see above)
            }
        }
    }

    /**
     * Column-oriented storage for the node values declared by a {@link PregelSchema}.
     * <p>
     * Every schema element gets one backing array indexed by node id:
     * <ul>
     *   <li>{@code DOUBLE}: a {@link HugeDoubleArray}, every slot pre-filled with {@code NaN};</li>
     *   <li>{@code LONG}: a {@link HugeLongArray}, every slot pre-filled with {@link Long#MIN_VALUE};</li>
     *   <li>{@code LONG_ARRAY} / {@code DOUBLE_ARRAY}: a {@link HugeObjectArray} of {@code long[]} /
     *       {@code double[]}.</li>
     * </ul>
     */
    public static final class CompositeNodeValue {
        private final PregelSchema pregelSchema;
        // property key -> backing array (HugeDoubleArray, HugeLongArray or HugeObjectArray)
        private final Map<String, Object> properties;
        // property key -> declared value type. Needed to tell long[] and double[] object arrays
        // apart: both are stored as the same (erased) HugeObjectArray class.
        private final Map<String, ValueType> propertyTypes;

        /**
         * Allocates and initialises one backing array per schema element.
         *
         * @param pregelSchema      schema declaring the property keys and their value types
         * @param nodeCount         number of slots allocated per property
         * @param concurrency       parallelism used to pre-fill the primitive arrays
         * @param allocationTracker tracker charged for the allocations
         * @return the initialised storage
         * @throws IllegalArgumentException if the schema contains an unsupported value type
         */
        static CompositeNodeValue of(PregelSchema pregelSchema, long nodeCount, int concurrency, AllocationTracker allocationTracker) {
            Map<String, Object> properties = new HashMap<>();
            Map<String, ValueType> propertyTypes = new HashMap<>();
            pregelSchema.elements().forEach(element -> {
                ValueType type = element.propertyType();
                switch (type) {
                    case DOUBLE:
                        HugeDoubleArray doubles = HugeDoubleArray.newArray(nodeCount, allocationTracker);
                        ParallelUtil.parallelStreamConsume(
                            LongStream.range(0L, nodeCount),
                            concurrency,
                            nodeIds -> nodeIds.forEach(nodeId -> doubles.set(nodeId, Double.NaN))
                        );
                        properties.put(element.propertyKey(), doubles);
                        break;
                    case LONG:
                        HugeLongArray longs = HugeLongArray.newArray(nodeCount, allocationTracker);
                        ParallelUtil.parallelStreamConsume(
                            LongStream.range(0L, nodeCount),
                            concurrency,
                            nodeIds -> nodeIds.forEach(nodeId -> longs.set(nodeId, Long.MIN_VALUE))
                        );
                        properties.put(element.propertyKey(), longs);
                        break;
                    case LONG_ARRAY:
                        properties.put(element.propertyKey(), HugeObjectArray.newArray(long[].class, nodeCount, allocationTracker));
                        break;
                    case DOUBLE_ARRAY:
                        properties.put(element.propertyKey(), HugeObjectArray.newArray(double[].class, nodeCount, allocationTracker));
                        break;
                    default:
                        throw new IllegalArgumentException(StringFormatting.formatWithLocale("Unsupported value type: %s", new Object[]{type}));
                }
                propertyTypes.put(element.propertyKey(), type);
            });
            return new CompositeNodeValue(pregelSchema, properties, propertyTypes);
        }

        private CompositeNodeValue(PregelSchema pregelSchema, Map<String, Object> properties, Map<String, ValueType> propertyTypes) {
            this.pregelSchema = pregelSchema;
            this.properties = properties;
            this.propertyTypes = propertyTypes;
        }

        /** @return the schema this storage was created from */
        public PregelSchema schema() {
            return this.pregelSchema;
        }

        /**
         * @throws IllegalArgumentException if {@code key} is unknown or not a {@code DOUBLE} property
         */
        public HugeDoubleArray doubleProperties(String key) {
            return checkProperty(key, HugeDoubleArray.class);
        }

        /**
         * @throws IllegalArgumentException if {@code key} is unknown or not a {@code LONG} property
         */
        public HugeLongArray longProperties(String key) {
            return checkProperty(key, HugeLongArray.class);
        }

        /**
         * @throws IllegalArgumentException if {@code key} is unknown or not a {@code LONG_ARRAY} property
         */
        public HugeObjectArray<long[]> longArrayProperties(String key) {
            return checkObjectArrayProperty(key, ValueType.LONG_ARRAY);
        }

        /**
         * @throws IllegalArgumentException if {@code key} is unknown or not a {@code DOUBLE_ARRAY} property
         */
        public HugeObjectArray<double[]> doubleArrayProperties(String key) {
            return checkObjectArrayProperty(key, ValueType.DOUBLE_ARRAY);
        }

        public double doubleValue(String key, long nodeId) {
            return doubleProperties(key).get(nodeId);
        }

        public long longValue(String key, long nodeId) {
            return longProperties(key).get(nodeId);
        }

        public long[] longArrayValue(String key, long nodeId) {
            return longArrayProperties(key).get(nodeId);
        }

        public double[] doubleArrayValue(String key, long nodeId) {
            return doubleArrayProperties(key).get(nodeId);
        }

        public void set(String key, long nodeId, double value) {
            doubleProperties(key).set(nodeId, value);
        }

        public void set(String key, long nodeId, long value) {
            longProperties(key).set(nodeId, value);
        }

        public void set(String key, long nodeId, long[] value) {
            longArrayProperties(key).set(nodeId, value);
        }

        public void set(String key, long nodeId, double[] value) {
            doubleArrayProperties(key).set(nodeId, value);
        }

        /**
         * Looks up an object-array property and verifies its declared element type.
         * <p>
         * Checking only the {@code HugeObjectArray} class is not enough: {@code long[]} and
         * {@code double[]} arrays share it after erasure. A mismatch would otherwise surface later as a
         * ClassCastException or as a wrongly typed array being stored.
         */
        @SuppressWarnings("unchecked") // element type verified against propertyTypes before the cast
        private <T> HugeObjectArray<T> checkObjectArrayProperty(String key, ValueType expectedType) {
            Object array = checkProperty(key, HugeObjectArray.class);
            ValueType actualType = this.propertyTypes.get(key);
            if (actualType != expectedType) {
                throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                    "Could not cast property %s of type %s into %s",
                    new Object[]{key, actualType, expectedType}
                ));
            }
            return (HugeObjectArray<T>) array;
        }

        /**
         * Returns the backing array stored under {@code key} if it is an instance of {@code type}.
         *
         * @throws IllegalArgumentException if the key does not exist or the stored array has another type
         */
        private <PROPERTY> PROPERTY checkProperty(String key, Class<? extends PROPERTY> type) {
            PROPERTY property = (PROPERTY) this.properties.get(key);
            if (property == null) {
                throw new IllegalArgumentException(StringFormatting.formatWithLocale("Property with key %s does not exist. Available properties are: %s", new Object[]{key, this.properties.keySet()}));
            }
            if (type.isAssignableFrom(property.getClass())) {
                return property;
            }
            throw new IllegalArgumentException(StringFormatting.formatWithLocale("Could not cast property %s of type %s into %s", new Object[]{key, property.getClass().getSimpleName(), type.getSimpleName()}));
        }
    }

    /* loaded from: input_file:org/neo4j/graphalgo/beta/pregel/Pregel$ComputeStep.class */
    public static final class ComputeStep<CONFIG extends PregelConfig> implements Runnable {
        private final long nodeCount;
        private final long relationshipCount;
        private final boolean isAsync;
        private final boolean isMultiGraph;
        private final PregelComputation<CONFIG> computation;
        private final InitContext<CONFIG> initContext;
        private final ComputeContext<CONFIG> computeContext;
        private final Partition nodeBatch;
        private final Degrees degrees;
        private final CompositeNodeValue nodeValues;
        private final HugeObjectArray<? extends Queue<Double>> messageQueues;
        private final RelationshipIterator relationshipIterator;
        private int iteration;
        private HugeAtomicBitSet messageBits;
        private HugeAtomicBitSet prevMessageBits;
        private final HugeAtomicBitSet voteBits;

        private ComputeStep(Graph graph, PregelComputation<CONFIG> pregelComputation, CONFIG config, int i, Partition partition, CompositeNodeValue compositeNodeValue, HugeObjectArray<? extends Queue<Double>> hugeObjectArray, HugeAtomicBitSet hugeAtomicBitSet, RelationshipIterator relationshipIterator) {
            this.iteration = i;
            this.nodeCount = graph.nodeCount();
            this.relationshipCount = graph.relationshipCount();
            this.isAsync = config.isAsynchronous();
            this.computation = pregelComputation;
            this.voteBits = hugeAtomicBitSet;
            this.nodeBatch = partition;
            this.degrees = graph;
            this.isMultiGraph = graph.isMultiGraph();
            this.nodeValues = compositeNodeValue;
            this.messageQueues = hugeObjectArray;
            this.relationshipIterator = relationshipIterator.concurrentCopy();
            this.computeContext = new ComputeContext<>(this, config);
            this.initContext = new InitContext<>(this, config, graph);
        }

        void init(int i, HugeAtomicBitSet hugeAtomicBitSet, HugeAtomicBitSet hugeAtomicBitSet2) {
            this.iteration = i;
            this.messageBits = hugeAtomicBitSet;
            this.prevMessageBits = hugeAtomicBitSet2;
        }

        @Override // java.lang.Runnable
        public void run() {
            Messages.MessageIterator async = this.isAsync ? new Messages.MessageIterator.Async() : new Messages.MessageIterator.Sync();
            Messages messages = new Messages(async);
            long startNode = this.nodeBatch.startNode();
            long nodeCount = startNode + this.nodeBatch.nodeCount();
            long j = startNode;
            while (true) {
                long j2 = j;
                if (j2 >= nodeCount) {
                    return;
                }
                if (this.computeContext.isInitialSuperstep()) {
                    this.initContext.setNodeId(j2);
                    this.computation.init(this.initContext);
                }
                if (this.prevMessageBits.get(j2) || !this.voteBits.get(j2)) {
                    this.voteBits.clear(j2);
                    this.computeContext.setNodeId(j2);
                    async.init(receiveMessages(j2));
                    this.computation.compute(this.computeContext, messages);
                }
                j = j2 + 1;
            }
        }

        public int iteration() {
            return this.iteration;
        }

        public boolean isMultiGraph() {
            return this.isMultiGraph;
        }

        public long nodeCount() {
            return this.nodeCount;
        }

        public long relationshipCount() {
            return this.relationshipCount;
        }

        public int degree(long j) {
            return this.degrees.degree(j);
        }

        public void voteToHalt(long j) {
            this.voteBits.set(j);
        }

        public void sendToNeighbors(long j, double d) {
            this.relationshipIterator.forEachRelationship(j, (j2, j3) -> {
                sendTo(j3, d);
                return true;
            });
        }

        public LongStream getNeighbors(long j) {
            LongStream.Builder builder = LongStream.builder();
            this.relationshipIterator.forEachRelationship(j, (j2, j3) -> {
                builder.accept(j3);
                return true;
            });
            return builder.build();
        }

        public void sendTo(long j, double d) {
            this.messageQueues.get(j).add(Double.valueOf(d));
            this.messageBits.set(j);
        }

        public void sendToNeighborsWeighted(long j, double d) {
            this.relationshipIterator.forEachRelationship(j, 1.0d, (j2, j3, d2) -> {
                this.messageQueues.get(j3).add(Double.valueOf(this.computation.applyRelationshipWeight(d, d2)));
                this.messageBits.set(j3);
                return true;
            });
        }

        @Nullable
        private Queue<Double> receiveMessages(long j) {
            if (this.prevMessageBits.get(j)) {
                return this.messageQueues.get(j);
            }
            return null;
        }

        public double doubleNodeValue(String str, long j) {
            return this.nodeValues.doubleValue(str, j);
        }

        public long longNodeValue(String str, long j) {
            return this.nodeValues.longValue(str, j);
        }

        public long[] longArrayNodeValue(String str, long j) {
            return this.nodeValues.longArrayValue(str, j);
        }

        public double[] doubleArrayNodeValue(String str, long j) {
            return this.nodeValues.doubleArrayValue(str, j);
        }

        public void setNodeValue(String str, long j, double d) {
            this.nodeValues.set(str, j, d);
        }

        public void setNodeValue(String str, long j, long j2) {
            this.nodeValues.set(str, j, j2);
        }

        public void setNodeValue(String str, long j, long[] jArr) {
            this.nodeValues.set(str, j, jArr);
        }

        public void setNodeValue(String str, long j, double[] dArr) {
            this.nodeValues.set(str, j, dArr);
        }
    }

    /**
     * Outcome of {@link Pregel#run()}.
     * <ul>
     *   <li>{@code nodeValues()}: the node value storage the computation wrote into;</li>
     *   <li>{@code ranIterations()}: the value of the superstep counter when the run loop stopped
     *       ({@code maxIterations} if it never converged, otherwise the index of the final superstep);</li>
     *   <li>{@code didConverge()}: whether the loop stopped because, in some superstep, no message was
     *       sent and every node had voted to halt.</li>
     * </ul>
     */
    @ValueClass
    /* loaded from: input_file:org/neo4j/graphalgo/beta/pregel/Pregel$PregelResult.class */
    public interface PregelResult {
        CompositeNodeValue nodeValues();

        int ranIterations();

        boolean didConverge();
    }

    /**
     * Creates a Pregel instance, allocating node value storage for the computation's schema and one
     * message queue per node.
     *
     * @param graph              graph to run on
     * @param config             run configuration
     * @param pregelComputation  the user computation
     * @param executorService    executor running the compute steps
     * @param allocationTracker  tracker charged for all allocations
     * @return a ready-to-run instance; call {@link #release()} when done
     */
    public static <CONFIG extends PregelConfig> Pregel<CONFIG> create(Graph graph, CONFIG config, PregelComputation<CONFIG> pregelComputation, ExecutorService executorService, AllocationTracker allocationTracker) {
        // The immutable copy is discarded on purpose: building it presumably runs the config
        // validations (e.g. concurrency limits) even for custom PregelConfig implementations.
        ImmutablePregelConfig.copyOf(config);
        CompositeNodeValue initialValues = CompositeNodeValue.of(
            pregelComputation.schema(),
            graph.nodeCount(),
            config.concurrency(),
            allocationTracker
        );
        return new Pregel<>(graph, config, pregelComputation, initialValues, executorService, allocationTracker);
    }

    /**
     * Builds the memory estimation for running a Pregel computation with the given node schema.
     * <p>
     * It covers:
     * <ul>
     *   <li>the three per-node bit sets used by {@link #run()};</li>
     *   <li>one {@link ComputeStep} per thread;</li>
     *   <li>one message queue per node, sized as average degree times 8 bytes;</li>
     *   <li>one backing array per schema element, with array-valued properties assumed to hold
     *       10 elements per node.</li>
     * </ul>
     *
     * @param pregelSchema schema whose elements determine the node value arrays
     * @return the composed estimation
     */
    public static MemoryEstimation memoryEstimation(PregelSchema pregelSchema) {
        return MemoryEstimations.builder((Class<?>) Pregel.class)
            .perNode("message bits", MemoryUsage::sizeOfHugeAtomicBitset)
            .perNode("previous message bits", MemoryUsage::sizeOfHugeAtomicBitset)
            .perNode("vote bits", MemoryUsage::sizeOfHugeAtomicBitset)
            .perThread("compute steps", MemoryEstimations.builder((Class<?>) ComputeStep.class).build())
            .add("message queues", MemoryEstimations.setup(GraphCreateConfig.IMPLICIT_GRAPH_NAME, (dimensions, ignored) ->
                MemoryEstimations.builder()
                    .fixed(HugeObjectArray.class.getSimpleName(), MemoryUsage.sizeOfInstance(HugeObjectArray.class))
                    .perNode("node queue", MemoryEstimations.builder((Class<?>) MpscLinkedQueue.class)
                        .fixed("messages", dimensions.averageDegree() * 8)
                        .build())
                    .build()
            ))
            .add("composite node value", MemoryEstimations.setup(GraphCreateConfig.IMPLICIT_GRAPH_NAME, (dimensions, ignored) ->
                compositeNodeValueEstimation(pregelSchema, dimensions.nodeCount())
            ))
            .build();
    }

    // One entry per schema element, mirroring the array types allocated by CompositeNodeValue.of;
    // unsupported value types contribute an empty estimation.
    private static MemoryEstimation compositeNodeValueEstimation(PregelSchema pregelSchema, long nodeCount) {
        MemoryEstimations.Builder builder = MemoryEstimations.builder();
        pregelSchema.elements().forEach(element -> {
            String description = StringFormatting.formatWithLocale("%s (%s)", new Object[]{element.propertyKey(), element.propertyType()});
            switch (element.propertyType()) {
                case DOUBLE:
                    builder.fixed(description, HugeDoubleArray.memoryEstimation(nodeCount));
                    break;
                case LONG:
                    builder.fixed(description, HugeLongArray.memoryEstimation(nodeCount));
                    break;
                case LONG_ARRAY:
                    builder.add(description, objectArrayEstimation("long[10]", MemoryUsage.sizeOfLongArray(10L)));
                    break;
                case DOUBLE_ARRAY:
                    builder.add(description, objectArrayEstimation("double[10]", MemoryUsage.sizeOfDoubleArray(10L)));
                    break;
                default:
                    builder.add(description, MemoryEstimations.empty());
                    break;
            }
        });
        return builder.build();
    }

    // Estimation for a HugeObjectArray instance plus a fixed number of bytes per node.
    private static MemoryEstimation objectArrayEstimation(String perNodeDescription, long bytesPerNode) {
        return MemoryEstimations.builder()
            .fixed(HugeObjectArray.class.getSimpleName(), MemoryUsage.sizeOfInstance(HugeObjectArray.class))
            .perNode(perNodeDescription, nodes -> nodes * bytesPerNode)
            .build();
    }

    /**
     * Wires the collaborators together and allocates one message queue per node.
     * Use {@link #create} rather than calling this directly.
     */
    private Pregel(Graph graph, CONFIG config, PregelComputation<CONFIG> pregelComputation, CompositeNodeValue compositeNodeValue, ExecutorService executorService, AllocationTracker allocationTracker) {
        this.config = config;
        this.concurrency = config.concurrency();
        this.graph = graph;
        this.computation = pregelComputation;
        this.nodeValues = compositeNodeValue;
        this.executor = executorService;
        this.tracker = allocationTracker;
        // must come after `concurrency` is assigned: initLinkedQueues reads it
        this.messageQueues = initLinkedQueues(graph, allocationTracker);
    }

    /**
     * Runs supersteps until the computation converges or {@code config.maxIterations()} supersteps ran.
     * <p>
     * Each superstep executes all compute steps, then the master compute. The run converges when, in a
     * superstep, no message was sent and every node has voted to halt.
     * <p>
     * NOTE(review): on convergence {@code ranIterations} is the index of the final superstep (the loop
     * breaks before incrementing), i.e. one less than the number of supersteps executed; without
     * convergence it equals {@code maxIterations}. Left unchanged because callers may depend on it.
     */
    public PregelResult run() {
        HugeAtomicBitSet messageBits = HugeAtomicBitSet.create(graph.nodeCount(), tracker);
        HugeAtomicBitSet prevMessageBits = HugeAtomicBitSet.create(graph.nodeCount(), tracker);
        HugeAtomicBitSet voteBits = HugeAtomicBitSet.create(graph.nodeCount(), tracker);
        List<ComputeStep<CONFIG>> computeSteps = createComputeSteps(voteBits);

        boolean didConverge = false;
        int iteration;
        for (iteration = 0; iteration < config.maxIterations(); iteration++) {
            // after the swap below, messageBits holds the previous-previous superstep's receivers
            if (iteration > 0) {
                messageBits.clear();
            }
            for (ComputeStep<CONFIG> step : computeSteps) {
                step.init(iteration, messageBits, prevMessageBits);
            }
            runComputeSteps(computeSteps, iteration, prevMessageBits);
            runMasterComputeStep(iteration);

            boolean noMessagesSent = messageBits.isEmpty();
            if (noMessagesSent && voteBits.allSet()) {
                didConverge = true;
                break;
            }

            // this superstep's receivers become the next superstep's "previous" receivers
            HugeAtomicBitSet recycled = messageBits;
            messageBits = prevMessageBits;
            prevMessageBits = recycled;
        }

        return ImmutablePregelResult.builder()
            .nodeValues(nodeValues)
            .didConverge(didConverge)
            .ranIterations(iteration)
            .build();
    }

    /** Releases the per-node message queues; the instance must not be run afterwards. */
    public void release() {
        HugeObjectArray<MpscLinkedQueue<Double>> queues = messageQueues;
        queues.release();
    }

    /**
     * Creates one {@link ComputeStep} per range partition of the node id space.
     *
     * @param voteBits halt-vote bit set shared by all returned steps
     * @return the compute steps, in partition order
     */
    private List<ComputeStep<CONFIG>> createComputeSteps(HugeAtomicBitSet voteBits) {
        List<Partition> partitions = PartitionUtils.rangePartition(concurrency, graph.nodeCount());
        // presize to the actual partition count, which need not equal the requested concurrency
        List<ComputeStep<CONFIG>> computeSteps = new ArrayList<>(partitions.size());
        for (Partition partition : partitions) {
            computeSteps.add(new ComputeStep<>(
                graph,
                computation,
                config,
                0,
                partition,
                nodeValues,
                messageQueues,
                voteBits,
                graph
            ));
        }
        return computeSteps;
    }

    /**
     * Executes one superstep across all compute steps.
     * <p>
     * In synchronous mode, from the second superstep on, {@code TERMINATION_SYMBOL} is first appended
     * to the queue of every node flagged in {@code previousReceivers}. Presumably this lets the sync
     * iterator stop at the previous superstep's messages — TODO confirm against Messages.
     *
     * @param computeSteps      steps to run with the configured concurrency on the executor
     * @param iteration         current superstep index
     * @param previousReceivers nodes that received messages in the previous superstep
     */
    private void runComputeSteps(Collection<ComputeStep<CONFIG>> computeSteps, int iteration, HugeAtomicBitSet previousReceivers) {
        boolean markSuperstepBoundary = !config.isAsynchronous() && iteration > 0;
        if (markSuperstepBoundary) {
            ParallelUtil.parallelStreamConsume(
                LongStream.range(0L, graph.nodeCount()),
                concurrency,
                nodeIds -> nodeIds
                    .filter(previousReceivers::get)
                    .forEach(nodeId -> messageQueues.get(nodeId).add(TERMINATION_SYMBOL))
            );
        }
        ParallelUtil.runWithConcurrency(concurrency, computeSteps, executor);
    }

    /** Invokes the computation's master compute hook once for the given superstep. */
    private void runMasterComputeStep(int iteration) {
        MasterComputeContext<CONFIG> masterContext = new MasterComputeContext<>(config, graph, iteration, nodeValues);
        computation.masterCompute(masterContext);
    }

    /**
     * Allocates one empty {@link MpscLinkedQueue} per node, filling the array in parallel.
     *
     * @param graph             graph whose node count sizes the array
     * @param allocationTracker tracker charged for the array allocation
     * @return the populated queue array
     */
    private HugeObjectArray<MpscLinkedQueue<Double>> initLinkedQueues(Graph graph, AllocationTracker allocationTracker) {
        // MpscLinkedQueue.class is a raw Class. The unchecked cast is safe because type arguments are
        // erased at runtime, and it avoids allocating a throwaway queue just to call getClass().
        @SuppressWarnings("unchecked")
        Class<MpscLinkedQueue<Double>> queueClass = (Class<MpscLinkedQueue<Double>>) (Class<?>) MpscLinkedQueue.class;
        HugeObjectArray<MpscLinkedQueue<Double>> queues = HugeObjectArray.newArray(queueClass, graph.nodeCount(), allocationTracker);
        ParallelUtil.parallelStreamConsume(
            LongStream.range(0L, graph.nodeCount()),
            this.concurrency,
            nodeIds -> nodeIds.forEach(nodeId -> queues.set(nodeId, new MpscLinkedQueue<>()))
        );
        return queues;
    }
}
