package org.neo4j.gds.ml.linkmodels.pipeline.predict;

import java.util.Map;
import java.util.stream.Stream;
import org.HdrHistogram.ConcurrentDoubleHistogram;
import org.jetbrains.annotations.Nullable;
import org.neo4j.gds.AlgorithmFactory;
import org.neo4j.gds.GraphStoreAlgorithmFactory;
import org.neo4j.gds.MutateComputationResultConsumer;
import org.neo4j.gds.MutateProc;
import org.neo4j.gds.Orientation;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.Aggregation;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.concurrency.Pools;
import org.neo4j.gds.core.loading.SingleTypeRelationships;
import org.neo4j.gds.core.loading.construction.GraphFactory;
import org.neo4j.gds.core.loading.construction.RelationshipsBuilder;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.executor.AlgorithmSpec;
import org.neo4j.gds.executor.ComputationResult;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.executor.ExecutionMode;
import org.neo4j.gds.executor.GdsCallable;
import org.neo4j.gds.ml.linkmodels.LinkPredictionResult;
import org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionPipelineCompanion;
import org.neo4j.gds.ml.pipeline.PipelineCompanion;
import org.neo4j.gds.result.AbstractResultBuilder;
import org.neo4j.gds.result.HistogramUtils;
import org.neo4j.gds.results.MemoryEstimateResult;
import org.neo4j.gds.results.StandardMutateResult;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/**
 * Procedure {@code gds.beta.pipeline.linkPrediction.predict.mutate}: runs link-prediction pipeline
 * prediction on a catalog graph and adds the predicted links to that graph's store as a new
 * relationship type, storing each link's predicted probability as a relationship property.
 *
 * <p>This class was recovered from bytecode. The overriding methods use the names the decompiler
 * reported as their originals ({@code build}, {@code computationResultConsumer}, {@code newConfig},
 * {@code algorithmFactory}); under the decompiler-generated names they would not implement the
 * corresponding base-class methods.
 */
@GdsCallable(name = "gds.beta.pipeline.linkPrediction.predict.mutate", description = LinkPredictionPipelineCompanion.PREDICT_DESCRIPTION, executionMode = ExecutionMode.MUTATE_RELATIONSHIP)
public class LinkPredictionPipelineMutateProc extends MutateProc<LinkPredictionPredictPipelineExecutor, LinkPredictionResult, MutateResult, LinkPredictionPredictPipelineMutateConfig> {

    /**
     * Result row returned by the mutate procedure.
     */
    public static final class MutateResult extends StandardMutateResult {
        /** Topology element count of the relationships added to the graph store. */
        public final long relationshipsWritten;
        /** Summary of predicted probabilities; empty unless the {@code probabilityDistribution} column was requested. */
        public final Map<String, Object> probabilityDistribution;
        /** Sampling statistics taken from the prediction result. */
        public final Map<String, Object> samplingStats;

        /**
         * Collects timings, statistics and (optionally) a histogram of predicted probabilities,
         * then produces a {@link MutateResult}.
         */
        public static class Builder extends AbstractResultBuilder<MutateResult> {
            /** Number of significant decimal digits the probability histogram preserves. */
            private static final int HISTOGRAM_SIGNIFICANT_DIGITS = 5;

            private Map<String, Object> samplingStats = null;

            // Only allocated when the caller asked for the probability distribution.
            @Nullable
            private ConcurrentDoubleHistogram histogram = null;

            Builder() {
            }

            /**
             * Builds the result row. The probability distribution is summarised from the histogram
             * if {@link #withHistogram()} was called, otherwise it is an empty map.
             *
             * @return the assembled result row
             */
            public MutateResult build() {
                Map<String, Object> probabilityDistribution = this.histogram == null
                    ? Map.of()
                    : HistogramUtils.similaritySummary(this.histogram);
                return new MutateResult(
                    this.preProcessingMillis,
                    this.computeMillis,
                    this.mutateMillis,
                    this.relationshipsWritten,
                    this.config.toMap(),
                    probabilityDistribution,
                    this.samplingStats
                );
            }

            /**
             * Enables recording of predicted probabilities. Calling it again keeps the existing histogram.
             *
             * @return this builder
             */
            Builder withHistogram() {
                if (this.histogram == null) {
                    this.histogram = new ConcurrentDoubleHistogram(HISTOGRAM_SIGNIFICANT_DIGITS);
                }
                return this;
            }

            /**
             * Records one predicted probability. This is a no-op unless {@link #withHistogram()} was called.
             * It is invoked from the parallel stream workers in {@code updateGraphStore}, which is why a
             * concurrent histogram type is used.
             *
             * @param probability the predicted link probability
             */
            void recordHistogramValue(double probability) {
                if (this.histogram != null) {
                    this.histogram.recordValue(probability);
                }
            }

            /**
             * @param samplingStats sampling statistics to report in the result row
             * @return this builder
             */
            Builder withSamplingStats(Map<String, Object> samplingStats) {
                this.samplingStats = samplingStats;
                return this;
            }
        }

        MutateResult(
            long preProcessingMillis,
            long computeMillis,
            long mutateMillis,
            long relationshipsWritten,
            Map<String, Object> configuration,
            Map<String, Object> probabilityDistribution,
            Map<String, Object> samplingStats
        ) {
            // NOTE(review): the third argument is fixed at 0; presumably the post-processing time.
            // Confirm against StandardMutateResult.
            super(preProcessingMillis, computeMillis, 0L, mutateMillis, configuration);
            this.relationshipsWritten = relationshipsWritten;
            this.probabilityDistribution = probabilityDistribution;
            this.samplingStats = samplingStats;
        }
    }

    /**
     * Procedure entry point. It lets {@code PipelineCompanion.preparePipelineConfig} adjust the
     * configuration, computes the predictions and mutates the in-memory graph.
     *
     * @param graphName     name of the catalog graph
     * @param configuration user-supplied procedure configuration
     * @return a stream containing the result row
     */
    @Procedure(name = "gds.beta.pipeline.linkPrediction.predict.mutate", mode = Mode.READ)
    @Description(LinkPredictionPipelineCompanion.PREDICT_DESCRIPTION)
    public Stream<MutateResult> mutate(@Name("graphName") String graphName, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) {
        PipelineCompanion.preparePipelineConfig(graphName, configuration);
        return mutate(compute(graphName, configuration));
    }

    /**
     * Memory-estimation entry point for the mutate procedure.
     *
     * @param graphNameOrConfiguration catalog graph name or graph projection configuration
     * @param algoConfiguration        algorithm configuration
     * @return a stream of memory estimation rows
     */
    @Procedure(name = "gds.beta.pipeline.linkPrediction.predict.mutate.estimate", mode = Mode.READ)
    @Description(LinkPredictionPipelineCompanion.ESTIMATE_PREDICT_DESCRIPTION)
    public Stream<MemoryEstimateResult> estimate(@Name("graphNameOrConfiguration") Object graphNameOrConfiguration, @Name("algoConfiguration") Map<String, Object> algoConfiguration) {
        PipelineCompanion.preparePipelineConfig(graphNameOrConfiguration, algoConfiguration);
        return computeEstimate(graphNameOrConfiguration, algoConfiguration);
    }

    /**
     * Creates the result builder. It carries the sampling statistics and, only if the caller
     * yields the {@code probabilityDistribution} column, a histogram of predicted probabilities.
     */
    protected AbstractResultBuilder<MutateResult> resultBuilder(ComputationResult<LinkPredictionPredictPipelineExecutor, LinkPredictionResult, LinkPredictionPredictPipelineMutateConfig> computationResult, ExecutionContext executionContext) {
        LinkPredictionResult result = computationResult.result();
        // NOTE(review): guards against an absent result (e.g. if computation was skipped for an
        // empty graph). Confirm whether ComputationResult.result() can be null here.
        Map<String, Object> samplingStats = result == null ? Map.of() : result.samplingStats();
        MutateResult.Builder builder = new MutateResult.Builder().withSamplingStats(samplingStats);
        if (executionContext.returnColumns().contains("probabilityDistribution")) {
            builder.withHistogram();
        }
        return builder;
    }

    /**
     * Returns the consumer that writes predicted links into the graph store as an undirected
     * relationship type carrying the probability as its property.
     */
    public MutateComputationResultConsumer<LinkPredictionPredictPipelineExecutor, LinkPredictionResult, LinkPredictionPredictPipelineMutateConfig, MutateResult> computationResultConsumer() {
        return new MutateComputationResultConsumer<LinkPredictionPredictPipelineExecutor, LinkPredictionResult, LinkPredictionPredictPipelineMutateConfig, MutateResult>(this::resultBuilder) {
            protected void updateGraphStore(AbstractResultBuilder<?> builder, ComputationResult<LinkPredictionPredictPipelineExecutor, LinkPredictionResult, LinkPredictionPredictPipelineMutateConfig> computationResult, ExecutionContext executionContext) {
                LinkPredictionPredictPipelineMutateConfig config = computationResult.config();
                int concurrency = config.concurrency();
                // Node ids in the predictions are relative to this label-filtered graph.
                Graph graph = computationResult.graphStore().getGraph(computationResult.algorithm().labelFilter().predictNodeLabels());

                RelationshipsBuilder relationshipsBuilder = GraphFactory.initRelationshipsBuilder()
                    .aggregation(Aggregation.SINGLE)
                    .nodes(graph)
                    .relationshipType(RelationshipType.of(config.mutateRelationshipType()))
                    .orientation(Orientation.UNDIRECTED)
                    .addPropertyConfig(GraphFactory.PropertyConfig.of(config.mutateProperty()))
                    .concurrency(concurrency)
                    .executorService(Pools.DEFAULT)
                    .build();

                // The builder handed in is the one created by resultBuilder(...) above.
                MutateResult.Builder resultWithHistogram = (MutateResult.Builder) builder;
                ParallelUtil.parallelStreamConsume(computationResult.result().stream(), concurrency, predictions ->
                    predictions.forEach(predictedLink -> {
                        double probability = predictedLink.probability();
                        // Translate filtered-graph ids back to root ids before adding to the store.
                        relationshipsBuilder.addFromInternal(
                            graph.toRootNodeId(predictedLink.sourceId()),
                            graph.toRootNodeId(predictedLink.targetId()),
                            probability
                        );
                        resultWithHistogram.recordHistogramValue(probability);
                    })
                );

                SingleTypeRelationships relationships = relationshipsBuilder.build();
                computationResult.graphStore().addRelationshipType(relationships);
                builder.withRelationshipsWritten(relationships.topology().elementCount());
            }
        };
    }

    /** Parses the user configuration into the mutate-mode config. */
    public LinkPredictionPredictPipelineMutateConfig newConfig(String username, CypherMapWrapper config) {
        return LinkPredictionPredictPipelineMutateConfig.of(username, config);
    }

    /** Creates the factory for the prediction pipeline executor, backed by the model catalog. */
    public GraphStoreAlgorithmFactory<LinkPredictionPredictPipelineExecutor, LinkPredictionPredictPipelineMutateConfig> algorithmFactory() {
        return new LinkPredictionPredictPipelineAlgorithmFactory(executionContext(), modelCatalog());
    }

    /**
     * Sets the model catalog used by this procedure.
     *
     * @param modelCatalog catalog to resolve trained models from
     * @return this procedure as an {@link AlgorithmSpec}
     */
    public AlgorithmSpec<LinkPredictionPredictPipelineExecutor, LinkPredictionResult, LinkPredictionPredictPipelineMutateConfig, Stream<MutateResult>, AlgorithmFactory<?, LinkPredictionPredictPipelineExecutor, LinkPredictionPredictPipelineMutateConfig>> withModelCatalog(ModelCatalog modelCatalog) {
        setModelCatalog(modelCatalog);
        return this;
    }
}
