package org.neo4j.gds.ml.pipeline.linkPipeline.train;

import java.util.Optional;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.progress.tasks.LeafTask;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.ml.pipeline.NonEmptySetValidation;
import org.neo4j.gds.ml.pipeline.linkPipeline.ExpectedSetSizes;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfig;
import org.neo4j.gds.ml.splitting.SplitRelationships;
import org.neo4j.gds.ml.splitting.SplitRelationshipsBaseConfig;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/ml/pipeline/linkPipeline/train/RelationshipSplitter.class */
/**
 * Splits the relationships of a graph store into the test, train and feature-input relationship sets used by a
 * link prediction training pipeline.
 *
 * <p>The split happens in two rounds, each driven by a {@link SplitRelationshipsBaseConfig} produced by the
 * {@link LinkPredictionSplitConfig}: first the test split, then the train split. Every round writes its output
 * relationship types into the graph store and is followed by a size validation. Afterwards the intermediate
 * test-complement relationship type is deleted from the graph store again.
 */
public class RelationshipSplitter {
    /** Name shared by the progress leaf task and the subtask opened/closed in {@link #splitRelationships}. */
    private static final String TASK_NAME = "Split relationships";
    /** Label / type value that selects everything; triggers a warning when used as a node label here. */
    private static final String WILDCARD = "*";

    private final LinkPredictionSplitConfig splitConfig;
    private final ProgressTracker progressTracker;
    private final GraphStore graphStore;
    private final TerminationFlag terminationFlag;

    /**
     * @param graphStore      graph store that is read for validation and mutated with the split relationship types
     * @param splitConfig     configuration providing the split fractions, relationship type names and fold count
     * @param progressTracker tracker that receives the subtask lifecycle and any warnings
     * @param terminationFlag flag handed to each {@link SplitRelationships} run
     */
    public RelationshipSplitter(
        GraphStore graphStore,
        LinkPredictionSplitConfig splitConfig,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        this.terminationFlag = terminationFlag;
        this.progressTracker = progressTracker;
        this.splitConfig = splitConfig;
        this.graphStore = graphStore;
    }

    /**
     * Builds the progress leaf task for the split.
     *
     * <p>The task volume is the sum of the expected train, feature-input, test and test-complement set sizes.
     * NOTE(review): the decompiler widened this from package-private; it stays public for caller compatibility.
     */
    @NotNull
    public static LeafTask progressTask(ExpectedSetSizes expectedSetSizes) {
        long volume = expectedSetSizes.trainSize()
            + expectedSetSizes.featureInputSize()
            + expectedSetSizes.testSize()
            + expectedSetSizes.testComplementSize();
        return Tasks.leaf(TASK_NAME, volume);
    }

    /**
     * Runs the test split followed by the train split and writes the results into the graph store.
     *
     * @param targetRelationshipType     relationship type to split, validated against the graph store first
     * @param sourceNodeLabel            source node label forwarded to both split configurations
     * @param targetNodeLabel            target node label forwarded to both split configurations
     * @param randomSeed                 optional seed forwarded to both split configurations
     * @param relationshipWeightProperty optional weight property forwarded to both split configurations
     */
    public void splitRelationships(
        RelationshipType targetRelationshipType,
        String sourceNodeLabel,
        String targetNodeLabel,
        Optional<Long> randomSeed,
        Optional<String> relationshipWeightProperty
    ) {
        progressTracker.beginSubTask(TASK_NAME);

        splitConfig.validateAgainstGraphStore(graphStore, targetRelationshipType);

        // `||` short-circuits: the target label is only inspected when the source label is not the wildcard.
        boolean usesWildcardLabel = sourceNodeLabel.equals(WILDCARD) || targetNodeLabel.equals(WILDCARD);
        if (usesWildcardLabel) {
            String warning = StringFormatting.formatWithLocale(
                "Using %s for the `sourceNodeLabel` or `targetNodeLabel` results in not ideal negative link sampling.",
                WILDCARD
            );
            progressTracker.logWarning(warning);
        }

        // Captured up front; this type is removed from the graph store once both rounds have been validated.
        RelationshipType testComplementType = splitConfig.testComplementRelationshipType();

        // Round 1: test / test-complement.
        relationshipSplit(splitConfig.testSplit(
            targetRelationshipType,
            sourceNodeLabel,
            targetNodeLabel,
            randomSeed,
            relationshipWeightProperty
        ));
        validateTestSplit(graphStore);

        // Round 2: train / feature-input.
        relationshipSplit(splitConfig.trainSplit(
            targetRelationshipType,
            sourceNodeLabel,
            targetNodeLabel,
            randomSeed,
            relationshipWeightProperty
        ));
        validateTrainSplit(graphStore);

        graphStore.deleteRelationships(testComplementType);

        progressTracker.endSubTask(TASK_NAME);
    }

    /** Checks that the test set has at least 1 and the test-complement set at least 3 relationships. */
    private void validateTestSplit(GraphStore store) {
        long testCount = store.relationshipCount(splitConfig.testRelationshipType());
        NonEmptySetValidation.validateRelSetSize(testCount, 1L, "test", "`testFraction` is too low");

        long testComplementCount = store.relationshipCount(splitConfig.testComplementRelationshipType());
        NonEmptySetValidation.validateRelSetSize(
            testComplementCount,
            3L,
            "test-complement",
            "`testFraction` is too high"
        );
    }

    /**
     * Checks that the train set has at least 2 relationships, that the feature-input set is non-empty, and that
     * each validation fold (train count integer-divided by {@code validationFolds()}) keeps at least 1.
     */
    private void validateTrainSplit(GraphStore store) {
        long trainCount = store.relationshipCount(splitConfig.trainRelationshipType());
        NonEmptySetValidation.validateRelSetSize(trainCount, 2L, "train", "`trainFraction` is too low");

        long featureInputCount = store.relationshipCount(splitConfig.featureInputRelationshipType());
        NonEmptySetValidation.validateRelSetSize(
            featureInputCount,
            1L,
            "feature-input",
            "`trainFraction` is too high"
        );

        long perFoldCount = trainCount / splitConfig.validationFolds();
        NonEmptySetValidation.validateRelSetSize(
            perFoldCount,
            1L,
            "validation",
            "`validationFolds` is too high or the `trainFraction` too low"
        );
    }

    /** Validates {@code config} against the graph store, runs {@link SplitRelationships}, and stores the result. */
    private void relationshipSplit(SplitRelationshipsBaseConfig config) {
        config.graphStoreValidation(
            graphStore,
            config.nodeLabelIdentifiers(graphStore),
            config.internalRelationshipTypes(graphStore)
        );

        SplitRelationships splitAlgorithm = SplitRelationships.of(graphStore, config);
        splitAlgorithm.setTerminationFlag(terminationFlag);

        SplitRelationshipGraphStoreMutator.mutate(graphStore, splitAlgorithm.compute(), config);
    }

    /**
     * Estimates memory for both split rounds.
     *
     * <p>A target relationship type of {@code "*"} maps to {@link RelationshipType#ALL_RELATIONSHIPS}. No random
     * seed is passed to either round's configuration.
     * NOTE(review): the decompiler widened this from package-private; it stays public for caller compatibility.
     */
    public static MemoryEstimation splitEstimation(
        LinkPredictionSplitConfig splitConfig,
        String targetRelationshipType,
        Optional<String> relationshipWeightProperty,
        String sourceNodeLabel,
        String targetNodeLabel
    ) {
        RelationshipType relationshipType;
        if (targetRelationshipType.equals(WILDCARD)) {
            relationshipType = RelationshipType.ALL_RELATIONSHIPS;
        } else {
            relationshipType = RelationshipType.of(targetRelationshipType);
        }
        Optional<Long> noSeed = Optional.empty();

        MemoryEstimation testEstimation = MemoryEstimations.builder("Test/Test-complement split")
            .add(SplitRelationships.estimate(splitConfig.testSplit(
                relationshipType,
                sourceNodeLabel,
                targetNodeLabel,
                noSeed,
                relationshipWeightProperty
            )))
            .build();

        MemoryEstimation trainEstimation = MemoryEstimations.builder("Train/Feature-input split")
            .add(SplitRelationships.estimate(splitConfig.trainSplit(
                relationshipType,
                sourceNodeLabel,
                targetNodeLabel,
                noSeed,
                relationshipWeightProperty
            )))
            .build();

        return MemoryEstimations.builder("Split relationships")
            .add(testEstimation)
            .add(trainEstimation)
            .build();
    }
}
