package com.uber.rss.tools;

import com.uber.rss.clients.ClientRetryOptions;
import com.uber.rss.clients.MultiServerSocketReadClient;
import com.uber.rss.clients.ReadClientDataOptions;
import com.uber.rss.clients.ServerReplicationGroupUtil;
import com.uber.rss.clients.TaskDataBlock;
import com.uber.rss.common.AppShuffleId;
import com.uber.rss.common.AppShufflePartitionId;
import com.uber.rss.common.ServerDetail;
import com.uber.rss.common.ServerReplicationGroup;
import com.uber.rss.storage.ShuffleFileUtils;
import com.uber.rss.util.ServerHostAndPort;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/uber/rss/tools/StreamReadClientVerify.class */
/**
 * Reads shuffle partitions back from RSS servers through {@link MultiServerSocketReadClient} and
 * verifies the number of records (data blocks) read, per partition and in total.
 *
 * <p>Configure an instance with the setters (or with the command-line flags accepted by
 * {@link #main(String[])}), then call {@link #verifyRecords(Collection, Collection)}. Instances hold
 * mutable configuration and are not designed for concurrent use.
 */
public class StreamReadClientVerify {
    private static final Logger logger = LoggerFactory.getLogger(StreamReadClientVerify.class);

    // NOTE(review): the read-client settings below keep the original literal values. Their units
    // (presumably milliseconds) are not visible from here; confirm against the client classes.
    /** Value passed as the timeout-position argument of {@link MultiServerSocketReadClient}. */
    private static final int READ_CLIENT_TIMEOUT = 120000;
    /** Second value given to {@link ClientRetryOptions} and {@link ReadClientDataOptions}. */
    private static final int MAX_WAIT = READ_CLIENT_TIMEOUT * 3;
    /** First numeric value given to {@link ClientRetryOptions} and {@link ReadClientDataOptions}. */
    private static final int RETRY_INTERVAL = 10;
    /** User name the read client is created with. */
    private static final String READ_USER = "user1";

    private String appId;
    private String appAttempt;
    private int shuffleId;
    /** Number of partitions verified when no explicit partition list is given. */
    private int numPartitions;
    /** Expected record count per partition id; {@code null} disables the per-partition check. */
    private Map<Integer, Long> expectedTotalRecordsForEachPartition;
    private List<ServerDetail> rssServers = new ArrayList<>();
    private int numReplicas = 1;
    private int partitionFanout = 1;
    /** Expected total over all verified partitions; {@code 0} disables the total check. */
    private long expectedTotalRecords = 0;
    // NOTE(review): initialized from ShuffleFileUtils.MAX_SPLITS, which may be a decompiler-inlined
    // constant rather than an intended length limit; confirm the intended value.
    private int maxValueLen = ShuffleFileUtils.MAX_SPLITS;
    /** Optional fault-injection hook run once the total read count reaches half the expected total. */
    private Runnable actionToSimulateBadServer = null;

    /**
     * Sets the servers to read from.
     *
     * @param list servers; copied, so later changes to the argument are not seen
     * @param i number of replicas used when building replication groups
     */
    public void setRssServers(List<ServerDetail> list, int i) {
        this.rssServers = new ArrayList<>(list);
        this.numReplicas = i;
    }

    /** Sets the hook run mid-read to simulate a bad server; {@code null} disables it. */
    public void setActionToSimulateBadServer(Runnable runnable) {
        this.actionToSimulateBadServer = runnable;
    }

    /** Sets the application id, attempt and shuffle id of the shuffle to read. */
    public void setAppShuffleId(AppShuffleId appShuffleId) {
        this.appId = appShuffleId.getAppId();
        this.appAttempt = appShuffleId.getAppAttempt();
        this.shuffleId = appShuffleId.getShuffleId();
    }

    /** Sets how many partitions ({@code [0, i)}) are verified when no partition list is given. */
    public void setNumPartitions(int i) {
        this.numPartitions = i;
    }

    /** Sets the partition fan-out used when building replication groups. */
    public void setPartitionFanout(int i) {
        this.partitionFanout = i;
    }

    /** Sets the expected total record count; {@code 0} skips the total check. */
    public void setExpectedTotalRecords(long j) {
        this.expectedTotalRecords = j;
    }

    /**
     * Sets the expected record count per partition id. Partitions missing from the map are
     * expected to have 0 records; a {@code null} map skips the per-partition check.
     */
    public void setExpectedTotalRecordsForEachPartition(Map<Integer, Long> map) {
        this.expectedTotalRecordsForEachPartition = map;
    }

    /**
     * Reads every requested partition to the end and checks the record counts.
     *
     * @param collection partition ids to verify; {@code null} means all partitions in
     *     {@code [0, numPartitions)}
     * @param collection2 forwarded unchanged to {@link ReadClientDataOptions} (presumably the task
     *     attempt ids whose data is read; confirm); {@code main} passes {@code null}
     * @throws RuntimeException if a data block's payload is longer than the maximum value length, if
     *     a partition's record count differs from its expected count (when a per-partition map is
     *     set), or if a non-zero expected total differs from the total number of records read
     */
    public void verifyRecords(Collection<Integer> collection, Collection<Long> collection2) {
        Collection<Integer> partitionIds = collection;
        if (partitionIds == null) {
            partitionIds = IntStream.range(0, this.numPartitions).boxed().collect(Collectors.toList());
            logger.info("Verifying record for partitions: [{}, {})", 0, this.numPartitions);
        } else {
            logger.info("Verifying record for partitions: {}", StringUtils.join(partitionIds, ","));
        }

        // Running count across all partitions; drives the bad-server hook and the final total check.
        AtomicLong totalRecordsRead = new AtomicLong();
        for (int partition : partitionIds) {
            verifyPartition(partition, collection2, totalRecordsRead);
        }

        String summary = String.format(
                "Total expected records: %s, total records read from servers: %s",
                this.expectedTotalRecords, totalRecordsRead.get());
        logger.info(summary);
        // An expected total of 0 means "not configured", so the total check is skipped.
        if (this.expectedTotalRecords != 0 && this.expectedTotalRecords != totalRecordsRead.get()) {
            throw new RuntimeException(summary);
        }
    }

    /**
     * Reads one partition to the end. It enforces the payload length limit and, when configured,
     * the expected per-partition record count.
     *
     * @param partition partition id to read
     * @param taskAttemptIds forwarded to {@link ReadClientDataOptions}
     * @param totalRecordsRead running count across partitions; incremented once per block read
     * @return number of data blocks read from this partition
     */
    private long verifyPartition(int partition, Collection<Long> taskAttemptIds, AtomicLong totalRecordsRead) {
        AppShufflePartitionId partitionId =
                new AppShufflePartitionId(this.appId, this.appAttempt, this.shuffleId, partition);
        List<ServerReplicationGroup> replicationGroups =
                ServerReplicationGroupUtil.createReplicationGroupsForPartition(
                        this.rssServers, this.numReplicas, partition, this.partitionFanout);
        MultiServerSocketReadClient readClient = new MultiServerSocketReadClient(
                replicationGroups,
                READ_CLIENT_TIMEOUT,
                new ClientRetryOptions(RETRY_INTERVAL, MAX_WAIT),
                READ_USER,
                partitionId,
                new ReadClientDataOptions(taskAttemptIds, RETRY_INTERVAL, MAX_WAIT),
                true);
        logger.info("Connecting replicated read client: {}", readClient);
        // connect() is inside the try so the client is closed even when connecting fails.
        try {
            readClient.connect();
            long recordCount = 0;
            for (TaskDataBlock block = readClient.readDataBlock();
                    block != null;
                    block = readClient.readDataBlock()) {
                recordCount++;
                // The counter is always incremented; the hook fires when the running total equals
                // half of the expected total.
                if (totalRecordsRead.incrementAndGet() == this.expectedTotalRecords / 2
                        && this.actionToSimulateBadServer != null) {
                    logger.info("Simulate bad server during shuffle read");
                    this.actionToSimulateBadServer.run();
                }
                if (block.getPayload() != null && block.getPayload().length > this.maxValueLen) {
                    // Report the offending length, not the array's identity string.
                    throw new RuntimeException(String.format(
                            "Read wrong value len %s after reading %s records for %s from server %s",
                            block.getPayload().length, recordCount, partitionId, replicationGroups));
                }
            }
            logger.info("Closing read client for {}", partitionId);

            if (this.expectedTotalRecordsForEachPartition == null) {
                logger.info("Read {} records for {} from server {} (no per-partition expectation set)",
                        recordCount, partitionId, replicationGroups);
                return recordCount;
            }
            long expected = this.expectedTotalRecordsForEachPartition.getOrDefault(partition, 0L);
            if (recordCount != expected) {
                throw new RuntimeException(String.format(
                        "Verify error for partition %s, servers %s, expected records: %s, actual records: %s",
                        partitionId, replicationGroups, expected, recordCount));
            }
            logger.info("Verified {} records for {} from server {}", recordCount, partitionId, replicationGroups);
            return recordCount;
        } finally {
            readClient.close();
        }
    }

    /**
     * Returns {@code args[index]}, the value of the flag {@code argName}.
     *
     * @throws IllegalArgumentException if the flag is the last argument and has no value
     */
    private static String argValue(String[] args, int index, String argName) {
        if (index >= args.length) {
            throw new IllegalArgumentException("Missing value for argument: " + argName);
        }
        return args[index];
    }

    /**
     * Command-line entry point. Flags are case-insensitive and each takes one value:
     * {@code -rssServers host1:port1,host2:port2}, {@code -appId}, {@code -appAttempt},
     * {@code -shuffleId}, {@code -expectedTotalRecords}, {@code -numPartitions},
     * {@code -numReplicas} and {@code -partitionFanout}. It then verifies partitions
     * {@code [0, numPartitions)}.
     *
     * @throws IllegalArgumentException for an unknown flag or a flag missing its value
     */
    public static void main(String[] strArr) {
        StreamReadClientVerify verify = new StreamReadClientVerify();
        int i = 0;
        while (i < strArr.length) {
            String argName = strArr[i++];
            if (argName.equalsIgnoreCase("-rssServers")) {
                String servers = argValue(strArr, i++, argName);
                // Entries are comma-separated so that each entry can itself be "host:port".
                verify.rssServers.addAll(Arrays.stream(servers.split(","))
                        .map(String::trim)
                        .filter(entry -> !entry.isEmpty())
                        .map(entry -> {
                            ServerHostAndPort hostAndPort = ServerHostAndPort.fromString(entry);
                            return TestUtils.getServerDetail(hostAndPort.getHost(), hostAndPort.getPort());
                        })
                        .collect(Collectors.toList()));
            } else if (argName.equalsIgnoreCase("-appId")) {
                verify.appId = argValue(strArr, i++, argName);
            } else if (argName.equalsIgnoreCase("-appAttempt")) {
                verify.appAttempt = argValue(strArr, i++, argName);
            } else if (argName.equalsIgnoreCase("-shuffleId")) {
                verify.shuffleId = Integer.parseInt(argValue(strArr, i++, argName));
            } else if (argName.equalsIgnoreCase("-expectedTotalRecords")) {
                verify.expectedTotalRecords = Long.parseLong(argValue(strArr, i++, argName));
            } else if (argName.equalsIgnoreCase("-numPartitions")) {
                verify.numPartitions = Integer.parseInt(argValue(strArr, i++, argName));
            } else if (argName.equalsIgnoreCase("-numReplicas")) {
                verify.numReplicas = Integer.parseInt(argValue(strArr, i++, argName));
            } else if (argName.equalsIgnoreCase("-partitionFanout")) {
                verify.partitionFanout = Integer.parseInt(argValue(strArr, i++, argName));
            } else {
                throw new IllegalArgumentException("Unsupported argument: " + argName);
            }
        }
        verify.verifyRecords(null, null);
    }
}
