/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.unsafe.impl.batchimport.cache.idmapping.string;

import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadLocalRandom;
import org.neo4j.helpers.progress.ProgressListener;
import org.neo4j.unsafe.impl.batchimport.Utils;
import org.neo4j.unsafe.impl.batchimport.cache.IntArray;
import org.neo4j.unsafe.impl.batchimport.cache.LongArray;
import org.neo4j.unsafe.impl.batchimport.cache.idmapping.string.EncodingIdMapper;
import org.neo4j.unsafe.impl.batchimport.cache.idmapping.string.Radix;
import org.neo4j.unsafe.impl.batchimport.cache.idmapping.string.RadixCalculator;
import org.neo4j.unsafe.impl.batchimport.cache.idmapping.string.Workers;

/**
 * Sorts a tracker in parallel, ordered by the data values its entries point to. The tracker is an
 * {@link IntArray} of indexes into the {@link LongArray} data cache.
 * <p>
 * Sorting happens in two phases:
 * <ol>
 * <li>SPLIT ({@link #sortRadix()}): the per-radix counts from {@link Radix} are grouped into at most
 * {@code threads} contiguous radix ranges ("buckets") of roughly {@code dataSize / threads} entries each.
 * For each bucket a {@link TrackerInitializer} scans the whole data cache. It writes the indexes whose radix
 * falls in that bucket's range into the bucket's own, non-overlapping slice of the tracker.</li>
 * <li>SORT ({@link #run()}): one {@link SortWorker} per non-empty bucket quicksorts that tracker slice in
 * place.</li>
 * </ol>
 * Values are compared with the supplied {@link Comparator}, after being passed through
 * {@code EncodingIdMapper.clearCollision}. The data cache itself is only read, never written.
 */
public class ParallelSort {
    /** Number of progress units a {@link SortWorker} accumulates locally before reporting them. */
    private static final int PROGRESS_REPORT_INTERVAL = 10000;

    /** Number of data entries per radix, as reported by {@link Radix#getRadixIndexCounts()}. */
    private final int[] radixIndexCount;
    private final RadixCalculator radixCalculator;
    /** The values the tracker is ordered by; only read here. */
    private final LongArray dataCache;
    /** Highest index (inclusive) in {@link #dataCache} taking part in the sort. */
    private final long highestSetIndex;
    /** Indexes into {@link #dataCache}; this is the array that actually gets reordered. */
    private final IntArray tracker;
    private final int threads;
    /**
     * Per bucket: {upper radix bound of the bucket, start offset of the bucket in the tracker}. An entry is
     * only filled in once its bucket has received all of its entries; otherwise it stays {0, 0}.
     */
    private long[][] sortBuckets;
    private final ProgressListener progress;
    private final Comparator comparator;

    /** Comparator that compares values as unsigned longs via {@code Utils.unsignedCompare}. */
    public static final Comparator DEFAULT = new Comparator() {

        @Override
        public boolean lt(long left, long pivot) {
            return Utils.unsignedCompare(left, pivot, Utils.CompareType.LT);
        }

        @Override
        public boolean ge(long right, long pivot) {
            return Utils.unsignedCompare(right, pivot, Utils.CompareType.GE);
        }
    };

    /**
     * @param radix source of the per-radix entry counts and of the calculator mapping a value to its radix
     * @param dataCache values to order the tracker by
     * @param highestSetIndex highest index (inclusive) in {@code dataCache} to include
     * @param tracker indexes into {@code dataCache} to be sorted; the assertion in {@link TrackerInitializer}
     *        expects every slot to be -1 beforehand (NOTE(review): presumably done by the caller — confirm)
     * @param threads maximum number of buckets, and therefore of initializer and sort threads
     * @param progress receives "SPLIT" and "SORT" phase progress
     * @param comparator ordering applied to the (collision-cleared) values
     */
    public ParallelSort(Radix radix, LongArray dataCache, long highestSetIndex, IntArray tracker, int threads, ProgressListener progress, Comparator comparator) {
        this.progress = progress;
        this.comparator = comparator;
        this.radixIndexCount = radix.getRadixIndexCounts();
        this.radixCalculator = radix.calculator();
        this.dataCache = dataCache;
        this.highestSetIndex = highestSetIndex;
        this.tracker = tracker;
        this.threads = threads;
    }

    /**
     * Sorts the tracker. It first splits the data into per-thread buckets (see {@link #sortRadix()}). It then
     * quicksorts every non-empty bucket in its own worker thread and waits for all of them to finish.
     *
     * @return per bucket: {upper radix bound, start offset of the bucket in the tracker}; buckets that never
     *         received all their entries stay {0, 0}
     * @throws InterruptedException propagated from waiting on the initializer or sort threads
     */
    public synchronized long[][] run() throws InterruptedException {
        int[][] sortParams = this.sortRadix();
        // The non-empty buckets form a prefix of sortParams: only the last assigned bucket can be empty.
        int threadsNeeded = 0;
        while (threadsNeeded < this.threads && sortParams[threadsNeeded][1] != 0) {
            ++threadsNeeded;
        }
        // Workers wait on this latch and only begin sorting once all of them have been started.
        CountDownLatch waitSignal = new CountDownLatch(1);
        Workers<SortWorker> sortWorkers = new Workers<SortWorker>("SortWorker");
        this.progress.started("SORT");
        try {
            try {
                for (int i = 0; i < threadsNeeded; ++i) {
                    sortWorkers.start(new SortWorker(sortParams[i][0], sortParams[i][1], waitSignal));
                }
            } finally {
                // Always release the latch. If starting a worker fails part-way, the workers already
                // started must not stay blocked on it forever.
                waitSignal.countDown();
            }
            sortWorkers.awaitAndThrowOnError();
        } finally {
            this.progress.done();
        }
        return this.sortBuckets;
    }

    /**
     * Splits the radixes into at most {@code threads} buckets. It also populates the tracker so that each
     * bucket's data indexes occupy their own contiguous slice of it.
     * <p>
     * Buckets are formed greedily over the radixes in order. Radix counts are accumulated until adding the
     * next radix would exceed {@code dataSize / threads}. A single radix larger than that forms a bucket on its
     * own. The last bucket takes all remaining radixes; it is reached when only one thread is left or the
     * radixes run out. One {@link TrackerInitializer} is started per bucket as soon as its range is known.
     *
     * @return per bucket: {start offset in the tracker, number of entries}; unused buckets are {0, 0}
     * @throws InterruptedException propagated from waiting on the initializer threads
     * @throws AssertionError if any initializer failed; the message includes a dump of the bucket layout
     */
    private int[][] sortRadix() throws InterruptedException {
        int[][] rangeParams = new int[this.threads][2];
        // bucketRange[k] is the highest radix (inclusive) of bucket k. Its lowest radix (exclusive) is
        // bucketRange[k - 1], or -1 for the first bucket.
        int[] bucketRange = new int[this.threads];
        Workers<TrackerInitializer> initializers = new Workers<TrackerInitializer>("TrackerInitializer");
        this.sortBuckets = new long[this.threads][2];
        long dataSize = this.highestSetIndex + 1L;
        // Target number of entries per bucket.
        int bucketSize = Utils.safeCastLongToInt(dataSize / (long) this.threads);
        int count = 0;      // entries accumulated into the bucket currently being formed
        int fullCount = 0;  // entries in all completed buckets, i.e. the start offset of the current one
        this.progress.started("SPLIT");
        int threadIndex = 0;
        for (int i = 0; i < this.radixIndexCount.length && threadIndex < this.threads; ++i) {
            if (count + this.radixIndexCount[i] > bucketSize) {
                // Adding radix i would overflow the current bucket, so close the bucket now.
                bucketRange[threadIndex] = count == 0 ? i : i - 1;
                rangeParams[threadIndex][0] = fullCount;
                if (count != 0) {
                    // Close the bucket without radix i; radix i starts the next bucket.
                    rangeParams[threadIndex][1] = count;
                    fullCount += count;
                    this.progress.add(count);
                    count = this.radixIndexCount[i];
                } else {
                    // Radix i alone exceeds the target, so it becomes a bucket of its own.
                    rangeParams[threadIndex][1] = this.radixIndexCount[i];
                    fullCount += this.radixIndexCount[i];
                    this.progress.add(this.radixIndexCount[i]);
                }
                initializers.start(new TrackerInitializer(threadIndex, rangeParams[threadIndex],
                        threadIndex > 0 ? bucketRange[threadIndex - 1] : -1, bucketRange[threadIndex],
                        this.sortBuckets[threadIndex]));
                ++threadIndex;
            } else {
                count += this.radixIndexCount[i];
            }
            if (threadIndex == this.threads - 1 || i == this.radixIndexCount.length - 1) {
                // Last bucket: it takes every remaining radix and every entry not yet assigned.
                // NOTE(review): this assumes the radix counts sum to dataSize — confirm.
                bucketRange[threadIndex] = this.radixIndexCount.length;
                rangeParams[threadIndex][0] = fullCount;
                rangeParams[threadIndex][1] = Utils.safeCastLongToInt(dataSize - (long) fullCount);
                initializers.start(new TrackerInitializer(threadIndex, rangeParams[threadIndex],
                        threadIndex > 0 ? bucketRange[threadIndex - 1] : -1, bucketRange[threadIndex],
                        this.sortBuckets[threadIndex]));
                break;
            }
        }
        this.progress.done();
        // Each initializer writes only to its own tracker slice. Wait for all of them before sorting.
        Throwable error = initializers.await();
        int[] bucketIndex = new int[this.threads];
        int i = 0;
        for (TrackerInitializer initializer : initializers) {
            bucketIndex[i++] = initializer.bucketIndex;
        }
        if (error != null) {
            throw new AssertionError(error.getMessage() + "\n" + this.dumpBuckets(rangeParams, bucketRange, bucketIndex), error);
        }
        return rangeParams;
    }

    /** Formats the bucket layout computed by {@link #sortRadix()} for its failure message. */
    private String dumpBuckets(int[][] rangeParams, int[] bucketRange, int[] bucketIndex) {
        StringBuilder builder = new StringBuilder();
        builder.append("rangeParams:\n");
        for (int[] range : rangeParams) {
            builder.append("  ").append(Arrays.toString(range)).append("\n");
        }
        builder.append("bucketRange:\n");
        for (int range : bucketRange) {
            builder.append("  ").append(range).append("\n");
        }
        builder.append("bucketIndex:\n");
        for (int index : bucketIndex) {
            builder.append("  ").append(index).append("\n");
        }
        return builder.toString();
    }

    /**
     * Partitions the tracker range {@code [leftIndex, rightIndex)} around the value referenced by
     * {@code pivotIndex}. Afterwards, entries before the returned index satisfy {@code comparator.lt(v, pivot)}
     * and entries after it satisfy {@code comparator.ge(v, pivot)}.
     * <p>
     * The pivot is parked at {@code rightIndex - 1} while partitioning, then swapped into its final slot.
     * Requires {@code rightIndex - leftIndex >= 2}. Entries equal to the pivot all end up on the right side.
     *
     * @return the final tracker index of the pivot entry
     */
    private int partition(int leftIndex, int rightIndex, int pivotIndex) {
        int li = leftIndex;
        int ri = rightIndex - 2;
        int pi = pivotIndex;
        long pivot = EncodingIdMapper.clearCollision(this.dataCache.get(this.tracker.get(pi)));
        this.tracker.swap(pi, rightIndex - 1, 1);
        // left/right cache the values referenced by tracker[li] and tracker[ri].
        long left = EncodingIdMapper.clearCollision(this.dataCache.get(this.tracker.get(li)));
        long right = EncodingIdMapper.clearCollision(this.dataCache.get(this.tracker.get(ri)));
        while (li < ri) {
            if (this.comparator.lt(left, pivot)) {
                left = EncodingIdMapper.clearCollision(this.dataCache.get(this.tracker.get(++li)));
                continue;
            }
            if (this.comparator.ge(right, pivot)) {
                right = EncodingIdMapper.clearCollision(this.dataCache.get(this.tracker.get(--ri)));
                continue;
            }
            // tracker[li] >= pivot and tracker[ri] < pivot: swap them, along with the cached values.
            this.tracker.swap(li, ri, 1);
            long temp = left;
            left = right;
            right = temp;
        }
        // li == ri: the entry at ri has not been classified yet.
        int partingIndex = ri;
        if (this.comparator.lt(right, pivot)) {
            ++partingIndex;
        }
        this.tracker.swap(rightIndex - 1, partingIndex, 1);
        return partingIndex;
    }

    /**
     * Quicksorts the tracker range {@code [start, end)} using random pivots.
     * <p>
     * Only the smaller partition is handled by recursion; the larger one is processed by the loop. This keeps
     * the stack depth logarithmic even when {@link #partition} degenerates. For example, with many equal keys
     * every entry lands on the pivot's right side, and plain two-way recursion would then nest once per entry.
     * Progress reported to {@code workerProgress} is one unit per partition step plus one per single-entry
     * range.
     */
    private void recursiveQsort(int start, int end, Random random, SortWorker workerProgress) {
        int lo = start;
        int hi = end;
        while (hi - lo >= 2) {
            workerProgress.incrementProgress(1);
            int pivot = this.partition(lo, hi, lo + random.nextInt(hi - lo));
            if (pivot - lo < hi - (pivot + 1)) {
                this.recursiveQsort(lo, pivot, random, workerProgress);
                lo = pivot + 1;
            } else {
                this.recursiveQsort(pivot + 1, hi, random, workerProgress);
                hi = pivot;
            }
        }
        // The remaining range has 0 or 1 entries and is already sorted.
        workerProgress.incrementProgress(hi - lo);
    }

    /**
     * Fills one bucket's slice of the tracker. It scans every data index in {@code [0, highestSetIndex]} and
     * appends those whose radix lies in {@code (lowRadixRange, highRadixRange]}.
     */
    private class TrackerInitializer
    implements Runnable {
        /** {start offset of this bucket in the tracker, expected number of entries}. */
        private final int[] rangeParams;
        private final int lowRadixRange;
        private final int highRadixRange;
        private final int threadIndex;
        /** Number of entries written so far; read by {@code sortRadix} for its failure dump. */
        private int bucketIndex;
        /** This bucket's row of {@code sortBuckets}; written once the bucket has all its entries. */
        private final long[] result;

        TrackerInitializer(int threadIndex, int[] rangeParams, int lowRadixRange, int highRadixRange, long[] result) {
            this.threadIndex = threadIndex;
            this.rangeParams = rangeParams;
            this.lowRadixRange = lowRadixRange;
            this.highRadixRange = highRadixRange;
            this.result = result;
        }

        @Override
        public void run() {
            for (long i = 0L; i <= ParallelSort.this.highestSetIndex; ++i) {
                int rIndex = ParallelSort.this.radixCalculator.radixOf(ParallelSort.this.dataCache.get(i));
                if (rIndex <= this.lowRadixRange || rIndex > this.highRadixRange) {
                    continue;
                }
                long trackerIndex = this.rangeParams[0] + this.bucketIndex++;
                // Slots are expected to be pre-filled with -1; anything else means two buckets overlap.
                assert (ParallelSort.this.tracker.get(trackerIndex) == -1) : "Overlapping buckets i:" + i + ", k:" + this.threadIndex + ", index:" + trackerIndex;
                ParallelSort.this.tracker.set(trackerIndex, (int) i);
                if (this.bucketIndex == this.rangeParams[1]) {
                    this.result[0] = this.highRadixRange;
                    this.result[1] = this.rangeParams[0];
                }
            }
        }
    }

    /**
     * Quicksorts one bucket's tracker slice {@code [start, start + size)} after the start latch opens. It
     * batches progress updates locally to limit calls to the shared listener.
     */
    private class SortWorker
    implements Runnable {
        private final int start;
        private final int size;
        private final CountDownLatch waitSignal;
        private int threadLocalProgress;

        SortWorker(int startRange, int size, CountDownLatch wait) {
            this.start = startRange;
            this.size = size;
            this.waitSignal = wait;
        }

        /** Accumulates progress and flushes it once at least {@link #PROGRESS_REPORT_INTERVAL} units are pending. */
        void incrementProgress(int diff) {
            this.threadLocalProgress += diff;
            if (this.threadLocalProgress >= PROGRESS_REPORT_INTERVAL) {
                this.reportProgress();
            }
        }

        private void reportProgress() {
            ParallelSort.this.progress.add(this.threadLocalProgress);
            this.threadLocalProgress = 0;
        }

        @Override
        public void run() {
            ThreadLocalRandom random = ThreadLocalRandom.current();
            try {
                this.waitSignal.await();
            }
            catch (InterruptedException e) {
                // Restore the interrupt flag and sort anyway, so the bucket is not left unsorted.
                Thread.currentThread().interrupt();
            }
            ParallelSort.this.recursiveQsort(this.start, this.start + this.size, random, this);
            // Flush whatever progress is still pending.
            this.reportProgress();
        }
    }

    /** Ordering used by {@link ParallelSort#partition}; both methods receive collision-cleared values. */
    public interface Comparator {
        /** @return whether {@code left} sorts strictly before {@code pivot} */
        boolean lt(long left, long pivot);

        /** @return whether {@code right} sorts at or after {@code pivot} */
        boolean ge(long right, long pivot);
    }
}

