package org.neo4j.gds.ml.pipeline.linkPipeline.train;

import java.util.Collection;
import java.util.List;
import java.util.Optional;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.config.ElementTypeValidator;
import org.neo4j.gds.core.GraphDimensions;
import org.neo4j.gds.core.utils.progress.tasks.LeafTask;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.mem.MemoryEstimations;
import org.neo4j.gds.mem.MemoryRange;
import org.neo4j.gds.ml.negativeSampling.NegativeSampler;
import org.neo4j.gds.ml.pipeline.NonEmptySetValidation;
import org.neo4j.gds.ml.pipeline.linkPipeline.ExpectedSetSizes;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfig;
import org.neo4j.gds.ml.splitting.EdgeSplitter;
import org.neo4j.gds.ml.splitting.UndirectedEdgeSplitter;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionRelationshipSampler.class */
public class LinkPredictionRelationshipSampler {
    /**
     * Progress-task name shared by {@link #progressTask}, {@code beginSubTask} and {@code endSubTask}
     * so the three uses cannot drift apart.
     */
    private static final String TASK_NAME = "Split relationships";

    /**
     * Wildcard value. A wildcard source/target node label triggers a warning; in
     * {@link #splitEstimation} it is mapped to {@link RelationshipType#ALL_RELATIONSHIPS}.
     */
    private static final String WILDCARD = "*";

    private final LinkPredictionSplitConfig splitConfig;
    private final LinkPredictionTrainConfig trainConfig;
    private final ProgressTracker progressTracker;
    private final TerminationFlag terminationFlag;
    private final GraphStore graphStore;

    /**
     * Creates a sampler that writes the split relationship sets into {@code graphStore}.
     *
     * @param graphStore                 graph store that is read from and to which split relationship types are added
     * @param linkPredictionSplitConfig  fractions, ratios and relationship type names of the split
     * @param linkPredictionTrainConfig  label/type filters, random seed and concurrency of the training run
     * @param progressTracker            tracker receiving the {@value #TASK_NAME} sub task and warnings
     * @param terminationFlag            checked between the expensive phases of the split
     */
    public LinkPredictionRelationshipSampler(
        GraphStore graphStore,
        LinkPredictionSplitConfig linkPredictionSplitConfig,
        LinkPredictionTrainConfig linkPredictionTrainConfig,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        this.graphStore = graphStore;
        this.splitConfig = linkPredictionSplitConfig;
        this.trainConfig = linkPredictionTrainConfig;
        this.progressTracker = progressTracker;
        this.terminationFlag = terminationFlag;
    }

    /**
     * Builds the leaf progress task for the relationship split.
     *
     * <p>Its volume is the sum of the expected train, feature-input, test and test-complement set sizes.
     *
     * <p>NOTE(review): the decompiled listing marked this method as originally package-private;
     * it stays public so existing callers keep compiling.
     *
     * @param expectedSetSizes expected sizes of the relationship sets
     * @return the leaf task named {@value #TASK_NAME}
     */
    @NotNull
    public static LeafTask progressTask(ExpectedSetSizes expectedSetSizes) {
        long totalVolume = expectedSetSizes.trainSize()
            + expectedSetSizes.featureInputSize()
            + expectedSetSizes.testSize()
            + expectedSetSizes.testComplementSize();
        return Tasks.leaf(TASK_NAME, totalVolume);
    }

    /**
     * Splits the target relationships into test, train and feature-input sets and samples negatives.
     *
     * <p>The steps are:
     * <ol>
     *   <li>The base graph is split into a test set and a test-complement set.</li>
     *   <li>The test-complement set is split again into a train set and a feature-input set.</li>
     *   <li>The negative sampler is given the (not yet built) test and train relationship builders.</li>
     *   <li>Those builders are built and added to the graph store.</li>
     *   <li>The set sizes are validated.</li>
     *   <li>The intermediate test-complement relationships are deleted.</li>
     * </ol>
     *
     * @param relationshipWeightProperty optional relationship property name; passed to
     *                                   {@code graphStore.getGraph} and to the edge splitter
     * @throws IllegalArgumentException if a graph to be split is not undirected
     */
    public void splitAndSampleRelationships(Optional<String> relationshipWeightProperty) {
        progressTracker.beginSubTask(TASK_NAME);

        splitConfig.validateAgainstGraphStore(graphStore, trainConfig.internalTargetRelationshipType());
        warnIfWildcardNodeLabel();

        RelationshipType testComplementType = splitConfig.testComplementRelationshipType();

        // NOTE(review): raw Collection kept because the element type (presumably NodeLabel) is not
        // imported in this file.
        Collection sourceLabels = ElementTypeValidator.resolve(graphStore, List.of(trainConfig.sourceNodeLabel()));
        Collection targetLabels = ElementTypeValidator.resolve(graphStore, List.of(trainConfig.targetNodeLabel()));
        Graph sourceNodes = graphStore.getGraph(sourceLabels);
        Graph targetNodes = graphStore.getGraph(targetLabels);

        Collection nodeLabels = trainConfig.nodeLabelIdentifiers(graphStore);
        Graph baseGraph = graphStore.getGraph(
            nodeLabels,
            trainConfig.internalRelationshipTypes(graphStore),
            relationshipWeightProperty
        );

        // Step 1: base graph -> test + test-complement.
        terminationFlag.assertRunning();
        EdgeSplitter.SplitResult testSplit = split(
            sourceNodes,
            targetNodes,
            baseGraph,
            relationshipWeightProperty,
            splitConfig.testRelationshipType(),
            testComplementType,
            splitConfig.testFraction()
        );

        // Step 2: test-complement -> train + feature-input.
        Graph testComplementGraph = graphStore.getGraph(
            nodeLabels,
            List.of(testComplementType),
            relationshipWeightProperty
        );
        terminationFlag.assertRunning();
        EdgeSplitter.SplitResult trainSplit = split(
            sourceNodes,
            targetNodes,
            testComplementGraph,
            relationshipWeightProperty,
            splitConfig.trainRelationshipType(),
            splitConfig.featureInputRelationshipType(),
            splitConfig.trainFraction()
        );

        // Step 3: negative sampling into the test and train builders, before they are built.
        NegativeSampler negativeSampler = NegativeSampler.of(
            graphStore,
            baseGraph,
            nodeLabels,
            splitConfig.negativeRelationshipType(),
            splitConfig.negativeSamplingRatio(),
            testSplit.selectedRelCount(),
            trainSplit.selectedRelCount(),
            sourceNodes,
            targetNodes,
            sourceLabels,
            targetLabels,
            trainConfig.randomSeed()
        );
        terminationFlag.assertRunning();
        negativeSampler.produceNegativeSamples(testSplit.selectedRels(), trainSplit.selectedRels());

        graphStore.addRelationshipType(testSplit.selectedRels().build());
        graphStore.addRelationshipType(trainSplit.selectedRels().build());

        validateTestSplit(graphStore);
        validateTrainSplit(graphStore);

        // The test-complement set was only an intermediate: it has been re-split into train and
        // feature-input above.
        graphStore.deleteRelationships(testComplementType);

        progressTracker.endSubTask(TASK_NAME);
    }

    /** Logs a warning when the source or target node label is the wildcard. */
    private void warnIfWildcardNodeLabel() {
        boolean wildcardUsed = trainConfig.sourceNodeLabel().equals(WILDCARD)
            || trainConfig.targetNodeLabel().equals(WILDCARD);
        if (wildcardUsed) {
            progressTracker.logWarning(StringFormatting.formatWithLocale(
                "Using %s for the `sourceNodeLabel` or `targetNodeLabel` results in not ideal negative link sampling.",
                WILDCARD
            ));
        }
    }

    /**
     * Splits the relationships of {@code graph} into a selected set and a remaining set.
     *
     * <p>The remaining set is built and added to the graph store right away. The selected set is
     * returned inside the result without being added.
     *
     * @param sourceNodes                 id map of the allowed source nodes
     * @param targetNodes                 id map of the allowed target nodes
     * @param graph                       graph whose relationships are split; must be undirected
     * @param relationshipWeightProperty  optional relationship property forwarded to the splitter
     * @param selectedType                relationship type for the selected set
     * @param remainingType               relationship type for the remaining set
     * @param selectedFraction            fraction forwarded to the splitter for the selected set
     * @return the split result, whose remaining relationships are already in the graph store
     * @throws IllegalArgumentException if {@code graph} is not undirected
     */
    private EdgeSplitter.SplitResult split(
        IdMap sourceNodes,
        IdMap targetNodes,
        Graph graph,
        Optional<String> relationshipWeightProperty,
        RelationshipType selectedType,
        RelationshipType remainingType,
        double selectedFraction
    ) {
        if (!graph.schema().isUndirected()) {
            throw new IllegalArgumentException("EdgeSplitter requires graph to be UNDIRECTED");
        }
        UndirectedEdgeSplitter splitter = new UndirectedEdgeSplitter(
            trainConfig.randomSeed(),
            graphStore.nodes(),
            sourceNodes,
            targetNodes,
            selectedType,
            remainingType,
            trainConfig.concurrency()
        );
        EdgeSplitter.SplitResult result = splitter.splitPositiveExamples(graph, selectedFraction, relationshipWeightProperty);
        graphStore.addRelationshipType(result.remainingRels().build());
        return result;
    }

    /**
     * Validates the test split.
     *
     * <p>The test set is checked against 1 and the test-complement set against 3. These are
     * presumably the minimum accepted sizes of NonEmptySetValidation.
     */
    private void validateTestSplit(GraphStore graphStore) {
        NonEmptySetValidation.validateRelSetSize(
            graphStore.relationshipCount(splitConfig.testRelationshipType()),
            1L,
            "test",
            "`testFraction` is too low"
        );
        NonEmptySetValidation.validateRelSetSize(
            graphStore.relationshipCount(splitConfig.testComplementRelationshipType()),
            3L,
            "test-complement",
            "`testFraction` is too high"
        );
    }

    /**
     * Validates the train split.
     *
     * <p>The following checks are made (against presumably minimum accepted sizes):
     * <ul>
     *   <li>the train set against 2;</li>
     *   <li>the feature-input set against 1;</li>
     *   <li>the per-fold train size ({@code trainCount / validationFolds}, integer division) against 1.</li>
     * </ul>
     */
    private void validateTrainSplit(GraphStore graphStore) {
        long trainCount = graphStore.relationshipCount(splitConfig.trainRelationshipType());
        NonEmptySetValidation.validateRelSetSize(trainCount, 2L, "train", "`trainFraction` is too low");
        NonEmptySetValidation.validateRelSetSize(
            graphStore.relationshipCount(splitConfig.featureInputRelationshipType()),
            1L,
            "feature-input",
            "`trainFraction` is too high"
        );
        NonEmptySetValidation.validateRelSetSize(
            trainCount / splitConfig.validationFolds(),
            1L,
            "validation",
            "`validationFolds` is too high or the `trainFraction` too low"
        );
    }

    /**
     * Estimates the memory needed for the relationship split.
     *
     * <p>The estimate covers positive test/train, feature-input and negative relationships.
     *
     * <p>NOTE(review): the decompiled listing marked this method as originally package-private;
     * it stays public so existing callers keep compiling.
     *
     * @param linkPredictionSplitConfig  split fractions, negative sampling ratio and optional negative type
     * @param str                        target relationship type name, or {@code "*"} for all relationships
     * @param optional                   optional relationship property; its presence raises the
     *                                   per-relationship multiplier
     * @return the combined memory estimation
     */
    public static MemoryEstimation splitEstimation(
        LinkPredictionSplitConfig linkPredictionSplitConfig,
        String str,
        Optional<String> optional
    ) {
        RelationshipType targetType = str.equals(WILDCARD)
            ? RelationshipType.ALL_RELATIONSHIPS
            : RelationshipType.of(str);
        String typeName = targetType.name;
        return MemoryEstimations.builder("Split relationships")
            .add(estimatePositiveRelations(
                typeName,
                linkPredictionSplitConfig.testFraction(),
                linkPredictionSplitConfig.trainFraction(),
                optional
            ))
            .add(estimateNegativeSampling(
                typeName,
                linkPredictionSplitConfig.testFraction(),
                linkPredictionSplitConfig.trainFraction(),
                linkPredictionSplitConfig.negativeSamplingRatio(),
                linkPredictionSplitConfig.negativeRelationshipType()
            ))
            .build();
    }

    /**
     * Returns the fraction of all relationships that ends up in the test or train set.
     *
     * <p>The test set takes {@code testFraction} of all relationships. The train set takes
     * {@code trainFraction} of the remaining {@code 1 - testFraction}. Together that is
     * {@code t + r - t*r}.
     */
    private static double testAndTrainFraction(double testFraction, double trainFraction) {
        return (testFraction + trainFraction) - (testFraction * trainFraction);
    }

    /**
     * Estimates memory for the positive test/train relationships and the feature-input relationships.
     *
     * <p>NOTE(review): the test/train count is halved while the feature-input count is not.
     * Presumably the estimated count includes both directions of each undirected relationship,
     * while the selected sets store only one direction. Confirm.
     */
    private static MemoryEstimation estimatePositiveRelations(
        String relationshipType,
        double testFraction,
        double trainFraction,
        Optional<String> relationshipProperty
    ) {
        // Per-relationship multiplier: 24 with a relationship property, 16 without
        // (presumably bytes per stored relationship).
        int perRelationship = relationshipProperty.isPresent() ? 24 : 16;
        return MemoryEstimations.builder("Relationship splitter")
            .perGraphDimension("Test and train positive relationships", (dimensions, concurrency) -> {
                double selected = dimensions.estimatedRelCount(List.of(relationshipType))
                    * testAndTrainFraction(testFraction, trainFraction);
                return MemoryRange.of(((long) selected) / 2).times(perRelationship);
            })
            .perGraphDimension("Feature input relationships", (dimensions, concurrency) -> {
                double featureInput = dimensions.estimatedRelCount(List.of(relationshipType))
                    * (1.0d - testFraction)
                    * (1.0d - trainFraction);
                return MemoryRange.of((long) featureInput).times(perRelationship);
            })
            .build();
    }

    /**
     * Estimates memory for the negative relationships.
     *
     * <p>A fixed multiplier of 24 is used, unlike {@link #estimatePositiveRelations}, where it depends
     * on whether a relationship property is present.
     */
    private static MemoryEstimation estimateNegativeSampling(
        String relationshipType,
        double testFraction,
        double trainFraction,
        double negativeSamplingRatio,
        Optional<String> negativeRelationshipType
    ) {
        int perRelationship = 24;
        return MemoryEstimations.builder("Relationship splitter")
            .perGraphDimension("Negative relationships", (dimensions, concurrency) -> {
                long negativeCount = estimateNegativeRelCount(
                    dimensions,
                    relationshipType,
                    testFraction,
                    trainFraction,
                    negativeSamplingRatio,
                    negativeRelationshipType
                );
                return MemoryRange.of(negativeCount / 2).times(perRelationship);
            })
            .build();
    }

    /**
     * Estimates the number of negative relationships.
     *
     * <p>If an explicit negative relationship type is configured, its estimated count is used.
     * Otherwise the estimate is the test/train share of the target relationships multiplied by
     * {@code negativeSamplingRatio}.
     */
    private static long estimateNegativeRelCount(
        GraphDimensions graphDimensions,
        String relationshipType,
        double testFraction,
        double trainFraction,
        double negativeSamplingRatio,
        Optional<String> negativeRelationshipType
    ) {
        return negativeRelationshipType
            .map(negativeType -> graphDimensions.estimatedRelCount(List.of(negativeType)))
            .orElseGet(() -> (long) (graphDimensions.estimatedRelCount(List.of(relationshipType))
                * testAndTrainFraction(testFraction, trainFraction)
                * negativeSamplingRatio));
    }
}
