package org.neo4j.gds.embeddings.graphsage.ddl4j.functions;

import java.util.List;
import org.neo4j.gds.embeddings.graphsage.ddl4j.AbstractVariable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.ComputationContext;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Dimensions;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Scalar;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/ddl4j/functions/LogisticLoss.class */
/**
 * Mean binary cross-entropy (logistic) loss over a batch of predictions.
 *
 * <p>For every row {@code i} the forward pass accumulates
 * {@code y_i * log(p_i) + (1 - y_i) * log(1 - p_i)}, where {@code p} are the values of
 * {@code predictions} and {@code y} the values of {@code targets}, and returns the negated
 * mean over all rows. An empty batch yields {@code NaN} (0 / 0).
 *
 * <p>Only {@code weights}, {@code features} and {@code targets} are registered as parents;
 * {@code predictions} is evaluated on demand via {@code ComputationContext.forward}. The
 * gradient is produced directly with respect to {@code weights} using the closed form
 * {@code sum_i (p_i - y_i) * x_i / n}. NOTE(review): that closed form is only the true
 * gradient if {@code predictions = sigmoid(features * weights^T)} — confirm against the code
 * that builds {@code predictions}.
 */
public class LogisticLoss extends AbstractVariable<Scalar> {
    private final Variable<Matrix> weights;
    private final Variable<Matrix> predictions;
    private final Variable<Matrix> features;
    private final Variable<Matrix> targets;

    /**
     * @param weights     weight matrix; the only variable that receives a non-zero gradient
     * @param predictions model output, evaluated lazily in {@link #apply} and {@link #gradient}
     * @param features    input features; read row-major with {@code weights.cols()} columns
     * @param targets     expected labels, one value per prediction row
     */
    LogisticLoss(
        Variable<Matrix> weights,
        Variable<Matrix> predictions,
        Variable<Matrix> features,
        Variable<Matrix> targets
    ) {
        // predictions is not a parent: gradient() differentiates w.r.t. weights in closed form
        // instead of back-propagating through the prediction sub-graph.
        super(List.of(weights, features, targets), Dimensions.scalar());
        this.weights = weights;
        this.predictions = predictions;
        this.features = features;
        this.targets = targets;
    }

    /**
     * Computes {@code -(1/n) * sum_i [y_i * log(p_i) + (1 - y_i) * log(1 - p_i)]}.
     *
     * @param computationContext context used to evaluate {@code predictions} and read tensors
     * @return the mean logistic loss as a scalar
     */
    @Override
    public Scalar apply(ComputationContext computationContext) {
        computationContext.forward(predictions);
        Matrix predicted = (Matrix) computationContext.data(predictions);
        Matrix expected = (Matrix) computationContext.data(targets);
        int rows = predicted.rows();

        double sum = 0.0d;
        for (int row = 0; row < rows; row++) {
            double p = predicted.dataAt(row);
            double y = expected.dataAt(row);
            double positiveTerm = y * Math.log(p);
            double negativeTerm = (1.0d - y) * Math.log(1.0d - p);
            // At a saturated prediction one of the logs is -Infinity; dropping that term avoids
            // NaN from 0 * -Infinity when the target agrees with the prediction.
            // NOTE(review): when the target disagrees (e.g. p == 0, y == 1) the dropped term is
            // -Infinity, so a maximally wrong prediction contributes 0 loss — confirm intended.
            if (p == 0.0d) {
                sum += negativeTerm;
            } else if (p == 1.0d) {
                sum += positiveTerm;
            } else {
                sum += positiveTerm + negativeTerm;
            }
        }
        return new Scalar(-sum / rows);
    }

    /**
     * Returns the gradient of the loss with respect to {@code variable}.
     *
     * <p>For {@code weights} this is {@code sum_i (p_i - y_i) * x_i / n}, accumulated into a
     * matrix shaped like the weights; for any other variable a zero tensor of that variable's
     * shape is returned.
     *
     * @param variable           the variable to differentiate with respect to
     * @param computationContext context used to evaluate {@code predictions} and read tensors
     * @return the gradient tensor
     */
    @Override
    public Tensor<?> gradient(Variable<?> variable, ComputationContext computationContext) {
        if (variable != weights) {
            return computationContext.data(variable).zeros();
        }
        computationContext.forward(predictions);
        Matrix predicted = (Matrix) computationContext.data(predictions);
        Matrix expected = (Matrix) computationContext.data(targets);
        Matrix weightData = (Matrix) computationContext.data(weights);
        Matrix featureData = (Matrix) computationContext.data(features);

        Matrix gradient = weightData.zeros();
        int cols = weightData.cols();
        int rows = predicted.rows();
        for (int row = 0; row < rows; row++) {
            double scaledError = (predicted.dataAt(row) - expected.dataAt(row)) / rows;
            // features is read as row-major with weights.cols() columns — assumes the feature
            // width matches the weight width; TODO confirm at the call site.
            int rowOffset = row * cols;
            for (int col = 0; col < cols; col++) {
                gradient.addDataAt(col, scaledError * featureData.dataAt(rowOffset + col));
            }
        }
        return gradient;
    }
}
