package org.neo4j.gds.ml.pipeline.linkPipeline.train;

import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.immutables.value.Value;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.annotation.Configuration;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.config.ElementTypeValidator;
import org.neo4j.gds.config.GraphNameConfig;
import org.neo4j.gds.config.RandomSeedConfig;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.ml.metrics.LinkMetric;
import org.neo4j.gds.ml.metrics.Metric;
import org.neo4j.gds.ml.training.TrainBaseConfig;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Configuration for training a link prediction pipeline.
 *
 * <p>It names the pipeline ({@link #pipeline()}) and the relationship type whose links are
 * predicted ({@link #targetRelationshipType()}). It optionally restricts the source and target
 * node labels, and selects the evaluation metrics. Checks that need the graph are performed by
 * the methods annotated with {@code @Configuration.GraphStoreValidationCheck}.
 */
@Configuration
public interface LinkPredictionTrainConfig extends TrainBaseConfig, GraphNameConfig, RandomSeedConfig {

    /**
     * Weight of the negative class.
     *
     * @return a strictly positive weight (enforced by the {@code DoubleRange} constraint);
     *     defaults to {@code 1.0}
     */
    @Value.Default
    @Configuration.DoubleRange(min = 0.0d, minInclusive = false)
    default double negativeClassWeight() {
        return 1.0d;
    }

    /**
     * The pipeline to train, given as a string (presumably its name; resolution happens outside
     * this interface).
     */
    String pipeline();

    /**
     * The relationship type whose links are predicted.
     *
     * <p>It must not be the wildcard {@code "*"} (see {@link #validate()}). It must be UNDIRECTED
     * in the graph store (see {@link #validateTargetRelIsUndirected}).
     */
    String targetRelationshipType();

    /** Node label for source nodes; defaults to the wildcard {@code "*"}. */
    default String sourceNodeLabel() {
        return "*";
    }

    /** Node label for target nodes; defaults to the wildcard {@code "*"}. */
    default String targetNodeLabel() {
        return "*";
    }

    /** The relationship types used by this configuration: only the target relationship type. */
    @Configuration.Ignore
    default List<String> relationshipTypes() {
        return List.of(targetRelationshipType());
    }

    /**
     * Rejects the wildcard as target relationship type.
     *
     * @throws IllegalArgumentException if {@link #targetRelationshipType()} is {@code "*"}
     */
    @Value.Check
    default void validate() {
        if (targetRelationshipType().equals("*")) {
            throw new IllegalArgumentException("'*' is not allowed as targetRelationshipType.");
        }
    }

    /** {@link #targetRelationshipType()} wrapped as a {@link RelationshipType}. */
    @Configuration.Ignore
    default RelationshipType internalTargetRelationshipType() {
        return RelationshipType.of(targetRelationshipType());
    }

    /**
     * The node labels used by this configuration.
     *
     * @return the source label followed by the target label, with duplicates removed (a single
     *     element when both labels are equal)
     */
    @Configuration.Ignore
    default List<String> nodeLabels() {
        return Stream.of(sourceNodeLabel(), targetNodeLabel())
            .distinct()
            .collect(Collectors.toList());
    }

    /**
     * The metrics to evaluate; the first entry is the main metric (see {@link #mainMetric()}).
     *
     * <p>User input is converted with {@link #namesToMetrics}. It is rendered back with
     * {@link #metricsToNames}. Defaults to {@link LinkMetric#AUCPR}.
     */
    @Configuration.ConvertWith(method = "org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig#namesToMetrics")
    @Configuration.ToMapValue("org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig#metricsToNames")
    default List<Metric> metrics() {
        return List.of(LinkMetric.AUCPR);
    }

    /**
     * The first configured metric.
     *
     * <p>Note: {@code List.get(0)} throws {@link IndexOutOfBoundsException} if {@link #metrics()}
     * is empty.
     */
    @Configuration.Ignore
    default Metric mainMetric() {
        return metrics().get(0);
    }

    /**
     * The configured metrics that are not model specific, as {@link LinkMetric}s.
     *
     * <p>This relies on every non-model-specific metric being a {@link LinkMetric}. Any other type
     * would fail the cast with a {@link ClassCastException}.
     */
    @Configuration.Ignore
    default List<LinkMetric> linkMetrics() {
        return metrics().stream()
            .filter(metric -> !metric.isModelSpecific())
            .map(LinkMetric.class::cast)
            .collect(Collectors.toList());
    }

    /**
     * Validates {@link #sourceNodeLabel()} against the graph store.
     * Errors are reported under the key {@code sourceNodeLabel}.
     */
    @Configuration.GraphStoreValidationCheck
    default void validateSourceNodeLabel(
        GraphStore graphStore,
        Collection<NodeLabel> nodeLabelCollection,
        Collection<RelationshipType> relationshipTypeCollection
    ) {
        // the two collection parameters are not used by this check
        ElementTypeValidator.resolveAndValidate(graphStore, List.of(sourceNodeLabel()), "sourceNodeLabel");
    }

    /**
     * Validates {@link #targetNodeLabel()} against the graph store.
     * Errors are reported under the key {@code targetNodeLabel}.
     */
    @Configuration.GraphStoreValidationCheck
    default void validateTargetNodeLabel(
        GraphStore graphStore,
        Collection<NodeLabel> nodeLabelCollection,
        Collection<RelationshipType> relationshipTypeCollection
    ) {
        // Report under the key that was actually configured. Previously this said
        // "sourceNodeLabel", which blamed the wrong parameter in error messages.
        ElementTypeValidator.resolveAndValidate(graphStore, List.of(targetNodeLabel()), "targetNodeLabel");
    }

    /**
     * Ensures the target relationship type is UNDIRECTED in the graph store's schema.
     *
     * @throws IllegalArgumentException if the schema, filtered to the target type, is not undirected
     */
    @Configuration.GraphStoreValidationCheck
    default void validateTargetRelIsUndirected(
        GraphStore graphStore,
        Collection<NodeLabel> nodeLabelCollection,
        Collection<RelationshipType> relationshipTypeCollection
    ) {
        boolean undirected = graphStore.schema()
            .filterRelationshipTypes(Set.of(internalTargetRelationshipType()))
            .isUndirected();
        if (!undirected) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "Target relationship type `%s` must be UNDIRECTED, but was directed.",
                targetRelationshipType()
            ));
        }
    }

    /**
     * Creates a configuration backed by the generated {@code LinkPredictionTrainConfigImpl}.
     *
     * @param str passed unchanged as the first constructor argument of the generated implementation
     * @param cypherMapWrapper the user-supplied configuration map
     */
    static LinkPredictionTrainConfig of(String str, CypherMapWrapper cypherMapWrapper) {
        return new LinkPredictionTrainConfigImpl(str, cypherMapWrapper);
    }

    /** Parses each element with {@link LinkMetric#parseLinkMetric}, preserving order. */
    static List<Metric> namesToMetrics(List<?> metricNames) {
        return metricNames.stream()
            .<Metric>map(LinkMetric::parseLinkMetric)
            .collect(Collectors.toList());
    }

    /** Maps each metric to its {@code name()}, preserving order. */
    static List<String> metricsToNames(List<Metric> metrics) {
        return metrics.stream()
            .map(Metric::name)
            .collect(Collectors.toList());
    }
}
