package org.neo4j.gds.applications.algorithms.machinelearning;

import com.carrotsearch.hppc.BitSet;
import org.neo4j.gds.algorithms.machinelearning.KGEPredictBaseConfig;
import org.neo4j.gds.algorithms.machinelearning.KGEPredictParameters;
import org.neo4j.gds.algorithms.machinelearning.KGEPredictResult;
import org.neo4j.gds.algorithms.machinelearning.TopKMapComputer;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel;
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmMachinery;
import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.ml.splitting.EdgeSplitter;
import org.neo4j.gds.ml.splitting.SplitRelationships;
import org.neo4j.gds.ml.splitting.SplitRelationshipsBaseConfig;
import org.neo4j.gds.similarity.filtering.NodeFilter;
import org.neo4j.gds.termination.TerminationFlag;

/* loaded from: input_file:org/neo4j/gds/applications/algorithms/machinelearning/MachineLearningAlgorithms.class */
/**
 * Runs the machine-learning flavoured algorithms of this application layer: knowledge-graph
 * embedding (KGE) top-k prediction and relationship splitting.
 *
 * <p>NOTE(review): this source appears to be decompiled, and the decompiler reported that the
 * constructor and methods were package-private originally. They are kept {@code public} here
 * so the visible interface is unchanged. Since the class itself is package-private, that only
 * matters for subclasses.
 */
class MachineLearningAlgorithms {
    private final AlgorithmMachinery algorithmMachinery;
    private final ProgressTrackerCreator progressTrackerCreator;
    private final TerminationFlag terminationFlag;

    /**
     * Creates the facade with its own {@link AlgorithmMachinery} instance.
     *
     * @param progressTrackerCreator factory used to build the progress tracker for each run
     * @param terminationFlag        flag passed through to the algorithms that are constructed here
     */
    public MachineLearningAlgorithms(ProgressTrackerCreator progressTrackerCreator, TerminationFlag terminationFlag) {
        this.algorithmMachinery = new AlgorithmMachinery();
        this.progressTrackerCreator = progressTrackerCreator;
        this.terminationFlag = terminationFlag;
    }

    /**
     * Computes KGE top-k predictions between filtered source and target nodes.
     *
     * <p>The method first marks candidate nodes in two bit sets. A node is a source candidate
     * when the configured source filter accepts it, and a target candidate when the target
     * filter accepts it; one node may be both. The method then hands the candidates and the
     * remaining parameters to a {@link TopKMapComputer}, which is run through the algorithm
     * machinery under a single-leaf progress task labelled {@code AlgorithmLabel.KGE}.
     *
     * @param graph                graph whose nodes are filtered and scored
     * @param kGEPredictBaseConfig configuration that supplies the prediction parameters
     * @return the result produced by the {@link TopKMapComputer}
     */
    public KGEPredictResult kge(Graph graph, KGEPredictBaseConfig kGEPredictBaseConfig) {
        ProgressTracker progressTracker = this.progressTrackerCreator.createProgressTracker(
            kGEPredictBaseConfig,
            Tasks.leaf(AlgorithmLabel.KGE.asString())
        );

        long nodeCount = graph.nodeCount();
        BitSet sourceNodes = new BitSet(nodeCount);
        BitSet targetNodes = new BitSet(nodeCount);

        KGEPredictParameters parameters = kGEPredictBaseConfig.toParameters();
        NodeFilter sourceFilter = parameters.sourceNodeFilter().toNodeFilter(graph);
        NodeFilter targetFilter = parameters.targetNodeFilter().toNodeFilter(graph);

        graph.forEachNode(nodeId -> {
            if (sourceFilter.test(nodeId)) {
                sourceNodes.set(nodeId);
            }
            if (targetFilter.test(nodeId)) {
                targetNodes.set(nodeId);
            }
            // Always continue: every node has to be classified.
            return true;
        });

        TopKMapComputer computer = new TopKMapComputer(
            graph,
            sourceNodes,
            targetNodes,
            parameters.nodeEmbeddingProperty(),
            parameters.relationshipTypeEmbedding(),
            parameters.scoringFunction(),
            parameters.topK(),
            parameters.concurrency(),
            progressTracker,
            this.terminationFlag
        );

        // NOTE(review): the meaning of the trailing boolean is defined by AlgorithmMachinery; confirm there.
        return (KGEPredictResult) this.algorithmMachinery.runAlgorithmsAndManageProgressTracker(
            computer,
            progressTracker,
            true
        );
    }

    /**
     * Splits the relationships of a graph store as configured.
     *
     * <p>This method delegates to {@code SplitRelationships.of(...).compute()}, with no
     * progress tracking or termination flag involved.
     *
     * @param graphStore                   graph store whose relationships are split
     * @param splitRelationshipsBaseConfig split configuration
     * @return the split result returned by {@link SplitRelationships}
     */
    public EdgeSplitter.SplitResult splitRelationships(GraphStore graphStore, SplitRelationshipsBaseConfig splitRelationshipsBaseConfig) {
        SplitRelationships splitter = SplitRelationships.of(graphStore, splitRelationshipsBaseConfig);
        return splitter.compute();
    }
}
