package org.neo4j.gds.embeddings.graphsage.ddl4j.functions;

import org.neo4j.gds.embeddings.graphsage.ddl4j.ComputationContext;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;

/**
 * Scales every row of a matrix to unit L2 (Euclidean) length.
 *
 * <p>The matrix data is treated as a row-major array of {@code rows * cols} values, i.e. element
 * {@code (row, col)} lives at index {@code row * cols + col}. A row whose entries are all zero has
 * no direction to normalize to; it is passed through as zeros (and receives a zero gradient)
 * instead of producing {@code NaN} from a {@code 0 / 0} division.
 */
public class NormalizeRows extends SingleParentVariable<Matrix> {
    private final int rows;
    private final int cols;

    /**
     * Creates a row-normalization node on top of {@code variable}.
     *
     * @param variable the matrix whose rows are normalized; the output has the same dimensions
     */
    public NormalizeRows(Variable<Matrix> variable) {
        super(variable, variable.dimensions());
        this.rows = dimension(0);
        this.cols = dimension(1);
    }

    /**
     * Computes {@code y[r][c] = x[r][c] / ||x[r]||} for every row {@code r}.
     *
     * @param computationContext context holding the already computed parent data
     * @return a new {@code rows x cols} matrix with unit-length rows; all-zero rows stay zero
     */
    @Override
    public Matrix apply(ComputationContext computationContext) {
        double[] input = computationContext.data(parent()).data();
        int rowCount = this.rows;
        int colCount = this.cols;
        double[] result = new double[rowCount * colCount];
        for (int row = 0; row < rowCount; row++) {
            int offset = row * colCount;
            double squaredNorm = 0.0d;
            for (int col = 0; col < colCount; col++) {
                double value = input[offset + col];
                squaredNorm += value * value;
            }
            if (squaredNorm == 0.0d) {
                // All-zero row: 0 / 0 would yield NaN; keep the row as zeros instead.
                continue;
            }
            double norm = Math.sqrt(squaredNorm);
            for (int col = 0; col < colCount; col++) {
                result[offset + col] = input[offset + col] / norm;
            }
        }
        return new Matrix(result, this.rows, this.cols);
    }

    /**
     * Backpropagates the upstream gradient of this node through the row normalization.
     *
     * <p>For an input row {@code x} with norm {@code n}, the Jacobian of {@code x / n} is the
     * symmetric matrix {@code (n^2 * I - x x^T) / n^3}. Applying it to the upstream row gradient
     * {@code g} gives {@code (n^2 * g - (g . x) * x) / n^3}, which is evaluated here in
     * {@code O(cols)} per row instead of walking the full {@code cols x cols} Jacobian.
     *
     * @param variable the input variable whose data the gradient is taken with respect to
     * @param computationContext context holding the input data and this node's upstream gradient
     * @return the gradient with respect to {@code variable}, shaped {@code rows x cols};
     *     all-zero input rows receive a zero gradient
     */
    @Override
    public Matrix gradient(Variable<?> variable, ComputationContext computationContext) {
        double[] input = computationContext.data(variable).data();
        double[] upstream = ((Matrix) computationContext.gradient(this)).data();
        double[] result = new double[input.length];
        int rowCount = this.rows;
        int colCount = this.cols;
        for (int row = 0; row < rowCount; row++) {
            int offset = row * colCount;
            double squaredNorm = 0.0d;
            double dot = 0.0d;
            for (int col = 0; col < colCount; col++) {
                double value = input[offset + col];
                squaredNorm += value * value;
                dot += value * upstream[offset + col];
            }
            if (squaredNorm == 0.0d) {
                // The normalization is not differentiable at the origin; apply() maps such rows to
                // zero, so treat them as constant and leave their gradient at zero instead of NaN.
                continue;
            }
            double normCubed = Math.sqrt(squaredNorm) * squaredNorm;
            for (int col = 0; col < colCount; col++) {
                int index = offset + col;
                result[index] = (upstream[index] * squaredNorm - input[index] * dot) / normCubed;
            }
        }
        return new Matrix(result, this.rows, this.cols);
    }
}
