/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.d2.balancer.strategies.relative;

import com.linkedin.d2.D2RelativeStrategyProperties;
import com.linkedin.d2.balancer.clients.TrackerClient;
import com.linkedin.d2.balancer.strategies.DelegatingRingFactory;
import com.linkedin.d2.balancer.strategies.PartitionStateUpdateListener;
import com.linkedin.d2.balancer.strategies.relative.PartitionState;
import com.linkedin.d2.balancer.strategies.relative.QuarantineManager;
import com.linkedin.d2.balancer.strategies.relative.TrackerClientState;
import com.linkedin.d2.balancer.util.hashing.Ring;
import com.linkedin.util.degrader.CallTracker;
import com.linkedin.util.degrader.ErrorType;
import java.net.URI;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Maintains the per-partition {@link PartitionState} used by the relative load balancer strategy.
 *
 * <p>State is recomputed for every known partition on a fixed delay
 * ({@code updateIntervalMs} from the strategy properties) by {@link #updateState()}, and
 * additionally whenever a caller reports a cluster change through
 * {@link #updateState(Set, int, long, boolean)}. Each recomputation derives a new
 * {@link PartitionState} from the previous one and publishes it into a {@link ConcurrentMap};
 * logging and listener notification for the new state run asynchronously on the executor.
 */
public class StateUpdater {
    private static final Logger LOG = LoggerFactory.getLogger(StateUpdater.class);
    public static final double MIN_HEALTH_SCORE = 0.0;
    public static final double MAX_HEALTH_SCORE = 1.0;
    /** Score given to a recovering host whose score is 0 while it is below the slow-start threshold. */
    private static final double SLOW_START_INITIAL_HEALTH_SCORE = 0.01;
    /** Multiplier applied to a recovering host's score while it is below the slow-start threshold. */
    private static final int SLOW_START_RECOVERY_FACTOR = 2;
    /** Maximum number of unhealthy clients listed individually in a state-update log line. */
    private static final int LOG_UNHEALTHY_CLIENT_NUMBERS = 10;
    /** Delay, in milliseconds, before the first scheduled state update. */
    private static final long EXECUTOR_INITIAL_DELAY = 10L;

    private final D2RelativeStrategyProperties _relativeStrategyProperties;
    private final QuarantineManager _quarantineManager;
    private final ScheduledExecutorService _executorService;
    /** Serializes partition initialization. */
    private final Lock _lock;
    private final List<PartitionStateUpdateListener.Factory<PartitionState>> _listenerFactories;
    private final String _serviceName;
    private final ConcurrentMap<Integer, PartitionState> _partitionLoadBalancerStateMap;
    // Written under _lock but read without it, so it must be volatile to be visible to readers.
    private volatile int _firstPartitionId = -1;

    StateUpdater(D2RelativeStrategyProperties relativeStrategyProperties, QuarantineManager quarantineManager, ScheduledExecutorService executorService, List<PartitionStateUpdateListener.Factory<PartitionState>> listenerFactories, String serviceName) {
        this(relativeStrategyProperties, quarantineManager, executorService, new ConcurrentHashMap<Integer, PartitionState>(), listenerFactories, serviceName);
    }

    /**
     * Creates the updater and schedules the periodic {@link #updateState()} task on
     * {@code executorService}.
     */
    StateUpdater(D2RelativeStrategyProperties relativeStrategyProperties, QuarantineManager quarantineManager, ScheduledExecutorService executorService, ConcurrentMap<Integer, PartitionState> partitionLoadBalancerStateMap, List<PartitionStateUpdateListener.Factory<PartitionState>> listenerFactories, String serviceName) {
        _relativeStrategyProperties = relativeStrategyProperties;
        _quarantineManager = quarantineManager;
        _executorService = executorService;
        _listenerFactories = listenerFactories;
        _partitionLoadBalancerStateMap = partitionLoadBalancerStateMap;
        _lock = new ReentrantLock();
        _serviceName = serviceName;
        _executorService.scheduleWithFixedDelay(this::updateState, EXECUTOR_INITIAL_DELAY, _relativeStrategyProperties.getUpdateIntervalMs(), TimeUnit.MILLISECONDS);
    }

    /**
     * Reports the current set of hosts for a partition.
     *
     * <p>A partition with no state yet is initialized synchronously on the calling thread under
     * {@code _lock}. For an existing partition, a recomputation is scheduled on the executor when
     * an update is forced, the cluster generation id changed, or the host count differs from the
     * current points map.
     *
     * @param trackerClients the hosts currently in the partition
     * @param partitionId the partition id
     * @param clusterGenerationId generation id of the cluster view {@code trackerClients} came from
     * @param shouldForceUpdate whether to recompute even if the cluster view looks unchanged
     */
    public void updateState(Set<TrackerClient> trackerClients, int partitionId, long clusterGenerationId, boolean shouldForceUpdate) {
        PartitionState currentState = _partitionLoadBalancerStateMap.get(partitionId);
        if (currentState == null) {
            _lock.lock();
            try {
                initializePartition(trackerClients, partitionId, clusterGenerationId);
            } finally {
                _lock.unlock();
            }
        } else if (needsClusterUpdate(currentState, trackerClients, clusterGenerationId, shouldForceUpdate)) {
            _executorService.execute(() -> updateStateDueToClusterChange(trackerClients, partitionId, clusterGenerationId, shouldForceUpdate));
        }
    }

    /** Returns the hash ring for the partition; throws NPE if the partition is not initialized. */
    Ring<URI> getRing(int partitionId) {
        return _partitionLoadBalancerStateMap.get(partitionId).getRing();
    }

    /** Returns the partition's points map, or a new empty map if the partition is unknown. */
    Map<URI, Integer> getPointsMap(int partitionId) {
        PartitionState partitionState = _partitionLoadBalancerStateMap.get(partitionId);
        return partitionState == null ? new HashMap<URI, Integer>() : partitionState.getPointsMap();
    }

    /** Returns the partition's state, or {@code null} if it has not been initialized. */
    PartitionState getPartitionState(int partitionId) {
        return _partitionLoadBalancerStateMap.get(partitionId);
    }

    /** Returns the id of the first partition ever initialized, or -1 if none has been. */
    int getFirstValidPartitionId() {
        return _firstPartitionId;
    }

    /**
     * Periodic task: recomputes the state of every known partition from its own tracker clients
     * and generation id.
     *
     * <p>Failures are caught and logged per partition, so one failing partition neither stops the
     * others from being updated nor propagates out of the scheduled task (which would suppress
     * all later runs of a {@code scheduleWithFixedDelay} task).
     */
    void updateState() {
        for (Map.Entry<Integer, PartitionState> entry : _partitionLoadBalancerStateMap.entrySet()) {
            int partitionId = entry.getKey();
            PartitionState partitionState = entry.getValue();
            try {
                updateStateForPartition(partitionState.getTrackerClients(), partitionId, partitionState, partitionState.getClusterGenerationId(), false);
            } catch (Exception ex) {
                LOG.error("Failed to update the state for service: " + _serviceName + ", partitionId: " + partitionId, ex);
            }
        }
    }

    /**
     * Builds a new {@link PartitionState} from {@code oldPartitionState} and publishes it.
     *
     * <p>The method recomputes health scores, prunes removed hosts when the cluster changed,
     * applies quarantine, and rebuilds the ring. It then schedules logging and listener
     * notification on the executor.
     */
    void updateStateForPartition(Set<TrackerClient> trackerClients, int partitionId, PartitionState oldPartitionState, Long clusterGenerationId, boolean shouldForceUpdate) {
        LOG.debug("Updating for partition: {}, state: {}", partitionId, oldPartitionState);
        PartitionState newPartitionState = new PartitionState(oldPartitionState);
        // Filled by getAvgClusterLatency so each host's stats are read exactly once per update.
        Map<TrackerClient, CallTracker.CallStats> latestCallStatsMap = new HashMap<>();
        long avgClusterLatency = getAvgClusterLatency(trackerClients, latestCallStatsMap);
        boolean clusterUpdated = shouldForceUpdate || clusterGenerationId.longValue() != oldPartitionState.getClusterGenerationId();
        updateBaseHealthScoreAndState(trackerClients, newPartitionState, avgClusterLatency, clusterUpdated, latestCallStatsMap);
        _quarantineManager.updateQuarantineState(newPartitionState, oldPartitionState, avgClusterLatency);
        newPartitionState.updateRing();
        newPartitionState.setClusterGenerationId(clusterGenerationId);
        _partitionLoadBalancerStateMap.put(partitionId, newPartitionState);
        _executorService.execute(() -> {
            logState(oldPartitionState, newPartitionState, partitionId);
            notifyPartitionStateUpdateListener(newPartitionState);
        });
    }

    /**
     * Executor task scheduled by {@link #updateState(Set, int, long, boolean)}. It re-checks
     * staleness against the latest state, because the state may have changed since the task was
     * queued, and updates only if the state is still stale.
     */
    void updateStateDueToClusterChange(Set<TrackerClient> trackerClients, int partitionId, Long newClusterGenerationId, boolean shouldForceUpdate) {
        PartitionState oldPartitionState = _partitionLoadBalancerStateMap.get(partitionId);
        if (needsClusterUpdate(oldPartitionState, trackerClients, newClusterGenerationId, shouldForceUpdate)) {
            updateStateForPartition(trackerClients, partitionId, oldPartitionState, newClusterGenerationId, shouldForceUpdate);
        }
    }

    /** Returns whether {@code state} is stale relative to the reported cluster view. */
    private static boolean needsClusterUpdate(PartitionState state, Set<TrackerClient> trackerClients, long clusterGenerationId, boolean shouldForceUpdate) {
        return shouldForceUpdate
            || clusterGenerationId != state.getClusterGenerationId()
            || trackerClients.size() != state.getPointsMap().size();
    }

    /** Recomputes health scores and, if the cluster changed, drops hosts no longer present. */
    private void updateBaseHealthScoreAndState(Set<TrackerClient> trackerClients, PartitionState partitionState, long clusterAvgLatency, boolean clusterUpdated, Map<TrackerClient, CallTracker.CallStats> lastCallStatsMap) {
        calculateBaseHealthScore(trackerClients, partitionState, clusterAvgLatency, lastCallStatsMap);
        Map<TrackerClient, TrackerClientState> trackerClientStateMap = partitionState.getTrackerClientStateMap();
        if (clusterUpdated) {
            // Collect first so removeTrackerClient (which may modify the state map) is not called
            // while streaming over the map's key set.
            List<TrackerClient> trackerClientsToRemove = trackerClientStateMap.keySet().stream()
                .filter(oldTrackerClient -> !trackerClients.contains(oldTrackerClient))
                .collect(Collectors.toList());
            for (TrackerClient trackerClient : trackerClientsToRemove) {
                partitionState.removeTrackerClient(trackerClient);
            }
        }
    }

    /**
     * Updates each host's health state and score, registers new hosts, and records partition-wide
     * latency, call-count and error-count statistics.
     */
    private void calculateBaseHealthScore(Set<TrackerClient> trackerClients, PartitionState partitionState, long avgClusterLatency, Map<TrackerClient, CallTracker.CallStats> lastCallStatsMap) {
        Map<TrackerClient, TrackerClientState> trackerClientStateMap = partitionState.getTrackerClientStateMap();
        long clusterCallCount = 0L;
        long clusterErrorCount = 0L;
        for (TrackerClient trackerClient : trackerClients) {
            TrackerClientState trackerClientState = trackerClientStateMap.get(trackerClient);
            if (trackerClientState == null) {
                // New host: full health if it opts out of slow start or load balancing,
                // otherwise the configured initial score.
                if (trackerClient.doNotSlowStart() || trackerClient.doNotLoadBalance()) {
                    trackerClientStateMap.put(trackerClient, new TrackerClientState(MAX_HEALTH_SCORE, _relativeStrategyProperties.getMinCallCount()));
                } else {
                    trackerClientStateMap.put(trackerClient, new TrackerClientState(_relativeStrategyProperties.getInitialHealthScore(), _relativeStrategyProperties.getMinCallCount()));
                }
                continue;
            }

            CallTracker.CallStats latestCallStats = lastCallStatsMap.get(trackerClient);
            int callCount = latestCallStats.getCallCount() + latestCallStats.getOutstandingCount();
            if (trackerClient.doNotLoadBalance()) {
                trackerClientState.setHealthState(TrackerClientState.HealthState.HEALTHY);
                trackerClientState.setHealthScore(MAX_HEALTH_SCORE);
                trackerClientState.setCallCount(callCount);
                continue;
            }

            // Count errors as integers: recomputing them as errorRate * callCount in floating
            // point and truncating can lose errors (e.g. (1.0 / 49) * 49 < 1).
            int errorCount = callCount == 0 ? 0 : getErrorCount(latestCallStats.getErrorTypeCounts());
            double errorRate = callCount == 0 ? 0.0 : (double) errorCount / callCount;
            long avgLatency = getAvgHostLatency(latestCallStats);
            double oldHealthScore = trackerClientState.getHealthScore();
            double newHealthScore = oldHealthScore;
            clusterCallCount += callCount;
            clusterErrorCount += errorCount;

            if (isUnhealthy(trackerClientState, avgClusterLatency, callCount, avgLatency, errorRate)) {
                newHealthScore = Math.max(oldHealthScore - _relativeStrategyProperties.getDownStep(), MIN_HEALTH_SCORE);
                trackerClientState.setHealthState(TrackerClientState.HealthState.UNHEALTHY);
                LOG.debug("Host is unhealthy. Host: {}, errorRate: {}, latency: {}, callCount: {}, healthScore dropped from {} to {}",
                    trackerClient, errorRate, avgLatency, callCount, oldHealthScore, newHealthScore);
            } else if (oldHealthScore < MAX_HEALTH_SCORE && isHealthy(trackerClientState, avgClusterLatency, callCount, avgLatency, errorRate)) {
                newHealthScore = getRecoveredHealthScore(oldHealthScore);
                trackerClientState.setHealthState(TrackerClientState.HealthState.HEALTHY);
            } else {
                trackerClientState.setHealthState(TrackerClientState.HealthState.NEUTRAL);
            }
            trackerClientState.setHealthScore(newHealthScore);
            trackerClientState.setCallCount(callCount);
        }
        partitionState.setPartitionStats(avgClusterLatency, clusterCallCount, clusterErrorCount);
    }

    /**
     * Returns the next health score for a healthy host that is below full health.
     *
     * <p>Below the slow-start threshold the score is multiplied by
     * {@link #SLOW_START_RECOVERY_FACTOR}. A score of 0 is first raised to
     * {@link #SLOW_START_INITIAL_HEALTH_SCORE}, since multiplying 0 would never recover it. At or
     * above the threshold the score grows by the configured up-step. The result is capped at
     * {@link #MAX_HEALTH_SCORE}.
     */
    private double getRecoveredHealthScore(double oldHealthScore) {
        if (oldHealthScore < _relativeStrategyProperties.getSlowStartThreshold()) {
            return oldHealthScore > MIN_HEALTH_SCORE
                ? Math.min(MAX_HEALTH_SCORE, SLOW_START_RECOVERY_FACTOR * oldHealthScore)
                : SLOW_START_INITIAL_HEALTH_SCORE;
        }
        return Math.min(MAX_HEALTH_SCORE, oldHealthScore + _relativeStrategyProperties.getUpStep());
    }

    /**
     * Returns the call-weighted average latency (rounded up) across load-balanced hosts, counting
     * both completed and outstanding calls. Returns 0 when there are no calls.
     *
     * <p>As a side effect, stores every host's latest call stats, including hosts with
     * {@code doNotLoadBalance}, into {@code latestCallStatsMap}.
     */
    private long getAvgClusterLatency(Set<TrackerClient> trackerClients, Map<TrackerClient, CallTracker.CallStats> latestCallStatsMap) {
        // Accumulate in double so per-host fractional latency is not truncated away.
        double latencySum = 0.0;
        long outstandingLatencySum = 0L;
        long callCountSum = 0L;
        long outstandingCallCountSum = 0L;
        for (TrackerClient trackerClient : trackerClients) {
            CallTracker.CallStats latestCallStats = trackerClient.getCallTracker().getCallStats();
            latestCallStatsMap.put(trackerClient, latestCallStats);
            if (trackerClient.doNotLoadBalance()) {
                continue;
            }
            int callCount = latestCallStats.getCallCount();
            int outstandingCallCount = latestCallStats.getOutstandingCount();
            latencySum += latestCallStats.getCallTimeStats().getAverage() * callCount;
            outstandingLatencySum += latestCallStats.getOutstandingStartTimeAvg() * (long) outstandingCallCount;
            callCountSum += callCount;
            outstandingCallCountSum += outstandingCallCount;
        }
        long totalCallCount = callCountSum + outstandingCallCountSum;
        return totalCallCount == 0 ? 0L : (long) Math.ceil((latencySum + outstandingLatencySum) / totalCallCount);
    }

    /**
     * Returns a host's average latency, weighting completed-call latency and outstanding-call
     * latency by their respective counts, rounded to the nearest long. Returns 0 when there are
     * no calls.
     */
    public static long getAvgHostLatency(CallTracker.CallStats callStats) {
        double avgLatency = callStats.getCallTimeStats().getAverage();
        long avgOutstandingLatency = callStats.getOutstandingStartTimeAvg();
        int callCount = callStats.getCallCount();
        int outstandingCallCount = callStats.getOutstandingCount();
        int totalCallCount = callCount + outstandingCallCount;
        if (totalCallCount == 0) {
            return 0L;
        }
        return Math.round(avgLatency * ((double) callCount / (double) totalCallCount)
            + (double) avgOutstandingLatency * ((double) outstandingCallCount / (double) totalCallCount));
    }

    /**
     * Returns whether a host has enough calls to judge and either its latency reaches the high
     * relative-latency threshold or its error rate reaches the high error rate.
     */
    private boolean isUnhealthy(TrackerClientState trackerClientState, long avgClusterLatency, int callCount, long latency, double errorRate) {
        return callCount >= trackerClientState.getAdjustedMinCallCount()
            && ((double) latency >= (double) avgClusterLatency * _relativeStrategyProperties.getRelativeLatencyHighThresholdFactor()
                || errorRate >= _relativeStrategyProperties.getHighErrorRate());
    }

    /**
     * Returns whether a host has enough calls to judge, its latency is at or below the low
     * relative-latency threshold, and its error rate is at or below the low error rate.
     */
    private boolean isHealthy(TrackerClientState trackerClientState, long avgClusterLatency, int callCount, long latency, double errorRate) {
        return callCount >= trackerClientState.getAdjustedMinCallCount()
            && (double) latency <= (double) avgClusterLatency * _relativeStrategyProperties.getRelativeLatencyLowThresholdFactor()
            && errorRate <= _relativeStrategyProperties.getLowErrorRate();
    }

    /** Invokes {@code onUpdate} on every listener attached to {@code state}. */
    private void notifyPartitionStateUpdateListener(PartitionState state) {
        state.getListeners().forEach(listener -> listener.onUpdate(state));
    }

    /**
     * Returns the number of errors that count against a host's health: connect, closed-channel,
     * server and timeout errors.
     */
    private static int getErrorCount(Map<ErrorType, Integer> errorTypeCounts) {
        return errorTypeCounts.getOrDefault(ErrorType.CONNECT_EXCEPTION, 0)
            + errorTypeCounts.getOrDefault(ErrorType.CLOSED_CHANNEL_EXCEPTION, 0)
            + errorTypeCounts.getOrDefault(ErrorType.SERVER_ERROR, 0)
            + errorTypeCounts.getOrDefault(ErrorType.TIMEOUT_EXCEPTION, 0);
    }

    /**
     * Creates and publishes state for a partition if it does not exist yet, and records the first
     * partition id ever initialized. Callers hold {@code _lock}.
     */
    private void initializePartition(Set<TrackerClient> trackerClients, int partitionId, long clusterGenerationId) {
        if (!_partitionLoadBalancerStateMap.containsKey(partitionId)) {
            PartitionState partitionState = new PartitionState(partitionId,
                new DelegatingRingFactory<URI>(_relativeStrategyProperties.getRingProperties()),
                _relativeStrategyProperties.getRingProperties().getPointsPerWeight(),
                _listenerFactories.stream().map(factory -> factory.create(partitionId)).collect(Collectors.toList()));
            updateStateForPartition(trackerClients, partitionId, partitionState, clusterGenerationId, false);
            if (_firstPartitionId < 0) {
                _firstPartitionId = partitionId;
            }
        }
    }

    /**
     * Logs a state transition. With debug enabled, every update is logged at debug level.
     * Otherwise an update is logged at info level only when {@link #allowToLog} reports a
     * noteworthy change.
     */
    private void logState(PartitionState oldState, PartitionState newState, int partitionId) {
        Map<TrackerClient, TrackerClientState> newTrackerClientStateMap = newState.getTrackerClientStateMap();
        Set<TrackerClient> newUnhealthyClients = getUnhealthyClients(newTrackerClientStateMap);
        if (LOG.isDebugEnabled()) {
            LOG.debug(buildStateUpdateMessage(oldState, newState, partitionId, newUnhealthyClients, newTrackerClientStateMap));
        } else if (LOG.isInfoEnabled()) {
            Set<TrackerClient> oldUnhealthyClients = getUnhealthyClients(oldState.getTrackerClientStateMap());
            if (allowToLog(oldState, newState, newUnhealthyClients, oldUnhealthyClients)) {
                LOG.info(buildStateUpdateMessage(oldState, newState, partitionId, newUnhealthyClients, newTrackerClientStateMap));
            }
        }
    }

    /** Returns the clients whose health score is below {@link #MAX_HEALTH_SCORE}. */
    private static Set<TrackerClient> getUnhealthyClients(Map<TrackerClient, TrackerClientState> trackerClientStateMap) {
        return trackerClientStateMap.entrySet().stream()
            .filter(entry -> entry.getValue().getHealthScore() < MAX_HEALTH_SCORE)
            .map(Map.Entry::getKey)
            .collect(Collectors.toSet());
    }

    /**
     * Formats the state-update log line. The line lists at most
     * {@link #LOG_UNHEALTHY_CLIENT_NUMBERS} unhealthy clients and appends the total when more
     * exist.
     */
    private String buildStateUpdateMessage(PartitionState oldState, PartitionState newState, int partitionId, Set<TrackerClient> unhealthyClients, Map<TrackerClient, TrackerClientState> trackerClientStateMap) {
        String unhealthyClientStats = unhealthyClients.stream()
            .limit(LOG_UNHEALTHY_CLIENT_NUMBERS)
            .map(client -> getClientStats(client, trackerClientStateMap))
            .collect(Collectors.joining(","));
        String truncationSuffix = unhealthyClients.size() > LOG_UNHEALTHY_CLIENT_NUMBERS
            ? "...(total " + unhealthyClients.size() + ")"
            : "";
        return "Strategy updated: service=" + _serviceName
            + ", partitionId=" + partitionId
            + ", unhealthyClientNumber=" + unhealthyClients.size()
            + ", newState=" + newState
            + ", unhealthyClients={" + unhealthyClientStats + truncationSuffix + "}"
            + ", oldState=" + oldState;
    }

    /**
     * Returns true if the new state has any URI in its points map, unhealthy client, recovery
     * client, or quarantined client that the old state did not have.
     */
    private static boolean allowToLog(PartitionState oldState, PartitionState newState, Set<TrackerClient> newUnhealthyClients, Set<TrackerClient> oldUnhealthyClients) {
        for (URI uri : newState.getPointsMap().keySet()) {
            if (!oldState.getPointsMap().containsKey(uri)) {
                return true;
            }
        }
        for (TrackerClient client : newUnhealthyClients) {
            if (!oldUnhealthyClients.contains(client)) {
                return true;
            }
        }
        for (TrackerClient trackerClient : newState.getRecoveryTrackerClients()) {
            if (!oldState.getRecoveryTrackerClients().contains(trackerClient)) {
                return true;
            }
        }
        for (TrackerClient trackerClient : newState.getQuarantineMap().keySet()) {
            if (!oldState.getQuarantineMap().containsKey(trackerClient)) {
                return true;
            }
        }
        return false;
    }

    /** Formats a client as {@code uri:healthScore}. */
    private static String getClientStats(TrackerClient client, Map<TrackerClient, TrackerClientState> trackerClientStateMap) {
        return client.getUri() + ":" + trackerClientStateMap.get(client).getHealthScore();
    }
}

