/*
 * Decompiled with CFR 0.152.
 */
package com.datastax.dse.driver.internal.core.loadbalancing;

import com.datastax.dse.driver.internal.core.tracker.MultiplexingRequestTracker;
import com.datastax.oss.driver.api.core.CqlIdentifier;
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
import com.datastax.oss.driver.api.core.config.DriverExecutionProfile;
import com.datastax.oss.driver.api.core.config.DriverOption;
import com.datastax.oss.driver.api.core.context.DriverContext;
import com.datastax.oss.driver.api.core.loadbalancing.LoadBalancingPolicy;
import com.datastax.oss.driver.api.core.loadbalancing.NodeDistance;
import com.datastax.oss.driver.api.core.metadata.Node;
import com.datastax.oss.driver.api.core.metadata.NodeState;
import com.datastax.oss.driver.api.core.metadata.TokenMap;
import com.datastax.oss.driver.api.core.metadata.token.Token;
import com.datastax.oss.driver.api.core.session.Request;
import com.datastax.oss.driver.api.core.session.Session;
import com.datastax.oss.driver.api.core.tracker.RequestTracker;
import com.datastax.oss.driver.internal.core.context.InternalDriverContext;
import com.datastax.oss.driver.internal.core.metadata.MetadataManager;
import com.datastax.oss.driver.internal.core.pool.ChannelPool;
import com.datastax.oss.driver.internal.core.session.DefaultSession;
import com.datastax.oss.driver.internal.core.util.ArrayUtils;
import com.datastax.oss.driver.internal.core.util.Reflection;
import com.datastax.oss.driver.internal.core.util.collection.QueryPlan;
import com.datastax.oss.driver.shaded.guava.common.annotations.VisibleForTesting;
import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableMap;
import edu.umd.cs.findbugs.annotations.NonNull;
import edu.umd.cs.findbugs.annotations.Nullable;
import java.nio.ByteBuffer;
import java.util.BitSet;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLongArray;
import java.util.function.IntUnaryOperator;
import java.util.function.Predicate;
import net.jcip.annotations.ThreadSafe;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Load balancing policy that only routes to live nodes accepted by a node filter (local
 * datacenter plus an optional user-provided predicate), and that places the replicas of a request
 * at the head of the query plan.
 *
 * <p>When more than two replicas are available, the replica ordering is further adjusted using the
 * time each node last came up, the number of in-flight requests on its pool, and the timestamps of
 * its two most recent responses (collected through the {@link RequestTracker} callbacks).
 *
 * <p>Mutable state is held in {@code volatile} fields and {@code java.util.concurrent}
 * collections.
 */
@ThreadSafe
public class DseLoadBalancingPolicy implements LoadBalancingPolicy, RequestTracker {
    private static final Logger LOG = LoggerFactory.getLogger(DseLoadBalancingPolicy.class);

    /** Default user filter, used when neither the builder nor the configuration provides one. */
    private static final Predicate<Node> INCLUDE_ALL_NODES = n -> true;

    /** Increments the round-robin counter, wrapping to 0 so that it never becomes negative. */
    private static final IntUnaryOperator INCREMENT = i -> i == Integer.MAX_VALUE ? 0 : i + 1;

    /** A replica whose recorded up time is more recent than this interval counts as "newly up". */
    private static final long NEWLY_UP_INTERVAL_NANOS = TimeUnit.MINUTES.toNanos(1L);

    /** Number of in-flight requests at or above which a node is considered busy. */
    private static final int MAX_IN_FLIGHT_THRESHOLD = 10;

    /**
     * If the older of a node's two most recent responses is older than this interval, the node's
     * response rate is considered insufficient.
     */
    private static final long RESPONSE_COUNT_RESET_INTERVAL_NANOS = TimeUnit.MILLISECONDS.toNanos(200L);

    @NonNull
    private final String logPrefix;
    @NonNull
    private final MetadataManager metadataManager;
    @NonNull
    private final Predicate<Node> filter;
    private final boolean isDefaultPolicy;
    @Nullable
    @VisibleForTesting
    volatile String localDc;
    // No-op until init() installs the real reporter.
    @NonNull
    private volatile LoadBalancingPolicy.DistanceReporter distanceReporter = (node, distance) -> {};
    private final AtomicInteger roundRobinAmount = new AtomicInteger();
    /** Nodes that passed the filter and are not known to be down; the source of every query plan. */
    @VisibleForTesting
    final CopyOnWriteArraySet<Node> localDcLiveNodes = new CopyOnWriteArraySet<>();
    /** Per node: timestamps (from {@link #nanoTime()}) of its one or two most recent responses. */
    @VisibleForTesting
    final Map<Node, AtomicLongArray> responseTimes = new ConcurrentHashMap<>();
    /** Per node: timestamp (from {@link #nanoTime()}) at which it was last added back to the live set. */
    @VisibleForTesting
    final Map<Node, Long> upTimes = new ConcurrentHashMap<>();

    /**
     * Builds the policy for the given execution profile.
     *
     * <p>The local datacenter and the user filter are read from the context; the policy also
     * registers itself on the context's request tracker so that it receives response callbacks.
     *
     * @param context the driver context; must be an {@link InternalDriverContext} whose request
     *     tracker is a {@link MultiplexingRequestTracker}, otherwise a {@link ClassCastException}
     *     is thrown.
     * @param profileName the name of the execution profile this policy is created for.
     */
    public DseLoadBalancingPolicy(@NonNull DriverContext context, @NonNull String profileName) {
        this.logPrefix = context.getSessionName() + "|" + profileName;
        InternalDriverContext internalContext = (InternalDriverContext) context;
        this.metadataManager = internalContext.getMetadataManager();
        this.isDefaultPolicy = profileName.equals("default");
        this.localDc = this.getLocalDcFromConfig(internalContext, profileName);
        Predicate<Node> filterFromConfig = DseLoadBalancingPolicy.getFilterFromConfig(context, profileName);
        this.filter = node -> {
            // Read the volatile field once: init() may still assign it.
            String dc = this.localDc;
            if (dc != null && !dc.equals(node.getDatacenter())) {
                LOG.debug("[{}] Ignoring {} because it doesn't belong to the local DC {}", this.logPrefix, node, dc);
                return false;
            }
            if (!filterFromConfig.test(node)) {
                LOG.debug("[{}] Ignoring {} because it doesn't match the user-provided predicate", this.logPrefix, node);
                return false;
            }
            return true;
        };
        ((MultiplexingRequestTracker) context.getRequestTracker()).register(this);
    }

    /**
     * Resolves the local datacenter if it is still unknown, then reports a distance for every
     * known node and seeds the live set.
     *
     * <p>If no local datacenter was configured, it is taken from the single implicit contact
     * point. If one was configured, contact points from another datacenter are collected and, for
     * the default profile only, reported in a warning.
     *
     * @param nodes the nodes known at initialization time.
     * @param distanceReporter the callback used, now and later, to report node distances.
     * @throws IllegalStateException if no local datacenter was configured and the contact points
     *     were provided explicitly.
     */
    public void init(@NonNull Map<UUID, Node> nodes, @NonNull LoadBalancingPolicy.DistanceReporter distanceReporter) {
        this.distanceReporter = distanceReporter;
        Set<? extends Node> contactPoints = this.metadataManager.getContactPoints();
        if (this.localDc == null) {
            if (!this.metadataManager.wasImplicitContactPoint()) {
                throw new IllegalStateException("You provided explicit contact points, the local DC must be specified (see " + DefaultDriverOption.LOAD_BALANCING_LOCAL_DATACENTER.getPath() + " in the config)");
            }
            assert contactPoints.size() == 1;
            Node contactPoint = contactPoints.iterator().next();
            this.localDc = contactPoint.getDatacenter();
            LOG.debug("[{}] Local DC set from contact point {}: {}", this.logPrefix, contactPoint, this.localDc);
        } else {
            ImmutableMap.Builder<Node, String> builder = ImmutableMap.builder();
            for (Node node : contactPoints) {
                String datacenter = node.getDatacenter();
                if (!Objects.equals(this.localDc, datacenter)) {
                    // The builder does not accept null values, hence the placeholder.
                    builder.put(node, datacenter == null ? "<null>" : datacenter);
                }
            }
            ImmutableMap<Node, String> badContactPoints = builder.build();
            if (this.isDefaultPolicy && !badContactPoints.isEmpty()) {
                LOG.warn("[{}] You specified {} as the local DC, but some contact points are from a different DC ({})", this.logPrefix, this.localDc, badContactPoints);
            }
        }
        for (Node node : nodes.values()) {
            if (this.filter.test(node)) {
                distanceReporter.setDistance(node, NodeDistance.LOCAL);
                if (node.getState() != NodeState.DOWN) {
                    this.localDcLiveNodes.add(node);
                }
            } else {
                distanceReporter.setDistance(node, NodeDistance.IGNORED);
            }
        }
    }

    /**
     * Builds a query plan from a snapshot of the live set.
     *
     * <p>Replicas of the request (if they can be determined) are moved to the head of the plan and
     * shuffled; with more than two replicas they are additionally reordered by
     * {@link #reorderReplicas}. The remaining nodes are rotated by a round-robin counter that is
     * incremented on every call.
     *
     * @param request the request to route, or {@code null}.
     * @param session the session executing the request, or {@code null}.
     * @return the query plan; it contains every node of the live-set snapshot.
     */
    @NonNull
    public Queue<Node> newQueryPlan(@Nullable Request request, @Nullable Session session) {
        Object[] currentNodes = this.localDcLiveNodes.toArray();
        Set<Node> allReplicas = this.getReplicas(request, session);
        int replicaCount = 0;
        if (!allReplicas.isEmpty()) {
            // Move every live replica to the head of the array.
            for (int i = 0; i < currentNodes.length; i++) {
                Node node = (Node) currentNodes[i];
                if (allReplicas.contains(node)) {
                    ArrayUtils.bubbleUp(currentNodes, i, replicaCount);
                    replicaCount++;
                }
            }
            if (replicaCount > 1) {
                this.shuffleHead(currentNodes, replicaCount);
                if (replicaCount > 2) {
                    // getReplicas() returns an empty set when the session is null.
                    assert session != null;
                    this.reorderReplicas(currentNodes, replicaCount, session);
                }
            }
        }
        LOG.trace("[{}] Prioritizing {} local replicas", this.logPrefix, replicaCount);
        // Round-robin over the non-replica tail only; the replica head keeps its order.
        ArrayUtils.rotate(currentNodes, replicaCount, currentNodes.length - replicaCount, this.roundRobinAmount.getAndUpdate(INCREMENT));
        return new QueryPlan(currentNodes);
    }

    /**
     * Reorders the first {@code replicaCount} entries of {@code currentNodes} in place:
     *
     * <ul>
     *   <li>if no replica is newly up and fewer than half of the replicas are unhealthy, the
     *       unhealthy ones are sent to the back of the replica range;
     *   <li>otherwise, if the most recently up replica sits in first or second position, it is
     *       sent to the back of the replica range, unless {@link #diceRoll1d4()} returns 1;
     *   <li>finally, the first two entries are swapped if the first has more in-flight requests
     *       than the second.
     * </ul>
     *
     * @param currentNodes the query plan under construction; its head holds the replicas.
     * @param replicaCount the number of replicas at the head of the array (callers pass &gt; 2).
     * @param session the session, used to look up in-flight request counts.
     */
    private void reorderReplicas(@NonNull Object[] currentNodes, int replicaCount, @NonNull Session session) {
        Node newestUpReplica = null;
        BitSet unhealthyReplicas = null;
        long mostRecentUpTimeNanos = 0L;
        long now = this.nanoTime();
        for (int i = 0; i < replicaCount; i++) {
            Node node = (Node) currentNodes[i];
            Long upTimeNanos = this.upTimes.get(node);
            // nanoTime() values have an arbitrary origin and may be negative, so they are only
            // ever compared through differences, and the first candidate is accepted as-is
            // rather than being compared against a sentinel value.
            if (upTimeNanos != null
                    && now - upTimeNanos - NEWLY_UP_INTERVAL_NANOS < 0L
                    && (newestUpReplica == null || upTimeNanos - mostRecentUpTimeNanos > 0L)) {
                newestUpReplica = node;
                mostRecentUpTimeNanos = upTimeNanos;
            }
            // Health only matters while no newly-up replica has been found.
            if (newestUpReplica == null && this.isUnhealthy(node, session, now)) {
                if (unhealthyReplicas == null) {
                    unhealthyReplicas = new BitSet(replicaCount);
                }
                unhealthyReplicas.set(i);
            }
        }
        int unhealthyReplicasCount = unhealthyReplicas == null ? 0 : unhealthyReplicas.cardinality();
        if (newestUpReplica == null && unhealthyReplicasCount > 0 && unhealthyReplicasCount < replicaCount / 2.0) {
            // Walk backwards so that already-moved entries are not visited again.
            int counter = 0;
            for (int i = replicaCount - 1; i >= 0 && counter < unhealthyReplicasCount; i--) {
                if (unhealthyReplicas.get(i)) {
                    ArrayUtils.bubbleDown(currentNodes, i, replicaCount - 1 - counter);
                    counter++;
                }
            }
        } else if ((newestUpReplica == currentNodes[0] || newestUpReplica == currentNodes[1]) && this.diceRoll1d4() != 1) {
            ArrayUtils.bubbleDown(currentNodes, newestUpReplica == currentNodes[0] ? 0 : 1, replicaCount - 1);
        }
        if (DseLoadBalancingPolicy.getInFlight((Node) currentNodes[0], session) > DseLoadBalancingPolicy.getInFlight((Node) currentNodes[1], session)) {
            ArrayUtils.swap(currentNodes, 0, 1);
        }
    }

    /**
     * Reports the distance of a newly added node: LOCAL if it passes the filter, IGNORED
     * otherwise. The node is not added to the live set here.
     */
    public void onAdd(@NonNull Node node) {
        if (this.filter.test(node)) {
            LOG.debug("[{}] {} was added, setting distance to LOCAL", this.logPrefix, node);
            this.distanceReporter.setDistance(node, NodeDistance.LOCAL);
        } else {
            this.distanceReporter.setDistance(node, NodeDistance.IGNORED);
        }
    }

    /**
     * Reports the distance of a node that came up and, if it passes the filter and was not already
     * in the live set, adds it there and records its up time.
     */
    public void onUp(@NonNull Node node) {
        if (this.filter.test(node)) {
            this.distanceReporter.setDistance(node, NodeDistance.LOCAL);
            if (this.localDcLiveNodes.add(node)) {
                this.upTimes.put(node, this.nanoTime());
                LOG.debug("[{}] {} came back UP, added to live set", this.logPrefix, node);
            }
        } else {
            this.distanceReporter.setDistance(node, NodeDistance.IGNORED);
        }
    }

    /** Removes a node that went down from the live set, along with its recorded up time. */
    public void onDown(@NonNull Node node) {
        if (this.localDcLiveNodes.remove(node)) {
            this.upTimes.remove(node);
            LOG.debug("[{}] {} went DOWN, removed from live set", this.logPrefix, node);
        }
    }

    /**
     * Removes a node from the live set, along with its recorded up time, and discards its
     * response-time history.
     */
    public void onRemove(@NonNull Node node) {
        if (this.localDcLiveNodes.remove(node)) {
            this.upTimes.remove(node);
            LOG.debug("[{}] {} was removed, removed from live set", this.logPrefix, node);
        }
        // Dropped unconditionally: the node may have left the live set earlier (onDown) while its
        // response-time entry was still kept, and nothing else ever evicts that entry.
        this.responseTimes.remove(node);
    }

    /** Records the current time as the most recent response time of {@code node}. */
    public void onNodeSuccess(@NonNull Request request, long latencyNanos, @NonNull DriverExecutionProfile executionProfile, @NonNull Node node, @NonNull String logPrefix) {
        this.updateResponseTimes(node);
    }

    /** Records the current time as the most recent response time of {@code node}. */
    public void onNodeError(@NonNull Request request, @NonNull Throwable error, long latencyNanos, @NonNull DriverExecutionProfile executionProfile, @NonNull Node node, @NonNull String logPrefix) {
        this.updateResponseTimes(node);
    }

    /** Does nothing. */
    public void close() {
    }

    /** Delegates to {@code ArrayUtils.shuffleHead}; a separate method so that tests can override it. */
    @VisibleForTesting
    void shuffleHead(Object[] array, int n) {
        ArrayUtils.shuffleHead(array, n);
    }

    /** Returns {@link System#nanoTime()}; a separate method so that tests can override it. */
    @VisibleForTesting
    long nanoTime() {
        return System.nanoTime();
    }

    /**
     * Returns a uniformly distributed random integer between 0 (inclusive) and 4 (exclusive).
     *
     * <p>Despite the name, the range is 0..3 and not 1..4; the only caller compares the result
     * with 1, which has the same 1-in-4 probability either way.
     */
    @VisibleForTesting
    int diceRoll1d4() {
        return ThreadLocalRandom.current().nextInt(4);
    }

    /**
     * Looks up the replicas for the request in the token map.
     *
     * <p>The keyspace is taken from the request keyspace, then the request routing keyspace, then
     * the session keyspace. The routing token is preferred over the routing key.
     *
     * @return the replicas, or an empty set if the request or session is {@code null}, if no
     *     keyspace or no routing information is available, or if there is no token map.
     */
    private Set<Node> getReplicas(@Nullable Request request, @Nullable Session session) {
        if (request == null || session == null) {
            return Collections.emptySet();
        }
        CqlIdentifier keyspace = request.getKeyspace();
        if (keyspace == null) {
            keyspace = request.getRoutingKeyspace();
        }
        if (keyspace == null && session.getKeyspace().isPresent()) {
            keyspace = session.getKeyspace().get();
        }
        if (keyspace == null) {
            return Collections.emptySet();
        }
        Token token = request.getRoutingToken();
        // The routing key is only needed (and only fetched) when there is no routing token.
        ByteBuffer key = token == null ? request.getRoutingKey() : null;
        if (token == null && key == null) {
            return Collections.emptySet();
        }
        Optional<? extends TokenMap> maybeTokenMap = this.metadataManager.getMetadata().getTokenMap();
        if (!maybeTokenMap.isPresent()) {
            return Collections.emptySet();
        }
        TokenMap tokenMap = maybeTokenMap.get();
        return token != null ? tokenMap.getReplicas(keyspace, token) : tokenMap.getReplicas(keyspace, key);
    }

    /** A node is unhealthy when it is both busy and responding at an insufficient rate. */
    private boolean isUnhealthy(@NonNull Node node, @NonNull Session session, long now) {
        return this.isBusy(node, session) && this.isResponseRateInsufficient(node, now);
    }

    /** A node is busy when its in-flight request count reaches {@link #MAX_IN_FLIGHT_THRESHOLD}. */
    private boolean isBusy(@NonNull Node node, @NonNull Session session) {
        return DseLoadBalancingPolicy.getInFlight(node, session) >= MAX_IN_FLIGHT_THRESHOLD;
    }

    /**
     * Tells whether the node's recent response rate is insufficient.
     *
     * @param node the node to check.
     * @param now the current time, as returned by {@link #nanoTime()}.
     * @return {@code true} if fewer than two response times are recorded for the node, or if the
     *     older of the two is more than {@link #RESPONSE_COUNT_RESET_INTERVAL_NANOS} before
     *     {@code now}; {@code false} otherwise.
     */
    @VisibleForTesting
    boolean isResponseRateInsufficient(@NonNull Node node, long now) {
        // Single lookup: avoids the double access (and check-then-act) of containsKey() + get().
        AtomicLongArray array = this.responseTimes.get(node);
        if (array != null && array.length() == 2) {
            long threshold = now - RESPONSE_COUNT_RESET_INTERVAL_NANOS;
            long leastRecent = array.get(0);
            // Compared through a difference, as required for nanoTime() values.
            return leastRecent - threshold < 0L;
        }
        return true;
    }

    /**
     * Records the current time as the latest response time of {@code node}, keeping at most the
     * two most recent timestamps (index 0 is the older one).
     */
    private void updateResponseTimes(@NonNull Node node) {
        this.responseTimes.compute(node, (key, times) -> {
            long now = this.nanoTime();
            if (times == null) {
                // First response for this node.
                AtomicLongArray first = new AtomicLongArray(1);
                first.set(0, now);
                return first;
            }
            if (times.length() == 1) {
                // Second response: grow to the final two-slot form.
                AtomicLongArray pair = new AtomicLongArray(2);
                pair.set(0, times.get(0));
                pair.set(1, now);
                return pair;
            }
            // Later responses: shift the newer timestamp down and store the current one.
            times.set(0, times.get(1));
            times.set(1, now);
            return times;
        });
    }

    /**
     * Reads the local datacenter for the profile: first the value returned by
     * {@code context.getLocalDatacenter(profileName)}, then the
     * {@code LOAD_BALANCING_LOCAL_DATACENTER} option of the profile's configuration.
     *
     * @return the local datacenter, or {@code null} if neither source defines one.
     */
    private String getLocalDcFromConfig(@NonNull InternalDriverContext context, @NonNull String profileName) {
        String localDataCenter = context.getLocalDatacenter(profileName);
        if (localDataCenter != null) {
            LOG.debug("[{}] Local DC set from builder: {}", this.logPrefix, localDataCenter);
            return localDataCenter;
        }
        DriverExecutionProfile config = context.getConfig().getProfile(profileName);
        localDataCenter = config.getString((DriverOption) DefaultDriverOption.LOAD_BALANCING_LOCAL_DATACENTER, null);
        if (localDataCenter != null) {
            LOG.debug("[{}] Local DC set from configuration: {}", this.logPrefix, localDataCenter);
        }
        return localDataCenter;
    }

    /**
     * Returns the number of in-flight requests reported by the node's channel pool, or 0 if the
     * session has no pool for the node.
     *
     * @throws ClassCastException if {@code session} is not a {@link DefaultSession}.
     */
    private static int getInFlight(@NonNull Node node, @NonNull Session session) {
        ChannelPool pool = ((DefaultSession) session).getPools().get(node);
        return pool == null ? 0 : pool.getInFlight();
    }

    /**
     * Resolves the user-provided node filter for the profile: the one returned by
     * {@code getNodeFilter(profileName)} if any, else an instance of the class named by the
     * {@code LOAD_BALANCING_FILTER_CLASS} option, else a filter that accepts every node.
     */
    private static Predicate<Node> getFilterFromConfig(@NonNull DriverContext context, @NonNull String profileName) {
        InternalDriverContext internalContext = (InternalDriverContext) context;
        Predicate<Node> filterFromBuilder = internalContext.getNodeFilter(profileName);
        if (filterFromBuilder != null) {
            return filterFromBuilder;
        }
        // Only the raw Predicate class can be requested, hence the unchecked conversion.
        @SuppressWarnings("unchecked")
        Predicate<Node> filterFromClass = Reflection.buildFromConfig(internalContext, profileName, (DriverOption) DefaultDriverOption.LOAD_BALANCING_FILTER_CLASS, Predicate.class, new String[0]).orElse(INCLUDE_ALL_NODES);
        return filterFromClass;
    }
}

