package org.neo4j.gds.embeddings.graphsage.ddl4j.functions;

import org.neo4j.gds.embeddings.graphsage.RelationshipWeights;
import org.neo4j.gds.embeddings.graphsage.ddl4j.ComputationContext;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Dimensions;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;
import org.neo4j.gds.embeddings.graphsage.subgraph.SubGraph;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/ddl4j/functions/WeightedMultiMean.class */
/**
 * Weighted mean aggregation over each output row's own parent row and its neighbor parent rows.
 *
 * <p>For output row {@code r}, with {@code s = selfAdjacency[r]} and {@code N = adjacency[r]}, the
 * result row is
 * {@code (parent[s] + sum_{k in N} weight(nextNodes[s], nextNodes[k]) * parent[k]) / (|N| + 1)}.
 * The self row always contributes with weight 1. Each neighbor row is scaled by the relationship
 * weight between the self node id and the neighbor node id.
 *
 * <p>NOTE(review): assumes the entries of {@code selfAdjacency}/{@code adjacency} are row indices
 * into the parent matrix, and that {@code subGraph.nextNodes} maps those indices to the node ids
 * understood by {@link RelationshipWeights#weight} — confirm against {@link SubGraph}.
 */
public class WeightedMultiMean extends SingleParentVariable<Matrix> {
    private final RelationshipWeights relationshipWeights;
    private final SubGraph subGraph;
    /** For each output row, the parent-row indices of its neighbors. */
    private final int[][] adjacency;
    /** For each output row, the parent-row index of the row's own node. */
    private final int[] selfAdjacency;
    private final int rows;
    private final int cols;

    /**
     * Creates the aggregation.
     *
     * @param variable            parent matrix whose rows are aggregated; its column count is the
     *                            output column count
     * @param relationshipWeights source of the per-edge weights applied to neighbor rows
     * @param subGraph            provides {@code adjacency}, {@code selfAdjacency} and
     *                            {@code nextNodes}; the output has one row per adjacency entry
     */
    public WeightedMultiMean(Variable<Matrix> variable, RelationshipWeights relationshipWeights, SubGraph subGraph) {
        super(variable, Dimensions.matrix(subGraph.adjacency.length, variable.dimension(1)));
        this.relationshipWeights = relationshipWeights;
        this.subGraph = subGraph;
        this.adjacency = subGraph.adjacency;
        this.selfAdjacency = subGraph.selfAdjacency;
        this.rows = this.adjacency.length;
        this.cols = variable.dimension(1);
    }

    /**
     * Computes the {@code rows x cols} weighted-mean matrix described in the class documentation.
     *
     * @param computationContext context holding the parent's current data
     * @return a new matrix with one aggregated row per adjacency entry
     */
    @Override
    public Matrix apply(ComputationContext computationContext) {
        double[] parentData = computationContext.data(parent()).data();
        double[] result = new double[rows * cols];

        for (int row = 0; row < rows; row++) {
            int selfRow = selfAdjacency[row];
            long selfNodeId = subGraph.nextNodes[selfRow];
            int[] neighbors = adjacency[row];
            // The row's own node counts as one element of the mean.
            int divisor = neighbors.length + 1;
            int outOffset = row * cols;

            int selfOffset = selfRow * cols;
            for (int col = 0; col < cols; col++) {
                result[outOffset + col] += parentData[selfOffset + col] / divisor;
            }

            for (int neighbor : neighbors) {
                double weight = relationshipWeights.weight(selfNodeId, subGraph.nextNodes[neighbor]);
                int neighborOffset = neighbor * cols;
                for (int col = 0; col < cols; col++) {
                    result[outOffset + col] += (parentData[neighborOffset + col] * weight) / divisor;
                }
            }
        }
        return new Matrix(result, rows, cols);
    }

    /**
     * Back-propagates the upstream gradient to the parent.
     *
     * <p>Each parent row {@code s} used as a self row receives {@code upstream[r] / (|N| + 1)}.
     * Each neighbor row {@code k} receives
     * {@code upstream[r] / (|N| + 1) * weight(nextNodes[s], nextNodes[k])}.
     * Contributions to the same parent row accumulate.
     *
     * @param variable           the variable to differentiate with respect to; its data supplies
     *                           the shape of the returned (initially zero) tensor
     * @param computationContext context holding this op's upstream gradient
     * @return the accumulated gradient tensor
     */
    @Override
    public Tensor<?> gradient(Variable<?> variable, ComputationContext computationContext) {
        double[] upstream = ((Matrix) computationContext.gradient(this)).data();
        Tensor<?> result = computationContext.data(variable).zeros();

        // Rows/neighbors outermost so each relationship weight is looked up once per edge rather
        // than once per column. Every target index still receives its contributions in the same
        // order (rows ascending, neighbors in list order, then self), so sums are unchanged.
        for (int row = 0; row < rows; row++) {
            int selfRow = selfAdjacency[row];
            long selfNodeId = subGraph.nextNodes[selfRow];
            int divisor = adjacency[row].length + 1;
            int upstreamOffset = row * cols;

            for (int neighbor : adjacency[row]) {
                double weight = relationshipWeights.weight(selfNodeId, subGraph.nextNodes[neighbor]);
                int neighborOffset = neighbor * cols;
                for (int col = 0; col < cols; col++) {
                    result.addDataAt(neighborOffset + col, (1.0d / divisor) * upstream[upstreamOffset + col] * weight);
                }
            }

            int selfOffset = selfRow * cols;
            for (int col = 0; col < cols; col++) {
                result.addDataAt(selfOffset + col, (1.0d / divisor) * upstream[upstreamOffset + col]);
            }
        }
        return result;
    }
}
