package org.neo4j.gds.ml.pipeline.node.classification.predict;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.neo4j.gds.AlgorithmFactory;
import org.neo4j.gds.GraphStoreAlgorithmFactory;
import org.neo4j.gds.StreamProc;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.executor.AlgorithmSpec;
import org.neo4j.gds.executor.ComputationResult;
import org.neo4j.gds.executor.ExecutionMode;
import org.neo4j.gds.executor.GdsCallable;
import org.neo4j.gds.ml.nodeClassification.NodeClassificationPredict;
import org.neo4j.gds.ml.pipeline.PipelineCompanion;
import org.neo4j.gds.results.MemoryEstimateResult;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/**
 * Stream-mode procedure for node classification pipeline prediction.
 *
 * <p>Exposes {@code gds.beta.pipeline.nodeClassification.predict.stream} and its
 * {@code .estimate} companion. For every node of the computed graph one
 * {@link NodeClassificationStreamResult} row is produced.
 *
 * <p>NOTE(review): this file appears to be decompiler output; the methods prefixed
 * {@code m18}/{@code m19}/{@code m20} were renamed by the decompiler (see the notes on each).
 */
@GdsCallable(name = "gds.beta.pipeline.nodeClassification.predict.stream", description = org.neo4j.gds.ml.pipeline.node.classification.NodeClassificationPipelineCompanion.PREDICT_DESCRIPTION, executionMode = ExecutionMode.STREAM)
public class NodeClassificationPipelineStreamProc extends StreamProc<NodeClassificationPredictPipelineExecutor, NodeClassificationPredict.NodeClassificationResult, NodeClassificationStreamResult, NodeClassificationPredictPipelineStreamConfig> {

    /**
     * One streamed result row: a node, its predicted class and, when available,
     * the predicted class probabilities.
     */
    public static final class NodeClassificationStreamResult {
        // NOTE(review): fields are presumably public because the procedure framework
        // reads result columns from public fields — confirm before changing visibility.
        public long nodeId;
        public long predictedClass;
        /** {@code null} when the computation produced no probabilities. */
        public List<Double> predictedProbabilities;

        /**
         * @param nodeId                 original (external) id of the node
         * @param predictedClass         class predicted for the node
         * @param predictedProbabilities per-class probabilities, or {@code null} if absent
         */
        public NodeClassificationStreamResult(long nodeId, long predictedClass, List<Double> predictedProbabilities) {
            this.nodeId = nodeId;
            this.predictedClass = predictedClass;
            this.predictedProbabilities = predictedProbabilities;
        }
    }

    /**
     * Procedure entry point for streaming predictions.
     *
     * <p>The Java method name is {@code mutate} although the registered procedure is
     * the {@code .stream} variant; the name is kept as-is so the class's visible
     * interface does not change.
     *
     * @param graphName     name of the graph to run on
     * @param configuration procedure configuration map (defaults to an empty map)
     * @return one row per node of the computed graph
     */
    @Procedure(name = "gds.beta.pipeline.nodeClassification.predict.stream", mode = Mode.READ)
    @Description(org.neo4j.gds.ml.pipeline.node.classification.NodeClassificationPipelineCompanion.PREDICT_DESCRIPTION)
    public Stream<NodeClassificationStreamResult> mutate(@Name("graphName") String graphName, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) {
        PipelineCompanion.preparePipelineConfig(graphName, configuration);
        return stream(compute(graphName, configuration));
    }

    /**
     * Procedure entry point for estimating the memory requirements of the stream procedure.
     *
     * @param graphNameOrConfiguration either a graph name or a graph configuration
     * @param algoConfiguration        procedure configuration map
     * @return the memory estimation result(s)
     */
    @Procedure(name = "gds.beta.pipeline.nodeClassification.predict.stream.estimate", mode = Mode.READ)
    @Description(org.neo4j.gds.ml.pipeline.node.classification.NodeClassificationPipelineCompanion.ESTIMATE_PREDICT_DESCRIPTION)
    public Stream<MemoryEstimateResult> estimate(@Name("graphNameOrConfiguration") Object graphNameOrConfiguration, @Name("algoConfiguration") Map<String, Object> algoConfiguration) {
        PipelineCompanion.preparePipelineConfig(graphNameOrConfiguration, algoConfiguration);
        return computeEstimate(graphNameOrConfiguration, algoConfiguration);
    }

    /**
     * Turns a computation result into a stream of result rows, one per internal node id
     * in {@code [0, graph.nodeCount())}.
     *
     * @param computationResult the finished computation
     * @return lazily evaluated stream of result rows
     */
    protected Stream<NodeClassificationStreamResult> stream(ComputationResult<NodeClassificationPredictPipelineExecutor, NodeClassificationPredict.NodeClassificationResult, NodeClassificationPredictPipelineStreamConfig> computationResult) {
        return runWithExceptionLogging("Graph streaming failed", () -> {
            Graph graph = computationResult.graph();
            NodeClassificationPredict.NodeClassificationResult result = computationResult.result();
            HugeLongArray predictedClasses = result.predictedClasses();
            Optional<HugeObjectArray<double[]>> predictedProbabilities = result.predictedProbabilities();
            // mapToObj keeps node ids primitive instead of boxing each one to Long.
            return LongStream.range(0L, graph.nodeCount()).mapToObj(nodeId ->
                new NodeClassificationStreamResult(
                    graph.toOriginalNodeId(nodeId),
                    predictedClasses.get(nodeId),
                    nodePropertiesAsList(predictedProbabilities, nodeId)
                )
            );
        });
    }

    /**
     * Reads the probability array stored for {@code nodeId} and boxes it into a list.
     *
     * @param probabilities optional per-node probability arrays
     * @param nodeId        index of the node to look up
     * @return the node's probabilities as a list, or {@code null} when {@code probabilities} is empty
     */
    private static List<Double> nodePropertiesAsList(Optional<HugeObjectArray<double[]>> probabilities, long nodeId) {
        // null (not an empty list) is returned on purpose so absent probabilities
        // stay distinguishable in the result row.
        return probabilities
            .map(values -> Arrays.stream(values.get(nodeId)).boxed().collect(Collectors.toList()))
            .orElse(null);
    }

    /**
     * Always throws: result rows are built directly in {@code stream(ComputationResult)}.
     *
     * <p>Decompiler note: renamed from {@code streamResult} (merged with bridge method);
     * access modifier changed from {@code protected}.
     *
     * @throws UnsupportedOperationException always
     */
    public NodeClassificationStreamResult m18streamResult(long originalNodeId, long internalNodeId, NodePropertyValues nodePropertyValues) {
        throw new UnsupportedOperationException("NodeClassification handles result building individually.");
    }

    /**
     * Creates the stream configuration from the user name and the wrapped configuration map.
     *
     * <p>Decompiler note: renamed from {@code newConfig} (merged with bridge method);
     * access modifier changed from {@code protected}.
     */
    public NodeClassificationPredictPipelineStreamConfig m19newConfig(String username, CypherMapWrapper config) {
        return NodeClassificationPredictPipelineStreamConfig.of(username, config);
    }

    /**
     * Creates the algorithm factory using this procedure's execution context and model catalog.
     *
     * <p>Decompiler note: renamed from {@code algorithmFactory} (merged with bridge method).
     */
    public GraphStoreAlgorithmFactory<NodeClassificationPredictPipelineExecutor, NodeClassificationPredictPipelineStreamConfig> m20algorithmFactory() {
        return new NodeClassificationPredictPipelineAlgorithmFactory(executionContext(), modelCatalog());
    }

    /**
     * Sets the model catalog on this procedure and returns this instance as its algorithm spec.
     *
     * @param modelCatalog the catalog to use
     * @return this instance
     */
    public AlgorithmSpec<NodeClassificationPredictPipelineExecutor, NodeClassificationPredict.NodeClassificationResult, NodeClassificationPredictPipelineStreamConfig, Stream<NodeClassificationStreamResult>, AlgorithmFactory<?, NodeClassificationPredictPipelineExecutor, NodeClassificationPredictPipelineStreamConfig>> withModelCatalog(ModelCatalog modelCatalog) {
        setModelCatalog(modelCatalog);
        return this;
    }
}
