/*
 * Decompiled with CFR 0.152.
 */
package ai.timefold.solver.core.impl.score.stream.collector;

import ai.timefold.solver.core.api.score.stream.common.LoadBalance;
import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

/**
 * Incrementally maintained {@link LoadBalance}: tracks the load of every balanced item together with
 * running aggregates, so that {@link #unfairness()} never has to iterate over all loads.
 * <p>
 * Invariants maintained by every update (in wrapping {@code long} arithmetic):
 * <ul>
 * <li>{@code sum} is the sum of all tracked loads;</li>
 * <li>{@code squaredDeviationIntegralPart} is the sum of the squares of all tracked loads;</li>
 * <li>{@code squaredDeviationFractionNumerator} is {@code -(sum * sum)}.</li>
 * </ul>
 * The total squared deviation from the mean load is therefore
 * {@code squaredDeviationIntegralPart + squaredDeviationFractionNumerator / n},
 * where {@code n} is the number of distinct balanced items.
 * <p>
 * No synchronization is performed. NOTE(review): none of the arithmetic is overflow-checked; if a
 * squared load, the sum of squares or the squared sum exceeds {@link Long#MAX_VALUE}, the computed
 * unfairness is meaningless.
 */
public final class LoadBalanceImpl<Balanced_> implements LoadBalance<Balanced_> {

    private static final MathContext RESULT_MATH_CONTEXT = new MathContext(6, RoundingMode.HALF_EVEN);

    /** Number of outstanding registrations per balanced item; the key is removed when it drops to zero. */
    private final Map<Balanced_, Integer> balancedItemCountMap = new HashMap<>();
    /** Current load per balanced item, in {@link LinkedHashMap} insertion order. */
    private final Map<Balanced_, Long> balancedItemToMetricValueMap = new LinkedHashMap<>();

    /** Sum of all tracked loads. */
    private long sum = 0L;
    /** Sum of the squares of all tracked loads. */
    private long squaredDeviationIntegralPart = 0L;
    /** Always equal to {@code -(sum * sum)}; divided by the item count in {@link #unfairness()}. */
    private long squaredDeviationFractionNumerator = 0L;

    /**
     * Registers one occurrence of {@code balanced} that adds {@code metricValue} to its load.
     * When this is the item's first outstanding registration, {@code initialMetricValue} is added too.
     *
     * @param balanced the item whose load increases
     * @param metricValue the amount this registration contributes to the item's load
     * @param initialMetricValue extra load applied only on the item's first outstanding registration
     * @return an undo action that unregisters exactly this occurrence
     */
    public Runnable registerBalanced(Balanced_ balanced, long metricValue, long initialMetricValue) {
        int registrationCount = balancedItemCountMap.merge(balanced, 1, Integer::sum);
        long diff = registrationCount == 1 ? metricValue + initialMetricValue : metricValue;
        addToMetric(balanced, diff);
        return () -> unregisterBalanced(balanced, metricValue);
    }

    /**
     * Removes one occurrence of {@code balanced} that was registered with {@code metricValue}.
     * When the last occurrence goes away, the item and its entire remaining load (including any
     * initial value) are dropped instead of subtracting {@code metricValue}.
     * <p>
     * NOTE(review): unregistering an item that is not currently registered throws a
     * {@link NullPointerException} (unboxing of the missing count).
     *
     * @param balanced the item whose load decreases
     * @param metricValue the amount the matching registration contributed
     */
    public void unregisterBalanced(Balanced_ balanced, long metricValue) {
        Integer remainingCount = balancedItemCountMap.compute(balanced,
                (key, count) -> count == 1 ? null : count - 1);
        if (remainingCount == null) {
            resetMetric(balanced);
        } else {
            addToMetric(balanced, -metricValue);
        }
    }

    /**
     * Adds {@code diff} to the load of {@code balanced} (an absent item counts as load 0) and updates
     * the aggregates. The entry is always written, so an item with zero load still appears in {@link #loads()}.
     */
    private void addToMetric(Balanced_ balanced, long diff) {
        long oldValue = balancedItemToMetricValueMap.getOrDefault(balanced, 0L);
        long newValue = oldValue + diff;
        balancedItemToMetricValueMap.put(balanced, newValue);
        if (oldValue != newValue) {
            updateSquaredDeviation(oldValue, newValue);
            sum += diff;
        }
    }

    /** Removes {@code balanced} entirely and retracts its whole load from the aggregates. */
    private void resetMetric(Balanced_ balanced) {
        long oldValue = Objects.requireNonNullElse(balancedItemToMetricValueMap.remove(balanced), 0L);
        if (oldValue != 0L) {
            updateSquaredDeviation(oldValue, 0L);
            sum -= oldValue;
        }
    }

    /**
     * Updates the squared-deviation aggregates for one load changing from {@code oldValue} to
     * {@code newValue}. Must run before {@code sum} itself is updated, because it reads the old sum.
     */
    private void updateSquaredDeviation(long oldValue, long newValue) {
        long newSum = sum - oldValue + newValue;
        // Sum of squares: swap the old square for the new one.
        squaredDeviationIntegralPart += newValue * newValue - oldValue * oldValue;
        // Keeps the numerator equal to -(sum * sum). This is an exact ring identity,
        // so it stays consistent even if the long arithmetic wraps.
        squaredDeviationFractionNumerator -= newSum * newSum - sum * sum;
    }

    /**
     * @return an unmodifiable live view of the current load per balanced item,
     *         or an empty map when no item is registered
     */
    @Override
    public Map<Balanced_, Long> loads() {
        if (balancedItemCountMap.isEmpty()) {
            return Collections.emptyMap();
        }
        return Collections.unmodifiableMap(balancedItemToMetricValueMap);
    }

    /**
     * Computes the square root of the total squared deviation of the loads from their mean,
     * rounded to 6 significant digits (HALF_EVEN).
     *
     * @return zero when no item is registered, otherwise the unfairness as described above
     */
    @Override
    public BigDecimal unfairness() {
        int balancedCount = balancedItemCountMap.size();
        return switch (balancedCount) {
            case 0 -> BigDecimal.ZERO;
            // With a single item, the sum of squares and the squared sum cancel out exactly.
            case 1 -> BigDecimal.valueOf(squaredDeviationFractionNumerator + squaredDeviationIntegralPart)
                    .sqrt(RESULT_MATH_CONTEXT);
            default -> {
                // squaredDeviation = integralPart + fractionNumerator / n, which is mathematically >= 0.
                // Converting both large terms to double first can round them apart and produce a
                // small negative value, on which BigDecimal.sqrt throws ArithmeticException.
                // Instead, cancel the large parts exactly in long arithmetic and only convert
                // the exact integer part plus a remainder fraction in (-1, 0] to double.
                long quotient = squaredDeviationFractionNumerator / balancedCount;
                long remainder = squaredDeviationFractionNumerator % balancedCount;
                double squaredDeviation = (double) (squaredDeviationIntegralPart + quotient)
                        + (double) remainder / balancedCount;
                // Only the final sqrt is done as BigDecimal; double precision suffices for the rest.
                yield BigDecimal.valueOf(squaredDeviation).sqrt(RESULT_MATH_CONTEXT);
            }
        };
    }
}

