/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.train;

import org.neo4j.gds.GraphStoreAlgorithmFactory;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.GraphDimensions;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.ml.pipeline.PipelineCompanion;
import org.neo4j.gds.ml.pipeline.TrainingPipeline;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfig;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainPipelineExecutor;

/**
 * Algorithm factory for the link prediction training pipeline.
 *
 * <p>Each entry point resolves the {@link LinkPredictionTrainingPipeline} registered in the
 * {@link PipelineCatalog} under {@code config.username()} / {@code config.pipeline()} and
 * delegates to {@link LinkPredictionTrainPipelineExecutor}. The catalog is consulted on every
 * call; nothing is cached in this factory.
 */
public class LinkPredictionTrainPipelineAlgorithmFactory
extends GraphStoreAlgorithmFactory<LinkPredictionTrainPipelineExecutor, LinkPredictionTrainConfig> {
    private final ExecutionContext executionContext;

    /**
     * @param executionContext context passed to the executor when building it and when estimating memory
     * @param gdsVersion       not used by this factory; the parameter is kept so existing callers still compile
     */
    LinkPredictionTrainPipelineAlgorithmFactory(ExecutionContext executionContext, String gdsVersion) {
        this.executionContext = executionContext;
    }

    /**
     * Builds the training executor for the configured pipeline.
     *
     * <p>The configured main metric is first passed to
     * {@code PipelineCompanion.validateMainMetric} together with the pipeline (presumably it
     * throws when the metric is not applicable — confirm in {@code PipelineCompanion}).
     *
     * @param graphStore      graph store the executor will train on
     * @param trainConfig     training configuration naming the pipeline and main metric
     * @param progressTracker tracker handed to the executor
     * @return a new {@link LinkPredictionTrainPipelineExecutor}
     */
    public LinkPredictionTrainPipelineExecutor build(GraphStore graphStore, LinkPredictionTrainConfig trainConfig, ProgressTracker progressTracker) {
        LinkPredictionTrainingPipeline pipeline = loadPipeline(trainConfig);
        PipelineCompanion.validateMainMetric(pipeline, trainConfig.mainMetric().name());
        return new LinkPredictionTrainPipelineExecutor(pipeline, trainConfig, executionContext, graphStore, progressTracker);
    }

    /**
     * @return the fixed task name {@code "Link Prediction Train Pipeline"}
     */
    public String taskName() {
        return "Link Prediction Train Pipeline";
    }

    /**
     * Creates the progress task tree for a training run.
     *
     * <p>The relationship count passed on is the sum of {@code graphStore.relationshipCount(type)}
     * over {@code config.internalRelationshipTypes(graphStore)}.
     *
     * @param graphStore graph store whose relationship counts size the task
     * @param config     training configuration naming the pipeline
     * @return the task produced by {@code LinkPredictionTrainPipelineExecutor.progressTask}
     */
    public Task progressTask(GraphStore graphStore, LinkPredictionTrainConfig config) {
        // Computed before the catalog lookup, matching the original evaluation order.
        long relationshipCount = config.internalRelationshipTypes(graphStore)
            .stream()
            .mapToLong(graphStore::relationshipCount)
            .sum();
        return LinkPredictionTrainPipelineExecutor.progressTask(taskName(), loadPipeline(config), relationshipCount);
    }

    /**
     * Estimates memory for training the configured pipeline.
     *
     * @param configuration training configuration naming the pipeline
     * @return the estimation from {@code LinkPredictionTrainPipelineExecutor.estimate}
     */
    public MemoryEstimation memoryEstimation(LinkPredictionTrainConfig configuration) {
        return LinkPredictionTrainPipelineExecutor.estimate(executionContext, loadPipeline(configuration), configuration);
    }

    /**
     * Transforms the estimated graph dimensions according to the pipeline's split configuration.
     *
     * @param graphDimensions dimensions of the input graph
     * @param config          training configuration naming the pipeline and target relationship type
     * @return dimensions as computed by {@code splitConfig().expectedGraphDimensions}
     */
    public GraphDimensions estimatedGraphDimensionTransformer(GraphDimensions graphDimensions, LinkPredictionTrainConfig config) {
        LinkPredictionSplitConfig splitConfig = loadPipeline(config).splitConfig();
        return splitConfig.expectedGraphDimensions(graphDimensions, config.targetRelationshipType());
    }

    /** Looks up the link prediction pipeline named by {@code config} in the user's pipeline catalog. */
    private static LinkPredictionTrainingPipeline loadPipeline(LinkPredictionTrainConfig config) {
        return PipelineCatalog.getTyped(config.username(), config.pipeline(), LinkPredictionTrainingPipeline.class);
    }
}

