package org.neo4j.graphalgo.experimental.community.overlapping;

import org.neo4j.graphalgo.api.Graph;

/* loaded from: input_file:org/neo4j/graphalgo/experimental/community/overlapping/AffiliationBlockGain.class */
/**
 * Gain (objective) function for the community-affiliation vector of a single node {@code u}.
 *
 * <p>The objective is evaluated at a candidate {@code x = F_u + step}, where {@code F_u} is
 * {@code communityAffiliations.nodeAffiliations(u)}. With {@code p_v = x·F_v + δ²} for every
 * relationship {@code (u, v)} of {@code u} in the graph, the value computed is
 *
 * <pre>
 *   Σ_v [log(1 - exp(-p_v)) + p_v]  -  x·(S - F_u)  -  (n - 1)·δ²  -  0.1·‖x‖₁
 * </pre>
 *
 * <p>where {@code S} is the affiliation sum captured at construction and {@code n} is
 * {@code graph.nodeCount()}. NOTE(review): this looks like a BigCLAM-style log-likelihood with an
 * L1 penalty; confirm against the algorithm's reference before relying on that reading.
 */
public class AffiliationBlockGain implements GainFunction {
    /** Weight of the L1 penalty subtracted in {@link #gain(Vector)}. */
    private static final double L1_PENALTY_WEIGHT = 0.1d;

    /** Crossover point between the two branches of {@link #log1mExp(double)}. */
    private static final double LN_2 = Math.log(2.0d);

    private final int nodeU;
    private final CommunityAffiliations communityAffiliations;
    private final Graph graph;
    private final double deltaSquared;
    /** All-zero step; evaluating the gain with it yields the objective at the current affiliations. */
    private final Vector zeroStep;
    /** Affiliation sum read once at construction time and reused by {@link #gain(Vector)}. */
    private final Vector affiliationSum;

    /**
     * @param communityAffiliations source of per-node affiliation vectors and their sum
     * @param graph graph whose relationships of node {@code i} define the neighbour terms
     * @param i id of the node whose affiliations this gain function evaluates
     * @param d offset whose square ({@code δ²}) is added to every pairwise inner product
     */
    public AffiliationBlockGain(CommunityAffiliations communityAffiliations, Graph graph, int i, double d) {
        this.nodeU = i;
        this.graph = graph;
        this.communityAffiliations = communityAffiliations;
        this.deltaSquared = d * d;
        // Read the sum once; the zero step's dimension is taken from the same snapshot.
        this.affiliationSum = communityAffiliations.affiliationSum();
        this.zeroStep = Vector.zero(this.affiliationSum.dim());
    }

    /** Returns the objective at the node's current affiliations (i.e. with a zero step). */
    @Override
    public double gain() {
        return gain(this.zeroStep);
    }

    /**
     * Returns the objective evaluated at {@code nodeAffiliations(u) + vector}.
     *
     * @param vector step added to the node's current affiliations
     * @return the objective value, or {@link Double#NEGATIVE_INFINITY} if some neighbour's
     *     {@code p_v} is negative (the log term is undefined there)
     */
    @Override
    public double gain(Vector vector) {
        Vector candidate = this.communityAffiliations.nodeAffiliations(this.nodeU).add(vector);
        // Single-element holder so the relationship callback below can accumulate into it.
        // -x·S - x·step + x·x == -x·(S - F_u); -n·δ² + δ² == -(n - 1)·δ².
        double[] total = {
            ((-candidate.innerProduct(this.affiliationSum)) - candidate.innerProduct(vector))
                - (this.graph.nodeCount() * this.deltaSquared)
        };
        total[0] = total[0] + candidate.l2Squared() + this.deltaSquared;
        this.graph.concurrentCopy().forEachRelationship(this.nodeU, (source, target) -> {
            double p = candidate.innerProduct(this.communityAffiliations.nodeAffiliations((int) target))
                + this.deltaSquared;
            if (p < 0.0d) {
                // log(1 - exp(-p)) is undefined for p < 0: mark the candidate infeasible and stop.
                total[0] = Double.NEGATIVE_INFINITY;
                return false;
            }
            total[0] = total[0] + log1mExp(p) + p;
            return true;
        });
        return total[0] - L1_PENALTY_WEIGHT * candidate.l1();
    }

    /**
     * Returns the gradient of the objective with respect to the node's affiliations, at the
     * current affiliations: {@code Σ_v F_v / (1 - exp(-p_v)) + F_u - S - l1PenaltyGradient}.
     */
    @Override
    public Vector gradient() {
        Vector current = this.communityAffiliations.nodeAffiliations(this.nodeU);
        Vector grad = Vector.zero(this.affiliationSum.dim());
        this.graph.concurrentCopy().forEachRelationship(this.nodeU, (source, target) -> {
            grad.addInPlace(weightedNeighbor(current, this.communityAffiliations.nodeAffiliations((int) target)));
            return true;
        });
        grad.addInPlace(current);
        // NOTE(review): this reads the live affiliation sum, whereas gain(...) uses the snapshot
        // taken at construction; confirm the two are meant to agree.
        grad.subtractInPlace(this.communityAffiliations.affiliationSum());
        grad.subtractInPlace(this.communityAffiliations.l1PenaltyGradient());
        return grad;
    }

    /** Returns {@code neighbour / (1 - exp(-(self·neighbour + δ²)))}. */
    private Vector weightedNeighbor(Vector vector, Vector vector2) {
        double p = vector.innerProduct(vector2) + this.deltaSquared;
        // 1 / (1 - exp(-p)) == -1 / expm1(-p); expm1 avoids cancellation when p is small.
        return vector2.multiply(-1.0d / Math.expm1(-p));
    }

    /**
     * Computes {@code log(1 - exp(-x))} for {@code x >= 0} without catastrophic cancellation:
     * {@code expm1} is accurate for small {@code x}, {@code log1p} for large {@code x}.
     */
    private static double log1mExp(double x) {
        return x <= LN_2 ? Math.log(-Math.expm1(-x)) : Math.log1p(-Math.exp(-x));
    }
}
