/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.kge;

import com.carrotsearch.hppc.BitSet;
import java.util.Optional;
import org.neo4j.gds.GraphStoreAlgorithmFactory;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.kge.KGEPredictBaseConfig;
import org.neo4j.gds.ml.kge.KGEPredictParameters;
import org.neo4j.gds.ml.kge.TopKMapComputer;
import org.neo4j.gds.similarity.filtering.NodeFilter;

public class KGEPredictAlgorithmFactory<CONFIG extends KGEPredictBaseConfig>
extends GraphStoreAlgorithmFactory<TopKMapComputer, CONFIG> {
    public TopKMapComputer build(GraphStore graphStore, KGEPredictParameters parameters, ProgressTracker progressTracker) {
        BitSet sourceNodes = new BitSet(graphStore.nodeCount());
        BitSet targetNodes = new BitSet(graphStore.nodeCount());
        Graph graph = graphStore.getGraph(parameters.relationshipTypesFilter(), Optional.empty());
        NodeFilter sourceNodeFilter = parameters.sourceNodeFilter().toNodeFilter((IdMap)graph);
        NodeFilter targetNodeFilter = parameters.targetNodeFilter().toNodeFilter((IdMap)graph);
        graph.forEachNode(node -> {
            if (sourceNodeFilter.test(node)) {
                sourceNodes.set(node);
            }
            if (targetNodeFilter.test(node)) {
                targetNodes.set(node);
            }
            return true;
        });
        return new TopKMapComputer(graph, sourceNodes, targetNodes, parameters.nodeEmbeddingProperty(), parameters.relationshipTypeEmbedding(), parameters.scoringFunction(), parameters.topK(), parameters.concurrency(), progressTracker);
    }

    public TopKMapComputer build(GraphStore graphStore, CONFIG configuration, ProgressTracker progressTracker) {
        return this.build(graphStore, configuration.toParameters(), progressTracker);
    }

    public String taskName() {
        return "KGEPredict";
    }
}

