/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.predict;

import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.neo4j.gds.ElementIdentifier;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.schema.Direction;
import org.neo4j.gds.config.ElementTypeValidator;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.ImmutableLPGraphStoreFilter;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LPGraphStoreFilter;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPredictionPredictPipelineBaseConfig;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig;
import org.neo4j.gds.utils.StringFormatting;
import org.neo4j.gds.utils.StringJoining;

/**
 * Builds the {@link LPGraphStoreFilter} used by link prediction: the source/target node labels,
 * the node labels the node-property steps run on, and the relationship types to predict.
 *
 * <p>Non-instantiable utility class.
 */
public final class LPGraphStoreFilterFactory {
    private LPGraphStoreFilterFactory() {
    }

    /**
     * Resolves the labels and relationship types for prediction against {@code graphStore} and
     * bundles them into an {@link LPGraphStoreFilter}.
     *
     * <ul>
     *   <li>Source/target node labels: taken from {@code predictConfig} when present (resolved
     *       without validation); otherwise the train config's label is resolved and validated.</li>
     *   <li>Relationship types: {@code predictConfig.relationshipTypes()} when non-empty, otherwise
     *       the train config's {@code targetRelationshipType}; both paths are resolved and validated.</li>
     *   <li>Node-property-step base labels: the union of source and target labels.</li>
     * </ul>
     *
     * @param trainConfig     configuration the model was trained with; supplies the fallbacks
     * @param predictConfig   prediction configuration; its explicit values take precedence
     * @param graphStore      graph the labels and types are resolved against
     * @param progressTracker receives an info message describing the resulting filter
     * @return the resolved filter
     * @throws IllegalArgumentException if a resolved relationship type is not UNDIRECTED
     *         (NOTE(review): {@code ElementTypeValidator} may also reject unknown labels/types — confirm its exception type)
     */
    public static LPGraphStoreFilter generate(
        LinkPredictionTrainConfig trainConfig,
        LinkPredictionPredictPipelineBaseConfig predictConfig,
        GraphStore graphStore,
        ProgressTracker progressTracker
    ) {
        // orElseGet, not orElse: the train-config fallback must only be resolved and validated when
        // the predict config does not override it. With orElse the fallback ran eagerly, so validating
        // an unused train-config label could fail the call even though it was overridden.
        Collection<NodeLabel> sourceNodeLabels = predictConfig.sourceNodeLabel()
            .map(label -> ElementTypeValidator.resolve(graphStore, List.of(label)))
            .orElseGet(() -> ElementTypeValidator.resolveAndValidate(
                graphStore,
                List.of(trainConfig.sourceNodeLabel()),
                "`sourceNodeLabel` from the model's train config"
            ));
        Collection<NodeLabel> targetNodeLabels = predictConfig.targetNodeLabel()
            .map(label -> ElementTypeValidator.resolve(graphStore, List.of(label)))
            .orElseGet(() -> ElementTypeValidator.resolveAndValidate(
                graphStore,
                List.of(trainConfig.targetNodeLabel()),
                "`targetNodeLabel` from the model's train config"
            ));

        Collection<RelationshipType> predictRelTypes = !predictConfig.relationshipTypes().isEmpty()
            ? ElementTypeValidator.resolveAndValidateTypes(
                graphStore,
                predictConfig.relationshipTypes(),
                "`relationshipTypes` from the model's predict config"
            )
            : ElementTypeValidator.resolveAndValidateTypes(
                graphStore,
                List.of(trainConfig.targetRelationshipType()),
                "`targetRelationshipType` from the model's train config"
            );

        validateGraphFilter(graphStore, predictRelTypes);

        // Node property steps must cover every node that can appear on either end of a predicted link.
        Set<NodeLabel> nodePropertyStepsBaseLabels = Stream.of(targetNodeLabels, sourceNodeLabels)
            .flatMap(Collection::stream)
            .collect(Collectors.toSet());

        LPGraphStoreFilter filter = ImmutableLPGraphStoreFilter.builder()
            .sourceNodeLabels(sourceNodeLabels)
            .targetNodeLabels(targetNodeLabels)
            .nodePropertyStepsBaseLabels(nodePropertyStepsBaseLabels)
            .predictRelationshipTypes(predictRelTypes)
            .build();

        progressTracker.logInfo(StringFormatting.formatWithLocale(
            "The graph filters used for filtering in prediction is %s",
            filter
        ));
        return filter;
    }

    /**
     * Checks that every relationship type in {@code predictedRelationships} is UNDIRECTED in
     * {@code graphStore}'s schema.
     *
     * @param graphStore             graph whose relationship schema is inspected
     * @param predictedRelationships relationship types that will be predicted
     * @throws IllegalArgumentException naming the requested types and those found to be directed
     */
    static void validateGraphFilter(GraphStore graphStore, Collection<RelationshipType> predictedRelationships) {
        List<String> directedPredictRels = graphStore.schema()
            .filterRelationshipTypes(new HashSet<RelationshipType>(predictedRelationships))
            .relationshipSchema()
            .directions()
            .entrySet()
            .stream()
            .filter(entry -> entry.getValue() != Direction.UNDIRECTED)
            .map(Map.Entry::getKey)
            .map(ElementIdentifier::name)
            .collect(Collectors.toList());

        if (!directedPredictRels.isEmpty()) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "Procedure requires all relationships of %s to be UNDIRECTED, but found %s to be directed.",
                StringJoining.join(predictedRelationships.stream().map(ElementIdentifier::name)),
                StringJoining.join(directedPredictRels)
            ));
        }
    }
}

