package org.neo4j.gds.ml.negativeSampling;

import com.carrotsearch.hppc.predicates.LongPredicate;
import java.util.HashSet;
import java.util.Optional;
import java.util.SplittableRandom;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.core.loading.construction.RelationshipsBuilder;

/**
 * {@link NegativeSampler} that draws negative (non-existing) relationships by picking random
 * target nodes.
 *
 * <p>For every node of the graph whose original id is contained in {@code validSourceNodes}, random
 * candidate targets are drawn from all nodes of the graph. A candidate is accepted only if its
 * original id is contained in {@code validTargetNodes}, it is not already a neighbour of the
 * source, and it is not the source itself. Every accepted pair is added to either the test or the
 * train builder; the test builder is chosen with probability
 * {@code remainingTest / (remainingTest + remainingTrain)}.
 */
public class RandomNegativeSampler implements NegativeSampler {

    /**
     * Number of rejected candidates per source node that are redrawn. Once exhausted, every further
     * rejection for that node consumes one of its sample slots, so the node may receive fewer
     * samples than planned.
     */
    private static final int MAX_RETRIES = 20;

    private final SplittableRandom rng;
    private final Graph graph;
    private final long testSampleCount;
    private final long trainSampleCount;
    private final IdMap validSourceNodes;
    private final IdMap validTargetNodes;

    /**
     * @param graph            graph whose nodes are sampled; its existing relationships are never emitted
     * @param testSampleCount  number of negative samples intended for the test builder
     * @param trainSampleCount number of negative samples intended for the train builder
     * @param validSourceNodes nodes (looked up by original id) allowed as source of a sample
     * @param validTargetNodes nodes (looked up by original id) allowed as target of a sample
     * @param randomSeed       seed for reproducible sampling; an unseeded generator is used when empty
     */
    public RandomNegativeSampler(
        Graph graph,
        long testSampleCount,
        long trainSampleCount,
        IdMap validSourceNodes,
        IdMap validTargetNodes,
        Optional<Long> randomSeed
    ) {
        this.graph = graph;
        this.testSampleCount = testSampleCount;
        this.trainSampleCount = trainSampleCount;
        this.validSourceNodes = validSourceNodes;
        this.validTargetNodes = validTargetNodes;
        this.rng = randomSeed.map(SplittableRandom::new).orElseGet(SplittableRandom::new);
    }

    /**
     * Samples negative relationships and adds them to the given builders.
     *
     * <p>The remaining sample budget is spread over the remaining valid source nodes as they are
     * visited. Each source node therefore receives roughly
     * {@code remainingSamples / remainingSourceNodes} samples (see {@link #samplesPerNode}).
     *
     * @param testSetBuilder  receives samples counted against {@code testSampleCount}
     * @param trainSetBuilder receives samples counted against {@code trainSampleCount}
     */
    @Override
    public void produceNegativeSamples(RelationshipsBuilder testSetBuilder, RelationshipsBuilder trainSetBuilder) {
        MutableLong remainingTestSamples = new MutableLong(testSampleCount);
        MutableLong remainingTrainSamples = new MutableLong(trainSampleCount);
        MutableLong remainingSourceNodes = new MutableLong(validSourceNodes.nodeCount());
        LongPredicate isValidSourceNode = nodeId -> validSourceNodes.contains(graph.toOriginalNodeId(nodeId));
        LongPredicate isValidTargetNode = nodeId -> validTargetNodes.contains(graph.toOriginalNodeId(nodeId));

        graph.forEachNode(sourceNode -> {
            if (!isValidSourceNode.apply(sourceNode)) {
                return true;
            }

            int degree = graph.degree(sourceNode);
            // Cap: every node except the source itself and its neighbours
            // (assumes neighbours are distinct — TODO confirm for graphs with parallel relationships).
            long maxCandidates = graph.nodeCount() - 1 - degree;
            long sampleCount = samplesPerNode(
                maxCandidates,
                remainingTestSamples.longValue() + remainingTrainSamples.longValue(),
                remainingSourceNodes.getAndDecrement()
            );

            HashSet<Long> neighbours = collectNeighbours(sourceNode, degree);
            int retriesLeft = MAX_RETRIES;
            // long counter: sampleCount is a long, an int counter could overflow and loop forever.
            for (long i = 0; i < sampleCount; i++) {
                long candidate = randomNodeId(graph);
                if (!isValidTargetNode.apply(candidate) || neighbours.contains(candidate) || candidate == sourceNode) {
                    if (retriesLeft-- > 0) {
                        // Redraw without consuming one of this node's sample slots.
                        i--;
                    }
                    continue;
                }

                double testProbability = remainingTestSamples.doubleValue()
                    / (remainingTestSamples.doubleValue() + remainingTrainSamples.doubleValue());
                if (sample(testProbability)) {
                    remainingTestSamples.decrement();
                    addNegative(testSetBuilder, sourceNode, candidate);
                } else {
                    remainingTrainSamples.decrement();
                    addNegative(trainSetBuilder, sourceNode, candidate);
                }
            }
            return true;
        });
    }

    /** Collects the targets of all relationships of {@code nodeId}, pre-sized by its degree. */
    private HashSet<Long> collectNeighbours(long nodeId, int degree) {
        HashSet<Long> neighbours = new HashSet<>(degree);
        graph.forEachRelationship(nodeId, (source, target) -> {
            neighbours.add(target);
            return true;
        });
        return neighbours;
    }

    /** Adds a negative relationship, translating both internal node ids to root node ids. */
    private void addNegative(RelationshipsBuilder builder, long sourceNode, long targetNode) {
        builder.addFromInternal(graph.toRootNodeId(sourceNode), graph.toRootNodeId(targetNode), NegativeSampler.NEGATIVE);
    }

    /** Draws a node id from {@code [0, graph.nodeCount())}, (near-)uniformly. */
    private long randomNodeId(Graph graph) {
        return Math.abs(rng.nextLong() % graph.nodeCount());
    }

    /**
     * Computes how many samples to draw for the current source node.
     *
     * <p>The result is the even share {@code remainingSamples / remainingNodes}, stochastically
     * rounded (the fractional part is the probability of rounding up) and capped at
     * {@code maxSamples}.
     */
    private long samplesPerNode(long maxSamples, double remainingSamples, long remainingNodes) {
        double share = remainingSamples / remainingNodes;
        long wholePart = (long) share;
        long roundUp = sample(share - wholePart) ? 1 : 0;
        return Math.min(maxSamples, wholePart + roundUp);
    }

    /** Returns {@code true} with probability {@code probability}. */
    private boolean sample(double probability) {
        return rng.nextDouble() < probability;
    }
}
