package io.trino.operator.output;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;
import io.trino.Session;
import io.trino.execution.resourcegroups.IndexedPriorityQueue;
import io.trino.operator.PartitionFunction;
import io.trino.spi.type.Type;
import io.trino.sql.planner.NodePartitioningManager;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.SystemPartitioningHandle;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicLongArray;
import java.util.stream.IntStream;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;

/**
 * Spreads partitions that receive a disproportionate share of data over additional tasks.
 *
 * <p>Each task is divided into {@code taskBucketCount} task buckets. Initially every partition is
 * assigned to exactly one task bucket: partitions are dealt round-robin over tasks, and within a task
 * round-robin over its buckets. Callers report processed data through {@link #addDataProcessed} and
 * per-partition row counts through {@link #addPartitionRowCount}. When {@link #rebalance()} finds that
 * at least {@code minDataProcessedRebalanceThreshold} has been processed since the previous rebalance,
 * it does the following:
 * <ol>
 *   <li>Estimates each partition's data size from its share of the total row count.</li>
 *   <li>Repeatedly takes the most loaded task bucket.</li>
 *   <li>Assigns that bucket's large partitions to additional, much less loaded task buckets that
 *       belong to other tasks.</li>
 * </ol>
 * {@link #getTaskId} then spreads rows of a partition over all task buckets assigned to it.
 *
 * <p>Counters are updated through atomics without locking; the rebalance itself runs under this
 * object's monitor.
 */
@ThreadSafe
public class SkewedPartitionRebalancer {
    private static final Logger log = Logger.get(SkewedPartitionRebalancer.class);
    // Partition count used for system (non-connector) partitionings; see createPartitionFunction.
    private static final int SCALE_WRITERS_PARTITION_COUNT = 4096;
    // A lighter task bucket is a rebalance target only when (max - min) / max exceeds this ratio.
    private static final double TASK_BUCKET_SKEWNESS_THRESHOLD = 0.7d;
    // Floor for the amount of processed data between two rebalances.
    private static final long MIN_DATA_PROCESSED_REBALANCE_THRESHOLD = DataSize.of(50, DataSize.Unit.MEGABYTE).toBytes();

    private final int partitionCount;
    private final int taskCount;
    private final int taskBucketCount;
    // A partition is only spread further if its per-task data since the last rebalance reaches this value.
    private final long minPartitionDataProcessedRebalanceThreshold;
    // Processed data required between rebalances: max(minPartitionDataProcessedRebalanceThreshold, 50MB).
    private final long minDataProcessedRebalanceThreshold;
    private final AtomicLongArray partitionRowCount;
    private final AtomicLong dataProcessed = new AtomicLong();
    private final AtomicLong dataProcessedAtLastRebalance = new AtomicLong();

    @GuardedBy("this")
    private final long[] partitionDataSizeAtLastRebalance;

    @GuardedBy("this")
    private final long[] partitionDataSizeSinceLastRebalancePerTask;

    @GuardedBy("this")
    private final long[] estimatedTaskBucketDataSizeSinceLastRebalance;

    // Per partition, the task buckets it is assigned to. Lists are appended to under the monitor while
    // rebalancing but read without locking by getTaskId, hence CopyOnWriteArrayList. They only grow.
    private final List<List<TaskBucket>> partitionAssignments;

    /**
     * A bucket within a task. {@code id} is a dense index in {@code [0, taskCount * taskBucketCount)}
     * used to address per-bucket arrays and queues.
     */
    private final class TaskBucket {
        private final int taskId;
        private final int id;

        private TaskBucket(int taskId, int bucketId) {
            this.taskId = taskId;
            this.id = (taskId * taskBucketCount) + bucketId;
        }

        @Override
        public int hashCode() {
            return Integer.hashCode(id);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            return ((TaskBucket) obj).id == id;
        }
    }

    /**
     * Returns whether partitions of {@code partitioningHandle} may be scaled across remote tasks.
     * All three of the following must hold:
     * <ul>
     *   <li>{@code taskCount} is greater than one.</li>
     *   <li>The connector bucket-node map (if the handle has a catalog) does not report a fixed mapping.</li>
     *   <li>The handle is a scaled-writer hash distribution.</li>
     * </ul>
     */
    public static boolean checkCanScalePartitionsRemotely(Session session, int taskCount, PartitioningHandle partitioningHandle, NodePartitioningManager nodePartitioningManager) {
        if (taskCount <= 1) {
            return false;
        }
        boolean hasFixedNodeMapping = partitioningHandle.getCatalogHandle()
                .map(catalogHandle -> nodePartitioningManager.getConnectorBucketNodeMap(session, partitioningHandle)
                        .map(bucketNodeMap -> bucketNodeMap.hasFixedMapping())
                        .orElse(false))
                .orElse(false);
        return !hasFixedNodeMapping && PartitioningHandle.isScaledWriterHashDistribution(partitioningHandle);
    }

    /**
     * Creates the partition function for {@code partitioningScheme} with an identity bucket-to-partition
     * mapping. System partitionings use {@value #SCALE_WRITERS_PARTITION_COUNT} buckets; other
     * partitionings use the bucket count of the connector's bucket node map.
     */
    public static PartitionFunction createPartitionFunction(Session session, NodePartitioningManager nodePartitioningManager, PartitioningScheme partitioningScheme, List<Type> types) {
        PartitioningHandle handle = partitioningScheme.getPartitioning().getHandle();
        int bucketCount = handle.getConnectorHandle() instanceof SystemPartitioningHandle
                ? SCALE_WRITERS_PARTITION_COUNT
                : nodePartitioningManager.getBucketNodeMap(session, handle).getBucketCount();
        int[] bucketToPartition = IntStream.range(0, bucketCount).toArray();
        return nodePartitioningManager.getPartitionFunction(session, partitioningScheme, types, bucketToPartition);
    }

    /**
     * Creates a rebalancer whose task-bucket count per task is {@code ceil(0.5 * taskPartitionedWriterCount)}.
     * Presumably {@code taskPartitionedWriterCount} is the number of partitioned writers per task;
     * confirm against callers.
     *
     * @throws IllegalArgumentException if a count is out of range (see the constructor)
     */
    public static SkewedPartitionRebalancer createSkewedPartitionRebalancer(int partitionCount, int taskCount, int taskPartitionedWriterCount, long minPartitionDataProcessedRebalanceThreshold) {
        int taskBucketCount = (int) Math.ceil(0.5d * taskPartitionedWriterCount);
        return new SkewedPartitionRebalancer(partitionCount, taskCount, taskBucketCount, minPartitionDataProcessedRebalanceThreshold);
    }

    /**
     * Returns the task count implied by the scheme's bucket-to-partition map: the largest partition id
     * plus one.
     *
     * @throws IllegalArgumentException if the map is absent or empty
     */
    public static int getTaskCount(PartitioningScheme partitioningScheme) {
        int[] bucketToPartition = partitioningScheme.getBucketToPartition()
                .orElseThrow(() -> new IllegalArgumentException("Bucket to partition must be set before calculating taskCount"));
        return IntStream.of(bucketToPartition)
                .max()
                .orElseThrow(() -> new IllegalArgumentException("Bucket to partition must not be empty")) + 1;
    }

    private SkewedPartitionRebalancer(int partitionCount, int taskCount, int taskBucketCount, long minPartitionDataProcessedRebalanceThreshold) {
        // Validate up front; otherwise the modulo operations below fail with an opaque ArithmeticException.
        if (partitionCount < 0) {
            throw new IllegalArgumentException("partitionCount must not be negative: " + partitionCount);
        }
        if (taskCount <= 0) {
            throw new IllegalArgumentException("taskCount must be positive: " + taskCount);
        }
        if (taskBucketCount <= 0) {
            throw new IllegalArgumentException("taskBucketCount must be positive: " + taskBucketCount);
        }
        this.partitionCount = partitionCount;
        this.taskCount = taskCount;
        this.taskBucketCount = taskBucketCount;
        this.minPartitionDataProcessedRebalanceThreshold = minPartitionDataProcessedRebalanceThreshold;
        this.minDataProcessedRebalanceThreshold = Math.max(minPartitionDataProcessedRebalanceThreshold, MIN_DATA_PROCESSED_REBALANCE_THRESHOLD);
        this.partitionRowCount = new AtomicLongArray(partitionCount);
        this.partitionDataSizeAtLastRebalance = new long[partitionCount];
        this.partitionDataSizeSinceLastRebalancePerTask = new long[partitionCount];
        this.estimatedTaskBucketDataSizeSinceLastRebalance = new long[taskCount * taskBucketCount];

        // Initial placement: round-robin over tasks, then round-robin over the chosen task's buckets.
        int[] nextBucketPerTask = new int[taskCount];
        ImmutableList.Builder<List<TaskBucket>> assignments = ImmutableList.builder();
        for (int partition = 0; partition < partitionCount; partition++) {
            int taskId = partition % taskCount;
            int bucketId = nextBucketPerTask[taskId]++ % taskBucketCount;
            assignments.add(new CopyOnWriteArrayList<>(ImmutableList.of(new TaskBucket(taskId, bucketId))));
        }
        this.partitionAssignments = assignments.build();
    }

    /** Returns, per partition, the task ids of the task buckets currently assigned to it. */
    @VisibleForTesting
    List<List<Integer>> getPartitionAssignments() {
        ImmutableList.Builder<List<Integer>> result = ImmutableList.builder();
        for (List<TaskBucket> taskBuckets : partitionAssignments) {
            result.add(taskBuckets.stream()
                    .map(taskBucket -> taskBucket.taskId)
                    .collect(ImmutableList.toImmutableList()));
        }
        return result.build();
    }

    public int getTaskCount() {
        return taskCount;
    }

    /**
     * Returns the task for a row of {@code partitionId}. {@code index} selects one of the partition's
     * assigned task buckets modulo their count (negative values are handled by floorMod).
     */
    public int getTaskId(int partitionId, long index) {
        List<TaskBucket> taskBuckets = partitionAssignments.get(partitionId);
        // Safe without locking: assignment lists only grow, so an index below an earlier size stays valid.
        return taskBuckets.get(Math.floorMod(index, taskBuckets.size())).taskId;
    }

    /** Adds {@code dataSize} to the total processed data used to trigger and scale rebalancing. */
    public void addDataProcessed(long dataSize) {
        dataProcessed.addAndGet(dataSize);
    }

    /** Adds {@code rowCount} rows to {@code partition}'s counter. */
    public void addPartitionRowCount(int partition, long rowCount) {
        partitionRowCount.addAndGet(partition, rowCount);
    }

    /**
     * Rebalances partitions if enough data has been processed since the last rebalance. The check is
     * first done without locking so that callers avoid the monitor when no rebalance is due.
     */
    public void rebalance() {
        long currentDataProcessed = dataProcessed.get();
        if (shouldRebalance(currentDataProcessed)) {
            rebalancePartitions(currentDataProcessed);
        }
    }

    private boolean shouldRebalance(long currentDataProcessed) {
        return currentDataProcessed - dataProcessedAtLastRebalance.get() >= minDataProcessedRebalanceThreshold;
    }

    private synchronized void rebalancePartitions(long currentDataProcessed) {
        // Re-check under the lock: another thread may have rebalanced since the unlocked check.
        if (!shouldRebalance(currentDataProcessed)) {
            return;
        }

        long[] partitionDataSize = calculatePartitionDataSize(currentDataProcessed);

        // Per-task-bucket share of each partition's data since its last rebalance. The estimate can shrink
        // when a partition's share of rows drops; clamp at 0 so sizes stay non-negative (a negative size
        // would overflow the Long.MAX_VALUE - size priority used for the min queue).
        for (int partition = 0; partition < partitionCount; partition++) {
            int assignedBucketCount = partitionAssignments.get(partition).size();
            long delta = partitionDataSize[partition] - partitionDataSizeAtLastRebalance[partition];
            partitionDataSizeSinceLastRebalancePerTask[partition] = Math.max(0, delta / assignedBucketCount);
        }

        // For every task bucket, a queue of its partitions prioritised by per-task data size.
        int totalTaskBucketCount = taskCount * taskBucketCount;
        List<IndexedPriorityQueue<Integer>> taskBucketMaxPartitions = new ArrayList<>(totalTaskBucketCount);
        for (int i = 0; i < totalTaskBucketCount; i++) {
            taskBucketMaxPartitions.add(new IndexedPriorityQueue<>());
        }
        for (int partition = 0; partition < partitionCount; partition++) {
            for (TaskBucket taskBucket : partitionAssignments.get(partition)) {
                taskBucketMaxPartitions.get(taskBucket.id).addOrUpdate(partition, partitionDataSizeSinceLastRebalancePerTask[partition]);
            }
        }

        // maxTaskBuckets is keyed by estimated size; minTaskBuckets by its inverse (Long.MAX_VALUE - size).
        IndexedPriorityQueue<TaskBucket> maxTaskBuckets = new IndexedPriorityQueue<>();
        IndexedPriorityQueue<TaskBucket> minTaskBuckets = new IndexedPriorityQueue<>();
        for (int taskId = 0; taskId < taskCount; taskId++) {
            for (int bucketId = 0; bucketId < taskBucketCount; bucketId++) {
                TaskBucket taskBucket = new TaskBucket(taskId, bucketId);
                long estimatedSize = calculateTaskBucketDataSizeSinceLastRebalance(taskBucketMaxPartitions.get(taskBucket.id));
                estimatedTaskBucketDataSizeSinceLastRebalance[taskBucket.id] = estimatedSize;
                maxTaskBuckets.addOrUpdate(taskBucket, estimatedSize);
                minTaskBuckets.addOrUpdate(taskBucket, Long.MAX_VALUE - estimatedSize);
            }
        }

        rebalanceBasedOnTaskBucketSkewness(maxTaskBuckets, minTaskBuckets, taskBucketMaxPartitions, partitionDataSize);
        dataProcessedAtLastRebalance.set(currentDataProcessed);
    }

    /**
     * Attributes {@code currentDataProcessed} to partitions in proportion to their row counts.
     * Returns all zeros when no rows have been reported yet.
     */
    private long[] calculatePartitionDataSize(long currentDataProcessed) {
        // Snapshot the counters once so the total is consistent with the parts it is divided into,
        // even while writers keep updating them concurrently.
        long[] rowCounts = new long[partitionCount];
        long totalRowCount = 0;
        for (int partition = 0; partition < partitionCount; partition++) {
            rowCounts[partition] = partitionRowCount.get(partition);
            totalRowCount += rowCounts[partition];
        }

        long[] partitionDataSize = new long[partitionCount];
        if (totalRowCount == 0) {
            // Nothing to attribute data to; dividing would throw ArithmeticException.
            return partitionDataSize;
        }
        for (int partition = 0; partition < partitionCount; partition++) {
            partitionDataSize[partition] = multiplyDivide(currentDataProcessed, rowCounts[partition], totalRowCount);
        }
        return partitionDataSize;
    }

    /**
     * Returns {@code value * numerator / denominator}. Uses exact long arithmetic when the product fits
     * in a long; otherwise it uses an approximate floating-point computation instead of silently
     * overflowing.
     */
    private static long multiplyDivide(long value, long numerator, long denominator) {
        long low = value * numerator;
        // The 128-bit product fits in a long iff its high half equals the sign extension of the low half.
        if (Math.multiplyHigh(value, numerator) == (low >> 63)) {
            return low / denominator;
        }
        return (long) (((double) numerator / denominator) * value);
    }

    /** Sums the per-task data size of every partition in {@code partitions}. */
    private long calculateTaskBucketDataSizeSinceLastRebalance(IndexedPriorityQueue<Integer> partitions) {
        long total = 0;
        Iterator<Integer> iterator = partitions.iterator();
        while (iterator.hasNext()) {
            total += partitionDataSizeSinceLastRebalancePerTask[iterator.next()];
        }
        return total;
    }

    /**
     * Repeatedly takes the most loaded task bucket and assigns its partitions to additional, much lighter
     * task buckets. It stops when no bucket remains or the most loaded bucket is not skewed relative to
     * any candidate.
     */
    private void rebalanceBasedOnTaskBucketSkewness(
            IndexedPriorityQueue<TaskBucket> maxTaskBuckets,
            IndexedPriorityQueue<TaskBucket> minTaskBuckets,
            List<IndexedPriorityQueue<Integer>> taskBucketMaxPartitions,
            long[] partitionDataSize) {
        while (true) {
            TaskBucket maxTaskBucket = maxTaskBuckets.poll();
            if (maxTaskBucket == null) {
                return;
            }

            IndexedPriorityQueue<Integer> maxPartitions = taskBucketMaxPartitions.get(maxTaskBucket.id);
            if (maxPartitions.isEmpty()) {
                continue;
            }

            List<TaskBucket> minSkewedTaskBuckets = findSkewedMinTaskBuckets(maxTaskBucket, minTaskBuckets);
            if (minSkewedTaskBuckets.isEmpty()) {
                return;
            }

            // Drain this bucket's partitions. Each poll removes one, so the loop terminates. The
            // decompiled original had no exit here and spun forever once the queue was empty.
            while (true) {
                Integer maxPartition = maxPartitions.poll();
                if (maxPartition == null) {
                    break;
                }
                // Only spread partitions that received enough data since the last rebalance.
                if (partitionDataSizeSinceLastRebalancePerTask[maxPartition] < minPartitionDataProcessedRebalanceThreshold) {
                    continue;
                }
                for (TaskBucket minTaskBucket : minSkewedTaskBuckets) {
                    if (rebalancePartition(maxPartition, minTaskBucket, maxTaskBuckets, minTaskBuckets, partitionDataSize[maxPartition])) {
                        break;
                    }
                }
            }
        }
    }

    /**
     * Collects task buckets, in {@code minTaskBuckets} iteration order, whose estimated size is lower
     * than {@code maxTaskBucket}'s by more than {@link #TASK_BUCKET_SKEWNESS_THRESHOLD} of the latter.
     * Buckets of the same task are skipped. Iteration stops at the first non-skewed bucket, which
     * assumes the iterator yields the lightest buckets first (TODO confirm for IndexedPriorityQueue).
     */
    private List<TaskBucket> findSkewedMinTaskBuckets(TaskBucket maxTaskBucket, IndexedPriorityQueue<TaskBucket> minTaskBuckets) {
        long maxSize = estimatedTaskBucketDataSizeSinceLastRebalance[maxTaskBucket.id];
        ImmutableList.Builder<TaskBucket> result = ImmutableList.builder();
        Iterator<TaskBucket> iterator = minTaskBuckets.iterator();
        while (iterator.hasNext()) {
            TaskBucket minTaskBucket = iterator.next();
            // Floating-point division: long division would truncate the ratio to 0 or 1 and throw when
            // maxSize is 0 (0.0 / 0 yields NaN instead, handled below).
            double skewness = ((double) (maxSize - estimatedTaskBucketDataSizeSinceLastRebalance[minTaskBucket.id])) / maxSize;
            if (skewness <= TASK_BUCKET_SKEWNESS_THRESHOLD || Double.isNaN(skewness)) {
                break;
            }
            if (maxTaskBucket.taskId != minTaskBucket.taskId) {
                result.add(minTaskBucket);
            }
        }
        return result.build();
    }

    /**
     * Adds {@code toTaskBucket} to {@code partitionId}'s assignments unless the partition already has a
     * bucket on the same task. On success, the method does three things:
     * <ul>
     *   <li>Records {@code partitionDataSize} as the partition's size at this rebalance.</li>
     *   <li>Re-estimates the sizes of all buckets now sharing the partition.</li>
     *   <li>Refreshes those buckets in both priority queues.</li>
     * </ul>
     *
     * @return whether the assignment was added
     */
    private boolean rebalancePartition(int partitionId, TaskBucket toTaskBucket, IndexedPriorityQueue<TaskBucket> maxTaskBuckets, IndexedPriorityQueue<TaskBucket> minTaskBuckets, long partitionDataSize) {
        List<TaskBucket> assignments = partitionAssignments.get(partitionId);
        if (assignments.stream().anyMatch(taskBucket -> taskBucket.taskId == toTaskBucket.taskId)) {
            return false;
        }

        assignments.add(toTaskBucket);
        partitionDataSizeAtLastRebalance[partitionId] = partitionDataSize;

        // The partition's per-task data is now split over one more bucket. The new bucket gains
        // (n - 1) / n of it, and every existing bucket gives up 1 / n.
        int newAssignmentCount = assignments.size();
        long perTaskSize = partitionDataSizeSinceLastRebalancePerTask[partitionId];
        for (TaskBucket taskBucket : assignments) {
            if (taskBucket.equals(toTaskBucket)) {
                estimatedTaskBucketDataSizeSinceLastRebalance[taskBucket.id] += (perTaskSize * (newAssignmentCount - 1)) / newAssignmentCount;
            }
            else {
                estimatedTaskBucketDataSizeSinceLastRebalance[taskBucket.id] -= perTaskSize / newAssignmentCount;
            }
            long estimatedSize = estimatedTaskBucketDataSizeSinceLastRebalance[taskBucket.id];
            maxTaskBuckets.addOrUpdate(taskBucket, estimatedSize);
            minTaskBuckets.addOrUpdate(taskBucket, Long.MAX_VALUE - estimatedSize);
        }

        // Routine operational event: DEBUG rather than WARN to avoid flooding logs.
        log.debug("Rebalanced partition %s to task %s with taskCount %s", partitionId, toTaskBucket.taskId, assignments.size());
        return true;
    }
}
