/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.d2.balancer.util.hashing;

import com.linkedin.d2.balancer.Facilities;
import com.linkedin.d2.balancer.ServiceUnavailableException;
import com.linkedin.d2.balancer.URIMapper;
import com.linkedin.d2.balancer.util.LoadBalancerUtil;
import com.linkedin.d2.balancer.util.URIKeyPair;
import com.linkedin.d2.balancer.util.URIMappingResult;
import com.linkedin.d2.balancer.util.URIRequest;
import com.linkedin.d2.balancer.util.hashing.HashFunction;
import com.linkedin.d2.balancer.util.hashing.HashRingProvider;
import com.linkedin.d2.balancer.util.hashing.RandomHash;
import com.linkedin.d2.balancer.util.hashing.Ring;
import com.linkedin.d2.balancer.util.hashing.URIRegexHash;
import com.linkedin.d2.balancer.util.partitions.PartitionAccessException;
import com.linkedin.d2.balancer.util.partitions.PartitionAccessor;
import com.linkedin.d2.balancer.util.partitions.PartitionInfoProvider;
import com.linkedin.r2.message.Request;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link URIMapper} that routes each request URI to a host by looking it up on the hash ring of the
 * partition the URI belongs to.
 *
 * <p>Requests whose partition id cannot be computed, or whose partition has no ring or no resolvable
 * host, are reported through the "unmapped" part of the returned {@link URIMappingResult} instead of
 * failing the whole batch.
 */
public class RingBasedUriMapper implements URIMapper {
    private static final Logger LOG = LoggerFactory.getLogger(RingBasedUriMapper.class);
    /** Bucket in the unmapped result used for keys whose partition id could not be computed. */
    private static final int PARTITION_NOT_FOUND_ID = -1;
    private final HashRingProvider _hashRingProvider;
    private final PartitionInfoProvider _partitionInfoProvider;

    /**
     * Creates a mapper backed by the given providers.
     *
     * @param hashRingProvider source of the per-partition rings and of the request hash function
     * @param partitionInfoProvider source of the partition accessor for a service
     */
    public RingBasedUriMapper(HashRingProvider hashRingProvider, PartitionInfoProvider partitionInfoProvider) {
        this._hashRingProvider = hashRingProvider;
        this._partitionInfoProvider = partitionInfoProvider;
    }

    /**
     * Creates a mapper using the hash ring and partition info providers exposed by {@code facilities}.
     *
     * @param facilities load balancer facilities supplying both providers
     */
    public RingBasedUriMapper(Facilities facilities) {
        this(facilities.getHashRingProvider(), facilities.getPartitionInfoProvider());
    }

    /**
     * Maps every request URI to a host.
     *
     * <p>The service name, partition accessor, rings and hash function are all derived from the URI of
     * the first pair. NOTE(review): every pair is therefore assumed to target the same service; this is
     * not validated here.
     *
     * @param requestUriKeyPairs requests to route; {@code null} or empty yields an empty result
     * @return host to keys, partition id to unmapped keys, and host to partition id
     * @throws ServiceUnavailableException propagated from the partition info / hash ring providers
     */
    @Override
    public <KEY> URIMappingResult<KEY> mapUris(List<URIKeyPair<KEY>> requestUriKeyPairs) throws ServiceUnavailableException {
        if (requestUriKeyPairs == null || requestUriKeyPairs.isEmpty()) {
            return new URIMappingResult<KEY>(Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap());
        }
        URI sampleURI = requestUriKeyPairs.get(0).getRequestUri();
        String serviceName = LoadBalancerUtil.getServiceNameFromUri(sampleURI);
        PartitionAccessor accessor = this._partitionInfoProvider.getPartitionAccessor(serviceName);
        Map<Integer, Ring<URI>> rings = this._hashRingProvider.getRings(sampleURI);
        HashFunction<Request> hashFunction = this._hashRingProvider.getRequestHashFunction(serviceName);

        Map<Integer, Set<KEY>> unmapped = new HashMap<>();
        Map<Integer, List<URIKeyPair<KEY>>> requestsByPartition = distributeToPartitions(requestUriKeyPairs, accessor, unmapped);
        Map<URI, Integer> hostToPartitionId = new HashMap<>();
        Map<URI, Set<KEY>> hostToKeySet = distributeToHosts(requestsByPartition, rings, hashFunction, hostToPartitionId, unmapped);
        return new URIMappingResult<KEY>(hostToKeySet, unmapped, hostToPartitionId);
    }

    /**
     * Returns whether requests to {@code serviceName} must be scattered: true when the service is
     * partitioned (max partition id above 0) or routed stickily via {@link URIRegexHash}.
     */
    @Override
    public boolean needScatterGather(String serviceName) throws ServiceUnavailableException {
        return isPartitioningEnabled(serviceName) || isStickyEnabled(serviceName);
    }

    /** True when the service's request hash function is a {@link URIRegexHash}. */
    private boolean isStickyEnabled(String serviceName) throws ServiceUnavailableException {
        HashFunction<Request> hashFunction = this._hashRingProvider.getRequestHashFunction(serviceName);
        return hashFunction instanceof URIRegexHash;
    }

    /** True when the service's partition accessor reports a max partition id greater than 0. */
    private boolean isPartitioningEnabled(String serviceName) throws ServiceUnavailableException {
        PartitionAccessor accessor = this._partitionInfoProvider.getPartitionAccessor(serviceName);
        return accessor.getMaxPartitionId() > 0;
    }

    /**
     * Groups requests by partition id.
     *
     * <p>Unpartitioned services put everything in partition 0; a single request carrying overridden
     * partition ids is fanned out to those ids. Requests whose partition cannot be computed are added to
     * {@code unmapped} under {@link #PARTITION_NOT_FOUND_ID}.
     */
    private <KEY> Map<Integer, List<URIKeyPair<KEY>>> distributeToPartitions(List<URIKeyPair<KEY>> requestUriKeyPairs,
            PartitionAccessor accessor, Map<Integer, Set<KEY>> unmapped) {
        if (accessor.getMaxPartitionId() == 0) {
            return distributeToPartitionsUnpartitioned(requestUriKeyPairs);
        }
        if (checkPartitionIdOverride(requestUriKeyPairs)) {
            return doPartitionIdOverride(requestUriKeyPairs.get(0));
        }
        Map<Integer, List<URIKeyPair<KEY>>> requestListsByPartitionId = new HashMap<>();
        for (URIKeyPair<KEY> request : requestUriKeyPairs) {
            try {
                int partitionId = accessor.getPartitionId(request.getRequestUri());
                requestListsByPartitionId.computeIfAbsent(partitionId, k -> new ArrayList<>()).add(request);
            } catch (PartitionAccessException e) {
                // Best effort: surface the key as unmapped instead of failing the whole batch.
                LOG.debug("Unable to determine partition for {}", request.getRequestUri(), e);
                unmapped.computeIfAbsent(PARTITION_NOT_FOUND_ID, k -> new HashSet<>()).add(request.getKey());
            }
        }
        return requestListsByPartitionId;
    }

    /** Places all requests in partition 0. */
    private <KEY> Map<Integer, List<URIKeyPair<KEY>>> distributeToPartitionsUnpartitioned(List<URIKeyPair<KEY>> requestUriKeyPairs) {
        return Collections.singletonMap(0, requestUriKeyPairs);
    }

    /**
     * Resolves a host for each partition's requests. A {@link RandomHash} picks one random host per
     * partition; any other hash function routes each request individually by its URI hash.
     *
     * <p>A partition is either fully mapped or fully reported in {@code unmapped}, never both: it is
     * unmapped when it has no ring or when any of its requests resolves to no host.
     */
    private <KEY> Map<URI, Set<KEY>> distributeToHosts(Map<Integer, List<URIKeyPair<KEY>>> requestsByPartitionId,
            Map<Integer, Ring<URI>> rings, HashFunction<Request> hashFunction, Map<URI, Integer> hostToPartitionId,
            Map<Integer, Set<KEY>> unmapped) {
        if (hashFunction instanceof RandomHash) {
            return distributeToHostNonSticky(requestsByPartitionId, rings, hostToPartitionId, unmapped);
        }
        Map<URI, Set<KEY>> hostToKeySet = new HashMap<>();
        for (Map.Entry<Integer, List<URIKeyPair<KEY>>> entry : requestsByPartitionId.entrySet()) {
            int partitionId = entry.getKey();
            List<URIKeyPair<KEY>> requests = entry.getValue();
            Map<URI, Set<KEY>> partitionHostToKeys = resolveStickyHosts(requests, rings.get(partitionId), hashFunction);
            if (partitionHostToKeys == null) {
                markPartitionUnmapped(partitionId, requests, unmapped);
                continue;
            }
            // Commit only after the whole partition resolved, so no key is both mapped and unmapped.
            for (Map.Entry<URI, Set<KEY>> hostKeys : partitionHostToKeys.entrySet()) {
                hostToPartitionId.putIfAbsent(hostKeys.getKey(), partitionId);
                hostToKeySet.computeIfAbsent(hostKeys.getKey(), host -> new HashSet<>()).addAll(hostKeys.getValue());
            }
        }
        return hostToKeySet;
    }

    /**
     * Routes every request of one partition to a host by hashing its URI onto {@code ring}.
     *
     * @return host to the non-null keys routed there (a host hit only by keyless requests maps to an
     *     empty set), or {@code null} if {@code ring} is missing or any request resolves to no host
     */
    private static <KEY> Map<URI, Set<KEY>> resolveStickyHosts(List<URIKeyPair<KEY>> requests, Ring<URI> ring,
            HashFunction<Request> hashFunction) {
        if (ring == null) {
            return null;
        }
        Map<URI, Set<KEY>> hostToKeys = new HashMap<>();
        for (URIKeyPair<KEY> request : requests) {
            int hashcode = hashFunction.hash(new URIRequest(request.getRequestUri()));
            URI resolvedHost = ring.get(hashcode);
            if (resolvedHost == null) {
                return null;
            }
            Set<KEY> keys = hostToKeys.computeIfAbsent(resolvedHost, host -> new HashSet<>());
            if (request.getKey() != null) {
                keys.add(request.getKey());
            }
        }
        return hostToKeys;
    }

    /**
     * Sends all of a partition's keys to one randomly chosen host of that partition's ring; the partition
     * is reported as unmapped when it has no ring or the ring yields no host.
     */
    private <KEY> Map<URI, Set<KEY>> distributeToHostNonSticky(Map<Integer, List<URIKeyPair<KEY>>> requestsByPartitionId,
            Map<Integer, Ring<URI>> rings, Map<URI, Integer> hostToPartitionId, Map<Integer, Set<KEY>> unmapped) {
        Map<URI, Set<KEY>> hostToKeySet = new HashMap<>();
        for (Map.Entry<Integer, List<URIKeyPair<KEY>>> entry : requestsByPartitionId.entrySet()) {
            int partitionId = entry.getKey();
            Ring<URI> ring = rings.get(partitionId);
            URI resolvedHost = ring == null ? null : ring.get(ThreadLocalRandom.current().nextInt());
            if (resolvedHost == null) {
                markPartitionUnmapped(partitionId, entry.getValue(), unmapped);
                continue;
            }
            hostToPartitionId.putIfAbsent(resolvedHost, partitionId);
            hostToKeySet.computeIfAbsent(resolvedHost, host -> new HashSet<>())
                    .addAll(convertURIKeyPairListToKeySet(entry.getValue()));
        }
        return hostToKeySet;
    }

    /** Records the keys of {@code requests} as unmapped under {@code partitionId}. */
    private static <KEY> void markPartitionUnmapped(int partitionId, List<URIKeyPair<KEY>> requests,
            Map<Integer, Set<KEY>> unmapped) {
        unmapped.computeIfAbsent(partitionId, k -> new HashSet<>()).addAll(convertURIKeyPairListToKeySet(requests));
    }

    /**
     * Collects the keys of {@code list}. Returns an empty set if any key is {@code null} (keyless
     * requests), so that callers report only the partition/host, not individual keys.
     */
    private static <KEY> Set<KEY> convertURIKeyPairListToKeySet(List<URIKeyPair<KEY>> list) {
        if (list.stream().anyMatch(uriKeyPair -> uriKeyPair.getKey() == null)) {
            return Collections.emptySet();
        }
        return list.stream().map(URIKeyPair::getKey).collect(Collectors.toSet());
    }

    /**
     * Returns true when partition ids were overridden by a custom scatter-gather strategy.
     *
     * @throws IllegalStateException if any request has overridden partition ids and there is more than
     *     one request
     */
    private <KEY> boolean checkPartitionIdOverride(List<URIKeyPair<KEY>> requests) {
        if (requests.stream().anyMatch(URIKeyPair::hasOverriddenPartitionIds)) {
            if (requests.size() == 1) {
                LOG.debug("Use partition ids provided by custom scatter gather strategy");
                return true;
            }
            throw new IllegalStateException("More than one request with overridden partition ids are provided. Consider put all partition ids in one set or send different request if URI is different");
        }
        return false;
    }

    /** Maps each overridden partition id of {@code request} to a singleton list holding that request. */
    private <KEY> Map<Integer, List<URIKeyPair<KEY>>> doPartitionIdOverride(URIKeyPair<KEY> request) {
        return request.getOverriddenPartitionIds().stream()
                .collect(Collectors.toMap(Function.identity(), partitionId -> Collections.singletonList(request)));
    }
}

