package org.neo4j.gds.embeddings.graphsage.ddl4j.functions;

import java.util.List;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.mult.MatrixMatrixMult_DDRM;
import org.neo4j.gds.embeddings.graphsage.ddl4j.AbstractVariable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.ComputationContext;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Dimensions;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;
import org.neo4j.graphalgo.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/ddl4j/functions/MatrixMultiplyWithTransposedSecondOperand.class */
/**
 * Computation-graph node for the matrix product {@code A * B^T}.
 *
 * <p>If {@code A} has shape (m, k) and {@code B} has shape (n, k), the result has shape (m, n).
 * Both operands must have the same number of columns. That precondition is checked with a Java
 * {@code assert}, so it is only enforced when assertions are enabled ({@code -ea}).
 */
public class MatrixMultiplyWithTransposedSecondOperand extends AbstractVariable<Matrix> {
    /** Left operand {@code A}. */
    private final Variable<Matrix> a;
    /** Right operand {@code B}; it enters the product transposed. */
    private final Variable<Matrix> b;

    /**
     * Creates the node {@code a * b^T}.
     *
     * @param a left operand, shape (m, k)
     * @param b right operand (used transposed), shape (n, k)
     */
    public MatrixMultiplyWithTransposedSecondOperand(Variable<Matrix> a, Variable<Matrix> b) {
        super(List.of(a, b), Dimensions.matrix(a.dimension(0), b.dimension(0)));
        assertDimensions(a, b);
        this.a = a;
        this.b = b;
    }

    /**
     * Evaluates {@code A * B^T} on the operand values stored in the context.
     *
     * @param ctx computation context holding the operands' forward values
     * @return a new (m, n) matrix
     */
    @Override // org.neo4j.gds.embeddings.graphsage.ddl4j.Variable
    public Matrix apply(ComputationContext ctx) {
        return multiplyTransB(ctx.data(a), ctx.data(b));
    }

    /**
     * Returns the gradient of the loss with respect to one operand.
     *
     * <p>With upstream gradient {@code G} of shape (m, n):
     * {@code dL/dA = G * B} (shape (m, k)) and {@code dL/dB = G^T * A} (shape (n, k)).
     *
     * @param parent the operand to differentiate with respect to ({@code A} or {@code B})
     * @param ctx    computation context holding forward values and this node's gradient
     * @return the gradient with respect to {@code parent}
     */
    @Override // org.neo4j.gds.embeddings.graphsage.ddl4j.Variable
    public Matrix gradient(Variable<?> parent, ComputationContext ctx) {
        Tensor<?> upstream = ctx.gradient(this);
        if (parent == a) {
            return multiply(upstream, ctx.data(b));
        }
        // NOTE(review): any parent other than A is treated as B. If A and B are the same
        // variable, this method only ever returns the G * B term. Confirm that the
        // ComputationContext accumulates per duplicated parent before relying on A * A^T here.
        return multiplyTransA(upstream, ctx.data(a));
    }

    /** Returns {@code left * right}. */
    private static Matrix multiply(Tensor<?> left, Tensor<?> right) {
        DMatrixRMaj l = wrap(left);
        DMatrixRMaj r = wrap(right);
        DMatrixRMaj out = new DMatrixRMaj(l.numRows, r.numCols);
        MatrixMatrixMult_DDRM.mult_reorder(l, r, out);
        return toMatrix(out);
    }

    /** Returns {@code left * right^T}. */
    private static Matrix multiplyTransB(Tensor<?> left, Tensor<?> right) {
        DMatrixRMaj l = wrap(left);
        DMatrixRMaj r = wrap(right);
        DMatrixRMaj out = new DMatrixRMaj(l.numRows, r.numRows);
        MatrixMatrixMult_DDRM.multTransB(l, r, out);
        return toMatrix(out);
    }

    /** Returns {@code left^T * right}. */
    private static Matrix multiplyTransA(Tensor<?> left, Tensor<?> right) {
        DMatrixRMaj l = wrap(left);
        DMatrixRMaj r = wrap(right);
        DMatrixRMaj out = new DMatrixRMaj(l.numCols, r.numCols);
        MatrixMatrixMult_DDRM.multTransA_reorder(l, r, out);
        return toMatrix(out);
    }

    /** Wraps a tensor's data array in an EJML matrix sized by the tensor's first two dimensions. */
    private static DMatrixRMaj wrap(Tensor<?> tensor) {
        return DMatrixRMaj.wrap(tensor.dimension(0), tensor.dimension(1), tensor.data());
    }

    /** Converts an EJML result back into a ddl4j {@link Matrix}. */
    private static Matrix toMatrix(DMatrixRMaj m) {
        return new Matrix(m.getData(), m.numRows, m.numCols);
    }

    /**
     * Static factory equivalent to the constructor.
     *
     * @param a left operand, shape (m, k)
     * @param b right operand (used transposed), shape (n, k)
     * @return the node {@code a * b^T}
     */
    public static MatrixMultiplyWithTransposedSecondOperand of(Variable<Matrix> a, Variable<Matrix> b) {
        return new MatrixMultiplyWithTransposedSecondOperand(a, b);
    }

    /**
     * Asserts (when assertions are enabled) that {@code a} and {@code b} have the same column count.
     * Both shapes are reported in (rows, cols) order.
     */
    private static void assertDimensions(Variable<Matrix> a, Variable<Matrix> b) {
        assert a.dimension(1) == b.dimension(1) : StringFormatting.formatWithLocale(
            "Cannot multiply matrix having dimensions (%d, %d) with transposed matrix of dimensions (%d, %d)",
            a.dimension(0), a.dimension(1),
            b.dimension(0), b.dimension(1)
        );
    }
}
