package org.neo4j.gds.embeddings.graphsage;

import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.core.concurrency.ParallelUtil;
import org.neo4j.graphalgo.core.utils.ProgressLogger;
import org.neo4j.graphalgo.core.utils.mem.AllocationTracker;
import org.neo4j.graphalgo.core.utils.paged.HugeObjectArray;
import org.neo4j.graphalgo.core.utils.partition.Partition;
import org.neo4j.graphalgo.core.utils.partition.PartitionUtils;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/GraphSageEmbeddingsGenerator.class */
/**
 * Computes node embeddings for an entire graph by evaluating the configured GraphSAGE
 * {@link Layer}s on batches of node ids.
 *
 * <p>The node id range is split with {@link PartitionUtils#rangePartition}. Each partition becomes
 * one task submitted through {@link ParallelUtil#run} on the configured executor, and each task
 * creates its own {@link ComputationContext} for the forward pass.
 */
public class GraphSageEmbeddingsGenerator {
    private final Layer[] layers;
    private final int batchSize;
    private final int concurrency;
    private final boolean isWeighted;
    private final FeatureFunction featureFunction;
    private final ExecutorService executor;
    private final ProgressLogger progressLogger;
    private final AllocationTracker tracker;

    /**
     * Creates a generator.
     *
     * @param layers          layers passed to the embedding computation for each batch
     * @param batchSize       batch size handed to the range partitioner
     * @param concurrency     concurrency handed to the range partitioner
     * @param isWeighted      weighted flag passed through to the embedding computation
     * @param featureFunction feature function passed through to the embedding computation
     * @param executor        executor on which the per-batch tasks are run
     * @param progressLogger  receives a start event, one progress event per finished batch, and a finish event
     * @param tracker         allocation tracker used when allocating the result array
     */
    public GraphSageEmbeddingsGenerator(
        Layer[] layers,
        int batchSize,
        int concurrency,
        boolean isWeighted,
        FeatureFunction featureFunction,
        ExecutorService executor,
        ProgressLogger progressLogger,
        AllocationTracker tracker
    ) {
        this.layers = layers;
        this.batchSize = batchSize;
        this.concurrency = concurrency;
        this.isWeighted = isWeighted;
        this.featureFunction = featureFunction;
        this.executor = executor;
        this.progressLogger = progressLogger;
        this.tracker = tracker;
    }

    /**
     * Computes an embedding for every node of {@code graph}.
     *
     * @param graph    the graph whose nodes are embedded; its node count sizes the result
     * @param features per-node input array handed to the embedding computation
     *                 (presumably the nodes' input feature vectors — confirm against callers)
     * @return a new array with one embedding vector per node id in {@code [0, graph.nodeCount())}
     */
    public HugeObjectArray<double[]> makeEmbeddings(Graph graph, HugeObjectArray<double[]> features) {
        HugeObjectArray<double[]> result = HugeObjectArray.newArray(double[].class, graph.nodeCount(), this.tracker);

        this.progressLogger.logStart();
        ParallelUtil.run(
            PartitionUtils.rangePartition(
                this.concurrency,
                graph.nodeCount(),
                this.batchSize,
                partition -> createEmbeddings(graph, partition, features, result)
            ),
            this.executor
        );
        // NOTE(review): logFinish assumes ParallelUtil.run has waited for every batch task — confirm.
        this.progressLogger.logFinish();
        return result;
    }

    /**
     * Builds the task that embeds one partition of node ids and writes the vectors into {@code result}.
     *
     * @param graph     the graph being embedded
     * @param partition contiguous node id range {@code [startNode, startNode + nodeCount)} handled by this task
     * @param features  per-node input array handed to the embedding computation
     * @param result    destination array; only indices inside {@code partition} are written
     * @return a task that performs the forward pass for the partition and stores the resulting rows
     */
    private Runnable createEmbeddings(
        Graph graph,
        Partition partition,
        HugeObjectArray<double[]> features,
        HugeObjectArray<double[]> result
    ) {
        return () -> {
            ComputationContext ctx = new ComputationContext();
            Variable<Matrix> embeddingVariable = GraphSageHelper.embeddings(
                graph,
                this.isWeighted,
                partition.stream().toArray(),
                features,
                this.layers,
                this.featureFunction
            );
            int embeddingDimension = embeddingVariable.dimension(1);
            double[] embeddings = ctx.forward(embeddingVariable).data();

            long startNode = partition.startNode();
            // Rows are sliced with int offsets out of a single double[], so the partition must fit in an int;
            // fail fast instead of relying on implicit int/long widening in the loop condition.
            int partitionNodeCount = Math.toIntExact(partition.nodeCount());
            for (int offset = 0; offset < partitionNodeCount; offset++) {
                // Row `offset` occupies [offset * dim, (offset + 1) * dim): the data is read as one
                // contiguous row per partition node, in partition order.
                int from = offset * embeddingDimension;
                result.set(startNode + offset, Arrays.copyOfRange(embeddings, from, from + embeddingDimension));
            }
            this.progressLogger.logProgress();
        };
    }
}
