/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.predict;

import java.util.Collection;
import java.util.stream.Stream;
import org.neo4j.gds.LoggingUtil;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.executor.AlgorithmSpec;
import org.neo4j.gds.executor.ComputationResultConsumer;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.executor.ExecutionMode;
import org.neo4j.gds.executor.GdsCallable;
import org.neo4j.gds.executor.NewConfigFunction;
import org.neo4j.gds.ml.linkmodels.LinkPredictionResult;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPredictionPredictPipelineAlgorithmFactory;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPredictionPredictPipelineExecutor;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPredictionPredictPipelineStreamConfig;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.StreamResult;
import org.neo4j.logging.Log;

@GdsCallable(name="gds.beta.pipeline.linkPrediction.predict.stream", description="Predicts relationships for all non-connected node pairs based on a previously trained LinkPrediction model.", executionMode=ExecutionMode.STREAM)
/**
 * Algorithm specification backing {@code gds.beta.pipeline.linkPrediction.predict.stream}.
 *
 * <p>Wires together the algorithm factory, the stream-mode configuration parser and a result
 * consumer. The consumer turns the computed {@link LinkPredictionResult} into a stream of
 * {@link StreamResult} rows.
 */
public class LinkPredictionPipelineStreamSpec
implements AlgorithmSpec<LinkPredictionPredictPipelineExecutor, LinkPredictionResult, LinkPredictionPredictPipelineStreamConfig, Stream<StreamResult>, LinkPredictionPredictPipelineAlgorithmFactory<LinkPredictionPredictPipelineStreamConfig>> {

    /**
     * Returns the internal name of this spec.
     *
     * @return the constant {@code "LinkPredictionPipelineStream"}
     */
    @Override
    public String name() {
        return "LinkPredictionPipelineStream";
    }

    /**
     * Creates the factory that builds the prediction pipeline executor.
     *
     * @param executionContext the context passed through to the factory constructor
     * @return a new factory parameterized with the stream-mode config type
     */
    @Override
    public LinkPredictionPredictPipelineAlgorithmFactory<LinkPredictionPredictPipelineStreamConfig> algorithmFactory(ExecutionContext executionContext) {
        return new LinkPredictionPredictPipelineAlgorithmFactory<>(executionContext);
    }

    /**
     * Returns the function used to build the stream-mode configuration.
     *
     * @return a reference to {@link LinkPredictionPredictPipelineStreamConfig#of}
     */
    @Override
    public NewConfigFunction<LinkPredictionPredictPipelineStreamConfig> newConfigFunction() {
        return LinkPredictionPredictPipelineStreamConfig::of;
    }

    /**
     * Returns a consumer that streams predicted links as {@link StreamResult} rows.
     *
     * <p>The graph is taken from the computation's graph store and filtered to the executor's
     * predict node labels. Each predicted link's internal source and target ids are mapped back
     * to original node ids through that graph, so the ids are resolved in the same id space the
     * prediction ran on. If the computation produced no result, an empty stream is returned.
     * Work is wrapped in {@code LoggingUtil.runWithExceptionLogging} with the message
     * {@code "Result streaming failed"}. NOTE(review): presumably this logs and then rethrows;
     * confirm against {@code LoggingUtil}.
     *
     * @return the result consumer for stream execution mode
     */
    @Override
    public ComputationResultConsumer<LinkPredictionPredictPipelineExecutor, LinkPredictionResult, LinkPredictionPredictPipelineStreamConfig, Stream<StreamResult>> computationResultConsumer() {
        // No casts needed: the consumer's generic parameters already fix the algorithm type and
        // the Stream<StreamResult> return type, so inference keeps this fully type-checked.
        return (computationResult, executionContext) -> LoggingUtil.runWithExceptionLogging(
            "Result streaming failed",
            executionContext.log(),
            () -> computationResult.result()
                .map(result -> {
                    GraphStore graphStore = computationResult.graphStore();
                    Collection<NodeLabel> labelFilter = computationResult.algorithm().labelFilter().predictNodeLabels();
                    Graph graph = graphStore.getGraph(labelFilter);
                    return result.stream().map(predictedLink -> new StreamResult(
                        graph.toOriginalNodeId(predictedLink.sourceId()),
                        graph.toOriginalNodeId(predictedLink.targetId()),
                        predictedLink.probability()
                    ));
                })
                .orElseGet(Stream::empty)
        );
    }
}

