package io.prestosql.operator.aggregation;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.airlift.units.DataSize;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.StandardErrorCode;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.BlockBuilder;
import io.prestosql.spi.block.BlockBuilderStatus;
import io.prestosql.spi.type.Type;
import io.prestosql.type.TypeUtils;
import io.prestosql.util.Failures;
import it.unimi.dsi.fastutil.HashCommon;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;
import java.util.Objects;
import java.util.Optional;
import org.openjdk.jol.info.ClassLayout;

/* loaded from: input_file:io/prestosql/operator/aggregation/TypedSet.class */
/**
 * A hash set of values of a single {@link Type}.
 *
 * <p>Each distinct value is appended once to a {@link BlockBuilder}, in order of first insertion,
 * and its position in that builder is indexed by an open-addressing hash table that uses linear
 * probing. Equality is decided by the optional {@code elementIsDistinctFrom} method handle when
 * present, otherwise by {@link TypeUtils#positionEqualsPosition}.
 *
 * <p>The bytes appended to the element builder since construction are bounded by a configurable
 * limit (default {@link #MAX_FUNCTION_MEMORY}). Exceeding the limit raises
 * {@code EXCEEDED_FUNCTION_MEMORY_LIMIT}.
 */
public class TypedSet {

    @VisibleForTesting
    public static final DataSize MAX_FUNCTION_MEMORY = DataSize.of(4, DataSize.Unit.MEGABYTE);
    private static final int INSTANCE_SIZE = ClassLayout.parseClass(TypedSet.class).instanceSize();
    private static final int INT_ARRAY_LIST_INSTANCE_SIZE = ClassLayout.parseClass(IntArrayList.class).instanceSize();
    private static final float FILL_RATIO = 0.75f;
    /** Value stored in {@link #blockPositionByHash} for a slot that holds no element. */
    private static final int EMPTY_SLOT = -1;
    /** Required signature of the {@code elementIsDistinctFrom} handle. */
    private static final MethodType DISTINCT_FROM_TYPE = MethodType.methodType(boolean.class, Block.class, int.class, Block.class, int.class);

    private final Type elementType;
    private final Optional<MethodHandle> elementIsDistinctFrom;
    /** Hash slot -> position of the element in {@link #elementBlock}, or {@link #EMPTY_SLOT}. */
    private final IntArrayList blockPositionByHash;
    private final BlockBuilder elementBlock;
    private final String functionName;
    private final long maxBlockMemoryInBytes;
    /** Position count of the builder when this set was created; earlier positions are not members. */
    private final int initialElementBlockOffset;
    /** Builder size at construction; the memory limit only applies to bytes appended afterwards. */
    private final long initialElementBlockSizeInBytes;
    private int size;
    private int hashCapacity;
    private int maxFill;
    private int hashMask;
    /** Lets {@link #contains} answer for a null value without probing the hash table. */
    private boolean containsNullElement;

    /**
     * Creates a set with a fresh element builder, type-based equality and the default memory limit.
     *
     * @param elementType type of the stored values
     * @param expectedSize expected number of distinct values; used to size the hash table
     * @param functionName name reported in the memory-limit error message
     */
    public TypedSet(Type elementType, int expectedSize, String functionName) {
        this(elementType, Optional.empty(), elementType.createBlockBuilder((BlockBuilderStatus) null, expectedSize), expectedSize, functionName);
    }

    /**
     * Creates a set with a fresh element builder, a custom distinct-from handle and the default
     * memory limit.
     *
     * @param elementIsDistinctFrom handle of type {@code (Block, int, Block, int)boolean}
     */
    public TypedSet(Type elementType, MethodHandle elementIsDistinctFrom, int expectedSize, String functionName) {
        this(elementType, Optional.of(elementIsDistinctFrom), elementType.createBlockBuilder((BlockBuilderStatus) null, expectedSize), expectedSize, functionName);
    }

    /**
     * Creates a set that appends into the supplied builder, using the default memory limit.
     */
    public TypedSet(Type elementType, Optional<MethodHandle> elementIsDistinctFrom, BlockBuilder blockBuilder, int expectedSize, String functionName) {
        this(elementType, elementIsDistinctFrom, blockBuilder, expectedSize, functionName, Optional.of(MAX_FUNCTION_MEMORY));
    }

    /**
     * Creates a set that appends into the supplied builder.
     *
     * @param elementType type of the stored values
     * @param elementIsDistinctFrom optional handle of type {@code (Block, int, Block, int)boolean};
     *     when absent, type-based equality is used
     * @param blockBuilder builder that receives the distinct values; positions already present in it
     *     are not considered members of this set
     * @param expectedSize expected number of distinct values; must not be negative
     * @param functionName name reported in the memory-limit error message
     * @param maxBlockMemory limit on bytes appended to {@code blockBuilder}; empty means unlimited
     * @throws IllegalArgumentException if {@code expectedSize} is negative or the handle has the wrong type
     * @throws NullPointerException if a required argument is null
     */
    public TypedSet(Type elementType, Optional<MethodHandle> elementIsDistinctFrom, BlockBuilder blockBuilder, int expectedSize, String functionName, Optional<DataSize> maxBlockMemory) {
        Preconditions.checkArgument(expectedSize >= 0, "expectedSize must not be negative");
        this.elementType = Objects.requireNonNull(elementType, "elementType must not be null");
        this.elementIsDistinctFrom = Objects.requireNonNull(elementIsDistinctFrom, "elementIsDistinctFrom is null");
        elementIsDistinctFrom.ifPresent(handle -> Preconditions.checkArgument(
                handle.type().equals(DISTINCT_FROM_TYPE),
                "elementIsDistinctFrom must have type %s, but was %s",
                DISTINCT_FROM_TYPE,
                handle.type()));
        this.elementBlock = Objects.requireNonNull(blockBuilder, "blockBuilder must not be null");
        this.functionName = functionName;
        this.maxBlockMemoryInBytes = Objects.requireNonNull(maxBlockMemory, "maxBlockMemory is null")
                .map(DataSize::toBytes)
                .orElse(Long.MAX_VALUE);
        this.initialElementBlockOffset = this.elementBlock.getPositionCount();
        this.initialElementBlockSizeInBytes = this.elementBlock.getSizeInBytes();

        this.hashCapacity = HashCommon.arraySize(expectedSize, FILL_RATIO);
        this.maxFill = calculateMaxFill(this.hashCapacity);
        this.hashMask = this.hashCapacity - 1;
        this.blockPositionByHash = new IntArrayList(this.hashCapacity);
        this.blockPositionByHash.size(this.hashCapacity);
        markAllSlotsEmpty();
    }

    /**
     * Returns an estimate of the memory retained by this set: the object headers, the element
     * builder's retained size and 4 bytes per hash slot.
     */
    public long getRetainedSizeInBytes() {
        // widen before multiplying so a large table cannot overflow int arithmetic
        return INSTANCE_SIZE + INT_ARRAY_LIST_INSTANCE_SIZE + this.elementBlock.getRetainedSizeInBytes() + (long) this.blockPositionByHash.size() * Integer.BYTES;
    }

    /**
     * Returns whether the value at {@code position} of {@code block} is a member of this set.
     * A null value is a member iff a null value has previously been added.
     *
     * @throws NullPointerException if {@code block} is null
     * @throws IllegalArgumentException if {@code position} is negative
     */
    public boolean contains(Block block, int position) {
        Objects.requireNonNull(block, "block must not be null");
        Preconditions.checkArgument(position >= 0, "position must be >= 0");
        if (block.isNull(position)) {
            return this.containsNullElement;
        }
        return this.blockPositionByHash.getInt(getHashPositionOfElement(block, position)) != EMPTY_SLOT;
    }

    /**
     * Adds the value at {@code position} of {@code block} if it is not already present.
     *
     * @throws NullPointerException if {@code block} is null
     * @throws IllegalArgumentException if {@code position} is negative
     * @throws PrestoException if the element builder grows beyond the configured memory limit, or
     *     the hash table would need more than {@link Integer#MAX_VALUE} slots
     */
    public void add(Block block, int position) {
        Objects.requireNonNull(block, "block must not be null");
        Preconditions.checkArgument(position >= 0, "position must be >= 0");
        if (block.isNull(position)) {
            this.containsNullElement = true;
        }
        int hashPosition = getHashPositionOfElement(block, position);
        if (this.blockPositionByHash.getInt(hashPosition) == EMPTY_SLOT) {
            addNewElement(hashPosition, block, position);
        }
    }

    /** Returns the number of distinct values added to this set. */
    public int size() {
        return this.size;
    }

    /**
     * Returns the position, within the element builder, of the value equal to the value at
     * {@code position} of {@code block}, or {@code -1} if no such value has been added.
     */
    public int positionOf(Block block, int position) {
        return this.blockPositionByHash.getInt(getHashPositionOfElement(block, position));
    }

    /**
     * Returns the hash slot that either holds a value equal to {@code block[position]} or is the
     * empty slot where that value would be inserted. Collisions are resolved by linear probing.
     * The loop terminates because the table is always kept below full ({@code maxFill < hashCapacity}).
     */
    private int getHashPositionOfElement(Block block, int position) {
        int hashPosition = getMaskedHash(TypeUtils.hashPosition(this.elementType, block, position));
        while (true) {
            int elementPosition = this.blockPositionByHash.getInt(hashPosition);
            if (elementPosition == EMPTY_SLOT || isContainedAt(block, position, elementPosition)) {
                return hashPosition;
            }
            // slot taken by a different value: keep probing
            hashPosition = getMaskedHash(hashPosition + 1);
        }
    }

    /** Returns whether {@code block[position]} equals the stored value at {@code elementPosition}. */
    private boolean isContainedAt(Block block, int position, int elementPosition) {
        if (!this.elementIsDistinctFrom.isPresent()) {
            return TypeUtils.positionEqualsPosition(this.elementType, this.elementBlock, elementPosition, block, position);
        }
        try {
            // invokeExact requires the call-site types to match (Block, int, Block, int)boolean
            // exactly, so the BlockBuilder field must be statically viewed as a Block.
            return !(boolean) this.elementIsDistinctFrom.get().invokeExact((Block) this.elementBlock, elementPosition, block, position);
        } catch (Throwable t) {
            throw Failures.internalError(t);
        }
    }

    /** Appends the value to the element builder, records it in {@code hashPosition} and grows if needed. */
    private void addNewElement(int hashPosition, Block block, int position) {
        this.elementType.appendTo(block, position, this.elementBlock);
        if (this.elementBlock.getSizeInBytes() - this.initialElementBlockSizeInBytes > this.maxBlockMemoryInBytes) {
            // report the limit that is actually enforced, which may differ from MAX_FUNCTION_MEMORY
            throw new PrestoException(StandardErrorCode.EXCEEDED_FUNCTION_MEMORY_LIMIT, String.format("The input to %s is too large. More than %s of memory is needed to hold the intermediate hash set.\n", this.functionName, DataSize.succinctBytes(this.maxBlockMemoryInBytes)));
        }
        this.blockPositionByHash.set(hashPosition, this.elementBlock.getPositionCount() - 1);
        this.size++;
        if (this.size >= this.maxFill) {
            rehash();
        }
    }

    /** Doubles the hash table and re-inserts every element owned by this set. */
    private void rehash() {
        // multiply in long: doubling 2^30 in int arithmetic would wrap negative and skip the guard
        long newCapacityLong = this.hashCapacity * 2L;
        if (newCapacityLong > Integer.MAX_VALUE) {
            throw new PrestoException(StandardErrorCode.GENERIC_INSUFFICIENT_RESOURCES, "Size of hash table cannot exceed 1 billion entries");
        }
        int newCapacity = (int) newCapacityLong;
        this.hashCapacity = newCapacity;
        this.hashMask = newCapacity - 1;
        this.maxFill = calculateMaxFill(newCapacity);
        this.blockPositionByHash.size(newCapacity);
        markAllSlotsEmpty();
        for (int position = this.initialElementBlockOffset; position < this.elementBlock.getPositionCount(); position++) {
            this.blockPositionByHash.set(getHashPositionOfElement(this.elementBlock, position), position);
        }
    }

    /** Sets the first {@code hashCapacity} slots of the hash table to {@link #EMPTY_SLOT}. */
    private void markAllSlotsEmpty() {
        for (int slot = 0; slot < this.hashCapacity; slot++) {
            this.blockPositionByHash.set(slot, EMPTY_SLOT);
        }
    }

    /**
     * Returns the element count at which the table must grow: {@code ceil(hashSize * FILL_RATIO)},
     * reduced by one if needed so at least one slot always stays empty.
     */
    private static int calculateMaxFill(int hashSize) {
        Preconditions.checkArgument(hashSize > 0, "hashSize must be greater than 0");
        int maxFill = (int) Math.ceil(hashSize * FILL_RATIO);
        if (maxFill == hashSize) {
            maxFill--;
        }
        Preconditions.checkArgument(hashSize > maxFill, "hashSize must be larger than maxFill");
        return maxFill;
    }

    /** Maps a hash value to a slot index; relies on {@code hashCapacity} being a power of two. */
    private int getMaskedHash(long rawHash) {
        return (int) (rawHash & this.hashMask);
    }
}
