package org.neo4j.gds.ml.linkmodels.pipeline.train;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.pipeline.ImmutableGraphFilter;
import org.neo4j.gds.ml.pipeline.PipelineExecutor;
import org.neo4j.gds.ml.pipeline.TrainingStatistics;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionModelInfo;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionPredictPipeline;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfig;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrain;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainResult;
import org.neo4j.gds.ml.util.TrainingSetWarnings;

/**
 * Runs a {@link LinkPredictionTrainingPipeline} end to end.
 *
 * <p>The run proceeds as follows:
 * <ul>
 *   <li>Relationships are split into train / test / feature-input sets ({@link #splitDataset()}).</li>
 *   <li>A classifier is trained and selected with {@link LinkPredictionTrain}.</li>
 *   <li>The winner is wrapped into a {@link Model} ({@link #execute(Map)}).</li>
 *   <li>The split relationship types are deleted from the graph store again ({@link #cleanUpGraphStore(Map)}).</li>
 * </ul>
 *
 * <p>NOTE(review): the orchestration of these hooks (and the node property step execution
 * listed in {@link #progressTask}) lives in {@link PipelineExecutor}, which is not shown here.
 */
public class LinkPredictionTrainPipelineExecutor extends PipelineExecutor<LinkPredictionTrainConfig, LinkPredictionTrainingPipeline, LinkPredictionTrainPipelineExecutor.LinkPredictionTrainPipelineResult> {

    /**
     * Relationship property requested when loading the train and test split graphs.
     * Presumably the positive/negative link label produced by the relationship split — TODO confirm.
     */
    private static final String LABEL_PROPERTY = "label";

    /** Splits the input relationships according to the pipeline's split config. */
    private final RelationshipSplitter relationshipSplitter;

    /**
     * Result of a training run: the trained model and the statistics gathered while selecting it.
     */
    @ValueClass
    public interface LinkPredictionTrainPipelineResult {
        /** The trained link prediction model, ready to be stored in the model catalog. */
        Model<Classifier.ClassifierData, LinkPredictionTrainConfig, LinkPredictionModelInfo> model();

        /** Metrics and candidate information collected during model selection. */
        TrainingStatistics trainingStatistics();
    }

    /**
     * Creates an executor for a single training run.
     *
     * @param linkPredictionTrainingPipeline the pipeline to train; its split config also configures the splitter
     * @param linkPredictionTrainConfig      the training configuration
     * @param executionContext               execution context, forwarded to the base class and the splitter
     * @param graphStore                     the graph store the pipeline operates on
     * @param str                            name forwarded to the base class and the splitter
     *                                       (presumably the graph name — confirm against {@link PipelineExecutor})
     * @param progressTracker                progress tracker shared by all phases
     */
    public LinkPredictionTrainPipelineExecutor(LinkPredictionTrainingPipeline linkPredictionTrainingPipeline, LinkPredictionTrainConfig linkPredictionTrainConfig, ExecutionContext executionContext, GraphStore graphStore, String str, ProgressTracker progressTracker) {
        super(linkPredictionTrainingPipeline, linkPredictionTrainConfig, executionContext, graphStore, str, progressTracker);
        this.relationshipSplitter = new RelationshipSplitter(str, linkPredictionTrainingPipeline.splitConfig(), executionContext, progressTracker);
    }

    /**
     * Builds the progress-task tree for a training run.
     *
     * <p>The tree contains, in order:
     * <ul>
     *   <li>a "Split relationships" leaf;</li>
     *   <li>an iterative "Execute node property steps" task with one "Step" leaf per node property step;</li>
     *   <li>the tasks contributed by {@link LinkPredictionTrain#progressTasks}.</li>
     * </ul>
     *
     * @param str                            description of the root task
     * @param linkPredictionTrainingPipeline the pipeline whose shape determines the task counts
     * @return the root task
     */
    public static Task progressTask(String str, final LinkPredictionTrainingPipeline linkPredictionTrainingPipeline) {
        // A plain list replaces the former anonymous ArrayList subclass ("double-brace initialization").
        List<Task> subTasks = new ArrayList<>();
        subTasks.add(Tasks.leaf("Split relationships"));
        subTasks.add(Tasks.iterativeFixed(
            "Execute node property steps",
            () -> List.of(Tasks.leaf("Step")),
            linkPredictionTrainingPipeline.nodePropertySteps().size()
        ));
        subTasks.addAll(LinkPredictionTrain.progressTasks(
            linkPredictionTrainingPipeline.splitConfig().validationFolds(),
            linkPredictionTrainingPipeline.numberOfModelSelectionTrials()
        ));
        return Tasks.task(str, subTasks);
    }

    /**
     * Estimates the memory needed to execute the pipeline.
     *
     * <p>The estimate is the maximum of the relationship-split, node-property-step and training
     * estimates. The pipeline's training parameter space is validated first.
     *
     * @param modelCatalog                   catalog used when estimating node property steps
     * @param linkPredictionTrainingPipeline the pipeline to estimate
     * @param linkPredictionTrainConfig      the training configuration
     * @return the combined memory estimation
     */
    public static MemoryEstimation estimate(ModelCatalog modelCatalog, LinkPredictionTrainingPipeline linkPredictionTrainingPipeline, LinkPredictionTrainConfig linkPredictionTrainConfig) {
        PipelineExecutor.validateTrainingParameterSpace(linkPredictionTrainingPipeline);
        LinkPredictionSplitConfig splitConfig = linkPredictionTrainingPipeline.splitConfig();
        return MemoryEstimations.builder(LinkPredictionTrainPipelineExecutor.class.getSimpleName())
            .max("Pipeline execution", List.of(
                RelationshipSplitter.splitEstimation(splitConfig, linkPredictionTrainConfig.relationshipTypes()),
                // Node property steps run on the feature-input relationships only.
                PipelineExecutor.estimateNodePropertySteps(
                    modelCatalog,
                    linkPredictionTrainingPipeline.nodePropertySteps(),
                    linkPredictionTrainConfig.nodeLabels(),
                    List.of(splitConfig.featureInputRelationshipType())
                ),
                MemoryEstimations.builder()
                    .add("Train pipeline", LinkPredictionTrain.estimate(linkPredictionTrainingPipeline, linkPredictionTrainConfig))
                    .build()
            ))
            .build();
    }

    /**
     * Splits the configured relationships and describes each resulting dataset.
     *
     * <p>The split is delegated to the relationship splitter, which is handed the graph store.
     * Presumably it writes the split relationship types into the store; they are deleted again
     * in {@link #cleanUpGraphStore(Map)}.
     *
     * @return graph filters for the TRAIN, TEST and FEATURE_INPUT splits, all restricted to the configured node labels
     */
    @Override
    public Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> splitDataset() {
        this.relationshipSplitter.splitRelationships(this.graphStore, this.config.relationshipTypes(), this.config.nodeLabels(), this.config.randomSeed(), this.pipeline.relationshipWeightProperty());

        LinkPredictionSplitConfig splitConfig = this.pipeline.splitConfig();
        var nodeLabelIdentifiers = this.config.nodeLabelIdentifiers(this.graphStore);

        return Map.of(
            PipelineExecutor.DatasetSplits.TRAIN,
            ImmutableGraphFilter.of(nodeLabelIdentifiers, RelationshipType.listOf(new String[]{splitConfig.trainRelationshipType()})),
            PipelineExecutor.DatasetSplits.TEST,
            ImmutableGraphFilter.of(nodeLabelIdentifiers, RelationshipType.listOf(new String[]{splitConfig.testRelationshipType()})),
            PipelineExecutor.DatasetSplits.FEATURE_INPUT,
            ImmutableGraphFilter.of(nodeLabelIdentifiers, RelationshipType.listOf(new String[]{splitConfig.featureInputRelationshipType()}))
        );
    }

    /**
     * Trains and selects a classifier on the TRAIN/TEST splits and wraps it into a model.
     *
     * @param map the dataset splits produced by {@link #splitDataset()}; must contain TRAIN and TEST
     * @return the trained model together with its training statistics
     */
    @Override
    protected LinkPredictionTrainPipelineResult execute(Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> map) {
        PipelineExecutor.validateTrainingParameterSpace(this.pipeline);

        PipelineExecutor.GraphFilter trainFilter = map.get(PipelineExecutor.DatasetSplits.TRAIN);
        PipelineExecutor.GraphFilter testFilter = map.get(PipelineExecutor.DatasetSplits.TEST);
        Graph trainGraph = this.graphStore.getGraph(trainFilter.nodeLabels(), trainFilter.relationshipTypes(), Optional.of(LABEL_PROPERTY));
        Graph testGraph = this.graphStore.getGraph(testFilter.nodeLabels(), testFilter.relationshipTypes(), Optional.of(LABEL_PROPERTY));

        TrainingSetWarnings.warnForSmallRelationshipSets(
            trainGraph.relationshipCount(),
            testGraph.relationshipCount(),
            this.pipeline.splitConfig().validationFolds(),
            this.progressTracker
        );

        LinkPredictionTrainResult trainResult = new LinkPredictionTrain(
            trainGraph, testGraph, this.pipeline, this.config, this.progressTracker, this.terminationFlag
        ).compute();
        TrainingStatistics trainingStatistics = trainResult.trainingStatistics();

        LinkPredictionModelInfo modelInfo = LinkPredictionModelInfo.of(
            trainingStatistics.winningModelTestMetrics(),
            trainingStatistics.winningModelOuterTrainMetrics(),
            trainingStatistics.bestCandidate(),
            LinkPredictionPredictPipeline.from(this.pipeline)
        );
        // The schema captured before node property steps ran is stored with the model.
        Model<Classifier.ClassifierData, LinkPredictionTrainConfig, LinkPredictionModelInfo> model = Model.of(
            this.config.username(),
            this.config.modelName(),
            "LinkPrediction",
            this.schemaBeforeSteps,
            trainResult.classifier().data(),
            this.config,
            modelInfo
        );
        return ImmutableLinkPredictionTrainPipelineResult.of(model, trainingStatistics);
    }

    /**
     * Deletes every relationship type referenced by the dataset splits from the graph store.
     *
     * @param map the dataset splits whose relationship types should be removed
     */
    private void removeDataSplitRelationships(Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> map) {
        // Collect first (de-duplicated) so that each type is deleted exactly once.
        List<RelationshipType> splitRelationshipTypes = map.values().stream()
            .flatMap(graphFilter -> graphFilter.relationshipTypes().stream())
            .distinct()
            .collect(Collectors.toList());
        splitRelationshipTypes.forEach(this.graphStore::deleteRelationships);
    }

    /**
     * Removes the split relationship types, then delegates to the base-class cleanup.
     *
     * @param map the dataset splits created by {@link #splitDataset()}
     */
    @Override
    protected void cleanUpGraphStore(Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> map) {
        removeDataSplitRelationships(map);
        super.cleanUpGraphStore(map);
    }

    /**
     * Decompiler-emitted copy of the bridge method for the generic {@code execute} override.
     *
     * <p>javac generates the real bridge automatically, so this method is redundant.
     * It is kept only so the class's member set does not change.
     * TODO(review): remove once no caller depends on it.
     */
    protected Object m13execute(Map map) {
        return execute((Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter>) map);
    }
}
