/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.predict;

import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.stream.Stream;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.DefaultPool;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.linkmodels.LinkPredictionResult;
import org.neo4j.gds.ml.linkmodels.PredictedLink;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LPNodeFilter;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPrediction;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPredictionPredictPipelineBaseConfig;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPredictionSimilarityComputer;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureExtractor;
import org.neo4j.gds.similarity.knn.ImmutableKnnContext;
import org.neo4j.gds.similarity.knn.Knn;
import org.neo4j.gds.similarity.knn.KnnBaseConfig;
import org.neo4j.gds.similarity.knn.KnnContext;
import org.neo4j.gds.similarity.knn.KnnFactory;
import org.neo4j.gds.similarity.knn.NeighborFilterFactory;
import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer;

/**
 * {@link LinkPrediction} variant that produces its predictions by running
 * {@link Knn} with the supplied {@link LinkPredictionSimilarityComputer} as
 * the similarity function.
 */
public class ApproximateLinkPrediction
extends LinkPrediction {
    /** KNN configuration handed to {@link Knn#create}; also the source of the concurrency passed to the superclass. */
    private final KnnBaseConfig knnConfig;
    /** Termination flag installed on the {@link Knn} instance before it is computed. */
    private final TerminationFlag terminationFlag;

    /**
     * Creates the approximate link prediction.
     *
     * @param classifier           forwarded to the {@link LinkPrediction} constructor
     * @param linkFeatureExtractor forwarded to the {@link LinkPrediction} constructor
     * @param graph                forwarded to the {@link LinkPrediction} constructor
     * @param sourceNodeFilter     forwarded to the {@link LinkPrediction} constructor
     * @param targetNodeFilter     forwarded to the {@link LinkPrediction} constructor
     * @param knnConfig            KNN configuration; its {@code concurrency()} is forwarded to the superclass
     * @param progressTracker      forwarded to the {@link LinkPrediction} constructor
     * @param terminationFlag      flag set on the KNN algorithm in {@link #predictLinks}
     */
    public ApproximateLinkPrediction(Classifier classifier, LinkFeatureExtractor linkFeatureExtractor, Graph graph, LPNodeFilter sourceNodeFilter, LPNodeFilter targetNodeFilter, KnnBaseConfig knnConfig, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
        super(
            classifier,
            linkFeatureExtractor,
            graph,
            sourceNodeFilter,
            targetNodeFilter,
            knnConfig.concurrency(),
            progressTracker
        );
        this.knnConfig = knnConfig;
        this.terminationFlag = terminationFlag;
    }

    /**
     * Builds a memory estimation, named after this class's simple name, that
     * consists of the {@link KnnFactory} estimation for the approximate (KNN)
     * configuration of {@code config}.
     *
     * @param config pipeline predict configuration providing {@code approximateConfig()}
     * @return the assembled memory estimation
     */
    public static MemoryEstimation estimate(LinkPredictionPredictPipelineBaseConfig config) {
        KnnBaseConfig approximateConfig = config.approximateConfig();
        MemoryEstimation knnMemory = new KnnFactory().memoryEstimation(approximateConfig);
        return MemoryEstimations
            .builder(ApproximateLinkPrediction.class.getSimpleName())
            .add(knnMemory)
            .build();
    }

    /**
     * Creates a {@link Knn} instance over the graph, sets the termination
     * flag on it, computes it, and wraps the outcome in a {@link Result}.
     *
     * @param linkPredictionSimilarityComputer similarity computer passed to {@link Knn#create}
     * @return the KNN outcome wrapped as a {@link LinkPredictionResult}
     */
    @Override
    LinkPredictionResult predictLinks(LinkPredictionSimilarityComputer linkPredictionSimilarityComputer) {
        // Collaborators are created in the same order the single-expression form evaluated them.
        NeighborFilterFactory linkFilterFactory =
            new LinkPredictionSimilarityComputer.LinkFilterFactory(graph, sourceNodeFilter, targetNodeFilter);
        ExecutorService executor = DefaultPool.INSTANCE;
        KnnContext knnContext = ImmutableKnnContext.of(executor, progressTracker);

        Knn knn = Knn.create(graph, knnConfig, linkPredictionSimilarityComputer, linkFilterFactory, knnContext);
        knn.setTerminationFlag(terminationFlag);
        return new Result(knn.compute());
    }

    /**
     * {@link LinkPredictionResult} backed by a {@link Knn.Result}.
     */
    static class Result
    implements LinkPredictionResult {
        /** The KNN outcome from which predicted links are streamed. */
        private final Knn.Result knnOutcome;
        /** Statistics map built once in the constructor from the KNN outcome. */
        private final Map<String, Object> samplingStats;

        /**
         * Stores the KNN outcome and records its statistics under the keys
         * {@code strategy}, {@code linksConsidered}, {@code ranIterations}
         * and {@code didConverge}.
         *
         * @param knnResult outcome of the KNN computation
         */
        Result(Knn.Result knnResult) {
            knnOutcome = knnResult;
            samplingStats = Map.of(
                "strategy", "approximate",
                "linksConsidered", knnResult.nodePairsConsidered(),
                "ranIterations", knnResult.ranIterations(),
                "didConverge", knnResult.didConverge()
            );
        }

        /**
         * Maps every similarity result of the KNN outcome to a
         * {@link PredictedLink} carrying its source node id, target node id
         * and similarity value.
         *
         * @return stream of predicted links
         */
        public Stream<PredictedLink> stream() {
            return knnOutcome
                .streamSimilarityResult()
                .map(pair -> PredictedLink.of(pair.sourceNodeId(), pair.targetNodeId(), pair.similarity));
        }

        /**
         * @return the statistics map built in the constructor
         */
        public Map<String, Object> samplingStats() {
            return samplingStats;
        }
    }
}

