package org.neo4j.gds.ml.pipeline.linkPipeline.train;

import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfig;
import org.neo4j.gds.ml.splitting.SplitRelationships;
import org.neo4j.gds.ml.splitting.SplitRelationshipsBaseConfig;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/ml/pipeline/linkPipeline/train/RelationshipSplitter.class */
/**
 * Splits the relationships of a {@link GraphStore} for link-prediction training.
 *
 * <p>The split runs in two stages, both configured by a {@link LinkPredictionSplitConfig}:
 * <ol>
 *   <li>the source relationships are split using {@code testSplit(...)}; afterwards the test
 *       relationship type must contain at least one relationship;</li>
 *   <li>the resulting test-complement relationship type is split again using
 *       {@code trainSplit(...)} and is then deleted from the graph store.</li>
 * </ol>
 * Each stage's result is written back into the graph store via
 * {@code SplitRelationshipGraphStoreMutator.mutate}.
 */
public class RelationshipSplitter {
    private static final String SPLIT_ERROR_TEMPLATE = "%s graph contains no relationships. Consider increasing the `%s` or provide a larger graph";
    private final LinkPredictionSplitConfig splitConfig;
    private final ProgressTracker progressTracker;
    private final GraphStore graphStore;
    private final TerminationFlag terminationFlag;

    /**
     * Creates a splitter that mutates the given graph store.
     *
     * <p>The decompiler reports this constructor as originally package-private; it stays
     * {@code public} so existing callers are unaffected.
     *
     * @param graphStore                the graph store that is read from and written to
     * @param linkPredictionSplitConfig provides the test/train split configurations
     * @param progressTracker           receives the "Split relationships" sub-task events
     * @param terminationFlag           handed to each {@link SplitRelationships} computation
     */
    public RelationshipSplitter(
        GraphStore graphStore,
        LinkPredictionSplitConfig linkPredictionSplitConfig,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        this.graphStore = graphStore;
        this.splitConfig = linkPredictionSplitConfig;
        this.progressTracker = progressTracker;
        this.terminationFlag = terminationFlag;
    }

    /**
     * Runs both split stages against the graph store, then removes the intermediate
     * test-complement relationship type.
     *
     * <p>NOTE(review): if any step throws, {@code endSubTask} is never reached; confirm whether
     * the progress tracker expects an explicit failure-end call in that case.
     *
     * @param sourceRelationshipTypes    relationship types fed into the test split
     * @param sourceNodeLabels           node labels used to filter the graph in both stages
     * @param randomSeed                 optional seed forwarded to both split configurations
     * @param relationshipWeightProperty optional property forwarded to both split configurations
     * @throws IllegalStateException if the test split produced no relationships (other
     *                               exceptions may come from the config's graph-store validation)
     */
    public void splitRelationships(
        Collection<RelationshipType> sourceRelationshipTypes,
        Collection<NodeLabel> sourceNodeLabels,
        Optional<Long> randomSeed,
        Optional<String> relationshipWeightProperty
    ) {
        progressTracker.beginSubTask("Split relationships");

        splitConfig.validateAgainstGraphStore(graphStore);

        RelationshipType testComplementRelationshipType = splitConfig.testComplementRelationshipType();

        // Stage 1: source relationships -> test + test-complement.
        relationshipSplit(
            splitConfig.testSplit(sourceRelationshipTypes, randomSeed, relationshipWeightProperty),
            sourceNodeLabels
        );
        validateTestSplit();

        // Stage 2: test-complement -> train + feature-input. Reuse the cached type so the
        // relationship type split here is exactly the one deleted below.
        relationshipSplit(
            splitConfig.trainSplit(List.of(testComplementRelationshipType), randomSeed, relationshipWeightProperty),
            sourceNodeLabels
        );

        // The test-complement type is only an intermediate result of stage 1.
        graphStore.deleteRelationships(testComplementRelationshipType);

        progressTracker.endSubTask("Split relationships");
    }

    /**
     * Fails if the test relationship type ended up empty after the first stage.
     *
     * @throws IllegalStateException if the test relationship type has no relationships
     */
    private void validateTestSplit() {
        long testRelationshipCount = graphStore.getGraph(splitConfig.testRelationshipType()).relationshipCount();
        if (testRelationshipCount <= 0) {
            throw new IllegalStateException(StringFormatting.formatWithLocale(
                SPLIT_ERROR_TEMPLATE,
                "Test",
                LinkPredictionSplitConfig.TEST_FRACTION_KEY
            ));
        }
    }

    /**
     * Runs one split stage: validates the stage config against the graph store, filters the
     * graph, computes the split and writes the result back into the graph store.
     *
     * @param stageConfig configuration of this split stage
     * @param nodeLabels  node labels used to filter the graph
     */
    private void relationshipSplit(SplitRelationshipsBaseConfig stageConfig, Collection<NodeLabel> nodeLabels) {
        Collection<RelationshipType> relationshipTypes = stageConfig.internalRelationshipTypes(graphStore);
        stageConfig.graphStoreValidation(graphStore, nodeLabels, relationshipTypes);

        Graph graph = graphStore.getGraph(
            nodeLabels,
            relationshipTypes,
            Optional.ofNullable(stageConfig.relationshipWeightProperty())
        );

        // The same filtered graph is passed for both graph arguments.
        SplitRelationships splitter = new SplitRelationships(graph, graph, stageConfig);
        splitter.setTerminationFlag(terminationFlag);

        SplitRelationshipGraphStoreMutator.mutate(graphStore, splitter.compute(), stageConfig);
    }

    /**
     * Builds the memory estimation for both split stages.
     *
     * <p>A relationship type of {@code "*"} is mapped to {@link RelationshipType#ALL_RELATIONSHIPS}.
     * Both stage configs are built with an empty seed ({@code Optional.empty()}).
     *
     * <p>The decompiler reports this method as originally package-private; it stays
     * {@code public} so existing callers are unaffected.
     *
     * @param linkPredictionSplitConfig  provides the test/train split configurations
     * @param relationshipTypes          source relationship type names ({@code "*"} = all)
     * @param relationshipWeightProperty optional property forwarded to both split configurations
     * @return a "Split relationships" estimation composed of the two stage estimations
     */
    public static MemoryEstimation splitEstimation(
        LinkPredictionSplitConfig linkPredictionSplitConfig,
        List<String> relationshipTypes,
        Optional<String> relationshipWeightProperty
    ) {
        List<RelationshipType> sourceRelationshipTypes = relationshipTypes.stream()
            .map(type -> type.equals("*") ? RelationshipType.ALL_RELATIONSHIPS : RelationshipType.of(type))
            .collect(Collectors.toList());

        Optional<Long> noSeed = Optional.empty();

        MemoryEstimation testSplitEstimation = MemoryEstimations.builder("Test/Test-complement split")
            .add(SplitRelationships.estimate(
                linkPredictionSplitConfig.testSplit(sourceRelationshipTypes, noSeed, relationshipWeightProperty)
            ))
            .build();

        MemoryEstimation trainSplitEstimation = MemoryEstimations.builder("Train/Feature-input split")
            .add(SplitRelationships.estimate(
                linkPredictionSplitConfig.trainSplit(
                    List.of(linkPredictionSplitConfig.testComplementRelationshipType()),
                    noSeed,
                    relationshipWeightProperty
                )
            ))
            .build();

        return MemoryEstimations.builder("Split relationships")
            .add(testSplitEstimation)
            .add(trainSplitEstimation)
            .build();
    }
}
