/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.predict;

import org.neo4j.gds.api.Graph;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LPNodeFilter;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureExtractor;
import org.neo4j.gds.similarity.knn.NeighborFilter;
import org.neo4j.gds.similarity.knn.NeighborFilterFactory;
import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer;

class LinkPredictionSimilarityComputer
implements SimilarityComputer {
    private static final int POSITIVE_CLASS_INDEX = 1;
    private final LinkFeatureExtractor linkFeatureExtractor;
    private final Classifier classifier;

    LinkPredictionSimilarityComputer(LinkFeatureExtractor linkFeatureExtractor, Classifier classifier) {
        this.linkFeatureExtractor = linkFeatureExtractor;
        this.classifier = classifier;
    }

    public double similarity(long sourceId, long targetId) {
        double[] features = this.linkFeatureExtractor.extractFeatures(sourceId, targetId);
        return this.classifier.predictProbabilities(features)[1];
    }

    public boolean isSymmetric() {
        return this.linkFeatureExtractor.isSymmetric();
    }

    static class LinkFilterFactory
    implements NeighborFilterFactory {
        private final Graph graph;
        private final LPNodeFilter sourceNodeFilter;
        private final LPNodeFilter targetNodeFilter;

        LinkFilterFactory(Graph graph, LPNodeFilter sourceNodeFilter, LPNodeFilter targetNodeFilter) {
            this.graph = graph;
            this.sourceNodeFilter = sourceNodeFilter;
            this.targetNodeFilter = targetNodeFilter;
        }

        public NeighborFilter create() {
            return new LinkFilter(this.graph.concurrentCopy(), this.sourceNodeFilter, this.targetNodeFilter);
        }
    }

    static final class LinkFilter
    implements NeighborFilter {
        private final LPNodeFilter sourceNodeFilter;
        private final LPNodeFilter targetNodeFilter;
        private final Graph graph;

        private LinkFilter(Graph graph, LPNodeFilter sourceNodeFilter, LPNodeFilter targetNodeFilter) {
            this.graph = graph;
            this.sourceNodeFilter = sourceNodeFilter;
            this.targetNodeFilter = targetNodeFilter;
        }

        public boolean excludeNodePair(long firstNodeId, long secondNodeId) {
            if (firstNodeId == secondNodeId) {
                return true;
            }
            boolean matchesFilter = this.sourceNodeFilter.test(firstNodeId) && this.targetNodeFilter.test(secondNodeId) || this.sourceNodeFilter.test(secondNodeId) && this.targetNodeFilter.test(firstNodeId);
            return !matchesFilter || this.graph.exists(firstNodeId, secondNodeId);
        }

        public long lowerBoundOfPotentialNeighbours(long node) {
            if (this.sourceNodeFilter.test(node)) {
                return Math.max(this.targetNodeFilter.validNodeCount() - 1L - (long)this.graph.degree(node), 0L);
            }
            return Math.max(this.sourceNodeFilter.validNodeCount() - 1L - (long)this.graph.degree(node), 0L);
        }
    }
}

