/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.output;

import io.trino.operator.PartitionFunction;
import io.trino.operator.output.SkewedPartitionRebalancer;
import io.trino.spi.Page;
import java.util.Objects;

/**
 * A {@link PartitionFunction} decorator that assigns rows to tasks through a
 * {@link SkewedPartitionRebalancer} instead of using the delegate's partition directly.
 *
 * <p>For every row, the delegate function picks a partition. This class then asks the
 * rebalancer which task should receive that row. It passes along how many rows this
 * partition had already produced since the last flush. Those per-partition counters are
 * kept locally and handed to the rebalancer by {@link #flushPartitionRowCountToRebalancer()}.
 *
 * <p>No synchronization is performed on the counters.
 * NOTE(review): presumably each instance is confined to a single thread — confirm with callers.
 */
public class SkewedPartitionFunction
        implements PartitionFunction
{
    private final PartitionFunction partitionFunction;
    private final SkewedPartitionRebalancer skewedPartitionRebalancer;
    // Rows routed per delegate partition since the last flush; sized by the delegate's partition count.
    private final long[] partitionRowCount;

    /**
     * @param partitionFunction delegate that chooses the underlying partition for each row
     * @param skewedPartitionRebalancer maps (partition, row index) pairs to task ids
     * @throws NullPointerException if either argument is {@code null}
     */
    public SkewedPartitionFunction(PartitionFunction partitionFunction, SkewedPartitionRebalancer skewedPartitionRebalancer)
    {
        this.partitionFunction = Objects.requireNonNull(partitionFunction, "partitionFunction is null");
        this.skewedPartitionRebalancer = Objects.requireNonNull(skewedPartitionRebalancer, "skewedPartitionRebalancer is null");
        int delegatePartitions = partitionFunction.partitionCount();
        this.partitionRowCount = new long[delegatePartitions];
    }

    /**
     * Returns the number of output targets, which here is the rebalancer's task count
     * rather than the delegate's partition count.
     */
    @Override
    public int partitionCount()
    {
        return skewedPartitionRebalancer.getTaskCount();
    }

    /**
     * Routes one row to a task.
     *
     * <p>The delegate's partition and that partition's current row counter are passed to the
     * rebalancer. The counter is then incremented, so the first row after a flush is index 0.
     *
     * @return the task id chosen by the rebalancer
     */
    @Override
    public int getPartition(Page page, int position)
    {
        int sourcePartition = partitionFunction.getPartition(page, position);
        // Post-increment: the rebalancer sees the count *before* this row is added.
        long rowIndexInPartition = partitionRowCount[sourcePartition]++;
        return skewedPartitionRebalancer.getTaskId(sourcePartition, rowIndexInPartition);
    }

    /**
     * Reports every partition's accumulated row count to the rebalancer and resets it to zero.
     * Each partition is reported, including those with a count of zero.
     */
    public void flushPartitionRowCountToRebalancer()
    {
        for (int p = 0; p < partitionFunction.partitionCount(); p++) {
            long accumulatedRows = partitionRowCount[p];
            skewedPartitionRebalancer.addPartitionRowCount(p, accumulatedRows);
            // Reset only after the rebalancer has accepted the count.
            partitionRowCount[p] = 0L;
        }
    }
}

