/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.kge;

import com.carrotsearch.hppc.BitSet;
import com.carrotsearch.hppc.DoubleArrayList;
import com.carrotsearch.hppc.predicates.LongLongPredicate;
import java.util.List;
import java.util.stream.BaseStream;
import java.util.stream.LongStream;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.SetBitsIterable;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.kge.KGEPredictResult;
import org.neo4j.gds.ml.kge.ScoreFunction;
import org.neo4j.gds.ml.kge.scorers.LinkScorer;
import org.neo4j.gds.ml.kge.scorers.LinkScorerFactory;
import org.neo4j.gds.similarity.nodesim.TopKMap;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.gds.utils.AutoCloseableThreadLocal;
import org.neo4j.gds.utils.CloseableThreadLocal;

public class TopKMapComputer
extends Algorithm<KGEPredictResult> {
    private final Graph graph;
    private final ProgressTracker progressTracker;
    private final BitSet sourceNodes;
    private final BitSet targetNodes;
    private final String nodeEmbeddingProperty;
    private final DoubleArrayList relationshipTypeEmbedding;
    private final Concurrency concurrency;
    private final int topK;
    private final ScoreFunction scoreFunction;
    private final boolean higherIsBetter;

    public TopKMapComputer(Graph graph, BitSet sourceNodes, BitSet targetNodes, String nodeEmbeddingProperty, List<Double> relationshipTypeEmbedding, ScoreFunction scoreFunction, int topK, Concurrency concurrency, ProgressTracker progressTracker) {
        super(progressTracker);
        this.graph = graph;
        this.progressTracker = progressTracker;
        this.sourceNodes = sourceNodes;
        this.targetNodes = targetNodes;
        this.nodeEmbeddingProperty = nodeEmbeddingProperty;
        this.relationshipTypeEmbedding = DoubleArrayList.from((double[])relationshipTypeEmbedding.stream().mapToDouble(Double::doubleValue).toArray());
        this.concurrency = concurrency;
        this.topK = topK;
        this.scoreFunction = scoreFunction;
        this.higherIsBetter = scoreFunction == ScoreFunction.DISTMULT;
    }

    public KGEPredictResult compute() {
        this.progressTracker.beginSubTask(this.estimateWorkload());
        TopKMap topKMap = new TopKMap(this.sourceNodes.capacity(), this.sourceNodes, Math.abs(this.topK), this.higherIsBetter);
        NodePropertyValues embeddings = this.graph.nodeProperties(this.nodeEmbeddingProperty);
        try (AutoCloseableThreadLocal threadLocalScorer = AutoCloseableThreadLocal.withInitial(() -> LinkScorerFactory.create(this.scoreFunction, embeddings, this.relationshipTypeEmbedding));){
            try (CloseableThreadLocal concurrentGraph = CloseableThreadLocal.withInitial(() -> ((Graph)this.graph).concurrentCopy());){
                ParallelUtil.parallelStreamConsume((BaseStream)new SetBitsIterable(this.sourceNodes).stream(), (Concurrency)this.concurrency, (TerminationFlag)this.terminationFlag, stream -> stream.forEach(node1 -> {
                    this.terminationFlag.assertRunning();
                    LongLongPredicate isCandidateLinkPredicate = this.isCandidateLink((Graph)concurrentGraph.get());
                    LinkScorer linkScorer = (LinkScorer)threadLocalScorer.get();
                    linkScorer.init(node1);
                    this.targetNodesStream().filter(node2 -> isCandidateLinkPredicate.apply(node1, node2)).forEach(node2 -> {
                        double similarity = linkScorer.computeScore(node2);
                        if (!Double.isNaN(similarity)) {
                            topKMap.put(node1, node2, similarity);
                        }
                    });
                }));
            }
            this.progressTracker.logProgress();
        }
        this.progressTracker.endSubTask();
        return KGEPredictResult.of(topKMap);
    }

    private LongStream targetNodesStream() {
        return new SetBitsIterable(this.targetNodes, 0L).stream();
    }

    private long estimateWorkload() {
        return this.sourceNodes.cardinality() * this.targetNodes.cardinality();
    }

    private LongLongPredicate isCandidateLink(Graph graph) {
        return (s, t) -> s != t && !graph.exists(s, t);
    }
}

