/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.d2.balancer.clients;

import com.linkedin.common.callback.Callback;
import com.linkedin.common.util.None;
import com.linkedin.d2.balancer.clients.TrackerClient;
import com.linkedin.d2.balancer.properties.PartitionData;
import com.linkedin.d2.balancer.util.LoadBalancerUtil;
import com.linkedin.d2.discovery.util.LogUtil;
import com.linkedin.data.ByteString;
import com.linkedin.r2.RemoteInvocationException;
import com.linkedin.r2.message.RequestContext;
import com.linkedin.r2.message.rest.RestException;
import com.linkedin.r2.message.rest.RestRequest;
import com.linkedin.r2.message.rest.RestResponse;
import com.linkedin.r2.message.stream.StreamException;
import com.linkedin.r2.message.stream.StreamRequest;
import com.linkedin.r2.message.stream.StreamResponse;
import com.linkedin.r2.message.stream.entitystream.EntityStream;
import com.linkedin.r2.message.stream.entitystream.Observer;
import com.linkedin.r2.transport.common.bridge.client.TransportClient;
import com.linkedin.r2.transport.common.bridge.common.TransportCallback;
import com.linkedin.r2.transport.common.bridge.common.TransportResponse;
import com.linkedin.util.clock.Clock;
import com.linkedin.util.degrader.CallCompletion;
import com.linkedin.util.degrader.CallTracker;
import com.linkedin.util.degrader.CallTrackerImpl;
import com.linkedin.util.degrader.ErrorType;
import java.net.ConnectException;
import java.net.URI;
import java.nio.channels.ClosedChannelException;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeoutException;
import java.util.function.Predicate;
import java.util.regex.Pattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link TrackerClient} implementation that forwards REST and stream requests to a wrapped
 * {@link TransportClient} and records the outcome of every call in a {@link CallTracker}.
 *
 * <p>Failed calls are classified by {@link #handleError}: an error response whose status matches
 * the configured error-status predicate counts as a server error, a
 * {@link RemoteInvocationException} is further classified by its original cause (connect,
 * closed-channel, timeout, or other), and anything else is recorded as a generic error.
 */
public class TrackerClientImpl
implements TrackerClient {
    /** Default regex for response statuses treated as server errors: any 5xx status. */
    public static final String DEFAULT_ERROR_STATUS_REGEX = "(5..)";
    /** Compiled form of {@link #DEFAULT_ERROR_STATUS_REGEX}. */
    public static final Pattern DEFAULT_ERROR_STATUS_PATTERN = Pattern.compile(DEFAULT_ERROR_STATUS_REGEX);
    /** Default interval handed to {@link CallTrackerImpl} (presumably milliseconds — TODO confirm). */
    public static final long DEFAULT_CALL_TRACKER_INTERVAL = 5000L;
    // NOTE(review): the logger is named after the TrackerClient interface rather than this class;
    // kept as-is so existing logging configuration keyed on that name keeps applying.
    private static final Logger _log = LoggerFactory.getLogger(TrackerClient.class);

    private final TransportClient _transportClient;
    /** Read-only view of the partition data supplied at construction. */
    private final Map<Integer, PartitionData> _partitionData;
    private final URI _uri;
    /** Decides whether a response status (from a RestException/StreamException) is a server error. */
    private final Predicate<Integer> _isErrorStatus;
    /** Per-partition subset weights; partitions without an entry report a weight of 1.0. */
    private final ConcurrentMap<Integer, Double> _subsetWeightMap;
    private final boolean _doNotLoadBalance;
    final CallTracker _callTracker;
    // volatile: may be flipped through setDoNotSlowStart() while other threads read it, matching
    // how the other shared state in this class is published.
    private volatile boolean _doNotSlowStart;
    /** Stats snapshot from the most recent tracker rollover, replaced by the rollover listener. */
    private volatile CallTracker.CallStats _latestCallStats;

    /**
     * Creates a tracker client with percentile tracking enabled, slow start allowed, and load
     * balancing enabled.
     *
     * @param uri the URI this client targets
     * @param partitionDataMap partition id to partition data; exposed read-only
     * @param transportClient the client requests are forwarded to
     * @param clock clock used by the call tracker
     * @param interval call-tracker interval
     * @param isErrorStatus predicate classifying response statuses as server errors
     */
    public TrackerClientImpl(URI uri, Map<Integer, PartitionData> partitionDataMap, TransportClient transportClient, Clock clock, long interval, Predicate<Integer> isErrorStatus) {
        this(uri, partitionDataMap, transportClient, clock, interval, isErrorStatus, true, false, false);
    }

    /**
     * Creates a tracker client.
     *
     * @param uri the URI this client targets
     * @param partitionDataMap partition id to partition data; wrapped in an unmodifiable view
     * @param transportClient the client requests are forwarded to
     * @param clock clock used by the call tracker
     * @param interval call-tracker interval
     * @param isErrorStatus predicate classifying response statuses as server errors
     * @param percentileTrackingEnabled passed through to {@link CallTrackerImpl}
     * @param doNotSlowStart initial value reported by {@link #doNotSlowStart()}
     * @param doNotLoadBalance value reported by {@link #doNotLoadBalance()}
     */
    public TrackerClientImpl(URI uri, Map<Integer, PartitionData> partitionDataMap, TransportClient transportClient, Clock clock, long interval, Predicate<Integer> isErrorStatus, boolean percentileTrackingEnabled, boolean doNotSlowStart, boolean doNotLoadBalance) {
        this._uri = uri;
        this._transportClient = transportClient;
        this._callTracker = new CallTrackerImpl(interval, clock, percentileTrackingEnabled);
        this._isErrorStatus = isErrorStatus;
        this._partitionData = Collections.unmodifiableMap(partitionDataMap);
        this._latestCallStats = this._callTracker.getCallStats();
        this._doNotSlowStart = doNotSlowStart;
        this._subsetWeightMap = new ConcurrentHashMap<>();
        this._doNotLoadBalance = doNotLoadBalance;
        // Refresh the cached snapshot every time the tracker rolls over to a new interval.
        this._callTracker.addStatsRolloverEventListener(event -> this._latestCallStats = event.getCallStats());
        LogUtil.debug(_log, "created tracker client: ", this);
    }

    /** Returns the stats snapshot captured at construction or at the latest rollover. */
    @Override
    public CallTracker.CallStats getLatestCallStats() {
        return this._latestCallStats;
    }

    /** Shuts down the wrapped transport client, completing {@code callback} when it finishes. */
    public void shutdown(Callback<None> callback) {
        this._transportClient.shutdown(callback);
    }

    @Override
    public TransportClient getTransportClient() {
        return this._transportClient;
    }

    /** Returns an unmodifiable view of the partition data map. */
    @Override
    public Map<Integer, PartitionData> getPartitionDataMap() {
        return this._partitionData;
    }

    @Override
    public void setSubsetWeight(int partitionId, double partitionWeight) {
        this._subsetWeightMap.put(partitionId, partitionWeight);
    }

    /** Returns the subset weight for {@code partitionId}, or 1.0 if none has been set. */
    @Override
    public double getSubsetWeight(int partitionId) {
        return this._subsetWeightMap.getOrDefault(partitionId, 1.0);
    }

    /**
     * Starts tracking a call and forwards the REST request to the transport client; the call is
     * ended when the transport delivers a response or an error.
     */
    public void restRequest(RestRequest request, RequestContext requestContext, Map<String, String> wireAttrs, TransportCallback<RestResponse> callback) {
        CallCompletion completion = this._callTracker.startCall();
        this._transportClient.restRequest(request, requestContext, wireAttrs, new TrackerClientRestCallback(callback, completion));
    }

    /**
     * Starts tracking a call and forwards the stream request to the transport client; a
     * successful call is ended only once the response entity stream finishes or fails.
     */
    public void streamRequest(StreamRequest request, RequestContext requestContext, Map<String, String> wireAttrs, TransportCallback<StreamResponse> callback) {
        CallCompletion completion = this._callTracker.startCall();
        this._transportClient.streamRequest(request, requestContext, wireAttrs, new TrackerClientStreamCallback(callback, completion));
    }

    @Override
    public URI getUri() {
        return this._uri;
    }

    @Override
    public CallTracker getCallTracker() {
        return this._callTracker;
    }

    @Override
    public String toString() {
        return this.getClass().getSimpleName() + " [_uri=" + this._uri + ", _partitionData=" + this._partitionData + "]";
    }

    @Override
    public void setDoNotSlowStart(boolean doNotSlowStart) {
        this._doNotSlowStart = doNotSlowStart;
    }

    @Override
    public boolean doNotSlowStart() {
        return this._doNotSlowStart;
    }

    @Override
    public boolean doNotLoadBalance() {
        return this._doNotLoadBalance;
    }

    /**
     * Ends {@code callCompletion} with an error type derived from {@code throwable}.
     *
     * <p>Server errors (per {@link #isServerError}) are checked first. A RestException or
     * StreamException whose status is not a server error then falls through to the
     * RemoteInvocationException branch, where the original cause decides the error type.
     */
    private void handleError(CallCompletion callCompletion, Throwable throwable) {
        if (this.isServerError(throwable)) {
            callCompletion.endCallWithError(ErrorType.SERVER_ERROR);
        } else if (throwable instanceof RemoteInvocationException) {
            Throwable originalThrowable = LoadBalancerUtil.findOriginalThrowable(throwable);
            if (originalThrowable instanceof ConnectException) {
                callCompletion.endCallWithError(ErrorType.CONNECT_EXCEPTION);
            } else if (originalThrowable instanceof ClosedChannelException) {
                callCompletion.endCallWithError(ErrorType.CLOSED_CHANNEL_EXCEPTION);
            } else if (originalThrowable instanceof TimeoutException) {
                callCompletion.endCallWithError(ErrorType.TIMEOUT_EXCEPTION);
            } else {
                callCompletion.endCallWithError(ErrorType.REMOTE_INVOCATION_EXCEPTION);
            }
        } else {
            callCompletion.endCallWithError();
        }
    }

    /**
     * Returns true if {@code throwable} is a RestException or StreamException that carries a
     * response whose status satisfies the configured error-status predicate.
     */
    private boolean isServerError(Throwable throwable) {
        if (throwable instanceof RestException) {
            RestResponse response = ((RestException) throwable).getResponse();
            return response != null && this._isErrorStatus.test(response.getStatus());
        }
        if (throwable instanceof StreamException) {
            StreamResponse response = ((StreamException) throwable).getResponse();
            return response != null && this._isErrorStatus.test(response.getStatus());
        }
        return false;
    }

    /**
     * Stream callback that records the call outcome before delegating to the caller's callback.
     * On success the call stays open until the response entity stream completes or fails.
     */
    private class TrackerClientStreamCallback
    implements TransportCallback<StreamResponse> {
        private final TransportCallback<StreamResponse> _wrappedCallback;
        private final CallCompletion _callCompletion;

        public TrackerClientStreamCallback(TransportCallback<StreamResponse> wrappedCallback, CallCompletion callCompletion) {
            this._wrappedCallback = wrappedCallback;
            this._callCompletion = callCompletion;
        }

        @Override
        public void onResponse(TransportResponse<StreamResponse> response) {
            if (response.hasError()) {
                TrackerClientImpl.this.handleError(this._callCompletion, response.getError());
            } else {
                EntityStream entityStream = response.getResponse().getEntityStream();
                // NOTE(review): record() presumably notes arrival of the response headers; the
                // call itself is ended by the observer below once the body finishes or fails.
                this._callCompletion.record();
                entityStream.addObserver(new Observer() {

                    @Override
                    public void onDataAvailable(ByteString data) {
                        // Body bytes are not inspected; only stream completion matters here.
                    }

                    @Override
                    public void onDone() {
                        TrackerClientStreamCallback.this._callCompletion.endCall();
                    }

                    @Override
                    public void onError(Throwable e) {
                        TrackerClientImpl.this.handleError(TrackerClientStreamCallback.this._callCompletion, e);
                    }
                });
            }
            this._wrappedCallback.onResponse(response);
        }
    }

    /** REST callback that ends the tracked call before delegating to the caller's callback. */
    private class TrackerClientRestCallback
    implements TransportCallback<RestResponse> {
        private final TransportCallback<RestResponse> _wrappedCallback;
        private final CallCompletion _callCompletion;

        public TrackerClientRestCallback(TransportCallback<RestResponse> wrappedCallback, CallCompletion callCompletion) {
            this._wrappedCallback = wrappedCallback;
            this._callCompletion = callCompletion;
        }

        @Override
        public void onResponse(TransportResponse<RestResponse> response) {
            if (response.hasError()) {
                TrackerClientImpl.this.handleError(this._callCompletion, response.getError());
            } else {
                this._callCompletion.endCall();
            }
            this._wrappedCallback.onResponse(response);
        }
    }
}

