/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.predict;

import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.api.TrainingMethod;
import org.neo4j.gds.ml.linkmodels.LinkPredictionResult;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.ApproximateLinkPrediction;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.ExhaustiveLinkPrediction;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LPGraphStoreFilter;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LPNodeFilter;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPrediction;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPredictionPredictPipelineBaseConfig;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.ClassifierFactory;
import org.neo4j.gds.ml.pipeline.ImmutablePipelineGraphFilter;
import org.neo4j.gds.ml.pipeline.NodePropertyStepExecutor;
import org.neo4j.gds.ml.pipeline.Pipeline;
import org.neo4j.gds.ml.pipeline.PipelineGraphFilter;
import org.neo4j.gds.ml.pipeline.PredictPipelineExecutor;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureExtractor;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionPredictPipeline;
import org.neo4j.gds.similarity.knn.KnnBaseConfig;
import org.neo4j.gds.similarity.knn.KnnFactory;
import org.neo4j.gds.utils.StringFormatting;
import org.neo4j.gds.utils.StringJoining;

public class LinkPredictionPredictPipelineExecutor
extends PredictPipelineExecutor<LinkPredictionPredictPipelineBaseConfig, LinkPredictionPredictPipeline, LinkPredictionResult> {
    private final Classifier classifier;
    private final LPGraphStoreFilter graphStoreFilter;

    public LinkPredictionPredictPipelineExecutor(LinkPredictionPredictPipeline pipeline, Classifier classifier, LPGraphStoreFilter graphStoreFilter, LinkPredictionPredictPipelineBaseConfig config, ExecutionContext executionContext, GraphStore graphStore, ProgressTracker progressTracker) {
        super((Pipeline)pipeline, (AlgoBaseConfig)config, executionContext, graphStore, progressTracker);
        this.classifier = classifier;
        this.graphStoreFilter = graphStoreFilter;
    }

    LPGraphStoreFilter labelFilter() {
        return this.graphStoreFilter;
    }

    protected LinkPredictionResult execute() {
        Graph graph = this.graphStore.getGraph(this.graphStoreFilter.predictNodeLabels(), this.graphStoreFilter.predictRelationshipTypes(), Optional.empty());
        LinkFeatureExtractor linkFeatureExtractor = LinkFeatureExtractor.of((Graph)graph, (List)((LinkPredictionPredictPipeline)this.pipeline).featureSteps());
        LinkPrediction linkPrediction = this.getLinkPredictionStrategy(graph, ((LinkPredictionPredictPipelineBaseConfig)this.config).isApproximateStrategy(), linkFeatureExtractor);
        return linkPrediction.compute();
    }

    protected PipelineGraphFilter nodePropertyStepFilter() {
        return ImmutablePipelineGraphFilter.builder().nodeLabels(this.graphStoreFilter.nodePropertyStepsBaseLabels()).relationshipTypes(this.graphStoreFilter.predictRelationshipTypes()).build();
    }

    public static Task progressTask(String taskName, LinkPredictionPredictPipeline pipeline, GraphStore graphStore, LinkPredictionPredictPipelineBaseConfig config) {
        return Tasks.task((String)taskName, (Task)NodePropertyStepExecutor.tasks((List)pipeline.nodePropertySteps(), (long)graphStore.relationshipCount()), (Task[])new Task[]{config.isApproximateStrategy() ? Tasks.task((String)"Approximate link prediction", (Task)KnnFactory.knnTaskTree((Graph)graphStore.getUnion(), (KnnBaseConfig)config.approximateConfig()), (Task[])new Task[0]) : Tasks.leaf((String)"Exhaustive link prediction", (long)(graphStore.getUnion().nodeCount() * graphStore.getUnion().nodeCount() / 2L))});
    }

    public static MemoryEstimation estimate(ModelCatalog modelCatalog, LinkPredictionPredictPipeline pipeline, LinkPredictionPredictPipelineBaseConfig configuration, Classifier.ClassifierData classifierData) {
        MemoryEstimation maxOverNodePropertySteps = NodePropertyStepExecutor.estimateNodePropertySteps((ModelCatalog)modelCatalog, (String)configuration.username(), (List)pipeline.nodePropertySteps(), configuration.nodeLabels(), configuration.relationshipTypes());
        MemoryEstimation strategyEstimation = configuration.isApproximateStrategy() ? ApproximateLinkPrediction.estimate(configuration) : ExhaustiveLinkPrediction.estimate(configuration, classifierData.featureDimension());
        MemoryRange classificationRange = classifierData.trainerMethod() == TrainingMethod.LogisticRegression ? MemoryRange.of((long)0L) : ClassifierFactory.runtimeOverheadMemoryEstimation((TrainingMethod)classifierData.trainerMethod(), (int)1, (int)2, (int)classifierData.featureDimension(), (boolean)true);
        MemoryEstimation predictEstimation = MemoryEstimations.builder((String)"Model prediction").add("Strategy runtime", strategyEstimation).add(MemoryEstimations.of((String)"Classifier runtime", (MemoryRange)classificationRange)).build();
        return MemoryEstimations.builder((String)LinkPredictionPredictPipelineExecutor.class.getSimpleName()).max("Pipeline execution", List.of(maxOverNodePropertySteps, predictEstimation)).build();
    }

    private LinkPrediction getLinkPredictionStrategy(Graph graph, boolean isApproximateStrategy, LinkFeatureExtractor linkFeatureExtractor) {
        if (linkFeatureExtractor.featureDimension() != this.classifier.data().featureDimension()) {
            Set inputNodeProperties = ((LinkPredictionPredictPipeline)this.pipeline).featureSteps().stream().flatMap(step -> step.inputNodeProperties().stream()).collect(Collectors.toSet());
            throw new IllegalArgumentException(StringFormatting.formatWithLocale((String)"Model expected link features to have a total dimension of `%d`, but got `%d`. This indicates the dimension of the node-properties %s differ between the input and the original train graph.", (Object[])new Object[]{this.classifier.data().featureDimension(), linkFeatureExtractor.featureDimension(), StringJoining.join(inputNodeProperties)}));
        }
        Graph sourceNodes = this.graphStore.getGraph(this.graphStoreFilter.sourceNodeLabels());
        Graph targetNodes = this.graphStore.getGraph(this.graphStoreFilter.targetNodeLabels());
        LPNodeFilter sourceNodeFilter = LPNodeFilter.of(graph, (IdMap)sourceNodes);
        LPNodeFilter targetNodeFilter = LPNodeFilter.of(graph, (IdMap)targetNodes);
        if (isApproximateStrategy) {
            return new ApproximateLinkPrediction(this.classifier, linkFeatureExtractor, graph, sourceNodeFilter, targetNodeFilter, ((LinkPredictionPredictPipelineBaseConfig)this.config).approximateConfig(), this.progressTracker, this.terminationFlag);
        }
        return new ExhaustiveLinkPrediction(this.classifier, linkFeatureExtractor, graph, sourceNodeFilter, targetNodeFilter, ((LinkPredictionPredictPipelineBaseConfig)this.config).concurrency(), ((LinkPredictionPredictPipelineBaseConfig)this.config).topN().orElseThrow(), ((LinkPredictionPredictPipelineBaseConfig)this.config).thresholdOrDefault(), this.progressTracker, this.terminationFlag);
    }
}

