package org.neo4j.gds.embeddings.graphsage;

import java.util.ArrayList;
import java.util.List;
import java.util.OptionalLong;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicLong;
import org.neo4j.graphalgo.annotation.ValueClass;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.core.utils.queue.BoundedLongPriorityQueue;

/**
 * Samples a bounded subset of a node's neighbors for GraphSAGE.
 *
 * <p>The node's relationships are visited once, in the graph's iteration order. Each visited
 * neighbor gets a factor, and it is skipped when
 * {@code remainingToConsider * factor > remainingToSample}; otherwise it is kept. Sampling stops
 * as soon as the quota is filled or all {@code degree} relationships have been considered. With
 * a uniform factor this is selection sampling in the style of Knuth's Algorithm S.
 *
 * <ul>
 *   <li>If the graph has no relationship property, or all weights around the node are equal,
 *       the factor is a pseudo-random double derived from the seed and the relationship's
 *       endpoints (see {@link #randomDouble}).</li>
 *   <li>Otherwise the factor is {@code 1 - ((weight - min) / (max - min))^beta}. Heavier
 *       relationships get smaller factors and are therefore favored; the heaviest one has
 *       factor 0.</li>
 * </ul>
 *
 * <p>NOTE(review): {@link #randomDouble} reseeds one shared {@link Random}, so a single instance
 * is not safe to use from several threads concurrently. Confirm how callers share samplers.
 */
public class NeighborhoodSampler {
    // Exponent applied to the normalized weight; 1.0 makes the factor linear in the weight.
    private final double beta = 1.0d;
    // Reseeded before every draw in randomDouble, so its previous state never matters.
    private final Random random;
    private long randomSeed;

    /** Minimum and maximum relationship weight observed around a node. */
    @ValueClass
    public interface MinMax {
        double min();

        double max();

        static MinMax of(double min, double max) {
            return ImmutableMinMax.of(min, max);
        }
    }

    /**
     * @param randomSeed base seed that is mixed into every per-relationship random draw
     */
    public NeighborhoodSampler(long randomSeed) {
        this.randomSeed = randomSeed;
        this.random = new Random(randomSeed);
    }

    /**
     * Samples up to {@code numberOfSamples} neighbors of {@code nodeId}.
     *
     * @param graph           graph to sample from
     * @param nodeId          node whose neighbors are sampled
     * @param numberOfSamples maximum number of neighbors to return
     * @return the sampled target node ids, in traversal order (mutable list)
     */
    public List<Long> sample(Graph graph, long nodeId, long numberOfSamples) {
        AtomicLong remainingToSample = new AtomicLong(numberOfSamples);
        AtomicLong remainingToConsider = new AtomicLong(graph.degree(nodeId));
        List<Long> sampled = new ArrayList<>();

        MinMax minMax = minMax(graph, nodeId);
        double min = minMax.min();
        double max = minMax.max();

        // A single traversal fully decides the sample. When it ends, either the quota is met or
        // all `degree` relationships have been considered, so a second pass would be a no-op.
        if (min == max) {
            long nodeCount = graph.nodeCount();
            graph.concurrentCopy().forEachRelationship(nodeId, (source, target) -> {
                if (isDone(remainingToSample, remainingToConsider)) {
                    return false;
                }
                double factor = randomDouble(source, target, nodeCount);
                consider(target, factor, remainingToSample, remainingToConsider, sampled);
                return true;
            });
        } else {
            double range = max - min;
            graph.concurrentCopy().forEachRelationship(nodeId, 1.0d, (source, target, weight) -> {
                if (isDone(remainingToSample, remainingToConsider)) {
                    return false;
                }
                double factor = 1.0d - Math.pow((weight - min) / range, beta);
                consider(target, factor, remainingToSample, remainingToConsider, sampled);
                return true;
            });
        }
        return sampled;
    }

    // True once the quota is filled or every relationship of the node has been considered.
    private static boolean isDone(AtomicLong remainingToSample, AtomicLong remainingToConsider) {
        return remainingToSample.get() == 0 || remainingToConsider.get() == 0;
    }

    // Processes one neighbor: consumes one unit of the to-consider budget and keeps `target`
    // unless the scaled factor exceeds the number of samples still needed.
    private static void consider(
        long target,
        double factor,
        AtomicLong remainingToSample,
        AtomicLong remainingToConsider,
        List<Long> sampled
    ) {
        // Phrased as "skip if greater" (not "keep if <=") so a NaN factor keeps the neighbor.
        if (remainingToConsider.getAndDecrement() * factor > remainingToSample.get()) {
            return;
        }
        sampled.add(target);
        remainingToSample.decrementAndGet();
    }

    // Deterministic pseudo-random double in [0, 1) for the relationship (source, target),
    // seeded with randomSeed + source + nodeCount * target.
    private double randomDouble(long source, long target, long nodeCount) {
        this.random.setSeed(this.randomSeed + source + (nodeCount * target));
        return this.random.nextDouble();
    }

    /** Returns the seed currently mixed into random draws. */
    public long randomState() {
        return this.randomSeed;
    }

    /** Replaces the seed with a fresh value from {@link ThreadLocalRandom}. */
    public void generateNewRandomState() {
        this.randomSeed = ThreadLocalRandom.current().nextLong();
    }

    /**
     * Samples a single neighbor of {@code nodeId}.
     *
     * @return the sampled neighbor, or empty if none was selected
     */
    public OptionalLong sampleOne(Graph graph, long nodeId) {
        List<Long> sample = sample(graph, nodeId, 1L);
        return sample.isEmpty() ? OptionalLong.empty() : OptionalLong.of(sample.get(0));
    }

    // Smallest and largest relationship weight around `nodeId`; (0, 0) when the graph has no
    // relationship property or the node has no relationships.
    private MinMax minMax(Graph graph, long nodeId) {
        if (!graph.hasRelationshipProperty()) {
            return MinMax.of(0.0d, 0.0d);
        }
        // Assumed: a capacity-1 max/min queue retains only the highest/lowest priority offered.
        BoundedLongPriorityQueue largest = BoundedLongPriorityQueue.max(1);
        BoundedLongPriorityQueue smallest = BoundedLongPriorityQueue.min(1);
        graph.concurrentCopy().forEachRelationship(nodeId, 1.0d, (source, target, weight) -> {
            largest.offer(target, weight);
            smallest.offer(target, weight);
            return true;
        });
        return MinMax.of(
            smallest.priorities().max().orElse(0.0d),
            largest.priorities().min().orElse(0.0d)
        );
    }
}
