package org.neo4j.gds.embeddings.node2vec;

import java.util.stream.LongStream;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
import org.neo4j.gds.core.utils.paged.HugeLongArray;

/* JADX INFO: Access modifiers changed from: package-private */
@ValueClass
/* loaded from: input_file:org/neo4j/gds/embeddings/node2vec/RandomWalkProbabilities.class */
public interface RandomWalkProbabilities {

    /* loaded from: input_file:org/neo4j/gds/embeddings/node2vec/RandomWalkProbabilities$Builder.class */
    /**
     * Accumulates node occurrence counts from random walks and derives the
     * sampling tables exposed by {@link RandomWalkProbabilities}.
     *
     * <p>Usage: call {@link #registerWalk(long[])} once per walk, then {@link #build()}.
     *
     * <p>This class performs no synchronization of its own; the sample counter is a plain
     * {@code MutableLong}. NOTE(review): callers presumably register walks from a single
     * thread (or serialize access externally) — confirm against the call sites.
     *
     * <p>Decompiler note: the constructor, {@code registerWalk} and {@code build} were
     * package-private in the original source.
     */
    public static class Builder {
        /** Number of nodes; sizes every per-node array created by this builder. */
        private final long nodeCount;
        /** Concurrency handed to {@code ParallelUtil} when computing the positive sampling table. */
        private final int concurrency;
        /** Factor {@code t} in the positive sampling formula {@code (sqrt(f / t) + 1) * (t / f)}. */
        private final double positiveSamplingFactor;
        /** Exponent applied to each node's raw frequency for the negative sampling table. */
        private final double negativeSamplingExponent;
        /** Per-node occurrence counter, indexed by node id. */
        private final HugeLongArray nodeFrequencies;
        /** Sum of the lengths of all walks registered so far. */
        private final MutableLong sampleCount = new MutableLong(0);

        /**
         * Creates a builder with zero registered samples.
         *
         * @param j  node count; a frequency array of this size is allocated
         * @param d  positive sampling factor
         * @param d2 negative sampling exponent
         * @param i  concurrency used for the parallel part of {@link #build()}
         */
        public Builder(long j, double d, double d2, int i) {
            this.nodeCount = j;
            this.concurrency = i;
            this.positiveSamplingFactor = d;
            this.negativeSamplingExponent = d2;
            this.nodeFrequencies = HugeLongArray.newArray(j);
        }

        /**
         * Records one walk: every entry increments the counter of the node it names, and the
         * walk length is added to the total sample count.
         *
         * @param jArr the walk; each element is used as an index into the frequency array
         * @return this builder, for chaining
         */
        public Builder registerWalk(long[] jArr) {
            for (long nodeId : jArr) {
                this.nodeFrequencies.addTo(nodeId, 1L);
            }
            this.sampleCount.add(jArr.length);
            return this;
        }

        /**
         * Derives both sampling tables from the frequencies registered so far and packages them,
         * together with the raw frequencies and the sample count, into an immutable value.
         *
         * @return the assembled {@link RandomWalkProbabilities}
         * @throws ArithmeticException if the cumulative negative sampling total overflows a {@code long}
         */
        public RandomWalkProbabilities build() {
            // Positive table first, then negative table: same evaluation order as before.
            HugeDoubleArray positiveProbabilities = computePositiveSamplingProbabilities();
            HugeLongArray negativeDistribution = computeNegativeSamplingDistribution();
            return ImmutableRandomWalkProbabilities.builder()
                .nodeFrequencies(this.nodeFrequencies)
                .positiveSamplingProbabilities(positiveProbabilities)
                .negativeSamplingDistribution(negativeDistribution)
                .sampleCount(this.sampleCount.longValue())
                .build();
        }

        /**
         * Computes, for every node, {@code (sqrt(f / t) + 1) * (t / f)} where {@code f} is the
         * node's share of all registered samples and {@code t} is the positive sampling factor.
         *
         * <p>A node that never occurred has {@code f == 0.0}, so its entry is non-finite
         * (division by {@code 0.0}); the same holds for every node when no samples were registered.
         */
        private HugeDoubleArray computePositiveSamplingProbabilities() {
            HugeDoubleArray probabilities = HugeDoubleArray.newArray(this.nodeCount);
            long totalSamples = this.sampleCount.longValue();
            ParallelUtil.parallelStreamConsume(LongStream.range(0L, this.nodeCount), this.concurrency, TerminationFlag.RUNNING_TRUE, nodeIds -> {
                nodeIds.forEach(nodeId -> {
                    // Cast before dividing: long / long would truncate the share to 0 for
                    // practically every node and make the formula below degenerate.
                    double frequency = (double) this.nodeFrequencies.get(nodeId) / totalSamples;
                    probabilities.set(nodeId, (Math.sqrt(frequency / this.positiveSamplingFactor) + 1.0d) * (this.positiveSamplingFactor / frequency));
                });
            });
            return probabilities;
        }

        /**
         * Builds the cumulative table used for negative sampling: entry {@code i} holds the sum of
         * {@code (long) pow(frequency(k), negativeSamplingExponent)} over all node ids {@code k <= i}.
         *
         * @throws ArithmeticException if the running total overflows a {@code long}
         */
        private HugeLongArray computeNegativeSamplingDistribution() {
            HugeLongArray distribution = HugeLongArray.newArray(this.nodeCount);
            long cumulativeWeight = 0L;
            for (long nodeId = 0L; nodeId < this.nodeCount; nodeId++) {
                // Each node's weight is added exactly once, and only through addExact, so an
                // overflow is always reported instead of being silently saturated by a cast.
                long weight = (long) Math.pow(this.nodeFrequencies.get(nodeId), this.negativeSamplingExponent);
                cumulativeWeight = Math.addExact(cumulativeWeight, weight);
                distribution.set(nodeId, cumulativeWeight);
            }
            return distribution;
        }
    }

    /**
     * Per-node occurrence counters, indexed by node id. The {@code Builder} in this file
     * fills them by adding 1 for every appearance of a node id in a registered walk.
     */
    HugeLongArray nodeFrequencies();

    /**
     * One {@code double} per node, indexed by node id. The {@code Builder} in this file
     * computes each entry from the node's frequency, the total sample count and the
     * positive sampling factor.
     */
    HugeDoubleArray positiveSamplingProbabilities();

    /**
     * One {@code long} per node, indexed by node id. The {@code Builder} in this file fills
     * it with a running total over ascending node ids, derived from each node's frequency
     * raised to the negative sampling exponent.
     */
    HugeLongArray negativeSamplingDistribution();

    /**
     * Total number of samples; the {@code Builder} in this file sets it to the sum of the
     * lengths of all registered walks.
     */
    long sampleCount();

    /**
     * Describes the memory footprint of a {@link RandomWalkProbabilities} instance.
     *
     * <p>The estimation is named after this interface's simple class name and consists of
     * three per-node components: the node frequencies, the positive sampling probabilities
     * and the negative sampling distribution.
     *
     * @return the composed memory estimation
     */
    static MemoryEstimation memoryEstimation() {
        String estimationName = RandomWalkProbabilities.class.getSimpleName();
        return MemoryEstimations.builder(estimationName)
            .perNode("node frequencies", HugeLongArray::memoryEstimation)
            .perNode("positive sampling probabilities", HugeDoubleArray::memoryEstimation)
            .perNode("negative sampling distribution", HugeLongArray::memoryEstimation)
            .build();
    }
}
