/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.parallelconsumer.integrationTests;

import io.confluent.csid.utils.ProgressBarUtils;
import io.confluent.csid.utils.ProgressTracker;
import io.confluent.csid.utils.StringUtils;
import io.confluent.csid.utils.TrimListRepresentation;
import io.confluent.parallelconsumer.ParallelConsumerOptions;
import io.confluent.parallelconsumer.ParallelEoSStreamProcessor;
import io.confluent.parallelconsumer.integrationTests.BrokerIntegrationTest;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Properties;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import me.tongfei.progressbar.ProgressBar;
import org.apache.commons.lang3.RandomUtils;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.assertj.core.api.AbstractCollectionAssert;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.BooleanAssert;
import org.assertj.core.api.CollectionAssert;
import org.assertj.core.api.ListAssert;
import org.assertj.core.api.SoftAssertions;
import org.assertj.core.internal.StandardComparisonStrategy;
import org.assertj.core.presentation.Representation;
import org.assertj.core.util.IterableUtil;
import org.awaitility.Awaitility;
import org.awaitility.core.TerminalFailureException;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;
import pl.tlinkowski.unij.api.UniLists;

/**
 * Integration test that runs several {@link ParallelEoSStreamProcessor} instances against the same input topic,
 * randomly stopping and restarting the secondary instances ("chaos") to force consumer-group rebalances, and then
 * checks that every produced key was consumed by at least one instance.
 */
public class MultiInstanceRebalanceTest
extends BrokerIntegrationTest<String, String> {
    private static final Logger log = LoggerFactory.getLogger(MultiInstanceRebalanceTest.class);
    static final int DEFAULT_MAX_POLL = 500;
    /** Upper bound (ms) of the random pause between chaos-monkey iterations. */
    public static final int CHAOS_FREQUENCY = 500;
    /** Default simulated processing time per record, in milliseconds. */
    public static final int DEFAULT_POLL_DELAY = 150;
    /** Total number of records processed by all instances (duplicates included); drives the progress tracker. */
    AtomicInteger count = new AtomicInteger();
    ProgressBar overallProgress;
    /** Distinct keys processed by any instance. */
    Set<String> overallConsumedKeys = new ConcurrentSkipListSet<String>();
    /** Source of instance ids; incremented without synchronisation, runners are constructed on the test thread. */
    int pcInstanceCount = 0;

    @ParameterizedTest
    @EnumSource(value=ParallelConsumerOptions.ProcessingOrder.class)
    void consumeWithMultipleInstancesPeriodicConsumerSync(ParallelConsumerOptions.ProcessingOrder order) throws Exception {
        this.numPartitions = 2;
        int expectedMessageCount = order == ParallelConsumerOptions.ProcessingOrder.PARTITION ? 100 : 1000;
        int numberOfPcsToRun = 2;
        this.runTest(DEFAULT_MAX_POLL, ParallelConsumerOptions.CommitMode.PERIODIC_CONSUMER_SYNC, order, expectedMessageCount, numberOfPcsToRun, 1.0, DEFAULT_POLL_DELAY);
    }

    @ParameterizedTest
    @EnumSource(value=ParallelConsumerOptions.ProcessingOrder.class)
    void consumeWithMultipleInstancesPeriodicConsumerAsynchronous(ParallelConsumerOptions.ProcessingOrder order) throws Exception {
        this.numPartitions = 2;
        int expectedMessageCount = order == ParallelConsumerOptions.ProcessingOrder.PARTITION ? 100 : 1000;
        this.runTest(DEFAULT_MAX_POLL, ParallelConsumerOptions.CommitMode.PERIODIC_CONSUMER_ASYNCHRONOUS, order, expectedMessageCount, 2, 1.0, DEFAULT_POLL_DELAY);
    }

    @Disabled
    @Test
    void largeNumberOfInstances() throws Exception {
        this.numPartitions = 80;
        int numberOfPcsToRun = 12;
        int expectedMessageCount = 500000;
        this.runTest(DEFAULT_MAX_POLL, ParallelConsumerOptions.CommitMode.PERIODIC_CONSUMER_ASYNCHRONOUS, ParallelConsumerOptions.ProcessingOrder.UNORDERED, expectedMessageCount, numberOfPcsToRun, 0.3, 1);
    }

    /**
     * Produces {@code expectedMessageCount} keyed records (a fraction before any PC starts, the rest concurrently),
     * runs {@code numberOfPcsToRun} PC instances while a chaos monkey toggles the secondary ones, and asserts that
     * all keys are consumed, that both the first and the secondary instances did work, and that the number of
     * duplicate consumptions stays below 20% of {@code expectedMessageCount}.
     *
     * @param maxPoll                        value used for the consumer's {@code max.poll.records}
     * @param commitMode                     PC commit mode under test
     * @param order                          PC processing order under test
     * @param expectedMessageCount           total number of records produced
     * @param numberOfPcsToRun               total number of PC instances, including the first one
     * @param fractionOfMessagesToPreProduce fraction of records produced and acked before the first PC starts
     * @param pollDelayMs                    simulated processing time per record, in milliseconds
     * @throws Exception if producing or acknowledging the pre-produced records fails or is interrupted
     */
    private void runTest(int maxPoll, ParallelConsumerOptions.CommitMode commitMode, ParallelConsumerOptions.ProcessingOrder order, final int expectedMessageCount, int numberOfPcsToRun, double fractionOfMessagesToPreProduce, int pollDelayMs) throws Exception {
        final String inputName = this.setupTopic(this.getClass().getSimpleName() + "-input-" + RandomUtils.nextInt());
        this.overallProgress = ProgressBarUtils.getNewMessagesBar("overall", log, expectedMessageCount);
        final ExecutorService pcExecutor = Executors.newWorkStealingPool();
        final ProgressBar sendingProgress = ProgressBarUtils.getNewMessagesBar("sending", log, expectedMessageCount);
        final ConcurrentSkipListSet<String> expectedKeys = new ConcurrentSkipListSet<String>();
        final ArrayList<Future> sends = new ArrayList<Future>();
        final int preProduceCount = (int)((double)expectedMessageCount * fractionOfMessagesToPreProduce);
        log.info("Producing {} messages before starting test", (Object)preProduceCount);
        try (KafkaProducer kafkaProducer = this.getKcu().createNewProducer(false);){
            for (int i = 0; i < preProduceCount; ++i) {
                String key = "key-" + i;
                Future send = kafkaProducer.send(new ProducerRecord(inputName, (Object)key, (Object)("value-" + i)), (meta, exception) -> {
                    if (exception != null) {
                        log.error("Error sending, ", (Throwable)exception);
                    }
                    sendingProgress.step();
                });
                sends.add(send);
                expectedKeys.add(key);
            }
            log.debug("Finished sending test data");
        }
        log.debug("Waiting for broker acks");
        for (Future send : sends) {
            send.get();
        }
        Assertions.assertThat(sends).hasSizeGreaterThanOrEqualTo(preProduceCount);
        log.info("Running first instance of pc");
        int expectedMessageCountPerPC = expectedMessageCount / numberOfPcsToRun;
        ParallelConsumerRunnable pc1 = new ParallelConsumerRunnable(maxPoll, commitMode, order, inputName, expectedMessageCountPerPC, pollDelayMs);
        pcExecutor.submit(pc1);
        Awaitility.waitAtMost((Duration)Duration.ofSeconds(10L)).until(() -> pc1.getConsumedKeys().size() > 1);
        // Produces the remaining records one at a time while the PCs are already consuming.
        Runnable sender = new Runnable(){

            @Override
            public void run() {
                log.info("Producing remaining {} messages while the test is running", (Object)(expectedMessageCount - preProduceCount));
                try (KafkaProducer kafkaProducer = MultiInstanceRebalanceTest.this.getKcu().createNewProducer(false);){
                    for (int i = preProduceCount; i < expectedMessageCount; ++i) {
                        String key = "key-" + i;
                        log.debug("sending {}", (Object)key);
                        Future send = kafkaProducer.send(new ProducerRecord(inputName, (Object)key, (Object)("value-" + i)), (meta, exception) -> {
                            if (exception != null) {
                                log.error("Error sending, ", (Throwable)exception);
                            }
                            sendingProgress.step();
                        });
                        send.get();
                        sends.add(send);
                        expectedKeys.add(key);
                    }
                    log.info("Finished sending test data");
                }
                catch (InterruptedException e) {
                    // restore the flag so the executor thread still sees the interrupt
                    Thread.currentThread().interrupt();
                    log.error("Interrupted while sending test data", e);
                    throw new IllegalStateException("Interrupted while sending test data", e);
                }
                catch (ExecutionException e) {
                    log.error("Error sending test data", e);
                    throw new IllegalStateException("Error sending test data", e);
                }
            }
        };
        pcExecutor.submit(sender);
        final List<ParallelConsumerRunnable> secondaryPcs = Collections.synchronizedList(IntStream.range(1, numberOfPcsToRun).mapToObj(value -> {
            try {
                int jitterRangeMs = 2;
                Thread.sleep((int)(Math.random() * (double)jitterRangeMs));
            }
            catch (InterruptedException e) {
                log.error(e.getMessage(), (Throwable)e);
            }
            log.info("Running pc instance {}", (Object)value);
            ParallelConsumerRunnable instance = new ParallelConsumerRunnable(maxPoll, commitMode, order, inputName, expectedMessageCountPerPC, pollDelayMs);
            pcExecutor.submit(instance);
            return instance;
        }).collect(Collectors.toList()));
        final List<ParallelConsumerRunnable> allPCRunners = Collections.synchronizedList(new ArrayList<ParallelConsumerRunnable>());
        allPCRunners.add(pc1);
        allPCRunners.addAll(secondaryPcs);
        ParallelConsumerRunnable[] parallelConsumerRunnablesArray = allPCRunners.toArray(new ParallelConsumerRunnable[0]);
        // Cleared in the finally below. Without it the chaos loop only stops when a PC *fails*; closed PCs are not
        // failures, so it would keep running (and could restart closed PCs) after the test finished.
        final AtomicBoolean chaosEnabled = new AtomicBoolean(true);
        Runnable chaosMonkey = new Runnable(){

            @Override
            public void run() {
                try {
                    while (chaosEnabled.get() && MultiInstanceRebalanceTest.this.noneHaveFailed(allPCRunners)) {
                        Thread.sleep((int)(CHAOS_FREQUENCY * Math.random()));
                        if (!chaosEnabled.get()) break;
                        boolean makeChaos = Math.random() > 0.2;
                        if (!makeChaos) continue;
                        int size = secondaryPcs.size();
                        int numberToMessWith = (int)(Math.random() * (double)size * 0.6);
                        if (numberToMessWith <= 0) continue;
                        log.info("Will mess with {} instances", (Object)numberToMessWith);
                        IntStream.range(0, numberToMessWith).forEach(value -> {
                            // size * random() lies in [0, size), so every secondary instance - including the last - can be picked
                            int instanceToGet = (int)((double)size * Math.random());
                            ParallelConsumerRunnable victim = secondaryPcs.get(instanceToGet);
                            log.info("Victim is instance: " + victim.instanceId);
                            victim.toggle(pcExecutor);
                        });
                    }
                }
                catch (Throwable e) {
                    log.error("Error in chaos loop", e);
                    throw new RuntimeException(e);
                }
                if (chaosEnabled.get()) {
                    log.error("Ending chaos as a PC instance has died");
                } else {
                    log.info("Ending chaos as the test has finished");
                }
            }
        };
        pcExecutor.submit(chaosMonkey);
        Assertions.useRepresentation((Representation)new TrimListRepresentation());
        String failureMessage = StringUtils.msg((String)"All keys sent to input-topic should be processed, within time (expected: {} commit: {} order: {} max poll: {})", (Object[])new Object[]{expectedMessageCount, commitMode, order, maxPoll});
        ProgressTracker progressTracker = new ProgressTracker(this.count);
        try {
            Awaitility.waitAtMost((Duration)Duration.ofMinutes(5L)).failFast("A PC has died - check logs", () -> !this.noneHaveFailed(allPCRunners)).alias(failureMessage).pollInterval(1L, TimeUnit.SECONDS).untilAsserted(() -> {
                log.trace("Processed-count: {}", (Object)this.getAllConsumedKeys(parallelConsumerRunnablesArray).size());
                if (progressTracker.hasProgressNotBeenMade()) {
                    // Report from a copy: removing from expectedKeys itself would weaken the containsAll check
                    // below on any later retry of this assertion.
                    Set<String> missingKeys = new ConcurrentSkipListSet<String>(expectedKeys);
                    missingKeys.removeAll(this.getAllConsumedKeys(parallelConsumerRunnablesArray));
                    throw progressTracker.constructError(StringUtils.msg((String)"No progress, missing keys: {}.", (Object[])new Object[]{missingKeys}));
                }
                SoftAssertions all = new SoftAssertions();
                ((BooleanAssert)all.assertThat(this.overallConsumedKeys.containsAll(expectedKeys)).as("contains all: all expected are consumed at least once", new Object[0])).isTrue();
                // overallConsumedKeys is a Set, so this size check cannot observe duplicates; duplicates are
                // checked after the wait via getAllConsumedKeys.
                ((CollectionAssert)all.assertThat(this.overallConsumedKeys).as("size: all expected are consumed only once", new Object[0])).hasSizeGreaterThanOrEqualTo(expectedKeys.size());
                all.assertAll();
            });
        }
        catch (Throwable error) {
            List<Exception> exceptions = this.checkForFailure(allPCRunners);
            if (error instanceof TerminalFailureException) {
                Optional<Exception> any = exceptions.stream().findAny();
                String message = StringUtils.msg((String)"{} \n Terminal failure in one or more of the PCs. Reported exception states are: {} \n {}", (Object[])new Object[]{failureMessage, exceptions, error});
                throw new RuntimeException(message, any.orElse(null));
            }
            String message = StringUtils.msg((String)"{} \n Assertion error. PC reported exception states: {} \n {}", (Object[])new Object[]{failureMessage, exceptions, error});
            throw new RuntimeException(message, error);
        }
        finally {
            chaosEnabled.set(false);
            this.overallProgress.close();
            sendingProgress.close();
        }
        allPCRunners.forEach(ParallelConsumerRunnable::close);
        Assertions.assertThat((Collection)pc1.consumedKeys).hasSizeGreaterThan(0);
        ((ListAssert)Assertions.assertThat(this.getAllConsumedKeys(secondaryPcs.toArray(new ParallelConsumerRunnable[0]))).as("Second PC should have taken over some of the work and consumed some records", new Object[0])).hasSizeGreaterThan(0);
        pcExecutor.shutdown();
        Collection duplicates = IterableUtil.toCollection((Iterable)StandardComparisonStrategy.instance().duplicatesFrom(this.getAllConsumedKeys(parallelConsumerRunnablesArray)));
        log.info("Duplicate consumed keys (at least one is expected due to the rebalance): {}", (Object)duplicates);
        double percentageDuplicateTolerance = 0.2;
        ((AbstractCollectionAssert)Assertions.assertThat((Collection)duplicates).as("There should be few duplicate keys", new Object[0])).hasSizeLessThan((int)((double)expectedMessageCount * percentageDuplicateTolerance));
    }

    /**
     * Returns true when none of the given runners has a processor that is closed/failed with a failure cause.
     */
    private boolean noneHaveFailed(List<ParallelConsumerRunnable> runners) {
        return this.checkForFailure(runners).isEmpty();
    }

    /**
     * Collects the failure causes of runners whose processor exists, reports closed-or-failed, and has a non-null
     * failure cause. Runners whose processor has not been created yet are skipped.
     */
    private List<Exception> checkForFailure(List<ParallelConsumerRunnable> runners) {
        return runners.stream()
                .map(ParallelConsumerRunnable::getParallelConsumer)
                .filter(pc -> pc != null && pc.isClosedOrFailed())
                // read the failure cause once, so the filter and the collected value cannot disagree
                .map(pc -> pc.getFailureCause())
                .filter(cause -> cause != null)
                .collect(Collectors.toList());
    }

    /**
     * Concatenates the keys consumed by the given runners; keys consumed more than once appear more than once.
     */
    List<String> getAllConsumedKeys(ParallelConsumerRunnable ... instances) {
        return Arrays.stream(instances)
                .flatMap(instance -> instance.consumedKeys.stream())
                .collect(Collectors.toList());
    }

    static {
        MDC.put((String)"pcId", (String)"Test-Thread");
    }

    /**
     * One PC instance. {@link #run()} creates, subscribes and starts a {@link ParallelEoSStreamProcessor}; the chaos
     * monkey stops and re-submits it through {@link #toggle(ExecutorService)}.
     */
    public class ParallelConsumerRunnable
    implements Runnable {
        private final int instanceId;
        private final int maxPoll;
        private final ParallelConsumerOptions.CommitMode commitMode;
        private final ParallelConsumerOptions.ProcessingOrder order;
        private final String inputTopic;
        private final int expectedMessageCount;
        private final ProgressBar bar;
        private final int pollDelayMs;
        // volatile: assigned in run() on an executor thread, read by the test and chaos-monkey threads
        private volatile ParallelEoSStreamProcessor<String, String> parallelConsumer;
        // volatile: written by run()/stop() and read by toggle() from different threads
        private volatile boolean started = false;
        private final Queue<String> consumedKeys = new ConcurrentLinkedQueue<String>();

        public ParallelConsumerRunnable(int maxPoll, ParallelConsumerOptions.CommitMode commitMode, ParallelConsumerOptions.ProcessingOrder order, String inputTopic, int expectedMessageCount, int pollDelayMs) {
            this.maxPoll = maxPoll;
            this.commitMode = commitMode;
            this.order = order;
            this.inputTopic = inputTopic;
            this.expectedMessageCount = expectedMessageCount;
            this.pollDelayMs = pollDelayMs;
            this.instanceId = MultiInstanceRebalanceTest.this.pcInstanceCount++;
            this.bar = ProgressBarUtils.getNewMessagesBar("PC" + this.instanceId, log, expectedMessageCount);
        }

        /**
         * Creates a fresh consumer and processor, subscribes to the input topic and registers a record handler
         * that sleeps {@code pollDelayMs} and then records the key.
         */
        @Override
        public void run() {
            MDC.put((String)"pcId", (String)("Runner-" + this.instanceId));
            this.started = true;
            log.info("Running consumer!");
            Properties consumerProps = new Properties();
            consumerProps.put("max.poll.records", (Object)this.maxPoll);
            KafkaConsumer newConsumer = MultiInstanceRebalanceTest.this.getKcu().createNewConsumer(false, consumerProps);
            ParallelEoSStreamProcessor<String, String> pc = new ParallelEoSStreamProcessor(ParallelConsumerOptions.builder().ordering(this.order).consumer(newConsumer).commitMode(this.commitMode).maxConcurrency(10).build());
            this.parallelConsumer = pc;
            pc.setTimeBetweenCommits(Duration.ofSeconds(1L));
            pc.setMyId(Optional.of("PC-" + this.instanceId));
            pc.subscribe((Collection)UniLists.of((Object)this.inputTopic));
            pc.poll(record -> {
                try {
                    // simulated per-record processing time
                    Thread.sleep(this.pollDelayMs);
                }
                catch (InterruptedException ignored) {
                    // Deliberately ignored: an interrupt only cuts the simulated delay short.
                    // NOTE(review): the interrupt flag is not restored; confirm the PC worker does not rely on it.
                }
                MultiInstanceRebalanceTest.this.count.incrementAndGet();
                this.bar.step();
                MultiInstanceRebalanceTest.this.overallProgress.step();
                this.consumedKeys.add((String)record.key());
                MultiInstanceRebalanceTest.this.overallConsumedKeys.add((String)record.key());
            });
        }

        /**
         * Marks this runner as stopped and closes its processor, if {@link #run()} has created one.
         */
        public void stop() {
            log.info("Stopping {}", (Object)this.instanceId);
            this.started = false;
            ParallelEoSStreamProcessor<String, String> pc = this.parallelConsumer;
            if (pc == null) {
                // run() has not created the processor yet - nothing to close
                return;
            }
            pc.close();
        }

        /**
         * Re-submits this runner to {@code pcExecutor}, unless its previous processor died with a failure cause.
         *
         * @throws RuntimeException if the previous processor recorded a failure cause
         */
        public void start(ExecutorService pcExecutor) {
            Exception failureCause = this.getParallelConsumer().getFailureCause();
            if (failureCause != null) {
                throw new RuntimeException("Error starting PC, pc died from previous error: " + failureCause.getMessage(), failureCause);
            }
            log.info("Starting {}", (Object)this);
            pcExecutor.submit(this);
        }

        /**
         * Stops the runner and closes its progress bar.
         */
        public void close() {
            log.info("Stopping {}", (Object)this);
            this.stop();
            this.bar.close();
        }

        /**
         * Stops the runner if it is started, otherwise starts it on {@code pcExecutor}.
         */
        public void toggle(ExecutorService pcExecutor) {
            if (this.started) {
                this.stop();
            } else {
                this.start(pcExecutor);
            }
        }

        public int getInstanceId() {
            return this.instanceId;
        }

        public int getMaxPoll() {
            return this.maxPoll;
        }

        public ParallelConsumerOptions.CommitMode getCommitMode() {
            return this.commitMode;
        }

        public ParallelConsumerOptions.ProcessingOrder getOrder() {
            return this.order;
        }

        public String getInputTopic() {
            return this.inputTopic;
        }

        public int getExpectedMessageCount() {
            return this.expectedMessageCount;
        }

        public ProgressBar getBar() {
            return this.bar;
        }

        public int getPollDelayMs() {
            return this.pollDelayMs;
        }

        public ParallelEoSStreamProcessor<String, String> getParallelConsumer() {
            return this.parallelConsumer;
        }

        public boolean isStarted() {
            return this.started;
        }

        public Queue<String> getConsumedKeys() {
            return this.consumedKeys;
        }

        @Override
        public String toString() {
            return "MultiInstanceRebalanceTest.ParallelConsumerRunnable(instanceId=" + this.getInstanceId() + ", maxPoll=" + this.getMaxPoll() + ", commitMode=" + this.getCommitMode() + ", order=" + this.getOrder() + ", inputTopic=" + this.getInputTopic() + ", expectedMessageCount=" + this.getExpectedMessageCount() + ", bar=" + this.getBar() + ", pollDelayMs=" + this.getPollDelayMs() + ", parallelConsumer=" + this.getParallelConsumer() + ", started=" + this.isStarted() + ")";
        }
    }
}

