package ai.timefold.solver.core.impl.score.stream.collector;

import ai.timefold.solver.core.api.score.stream.common.LoadBalance;
import ai.timefold.solver.core.impl.domain.solution.cloner.gizmo.GizmoSolutionClonerImplementor;
import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

/* loaded from: input_file:ai/timefold/solver/core/impl/score/stream/collector/LoadBalanceImpl.class */
/**
 * Incrementally maintained implementation of {@link LoadBalance}.
 *
 * <p>Every balanced item has a registration count and an accumulated metric value (its load).
 * The class keeps running aggregates so that {@link #unfairness()} can be answered without
 * iterating over all loads. With {@code x_i} being the individual loads, {@code s} their sum
 * and {@code n} the number of balanced items, the sum of squared deviations from the mean is
 * {@code sum(x_i^2) - s^2 / n}. The first term is tracked in
 * {@code squaredDeviationIntegralPart}, the numerator of the second term ({@code -s^2}) is
 * tracked in {@code squaredDeviationFractionNumerator}.
 *
 * <p>All aggregates use plain {@code long} arithmetic; no overflow checks are performed.
 *
 * @param <Balanced_> type of the items whose load is being balanced
 */
public final class LoadBalanceImpl<Balanced_> implements LoadBalance<Balanced_> {

    /** Precision (6 significant digits, half-even rounding) applied to the final square root. */
    private static final MathContext RESULT_MATH_CONTEXT = new MathContext(6, RoundingMode.HALF_EVEN);

    /** Number of active registrations per balanced item; an entry is removed once it drops to zero. */
    private final Map<Balanced_, Integer> balancedItemCountMap = new HashMap<>();
    /** Accumulated metric value per balanced item, kept in insertion order. */
    private final Map<Balanced_, Long> balancedItemToMetricValueMap = new LinkedHashMap<>();

    /** Sum of all values currently stored in {@code balancedItemToMetricValueMap}. */
    private long sum = 0;
    /** Running {@code sum(x_i^2)} over all loads. */
    private long squaredDeviationIntegralPart = 0;
    /** Running {@code -s^2}; has to be divided by the number of balanced items before use. */
    private long squaredDeviationFractionNumerator = 0;

    /**
     * Registers one more occurrence of the given item and adds its metric value to the item's load.
     *
     * @param balanced_ the item being balanced
     * @param j metric value added to the item's load by this registration
     * @param j2 initial metric value; added only when this is the first active registration of the item
     * @return an undo action that calls {@link #unregisterBalanced(Object, long)} with {@code balanced_} and {@code j}
     */
    public Runnable registerBalanced(Balanced_ balanced_, long j, long j2) {
        int registrationCount = this.balancedItemCountMap.merge(balanced_, 1, Integer::sum);
        // Only the first registration of an item contributes the initial metric value.
        addToMetric(balanced_, registrationCount == 1 ? j + j2 : j);
        return () -> unregisterBalanced(balanced_, j);
    }

    /**
     * Removes one registration of the given item.
     * When the last registration disappears, the item's entire load (including any initial value) is dropped;
     * otherwise only {@code j} is subtracted from it.
     *
     * @param balanced_ the item being balanced
     * @param j metric value that was added by the registration being undone
     * @throws NullPointerException if the item has no active registration
     */
    public void unregisterBalanced(Balanced_ balanced_, long j) {
        // Returning null from the remapping function removes the entry from the map.
        Integer remainingCount = this.balancedItemCountMap.compute(balanced_,
                (key, count) -> count == 1 ? null : count - 1);
        if (remainingCount == null) {
            resetMetric(balanced_);
        } else {
            addToMetric(balanced_, -j);
        }
    }

    /**
     * Adds {@code diff} to the load of the given item (starting from zero if it has none yet)
     * and updates the running aggregates if the load actually changed.
     */
    private void addToMetric(Balanced_ balanced, long diff) {
        long oldValue = this.balancedItemToMetricValueMap.getOrDefault(balanced, 0L);
        long newValue = oldValue + diff;
        this.balancedItemToMetricValueMap.put(balanced, newValue);
        if (oldValue != newValue) {
            // The deviation update reads the old sum, so it has to run before the sum changes.
            updateSquaredDeviation(oldValue, newValue);
            this.sum += diff;
        }
    }

    /** Removes the load of the given item altogether and updates the running aggregates accordingly. */
    private void resetMetric(Balanced_ balanced) {
        long oldValue = Objects.requireNonNullElse(this.balancedItemToMetricValueMap.remove(balanced), 0L);
        if (oldValue != 0) {
            // The deviation update reads the old sum, so it has to run before the sum changes.
            updateSquaredDeviation(oldValue, 0L);
            this.sum -= oldValue;
        }
    }

    /**
     * Applies the change of a single load from {@code oldValue} to {@code newValue} to both squared deviation
     * aggregates. Must be called while {@code sum} still holds the value from before the change.
     *
     * <p>The integral part grows by {@code newValue^2 - oldValue^2}.
     * The three fraction terms add up to {@code -(newSum^2 - oldSum^2)}.
     */
    private void updateSquaredDeviation(long oldValue, long newValue) {
        long doubledSumOfOthers = 2 * (this.sum - oldValue);
        long newSum = (this.sum - oldValue) + newValue;
        long sumDecrease = this.sum - newSum;
        long squaredSumIncrease = (newSum * newSum) - (this.sum * this.sum);
        long doubledCrossTerm = 2 * ((oldValue * this.sum) - (newValue * newSum));

        this.squaredDeviationIntegralPart += (newValue * newValue) - (oldValue * oldValue);
        this.squaredDeviationFractionNumerator += (doubledSumOfOthers * sumDecrease) + squaredSumIncrease
                + doubledCrossTerm;
    }

    /**
     * @return an empty map when nothing is registered, otherwise an unmodifiable view of the current loads
     */
    @Override // ai.timefold.solver.core.api.score.stream.common.LoadBalance
    public Map<Balanced_, Long> loads() {
        if (this.balancedItemCountMap.isEmpty()) {
            return Collections.emptyMap();
        }
        return Collections.unmodifiableMap(this.balancedItemToMetricValueMap);
    }

    /**
     * Computes the square root of the sum of squared deviations of all loads from their mean,
     * rounded to {@code RESULT_MATH_CONTEXT}.
     *
     * @return {@link BigDecimal#ZERO} when nothing is registered, otherwise the value described above
     */
    @Override // ai.timefold.solver.core.api.score.stream.common.LoadBalance
    public BigDecimal unfairness() {
        int balancedCount = this.balancedItemCountMap.size();
        switch (balancedCount) {
            case 0:
                return BigDecimal.ZERO;
            case 1:
                // Dividing by one is a no-op, so this case can stay in exact long arithmetic.
                return BigDecimal.valueOf(this.squaredDeviationFractionNumerator + this.squaredDeviationIntegralPart)
                        .sqrt(RESULT_MATH_CONTEXT);
            default: {
                // The division must happen in floating point: long / int would truncate the fraction
                // toward zero and distort the result (e.g. loads {1, 2} would yield 1 instead of ~0.707107).
                double sumOfSquaredDeviations =
                        ((double) this.squaredDeviationFractionNumerator / balancedCount)
                                + this.squaredDeviationIntegralPart;
                return BigDecimal.valueOf(sumOfSquaredDeviations).sqrt(RESULT_MATH_CONTEXT);
            }
        }
    }
}
