/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.predict;

import org.neo4j.gds.GraphStoreAlgorithmFactory;
import org.neo4j.gds.api.DatabaseId;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.GraphDimensions;
import org.neo4j.gds.core.loading.CatalogRequest;
import org.neo4j.gds.core.loading.GraphStoreCatalog;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionPipelineCompanion;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LPGraphStoreFilter;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LPGraphStoreFilterFactory;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPredictionPredictPipelineBaseConfig;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPredictionPredictPipelineExecutor;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.ClassifierFactory;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionModelInfo;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionPredictPipeline;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig;

public class LinkPredictionPredictPipelineAlgorithmFactory<CONFIG extends LinkPredictionPredictPipelineBaseConfig>
extends GraphStoreAlgorithmFactory<LinkPredictionPredictPipelineExecutor, CONFIG> {
    private final ExecutionContext executionContext;
    private final ModelCatalog modelCatalog;

    LinkPredictionPredictPipelineAlgorithmFactory(ExecutionContext executionContext) {
        this.executionContext = executionContext;
        this.modelCatalog = executionContext.modelCatalog();
    }

    public Task progressTask(GraphStore graphStore, CONFIG config) {
        LinkPredictionPredictPipeline pipeline = ((LinkPredictionModelInfo)LinkPredictionPipelineCompanion.getTrainedLPPipelineModel(this.modelCatalog, config.modelName(), config.username()).customInfo()).pipeline();
        return LinkPredictionPredictPipelineExecutor.progressTask(this.taskName(), pipeline, graphStore, config);
    }

    public String taskName() {
        return "Link Prediction Predict Pipeline";
    }

    public LinkPredictionPredictPipelineExecutor build(GraphStore graphStore, CONFIG configuration, ProgressTracker progressTracker) {
        Model<Classifier.ClassifierData, LinkPredictionTrainConfig, LinkPredictionModelInfo> model = LinkPredictionPipelineCompanion.getTrainedLPPipelineModel(this.modelCatalog, configuration.modelName(), configuration.username());
        LinkPredictionTrainConfig trainConfig = (LinkPredictionTrainConfig)model.trainConfig();
        LPGraphStoreFilter lpGraphStoreFilter = LPGraphStoreFilterFactory.generate(trainConfig, configuration, graphStore, progressTracker);
        return new LinkPredictionPredictPipelineExecutor(((LinkPredictionModelInfo)model.customInfo()).pipeline(), ClassifierFactory.create((Classifier.ClassifierData)((Classifier.ClassifierData)model.data())), lpGraphStoreFilter, (LinkPredictionPredictPipelineBaseConfig)configuration, this.executionContext, graphStore, progressTracker);
    }

    public MemoryEstimation memoryEstimation(CONFIG configuration) {
        Model<Classifier.ClassifierData, LinkPredictionTrainConfig, LinkPredictionModelInfo> model = LinkPredictionPipelineCompanion.getTrainedLPPipelineModel(this.modelCatalog, configuration.modelName(), configuration.username());
        LinkPredictionPredictPipeline linkPredictionPipeline = ((LinkPredictionModelInfo)model.customInfo()).pipeline();
        return LinkPredictionPredictPipelineExecutor.estimate(this.modelCatalog, linkPredictionPipeline, configuration, (Classifier.ClassifierData)model.data());
    }

    public GraphDimensions estimatedGraphDimensionTransformer(GraphDimensions graphDimensions, CONFIG config) {
        Model<Classifier.ClassifierData, LinkPredictionTrainConfig, LinkPredictionModelInfo> model = LinkPredictionPipelineCompanion.getTrainedLPPipelineModel(this.modelCatalog, config.modelName(), config.username());
        if (config.graphName().equals("__ANONYMOUS_GRAPH__")) {
            return graphDimensions;
        }
        GraphStore graphStore = GraphStoreCatalog.get((CatalogRequest)CatalogRequest.of((String)config.username(), (DatabaseId)this.executionContext.databaseId()), (String)config.graphName()).graphStore();
        LPGraphStoreFilter lpNodeLabelFilter = LPGraphStoreFilterFactory.generate((LinkPredictionTrainConfig)model.trainConfig(), config, graphStore, ProgressTracker.NULL_TRACKER);
        return GraphDimensions.builder().from(graphDimensions).nodeCount(graphStore.getGraph(lpNodeLabelFilter.nodePropertyStepsBaseLabels()).nodeCount()).build();
    }
}

