package org.neo4j.gds.embeddings.node2vec;

import java.util.Objects;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.Pools;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.embeddings.node2vec.Node2VecModel;
import org.neo4j.gds.embeddings.node2vec.RandomWalkProbabilities;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.traversal.RandomWalk;

/* loaded from: input_file:org/neo4j/gds/embeddings/node2vec/Node2Vec.class */
/**
 * Node2Vec embedding algorithm.
 *
 * <p>Runs in two phases:
 * <ol>
 *   <li>generate random walks over {@link #graph};</li>
 *   <li>train a {@link Node2VecModel} on the compressed walks, together with the sampling
 *       probabilities collected while the walks were generated.</li>
 * </ol>
 */
public class Node2Vec extends Algorithm<Node2VecModel.Result> {

    /** The graph whose nodes are embedded. */
    private final Graph graph;

    /** Configuration shared by the random-walk, sampling-probability and training phases. */
    private final Node2VecBaseConfig config;

    /**
     * Builds the memory estimation for a Node2Vec run.
     *
     * <p>The estimation has three parts:
     * <ul>
     *   <li>a per-node "random walks" entry: a {@code HugeObjectArray} with
     *       {@code nodeCount * walksPerNode} slots, each sized as a {@code long[]} of
     *       {@code walkLength};</li>
     *   <li>the {@link RandomWalkProbabilities} estimation ("probability cache");</li>
     *   <li>the {@link Node2VecModel} estimation ("model").</li>
     * </ul>
     *
     * @param node2VecBaseConfig configuration supplying walk count, walk length and model settings
     * @return the combined memory estimation
     */
    public static MemoryEstimation memoryEstimation(Node2VecBaseConfig node2VecBaseConfig) {
        return MemoryEstimations
            .builder(Node2Vec.class.getSimpleName())
            .perNode(
                "random walks",
                nodeCount -> HugeObjectArray.memoryEstimation(
                    nodeCount * node2VecBaseConfig.walksPerNode(),
                    MemoryUsage.sizeOfLongArray(node2VecBaseConfig.walkLength())
                )
            )
            .add("probability cache", RandomWalkProbabilities.memoryEstimation())
            .add("model", Node2VecModel.memoryEstimation(node2VecBaseConfig))
            .build();
    }

    /**
     * Creates a Node2Vec run.
     *
     * @param graph              the graph to embed
     * @param node2VecBaseConfig configuration for walks, sampling and training
     * @param progressTracker    tracker that receives the "Node2Vec" sub-task events
     */
    public Node2Vec(Graph graph, Node2VecBaseConfig node2VecBaseConfig, ProgressTracker progressTracker) {
        super(progressTracker);
        this.graph = graph;
        this.config = node2VecBaseConfig;
    }

    /**
     * Generates random walks, collects sampling probabilities, and trains the embedding model.
     *
     * <p>Each generated walk is fed to both a {@link RandomWalkProbabilities.Builder} and a
     * {@link CompressedRandomWalks} store. Training then runs on the stored walks and the built
     * probabilities. The whole run is reported as the "Node2Vec" sub-task.
     *
     * <p>NOTE(review): the name looks like a decompiler rename of {@code compute()}; confirm it
     * against the {@code Algorithm} contract before relying on it.
     *
     * @return the trained model result
     */
    public Node2VecModel.Result m22compute() {
        this.progressTracker.beginSubTask("Node2Vec");

        RandomWalk randomWalk = RandomWalk.create(this.graph, this.config, this.progressTracker, Pools.DEFAULT);

        RandomWalkProbabilities.Builder probabilitiesBuilder = new RandomWalkProbabilities.Builder(
            this.graph.nodeCount(),
            this.config.positiveSamplingFactor(),
            this.config.negativeSamplingExponent(),
            this.config.concurrency()
        );
        // Presized for walksPerNode walks starting at every node.
        CompressedRandomWalks walks = new CompressedRandomWalks(this.graph.nodeCount() * this.config.walksPerNode());

        // Each walk is consumed once: counted for sampling probabilities and kept for training.
        randomWalk.m103compute().forEach(walk -> {
            probabilitiesBuilder.registerWalk(walk);
            walks.add(walk);
        });

        // A bound method reference null-checks its receiver when evaluated.
        Node2VecModel model = new Node2VecModel(
            this.graph::toOriginalNodeId,
            this.graph.nodeCount(),
            this.config,
            walks,
            probabilitiesBuilder.build(),
            this.progressTracker
        );
        Node2VecModel.Result result = model.train();

        this.progressTracker.endSubTask("Node2Vec");
        return result;
    }

    /** Releases resources held by this algorithm; this implementation holds none to free. */
    public void release() {
    }
}
