/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.predict;

import com.carrotsearch.hppc.LongHashSet;
import com.carrotsearch.hppc.predicates.LongPredicate;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.LongConsumer;
import java.util.function.Supplier;
import java.util.stream.LongStream;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.queue.BoundedLongLongPriorityQueue;
import org.neo4j.gds.mem.Estimate;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.mem.MemoryEstimations;
import org.neo4j.gds.mem.MemoryRange;
import org.neo4j.gds.ml.linkmodels.ExhaustiveLinkPredictionResult;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LPNodeFilter;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPrediction;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPredictionPredictPipelineBaseConfig;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPredictionSimilarityComputer;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureExtractor;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.gds.utils.CloseableThreadLocal;

/**
 * Link prediction that scores every candidate node pair of the graph and keeps the
 * {@code topN} highest-scoring pairs whose score is at least {@code threshold}.
 *
 * <p>Pairs are treated as unordered: each pair is visited exactly once, from the node with
 * the smaller id. Pairs already connected by a relationship of the smaller node are skipped.
 * A pair is a candidate when one endpoint passes the source node filter and the other passes
 * the target node filter.
 */
public class ExhaustiveLinkPrediction
extends LinkPrediction {
    private final int topN;
    private final double threshold;
    private final TerminationFlag terminationFlag;

    /**
     * @param classifier           passed through to {@link LinkPrediction}
     * @param linkFeatureExtractor passed through to {@link LinkPrediction}
     * @param graph                graph whose node pairs are scored
     * @param sourceNodeFilter     filter deciding which nodes may act as link sources
     * @param targetNodeFilter     filter deciding which nodes may act as link targets
     * @param concurrency          parallelism used when iterating over the nodes
     * @param topN                 capacity of the result priority queue
     * @param threshold            minimum score a pair needs to be offered to the result queue
     * @param progressTracker      receives one step per processed node
     * @param terminationFlag      checked by the parallel node iteration
     */
    public ExhaustiveLinkPrediction(Classifier classifier, LinkFeatureExtractor linkFeatureExtractor, Graph graph, LPNodeFilter sourceNodeFilter, LPNodeFilter targetNodeFilter, Concurrency concurrency, int topN, double threshold, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
        super(classifier, linkFeatureExtractor, graph, sourceNodeFilter, targetNodeFilter, concurrency, progressTracker);
        this.topN = topN;
        this.threshold = threshold;
        this.terminationFlag = terminationFlag;
    }

    /**
     * Estimates the memory used by {@link #predictLinks}.
     *
     * <p>Counts one bounded priority queue of {@code config.topN()} entries. It also counts,
     * per thread, a double array of {@code linkFeatureDimension} entries plus a
     * {@code LongHashSet} sized by the graph's average degree.
     *
     * @param config               pipeline config; {@code topN()} must be present
     * @param linkFeatureDimension length of a single link feature vector
     * @return the memory estimation tree for this algorithm
     */
    public static MemoryEstimation estimate(LinkPredictionPredictPipelineBaseConfig config, int linkFeatureDimension) {
        return MemoryEstimations.builder(ExhaustiveLinkPrediction.class.getSimpleName())
            .add("Priority queue", BoundedLongLongPriorityQueue.memoryEstimation((int) config.topN().orElseThrow()))
            .perGraphDimension(
                "Predict links operation",
                (dim, threads) -> MemoryRange
                    .of((long) (Estimate.sizeOfDoubleArray((long) linkFeatureDimension)
                        + Estimate.sizeOfLongHashSet((long) dim.averageDegree())))
                    .times((long) threads.value())
            )
            .build();
    }

    /**
     * Scores all candidate pairs in parallel and collects the best ones.
     *
     * @param linkPredictionSimilarityComputer computes the score of a (source, target) pair
     * @return the top-N queue of pairs together with the number of pairs that were scored
     */
    ExhaustiveLinkPredictionResult predictLinks(LinkPredictionSimilarityComputer linkPredictionSimilarityComputer) {
        this.progressTracker.setSteps(this.graph.nodeCount());
        BoundedLongLongPriorityQueue predictionQueue = BoundedLongLongPriorityQueue.max(this.topN);
        LongAdder linksConsidered = new LongAdder();
        // One consumer per worker thread: each consumer holds its own concurrent copy of the graph.
        Supplier<LongConsumer> linkPredictorSupplier = () -> new LinkPredictionScoreByIdsConsumer(
            this.graph,
            this.sourceNodeFilter::test,
            this.targetNodeFilter::test,
            linkPredictionSimilarityComputer,
            predictionQueue,
            this.progressTracker,
            linksConsidered
        );
        try (CloseableThreadLocal<LongConsumer> localLinkPredictor = CloseableThreadLocal.withInitial(linkPredictorSupplier)) {
            ParallelUtil.parallelForEachNode(
                this.graph.nodeCount(),
                this.concurrency,
                this.terminationFlag,
                nodeId -> localLinkPredictor.get().accept(nodeId)
            );
        }
        return new ExhaustiveLinkPredictionResult(predictionQueue, linksConsidered.longValue());
    }

    /**
     * Per-thread worker that scores all candidate pairs whose smaller endpoint is the given node.
     * The shared result queue is guarded by synchronizing on it; the link counter is a
     * {@link LongAdder}.
     */
    final class LinkPredictionScoreByIdsConsumer
    implements LongConsumer {
        private final Graph graph;
        private final LongPredicate sourceNodeFilter;
        private final LongPredicate targetNodeFilter;
        // Partners for a node that passes both filters: such a node may act as source
        // (partner must be a target) or as target (partner must be a source).
        private final LongPredicate sourceOrTargetNodeFilter;
        private final LinkPredictionSimilarityComputer linkPredictionSimilarityComputer;
        private final BoundedLongLongPriorityQueue predictionQueue;
        private final ProgressTracker progressTracker;
        private final LongAdder linksConsidered;

        LinkPredictionScoreByIdsConsumer(Graph graph, LongPredicate sourceNodeFilter, LongPredicate targetNodeFilter, LinkPredictionSimilarityComputer linkPredictionSimilarityComputer, BoundedLongLongPriorityQueue predictionQueue, ProgressTracker progressTracker, LongAdder linksConsidered) {
            this.graph = graph.concurrentCopy();
            this.sourceNodeFilter = sourceNodeFilter;
            this.targetNodeFilter = targetNodeFilter;
            this.sourceOrTargetNodeFilter = nodeId -> sourceNodeFilter.apply(nodeId) || targetNodeFilter.apply(nodeId);
            this.linkPredictionSimilarityComputer = linkPredictionSimilarityComputer;
            this.predictionQueue = predictionQueue;
            this.progressTracker = progressTracker;
            this.linksConsidered = linksConsidered;
        }

        /**
         * Scores every pair {@code (sourceId, t)} with {@code t > sourceId} whose endpoints match
         * the source/target filters in either role. Logs one progress step.
         */
        @Override
        public void accept(long sourceId) {
            boolean isSource = this.sourceNodeFilter.apply(sourceId);
            boolean isTarget = this.targetNodeFilter.apply(sourceId);
            if (isSource && isTarget) {
                // Checking only one role here would drop pairs whose other endpoint fits the
                // other role, and the larger node never revisits smaller ids.
                // That would make the result depend on node id order.
                this.predictLinksFromNode(sourceId, this.sourceOrTargetNodeFilter);
            } else if (isSource) {
                this.predictLinksFromNode(sourceId, this.targetNodeFilter);
            } else if (isTarget) {
                this.predictLinksFromNode(sourceId, this.sourceNodeFilter);
            }
            this.progressTracker.logSteps(1L);
        }

        /**
         * Collects the neighbors of {@code sourceId} with a larger id that pass {@code targetNodeFilter}.
         * NOTE(review): only relationships reported by {@code forEachRelationship(sourceId, ...)}
         * are seen; this assumes the graph is undirected or stores both directions — confirm.
         */
        private LongHashSet largerValidNeighbors(long sourceId, LongPredicate targetNodeFilter) {
            LongHashSet neighbors = new LongHashSet();
            this.graph.forEachRelationship(sourceId, (src, trg) -> {
                if (src < trg && targetNodeFilter.apply(trg)) {
                    neighbors.add(trg);
                }
                return true;
            });
            return neighbors;
        }

        /**
         * Scores all non-adjacent nodes with a larger id than {@code sourceId} that pass {@code nodeFilter}.
         * Every scored pair is counted; pairs scoring at least the threshold are offered to the queue.
         */
        private void predictLinksFromNode(long sourceId, LongPredicate nodeFilter) {
            LongHashSet largerNeighbors = this.largerValidNeighbors(sourceId, nodeFilter);
            long nodeCount = this.graph.nodeCount();
            for (long targetId = sourceId + 1L; targetId < nodeCount; targetId++) {
                // Existing links are not predictions; non-matching nodes are not candidates.
                if (largerNeighbors.contains(targetId) || !nodeFilter.apply(targetId)) {
                    continue;
                }
                double probability = this.linkPredictionSimilarityComputer.similarity(sourceId, targetId);
                this.linksConsidered.increment();
                if (probability < ExhaustiveLinkPrediction.this.threshold) {
                    continue;
                }
                // The queue is shared by all worker threads.
                synchronized (this.predictionQueue) {
                    this.predictionQueue.offer(sourceId, targetId, probability);
                }
            }
        }
    }
}

