package org.neo4j.gds.embeddings.graphsage;

import java.util.Arrays;
import org.neo4j.gds.embeddings.graphsage.ddl4j.ComputationContext;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.core.concurrency.ParallelUtil;
import org.neo4j.graphalgo.core.utils.ProgressLogger;
import org.neo4j.graphalgo.core.utils.mem.AllocationTracker;
import org.neo4j.graphalgo.core.utils.paged.HugeObjectArray;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/GraphSageEmbeddingsGenerator.class */
/**
 * Computes node embeddings by running a stack of GraphSage {@link Layer}s over a graph, one batch of
 * nodes at a time, with batches consumed concurrently.
 */
public class GraphSageEmbeddingsGenerator {
    private final Layer[] layers;
    private final BatchProvider batchProvider;
    private final int concurrency;
    private final boolean isWeighted;
    private final FeatureFunction featureFunction;
    private final ProgressLogger progressLogger;
    private final AllocationTracker tracker;

    /**
     * Creates a generator that uses {@code GraphSageHelper::features} as its feature function.
     *
     * @param layers            the layers passed to {@code GraphSageHelper.embeddings} for every batch
     * @param batchSize         forwarded to {@code new BatchProvider(batchSize)}; presumably the number of
     *                          nodes per batch (TODO confirm against BatchProvider)
     * @param concurrency       number of workers used to consume batches
     * @param isWeighted        forwarded to {@code GraphSageHelper.embeddings}
     * @param progressLogger    receives start/finish events and one progress event per batch
     * @param allocationTracker used when allocating the result array
     */
    public GraphSageEmbeddingsGenerator(
        Layer[] layers,
        int batchSize,
        int concurrency,
        boolean isWeighted,
        ProgressLogger progressLogger,
        AllocationTracker allocationTracker
    ) {
        this(layers, batchSize, concurrency, isWeighted, GraphSageHelper::features, progressLogger, allocationTracker);
    }

    /**
     * Creates a generator with an explicit feature function.
     *
     * @param layers            the layers passed to {@code GraphSageHelper.embeddings} for every batch
     * @param batchSize         forwarded to {@code new BatchProvider(batchSize)}; presumably the number of
     *                          nodes per batch (TODO confirm against BatchProvider)
     * @param concurrency       number of workers used to consume batches
     * @param isWeighted        forwarded to {@code GraphSageHelper.embeddings}
     * @param featureFunction   forwarded to {@code GraphSageHelper.embeddings}
     * @param progressLogger    receives start/finish events and one progress event per batch
     * @param allocationTracker used when allocating the result array
     */
    public GraphSageEmbeddingsGenerator(
        Layer[] layers,
        int batchSize,
        int concurrency,
        boolean isWeighted,
        FeatureFunction featureFunction,
        ProgressLogger progressLogger,
        AllocationTracker allocationTracker
    ) {
        this.layers = layers;
        this.batchProvider = new BatchProvider(batchSize);
        this.concurrency = concurrency;
        this.isWeighted = isWeighted;
        this.featureFunction = featureFunction;
        this.progressLogger = progressLogger;
        this.tracker = allocationTracker;
    }

    /**
     * Computes one embedding vector per node and stores it at the node's id in a fresh array.
     *
     * <p>Each batch from the {@link BatchProvider} is evaluated with a new {@link ComputationContext}.
     * The resulting matrix is read as row-major, with one row of length {@code dimension(1)} per node
     * in the batch. Each row is copied out as that node's embedding.
     *
     * <p>Node ids that do not appear in any batch remain unset in the returned array. Whether batches
     * cover every node depends on {@code BatchProvider.stream} — NOTE(review): confirm.
     *
     * @param graph    the graph whose nodes are embedded; also sizes the result ({@code nodeCount()})
     * @param features per-node arrays forwarded to {@code GraphSageHelper.embeddings}; presumably the
     *                 input node features (verify against caller)
     * @return a new array of length {@code graph.nodeCount()} holding the computed embeddings
     */
    public HugeObjectArray<double[]> makeEmbeddings(Graph graph, HugeObjectArray<double[]> features) {
        HugeObjectArray<double[]> result = HugeObjectArray.newArray(double[].class, graph.nodeCount(), this.tracker);

        this.progressLogger.logStart();
        ParallelUtil.parallelStreamConsume(this.batchProvider.stream(graph), this.concurrency, batches ->
            batches.forEach(batch -> {
                ComputationContext ctx = new ComputationContext();
                Variable<Matrix> embeddingVariable = GraphSageHelper.embeddings(
                    graph,
                    this.isWeighted,
                    batch,
                    features,
                    this.layers,
                    this.featureFunction
                );
                // Width of one embedding row; the forward result is sliced into rows of this length.
                int dimension = embeddingVariable.dimension(1);
                double[] flat = ctx.forward(embeddingVariable).data();
                for (int row = 0; row < batch.length; row++) {
                    // Each row is copied so the stored embedding does not alias the batch's output buffer.
                    result.set(batch[row], Arrays.copyOfRange(flat, row * dimension, (row + 1) * dimension));
                }
                this.progressLogger.logProgress();
            })
        );
        this.progressLogger.logFinish();

        return result;
    }
}
