package org.neo4j.graphalgo.experimental.community.overlapping;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;
import org.neo4j.graphalgo.api.Graph;

/* loaded from: input_file:org/neo4j/graphalgo/experimental/community/overlapping/CommunityAffiliations.class */
/**
 * Holds one affiliation {@link Vector} per node together with the running sum of all of
 * those vectors, and evaluates a gain value over the whole graph.
 *
 * <p>Only {@link #updateNodeAffiliations(int, Vector)} is {@code synchronized}; every other
 * method reads the shared state without taking a lock.
 *
 * <p>Decompiler note: the constructor, {@code blockGain}, {@code affiliationSum},
 * {@code l1PenaltyGradient}, {@code updateNodeAffiliations} and {@code gain} were flagged as
 * originally package-private; they are public in this file.
 */
public class CommunityAffiliations {
    /** Weight applied to the summed L1 norms in {@link #gain()} and to the L1 penalty gradient. */
    static final double LAMBDA = 0.1d;

    /** Numerator of the epsilon ratio; presumably twice the edge count -- confirm with callers. */
    private final long totalDoubleEdgeCount;
    /** Per-node affiliation vectors, indexed by node id. The caller's list is kept, not copied. */
    private final List<Vector> affiliationVectors;
    private final Graph graph;
    /** Sum of all entries of {@link #affiliationVectors}, updated in place on every change. */
    private final Vector affiliationSum;
    private final Vector l1PenaltyGradient;

    /**
     * Stores the arguments and derives the affiliation sum and the L1 penalty gradient.
     *
     * @param totalDoubleEdgeCount value later divided by {@code n * (n - 1)} to obtain epsilon
     * @param affiliationVectors   one vector per node; retained by reference and mutated later
     * @param graph                graph supplying the node count and the relationships
     */
    public CommunityAffiliations(long totalDoubleEdgeCount, List<Vector> affiliationVectors, Graph graph) {
        this.totalDoubleEdgeCount = totalDoubleEdgeCount;
        this.affiliationVectors = affiliationVectors;
        this.graph = graph;
        this.affiliationSum = Vector.sum(affiliationVectors);
        this.l1PenaltyGradient = Vector.l1PenaltyGradient(this.affiliationSum.dim(), LAMBDA);
    }

    /**
     * Creates an {@code AffiliationBlockGain} bound to this instance and its graph.
     *
     * @param i forwarded unchanged to {@code AffiliationBlockGain} (presumably a node id -- confirm)
     * @param d forwarded unchanged to {@code AffiliationBlockGain}
     * @return the newly created gain function
     */
    public GainFunction blockGain(int i, double d) {
        return new AffiliationBlockGain(this, this.graph, i, d);
    }

    /**
     * @return the internally held sum vector itself (not a copy)
     */
    public Vector affiliationSum() {
        return this.affiliationSum;
    }

    /**
     * @return the penalty gradient vector created in the constructor
     */
    public Vector l1PenaltyGradient() {
        return this.l1PenaltyGradient;
    }

    /**
     * @param nodeId index into the affiliation list
     * @return the affiliation vector currently stored for that node
     */
    public Vector nodeAffiliations(int nodeId) {
        return this.affiliationVectors.get(nodeId);
    }

    /**
     * Replaces the affiliation vector of one node and keeps the running sum consistent.
     *
     * @param nodeId    index of the node whose vector is replaced
     * @param increment vector handed to {@code addAndProject} on the node's current vector
     */
    public synchronized void updateNodeAffiliations(int nodeId, Vector increment) {
        final Vector previous = this.affiliationVectors.get(nodeId);
        final Vector updated = previous.addAndProject(increment);
        // Only the difference is applied to the sum, so the sum never has to be rebuilt.
        final Vector change = updated.subtract(previous);
        this.affiliationVectors.set(nodeId, updated);
        this.affiliationSum.addInPlace(change);
    }

    /**
     * Evaluates the gain over all nodes and relationships.
     *
     * <p>The accumulator starts at {@code -|sum|^2 - (n * delta)^2}. Each node adds
     * {@code |F|^2 + delta^2}, and each visited relationship whose first id is not smaller
     * than its second adds {@code 2 * (log(1 - exp(-x)) + x)} with
     * {@code x = F . F_other + delta^2}. Finally {@code LAMBDA} times the sum of the nodes'
     * L1 norms is subtracted.
     *
     * @return the gain value
     */
    public double gain() {
        final double delta = getDelta();
        final double deltaSquared = delta * delta;
        // One-element arrays serve as mutable accumulators the lambda below can capture.
        final double[] total = {
            (-this.affiliationSum.l2Squared())
                - ((this.graph.nodeCount() * delta) * (this.graph.nodeCount() * delta))
        };
        final double[] l1Total = {0.0d};
        for (int node = 0; node < this.graph.nodeCount(); node++) {
            final Vector own = nodeAffiliations(node);
            l1Total[0] += own.l1();
            total[0] = total[0] + own.l2Squared() + deltaSquared;
            // Callback arguments are presumably (source id, target id) -- confirm against Graph.
            this.graph.forEachRelationship(node, (sourceId, targetId) -> {
                if (sourceId >= targetId) {
                    final double x = own.innerProduct(nodeAffiliations((int) targetId)) + deltaSquared;
                    total[0] = total[0] + relationshipTerm(x);
                }
                // Always ask for the next relationship.
                return true;
            });
        }
        return total[0] - (LAMBDA * l1Total[0]);
    }

    /** Computes {@code 2 * (log(1 - exp(-x)) + x)}, the contribution of one relationship. */
    private static double relationshipTerm(double x) {
        return 2.0d * (Math.log(1.0d - Math.exp(-x)) + x);
    }

    /**
     * @return the node count reported by the underlying graph
     */
    public long nodeCount() {
        return this.graph.nodeCount();
    }

    /**
     * Computes {@code totalDoubleEdgeCount / (n * (n - 1))} with 12 decimal places, rounded
     * towards negative infinity.
     *
     * <p>With a node count of 0 or 1 the divisor is zero and {@code BigDecimal.divide}
     * throws {@link ArithmeticException}.
     */
    private double getEpsilon() {
        final BigDecimal numerator = BigDecimal.valueOf(this.totalDoubleEdgeCount);
        final BigDecimal nodePairs =
            BigDecimal.valueOf(this.graph.nodeCount() * (this.graph.nodeCount() - 1));
        return numerator.divide(nodePairs, 12, RoundingMode.FLOOR).doubleValue();
    }

    /**
     * Computes {@code sqrt(-log(1 - epsilon))}; the value is recomputed on every call.
     *
     * <p>An epsilon of exactly 1 yields positive infinity and an epsilon above 1 yields NaN.
     *
     * @return the delta value derived from epsilon
     */
    public double getDelta() {
        final double epsilon = getEpsilon();
        return Math.sqrt(-Math.log(1.0d - epsilon));
    }
}
