package org.neo4j.gds.ml.pipeline.linkPipeline;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.immutables.value.Value;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.annotation.Configuration;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.config.ToMapConvertible;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.GraphDimensions;
import org.neo4j.gds.ml.pipeline.NonEmptySetValidation;
import org.neo4j.gds.utils.StringFormatting;
import org.neo4j.gds.utils.StringJoining;

/**
 * Configuration of the relationship split used by the link prediction pipeline.
 *
 * <p>Besides the user-facing settings (fractions, fold count, negative sampling), this
 * interface names the reserved relationship types that hold the individual splits and offers
 * helpers to estimate the size of each split for a given relationship count.
 */
@Configuration
public interface LinkPredictionSplitConfig extends ToMapConvertible {
    /** Configuration key under which {@link #testFraction()} is read. */
    String TEST_FRACTION_KEY = "testFraction";

    /** Configuration key under which {@link #trainFraction()} is read. */
    String TRAIN_FRACTION_KEY = "trainFraction";

    /** Configuration built from an empty map, i.e. with every setting at its default. */
    LinkPredictionSplitConfig DEFAULT_CONFIG = of(CypherMapWrapper.empty());

    /**
     * Number of folds the train set is divided into for validation.
     *
     * @return the fold count; defaults to {@code 3} and is bounded below by
     *     {@code NonEmptySetValidation.MIN_TRAIN_SET_SIZE}
     */
    @Configuration.IntegerRange(min = NonEmptySetValidation.MIN_TRAIN_SET_SIZE)
    @Value.Default
    default int validationFolds() {
        return 3;
    }

    /**
     * Fraction of the relationships assigned to the test set.
     *
     * @return the test fraction; defaults to {@code 0.1} and must be strictly greater than zero
     */
    @Value.Default
    @Configuration.Key(TEST_FRACTION_KEY)
    @Configuration.DoubleRange(min = 0.0d, minInclusive = false)
    default double testFraction() {
        return 0.1d;
    }

    /**
     * Fraction of the test-complement relationships assigned to the train set.
     *
     * @return the train fraction; defaults to {@code 0.1} and must be strictly greater than zero
     */
    @Value.Default
    @Configuration.Key(TRAIN_FRACTION_KEY)
    @Configuration.DoubleRange(min = 0.0d, minInclusive = false)
    default double trainFraction() {
        return 0.1d;
    }

    /**
     * Ratio of negative to positive samples; {@link #expectedSetSizes(long)} scales the test
     * and train sets by {@code 1 + negativeSamplingRatio()}.
     *
     * @return the ratio; defaults to {@code 1.0} and must be strictly greater than zero
     */
    @Value.Default
    @Configuration.DoubleRange(min = 0.0d, minInclusive = false)
    default double negativeSamplingRatio() {
        return 1.0d;
    }

    /**
     * Optional name of a relationship type for negative examples.
     *
     * <p>NOTE(review): nothing in this interface reads this value; in particular
     * {@link #expectedSetSizes(long)} always applies {@link #negativeSamplingRatio()}.
     * Confirm that this is intended.
     *
     * @return the configured relationship type name, if any
     */
    Optional<String> negativeRelationshipType();

    /**
     * @return the reserved relationship type {@code _TEST_}
     */
    @Configuration.Ignore
    @Value.Default
    default RelationshipType testRelationshipType() {
        return RelationshipType.of("_TEST_");
    }

    /**
     * @return the reserved relationship type {@code _TEST_COMPLEMENT_}
     */
    @Configuration.Ignore
    @Value.Default
    default RelationshipType testComplementRelationshipType() {
        return RelationshipType.of("_TEST_COMPLEMENT_");
    }

    /**
     * @return the reserved relationship type {@code _TRAIN_}
     */
    @Configuration.Ignore
    @Value.Default
    default RelationshipType trainRelationshipType() {
        return RelationshipType.of("_TRAIN_");
    }

    /**
     * @return the reserved relationship type {@code _FEATURE_INPUT_}
     */
    @Configuration.Ignore
    @Value.Default
    default RelationshipType featureInputRelationshipType() {
        return RelationshipType.of("_FEATURE_INPUT_");
    }

    /**
     * @return this configuration as a map
     */
    @Configuration.ToMap
    Map<String, Object> toMap();

    /**
     * @return the configuration keys; this default implementation returns an empty list
     *     (presumably overridden by the generated implementation -- confirm)
     */
    @Configuration.CollectKeys
    default Collection<String> configKeys() {
        return Collections.emptyList();
    }

    /**
     * Creates a split configuration from the given map wrapper.
     *
     * @param cypherMapWrapper the user-supplied configuration values
     * @return a new configuration instance
     */
    static LinkPredictionSplitConfig of(CypherMapWrapper cypherMapWrapper) {
        return new LinkPredictionSplitConfigImpl(cypherMapWrapper);
    }

    /**
     * Checks that this split configuration can be applied to the given graph store.
     *
     * <p>First rejects graphs that already contain one of the reserved split relationship
     * types, then validates the expected size of every split against its minimum.
     *
     * @param graphStore the graph store the split would be applied to
     * @param relationshipType the relationship type whose relationship count is used to
     *     estimate the split sizes
     * @throws IllegalArgumentException if a reserved relationship type is already present in
     *     the graph store
     */
    @Configuration.Ignore
    default void validateAgainstGraphStore(GraphStore graphStore, RelationshipType relationshipType) {
        // The order of the candidates determines the order of the names in the error message.
        List<String> reservedTypesInGraph = Stream
            .of(
                testRelationshipType(),
                trainRelationshipType(),
                featureInputRelationshipType(),
                testComplementRelationshipType()
            )
            .filter(graphStore::hasRelationshipType)
            .map(RelationshipType::name)
            .collect(Collectors.toList());

        if (!reservedTypesInGraph.isEmpty()) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "The relationship types %s are in the input graph, but are reserved for splitting.",
                StringJoining.join(reservedTypesInGraph)
            ));
        }

        ExpectedSetSizes expectedSetSizes = expectedSetSizes(graphStore.relationshipCount(relationshipType));

        // Each split is checked against its own minimum size; the last argument is the hint
        // handed to the validator for the setting that most likely causes a violation.
        NonEmptySetValidation.validateRelSetSize(expectedSetSizes.testSize(), 1L, "test", "`testFraction` is too low");
        NonEmptySetValidation.validateRelSetSize(expectedSetSizes.testComplementSize(), 3L, "test-complement", "`testFraction` is too high");
        NonEmptySetValidation.validateRelSetSize(expectedSetSizes.trainSize(), 2L, "train", "`trainFraction` is too low");
        NonEmptySetValidation.validateRelSetSize(expectedSetSizes.featureInputSize(), 1L, "feature-input", "`trainFraction` is too high");
        NonEmptySetValidation.validateRelSetSize(expectedSetSizes.validationFoldSize(), 1L, "validation", "`validationFolds` is too high or the `trainFraction` too low");
    }

    /**
     * Estimates the size of every split for a graph with the given number of relationships.
     *
     * <p>Every intermediate result is truncated to a {@code long}, so the order of the casts
     * below is significant.
     *
     * @param relationshipCount the number of relationships that are split
     * @return the expected sizes of the test, test-complement, train, feature-input and
     *     validation-fold sets
     */
    @Value.Derived
    @Configuration.Ignore
    default ExpectedSetSizes expectedSetSizes(long relationshipCount) {
        // NOTE(review): the division by two presumably accounts for relationships being
        // counted once per direction -- confirm against the splitter implementation.
        long positiveTestSize = (long) ((relationshipCount * testFraction()) / 2.0d);
        long testSize = (long) (positiveTestSize * (1.0d + negativeSamplingRatio()));

        long testComplementSize = (long) (relationshipCount * (1.0d - testFraction()));

        long positiveTrainSize = (long) ((testComplementSize * trainFraction()) / 2.0d);
        long trainSize = (long) (positiveTrainSize * (1.0d + negativeSamplingRatio()));

        long featureInputSize = (long) (testComplementSize * (1.0d - trainFraction()));
        long validationFoldSize = trainSize / validationFolds();

        return ImmutableExpectedSetSizes.builder()
            .testSize(testSize)
            .trainSize(trainSize)
            .featureInputSize(featureInputSize)
            .testComplementSize(testComplementSize)
            .validationFoldSize(validationFoldSize)
            .build();
    }

    /**
     * Derives graph dimensions that additionally contain the expected relationship counts of
     * the split relationship types.
     *
     * @param graphDimensions the dimensions of the input graph
     * @param targetRelationshipType name of the relationship type that is split; when the
     *     input dimensions hold no count for it, the relationship count upper bound is used
     * @return dimensions with the same node count and relationship count upper bound, the
     *     expected split counts, and all relationship counts of the input dimensions
     */
    @Value.Derived
    @Configuration.Ignore
    default GraphDimensions expectedGraphDimensions(GraphDimensions graphDimensions, String targetRelationshipType) {
        long relationshipCount = graphDimensions
            .relationshipCounts()
            .getOrDefault(RelationshipType.of(targetRelationshipType), graphDimensions.relCountUpperBound());
        ExpectedSetSizes expectedSetSizes = expectedSetSizes(relationshipCount);

        return GraphDimensions.builder()
            .nodeCount(graphDimensions.nodeCount())
            .relCountUpperBound(graphDimensions.relCountUpperBound())
            .putRelationshipCount(testRelationshipType(), expectedSetSizes.testSize())
            .putRelationshipCount(testComplementRelationshipType(), expectedSetSizes.testComplementSize())
            .putRelationshipCount(trainRelationshipType(), expectedSetSizes.trainSize())
            .putRelationshipCount(featureInputRelationshipType(), expectedSetSizes.featureInputSize())
            .putAllRelationshipCounts(graphDimensions.relationshipCounts())
            .build();
    }
}
