package org.neo4j.gds.ml.linkmodels.pipeline.predict;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.immutables.value.Value;
import org.neo4j.gds.annotation.Configuration;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.config.SingleThreadedRandomSeedConfig;
import org.neo4j.gds.core.MissingParameterExceptions;
import org.neo4j.gds.model.ModelConfig;
import org.neo4j.gds.similarity.knn.ImmutableKnnBaseConfig;
import org.neo4j.gds.similarity.knn.KnnBaseConfig;
import org.neo4j.gds.similarity.knn.KnnNodePropertySpec;
import org.neo4j.gds.similarity.knn.KnnSampler;
import org.neo4j.gds.utils.StringFormatting;
import org.neo4j.gds.utils.StringJoining;

/**
 * Base configuration for predicting links with a trained link prediction pipeline model.
 *
 * <p>The prediction strategy is selected by {@link #sampleRate()}:
 * <ul>
 *   <li><b>Exhaustive</b> ({@code sampleRate} equal to 1, the default): {@code topN} is required and
 *       {@code threshold} is optional. The KNN-specific parameters ({@code topK}, {@code deltaThreshold},
 *       {@code maxIterations}, {@code randomJoins}, {@code initialSampler}) must not be set.</li>
 *   <li><b>Approximate</b> ({@code sampleRate} less than 1): the KNN-specific parameters may be set and
 *       are forwarded through {@link #approximateConfig()}. {@code topN} and {@code threshold} must not
 *       be set.</li>
 * </ul>
 * These combinations are enforced by {@link #validateParameterCombinations()}.
 */
@Configuration
public interface LinkPredictionPredictPipelineBaseConfig extends AlgoBaseConfig, SingleThreadedRandomSeedConfig, ModelConfig {

    /** Value returned by {@link #thresholdOrDefault()} when no {@code threshold} is configured. */
    double DEFAULT_THRESHOLD = 0.0d;

    String graphName();

    /**
     * Sampling rate in the half-open range (0, 1]. Defaults to 1.
     * A value below 1 selects the approximate strategy (see {@link #isApproximateStrategy()}).
     */
    @Value.Default
    @Configuration.DoubleRange(min = 0.0d, max = 1.0d, minInclusive = false)
    default double sampleRate() {
        return 1.0d;
    }

    /** Exhaustive strategy only; required when {@code sampleRate} is 1. */
    @Configuration.IntegerRange(min = 1)
    Optional<Integer> topN();

    /** Exhaustive strategy only; see {@link #thresholdOrDefault()}. */
    @Configuration.DoubleRange(min = 0.0d, max = 1.0d)
    Optional<Double> threshold();

    /** Approximate strategy only; forwarded to the KNN configuration when present. */
    @Configuration.IntegerRange(min = 1)
    Optional<Integer> topK();

    /** Approximate strategy only; forwarded to the KNN configuration when present. */
    @Configuration.DoubleRange(min = 0.0d, max = 1.0d)
    Optional<Double> deltaThreshold();

    /** Approximate strategy only; forwarded to the KNN configuration when present. */
    @Configuration.IntegerRange(min = 1)
    Optional<Integer> maxIterations();

    /** Approximate strategy only; forwarded to the KNN configuration when present. */
    @Configuration.IntegerRange(min = 0)
    Optional<Integer> randomJoins();

    /**
     * Approximate strategy only; the raw sampler name, parsed by {@link #derivedInitialSampler()}.
     */
    Optional<String> initialSampler();

    /**
     * Parses {@link #initialSampler()} via {@code KnnSampler.SamplerType.parse}.
     *
     * @return the parsed sampler type, or empty if no initial sampler was configured
     */
    @Value.Derived
    @Configuration.Ignore
    default Optional<KnnSampler.SamplerType> derivedInitialSampler() {
        return initialSampler().map(KnnSampler.SamplerType::parse);
    }

    /**
     * Rejects parameters that do not belong to the strategy selected by {@link #sampleRate()}.
     * In exhaustive mode it also requires {@code topN}.
     *
     * @throws IllegalArgumentException if a parameter of the other strategy is set
     */
    @Value.Check
    default void validateParameterCombinations() {
        if (isApproximateStrategy()) {
            validateStrategySpecificParameters(
                Map.of(
                    "topN", topN().isPresent(),
                    "threshold", threshold().isPresent()
                ),
                "equal to 1"
            );
        } else {
            validateStrategySpecificParameters(
                Map.of(
                    "topK", topK().isPresent(),
                    "deltaThreshold", deltaThreshold().isPresent(),
                    "maxIterations", maxIterations().isPresent(),
                    "randomJoins", randomJoins().isPresent(),
                    "initialSampler", derivedInitialSampler().isPresent()
                ),
                "less than 1"
            );
            // The exhaustive strategy cannot run without topN.
            topN().orElseThrow(() -> MissingParameterExceptions.missingValueFor("topN", Collections.emptyList()));
        }
    }

    /**
     * Fails if any of the given parameters is present.
     *
     * @param parameterPresence            parameter name mapped to whether the user set it
     * @param requiredSampleRateDescription phrase describing the {@code sampleRate} those parameters
     *                                      require, inserted into the error message
     * @throws IllegalArgumentException listing every offending parameter, sorted by name so the
     *                                  message is deterministic regardless of map iteration order
     */
    @Configuration.Ignore
    default void validateStrategySpecificParameters(Map<String, Boolean> parameterPresence, String requiredSampleRateDescription) {
        List<String> offendingParameters = parameterPresence.entrySet().stream()
            .filter(Map.Entry::getValue)
            .map(Map.Entry::getKey)
            .sorted()
            .collect(Collectors.toList());

        if (!offendingParameters.isEmpty()) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "Configuration parameters %s may only be set if parameter 'sampleRate' is %s.",
                new Object[]{StringJoining.join(offendingParameters), requiredSampleRateDescription}
            ));
        }
    }

    /**
     * Builds the KNN configuration used by the approximate strategy. Optional KNN parameters
     * (and the random seed) are only forwarded when present, so unset values keep the KNN defaults.
     *
     * @throws IllegalStateException if the approximate strategy is not selected ({@code sampleRate} is 1)
     */
    @Configuration.Ignore
    @Value.Derived
    default KnnBaseConfig approximateConfig() {
        if (!isApproximateStrategy()) {
            throw new IllegalStateException(StringFormatting.formatWithLocale("Cannot derive approximateConfig when 'sampleRate' is 1.", new Object[0]));
        }

        ImmutableKnnBaseConfig.Builder knnBuilder = ImmutableKnnBaseConfig.builder()
            .sampleRate(sampleRate())
            // NOTE(review): placeholder property; presumably the KNN config requires at least one
            // node property even though link prediction does not use it — confirm.
            .nodeProperties(List.of(new KnnNodePropertySpec("NotUsedInLP")))
            .minBatchSize(10)
            .concurrency(concurrency());

        topK().ifPresent(knnBuilder::topK);
        deltaThreshold().ifPresent(knnBuilder::deltaThreshold);
        maxIterations().ifPresent(knnBuilder::maxIterations);
        randomJoins().ifPresent(knnBuilder::randomJoins);
        derivedInitialSampler().ifPresent(knnBuilder::initialSampler);
        randomSeed().ifPresent(knnBuilder::randomSeed);

        return knnBuilder.build();
    }

    /** Returns the configured {@code threshold}, or {@link #DEFAULT_THRESHOLD} if none was set. */
    @Configuration.Ignore
    @Value.Derived
    default double thresholdOrDefault() {
        return threshold().orElse(DEFAULT_THRESHOLD);
    }

    /** Returns {@code true} when {@code sampleRate} is below 1, i.e. the approximate strategy is used. */
    @Configuration.Ignore
    @Value.Derived
    default boolean isApproximateStrategy() {
        return sampleRate() < 1.0d;
    }
}
