/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.node.classification.predict;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.neo4j.gds.LoggingUtil;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.executor.AlgorithmSpec;
import org.neo4j.gds.executor.ComputationResultConsumer;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.executor.ExecutionMode;
import org.neo4j.gds.executor.GdsCallable;
import org.neo4j.gds.logging.Log;
import org.neo4j.gds.ml.models.BaseModelData;
import org.neo4j.gds.ml.pipeline.PipelineGraphFilter;
import org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationPipelineResult;
import org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationPredictPipelineAlgorithmFactory;
import org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationPredictPipelineExecutor;
import org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationPredictPipelineStreamConfig;
import org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationStreamResult;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;
import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction;

/**
 * Algorithm specification for the {@code gds.beta.pipeline.nodeClassification.predict.stream}
 * procedure: runs a {@link NodeClassificationPredictPipelineExecutor} and streams one
 * {@link NodeClassificationStreamResult} per node of the filtered graph.
 */
@GdsCallable(name="gds.beta.pipeline.nodeClassification.predict.stream", description="Predicts classes for all nodes based on a previously trained pipeline model", executionMode=ExecutionMode.STREAM)
public class NodeClassificationPipelineStreamSpec
    implements AlgorithmSpec<NodeClassificationPredictPipelineExecutor, NodeClassificationPipelineResult, NodeClassificationPredictPipelineStreamConfig, Stream<NodeClassificationStreamResult>, NodeClassificationPredictPipelineAlgorithmFactory<NodeClassificationPredictPipelineStreamConfig>> {

    /**
     * @return the fixed name of this specification
     */
    @Override
    public String name() {
        return "NodeClassificationPipelineStream";
    }

    /**
     * Creates the factory that builds the predict-pipeline executor.
     *
     * @param executionContext context handed on to the factory
     * @return a new factory bound to {@code executionContext}
     */
    @Override
    public NodeClassificationPredictPipelineAlgorithmFactory<NodeClassificationPredictPipelineStreamConfig> algorithmFactory(ExecutionContext executionContext) {
        return new NodeClassificationPredictPipelineAlgorithmFactory<>(executionContext);
    }

    /**
     * @return the function used to build the stream configuration
     *         ({@code NodeClassificationPredictPipelineStreamConfig::of})
     */
    @Override
    public NewConfigFunction<NodeClassificationPredictPipelineStreamConfig> newConfigFunction() {
        return NodeClassificationPredictPipelineStreamConfig::of;
    }

    /**
     * Builds the consumer that converts a computation result into the procedure's output stream.
     *
     * <p>When the computation produced a result, the graph is resolved from the graph store using
     * the node labels of the executor's node-property-step filter and one row is emitted per node
     * of that graph. When there is no result, an empty stream is returned. The conversion runs
     * inside {@code LoggingUtil.runWithExceptionLogging} with the message
     * {@code "Result streaming failed"}.
     *
     * @return the result consumer for stream mode
     */
    @Override
    public ComputationResultConsumer<NodeClassificationPredictPipelineExecutor, NodeClassificationPipelineResult, NodeClassificationPredictPipelineStreamConfig, Stream<NodeClassificationStreamResult>> computationResultConsumer() {
        return (computationResult, executionContext) -> LoggingUtil.runWithExceptionLogging(
            "Result streaming failed",
            executionContext.log(),
            () -> computationResult.result()
                .map(result -> {
                    PipelineGraphFilter nodePropertyStepFilter = computationResult.algorithm().nodePropertyStepFilter();
                    Graph graph = computationResult.graphStore().getGraph(nodePropertyStepFilter.nodeLabels());
                    return streamPredictions(graph, result);
                })
                .orElseGet(Stream::empty)
        );
    }

    /**
     * Emits one stream row per internal node id in {@code [0, graph.nodeCount())}.
     *
     * @param graph  graph whose node ids index the prediction arrays and are mapped to original ids
     * @param result prediction result holding the predicted classes and optional probabilities
     * @return a lazily evaluated stream of result rows
     */
    private static Stream<NodeClassificationStreamResult> streamPredictions(Graph graph, NodeClassificationPipelineResult result) {
        HugeLongArray predictedClasses = result.predictedClasses();
        Optional<HugeObjectArray<double[]>> predictedProbabilities = result.predictedProbabilities();
        return LongStream.range(0L, graph.nodeCount())
            .mapToObj(nodeId -> new NodeClassificationStreamResult(
                graph.toOriginalNodeId(nodeId),
                predictedClasses.get(nodeId),
                nodePropertiesAsList(predictedProbabilities, nodeId)
            ));
    }

    /**
     * Converts the probability array stored for a node into a boxed list.
     *
     * @param predictedProbabilities optional per-node probability arrays
     * @param nodeId                 internal node id used to index the array
     * @return the node's probabilities as a list, or {@code null} when no probabilities are
     *         present (the {@code null} is deliberately kept so the row's probability field
     *         stays unset)
     */
    private static List<Double> nodePropertiesAsList(Optional<HugeObjectArray<double[]>> predictedProbabilities, long nodeId) {
        return predictedProbabilities
            .map(probabilities -> Arrays.stream(probabilities.get(nodeId)).boxed().collect(Collectors.toList()))
            .orElse(null);
    }

    /**
     * Fills in {@code targetNodeLabels} and {@code relationshipTypes} from the train configuration
     * of the model named by {@code modelName}, for whichever of the two keys the user did not
     * supply. Does nothing when {@code modelName} is not part of the input.
     *
     * @param userInput        mutable user configuration map; may be modified in place
     * @param executionContext provides the model catalog and the current username
     */
    @Override
    public void preProcessConfig(Map<String, Object> userInput, ExecutionContext executionContext) {
        if (!userInput.containsKey("modelName")) {
            return;
        }
        String modelName = (String) userInput.get("modelName");
        Model<BaseModelData, NodeClassificationPipelineTrainConfig, Model.CustomInfo> model = executionContext.modelCatalog().get(
            executionContext.username(),
            modelName,
            BaseModelData.class,
            NodeClassificationPipelineTrainConfig.class,
            Model.CustomInfo.class
        );
        NodeClassificationPipelineTrainConfig trainConfig = model.trainConfig();
        // containsKey (not putIfAbsent) so an explicitly supplied null value is still respected.
        if (!userInput.containsKey("targetNodeLabels")) {
            userInput.put("targetNodeLabels", trainConfig.targetNodeLabels());
        }
        if (!userInput.containsKey("relationshipTypes")) {
            userInput.put("relationshipTypes", trainConfig.relationshipTypes());
        }
    }
}

