package io.trino.sql.planner;

import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.units.DataSize;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.Type;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.function.Function;
import javax.annotation.concurrent.GuardedBy;

/* loaded from: input_file:io/trino/sql/planner/LocalDynamicFilterConsumer.class */
/**
 * Accumulates build-side dynamic filter domains produced by one or more partitions and publishes
 * the combined per-filter {@link Domain}s to every collector once collection completes.
 *
 * <p>Collection completes when any of the following happens:
 * <ul>
 *   <li>the number of partitions announced via {@link #setPartitionCount} has been added
 *       (the two calls may arrive in either order);</li>
 *   <li>a partition contributes {@link TupleDomain#all()};</li>
 *   <li>the accumulated summary stays above the size limit even after simplification.</li>
 * </ul>
 * In the last two cases the published result is {@link TupleDomain#all()}. Partitions that arrive
 * after completion are ignored.
 *
 * <p>Thread-safety: partition domains are queued and compacted without holding the monitor;
 * partition counting, completion and collector notification happen while synchronized on
 * {@code this}, so collectors are invoked with that lock held.
 */
public class LocalDynamicFilterConsumer implements DynamicFilterSourceConsumer {
    // Dynamic filter id -> index of its channel among the build side's output symbols.
    private final Map<DynamicFilterId, Integer> buildChannels;
    // Dynamic filter id -> type of that build channel (same key set as buildChannels).
    private final Map<DynamicFilterId, Type> filterBuildTypes;
    private final List<Consumer<Map<DynamicFilterId, Domain>>> collectors;
    private final long domainSizeLimitInBytes;

    // Number of partitions to collect; null until setPartitionCount is called.
    @GuardedBy("this")
    private Integer expectedPartitionCount;

    @GuardedBy("this")
    private int collectedPartitionCount;

    // Only written under the lock; volatile so addPartition can bail out early without locking.
    @GuardedBy("this")
    private volatile boolean collected;

    // Pending (partially unioned) partition domains and the running total of their retained sizes.
    // Both are concurrent because addPartition compacts them outside the lock.
    private final Queue<TupleDomain<DynamicFilterId>> summaryDomains = new ConcurrentLinkedQueue<>();
    private final AtomicLong summaryDomainsRetainedSizeInBytes = new AtomicLong();

    /**
     * Creates a consumer for the given dynamic filters.
     *
     * @param buildChannels dynamic filter id to build channel index; copied
     * @param filterBuildTypes dynamic filter id to build channel type; copied, must have the same keys as {@code buildChannels}
     * @param collectors consumers notified with the final per-filter domains; must not be empty
     * @param dataSize size limit above which the accumulated summary is compacted, simplified or abandoned
     * @throws NullPointerException if any argument is null
     * @throws IllegalArgumentException if {@code collectors} is empty
     */
    public LocalDynamicFilterConsumer(
            Map<DynamicFilterId, Integer> buildChannels,
            Map<DynamicFilterId, Type> filterBuildTypes,
            List<Consumer<Map<DynamicFilterId, Domain>>> collectors,
            DataSize dataSize) {
        // copyOf is a no-op for the ImmutableMaps built by create(), but protects against caller mutation otherwise
        this.buildChannels = ImmutableMap.copyOf(Objects.requireNonNull(buildChannels, "buildChannels is null"));
        this.filterBuildTypes = ImmutableMap.copyOf(Objects.requireNonNull(filterBuildTypes, "filterBuildTypes is null"));
        Verify.verify(this.buildChannels.keySet().equals(this.filterBuildTypes.keySet()), "filterBuildTypes and buildChannels must have same keys", new Object[0]);
        Objects.requireNonNull(collectors, "collectors is null");
        Preconditions.checkArgument(!collectors.isEmpty(), "collectors is empty");
        this.collectors = ImmutableList.copyOf(collectors);
        this.domainSizeLimitInBytes = Objects.requireNonNull(dataSize, "dataSize is null").toBytes();
    }

    /**
     * Adds the domain collected from one build-side partition.
     *
     * <p>Completes collection (and notifies collectors) when this was the last expected partition,
     * when {@code tupleDomain} is all, or when the summary cannot be kept under the size limit.
     */
    @Override
    public void addPartition(TupleDomain<DynamicFilterId> tupleDomain) {
        if (collected) {
            return;
        }
        summaryDomainsRetainedSizeInBytes.addAndGet(getRetainedSizeInBytes(tupleDomain));
        summaryDomains.add(tupleDomain);
        // Compact outside the lock so producers finishing at the same time do not serialize on the union.
        unionSummaryDomainsIfNecessary(false);

        synchronized (this) {
            Verify.verify(expectedPartitionCount == null || collectedPartitionCount < expectedPartitionCount);
            if (collected) {
                // Collection completed while we were queueing; drop whatever is still pending.
                clearSummaryDomains();
                return;
            }
            collectedPartitionCount++;

            boolean allPartitionsCollected = expectedPartitionCount != null && collectedPartitionCount == expectedPartitionCount;
            if (allPartitionsCollected) {
                // Earlier unlocked compactions may have left more than one domain; fold them into one.
                unionSummaryDomainsIfNecessary(true);
            }

            boolean sizeLimitExceeded = enforceSummarySizeLimit();

            TupleDomain<DynamicFilterId> result;
            if (sizeLimitExceeded || tupleDomain.isAll()) {
                // Give up on collecting further domains and publish an all-domain instead.
                clearSummaryDomains();
                result = TupleDomain.all();
            } else if (allPartitionsCollected) {
                result = pollFinalSummary();
            } else {
                return;
            }
            completeCollection(result);
        }
    }

    /**
     * Announces how many partitions will be added. If that many have already been added,
     * completes collection immediately (publishing a none-domain when {@code i} is zero).
     *
     * @throws IllegalArgumentException if {@code i} is negative
     * @throws IllegalStateException if called more than once before collection completes
     */
    @Override
    public void setPartitionCount(int i) {
        Preconditions.checkArgument(i >= 0, "partitionCount must not be negative: %s", i);
        synchronized (this) {
            if (collected) {
                return;
            }
            Preconditions.checkState(expectedPartitionCount == null, "setPartitionCount should be called only once");
            expectedPartitionCount = i;
            if (collectedPartitionCount < i) {
                return;
            }

            TupleDomain<DynamicFilterId> result;
            if (i == 0) {
                result = TupleDomain.none();
            } else {
                // Earlier unlocked compactions may have left more than one domain; fold them into one.
                unionSummaryDomainsIfNecessary(true);
                result = pollFinalSummary();
            }
            completeCollection(result);
        }
    }

    /**
     * Polls one pending summary and, if it exceeds the size limit, simplifies it. The summary is
     * re-queued when it then fits; otherwise it is dropped and its size removed from the total.
     *
     * @return true if a summary was dropped because it could not be brought under the limit
     */
    @GuardedBy("this")
    private boolean enforceSummarySizeLimit() {
        TupleDomain<DynamicFilterId> summary = summaryDomains.poll();
        // May be null: an unlocked compaction in another thread can have the queue drained right now.
        if (summary == null) {
            return false;
        }
        long originalSize = getRetainedSizeInBytes(summary);
        if (originalSize > domainSizeLimitInBytes) {
            summary = summary.simplify(1);
        }
        long currentSize = getRetainedSizeInBytes(summary);
        if (currentSize > domainSizeLimitInBytes) {
            summaryDomainsRetainedSizeInBytes.addAndGet(-originalSize);
            return true;
        }
        summaryDomainsRetainedSizeInBytes.addAndGet(currentSize - originalSize);
        summaryDomains.add(summary);
        return false;
    }

    /** Removes the single remaining summary domain and checks that size bookkeeping returns to zero. */
    @GuardedBy("this")
    private TupleDomain<DynamicFilterId> pollFinalSummary() {
        Verify.verify(summaryDomains.size() == 1);
        TupleDomain<DynamicFilterId> summary = summaryDomains.poll();
        Verify.verify(summary != null);
        long currentSize = summaryDomainsRetainedSizeInBytes.addAndGet(-getRetainedSizeInBytes(summary));
        Verify.verify(currentSize == 0, "currentSize is expected to be zero: %s", currentSize);
        return summary;
    }

    /** Marks collection as complete and hands the converted result to every collector. */
    @GuardedBy("this")
    private void completeCollection(TupleDomain<DynamicFilterId> result) {
        collected = true;
        // The converted map is immutable, so one instance can be shared by all collectors.
        Map<DynamicFilterId, Domain> domains = convertTupleDomain(result);
        collectors.forEach(collector -> collector.accept(domains));
    }

    /**
     * Replaces all pending summaries with their column-wise union, either when {@code force} is
     * set or when their accumulated retained size has reached the size limit.
     */
    private void unionSummaryDomainsIfNecessary(boolean force) {
        if (!force && summaryDomainsRetainedSizeInBytes.get() < domainSizeLimitInBytes) {
            return;
        }
        List<TupleDomain<DynamicFilterId>> drained = new ArrayList<>();
        long drainedSize = 0;
        for (TupleDomain<DynamicFilterId> domain = summaryDomains.poll(); domain != null; domain = summaryDomains.poll()) {
            drained.add(domain);
            drainedSize += getRetainedSizeInBytes(domain);
        }
        if (drained.isEmpty()) {
            return;
        }
        TupleDomain<DynamicFilterId> union = TupleDomain.columnWiseUnion(drained);
        long currentSize = summaryDomainsRetainedSizeInBytes.addAndGet(getRetainedSizeInBytes(union) - drainedSize);
        Verify.verify(currentSize >= 0, "currentSize is expected to be greater than or equal to zero: %s", currentSize);
        summaryDomains.add(union);
    }

    /** Returns whether collection has completed and collectors have been notified. */
    @Override
    public synchronized boolean isDomainCollectionComplete() {
        return collected;
    }

    /** Drops every pending summary and subtracts their retained sizes from the running total. */
    private void clearSummaryDomains() {
        long clearedSize = 0;
        for (TupleDomain<DynamicFilterId> domain = summaryDomains.poll(); domain != null; domain = summaryDomains.poll()) {
            clearedSize += getRetainedSizeInBytes(domain);
        }
        long currentSize = summaryDomainsRetainedSizeInBytes.addAndGet(-clearedSize);
        Verify.verify(currentSize >= 0, "currentSize is expected to be greater than or equal to zero: %s", currentSize);
    }

    /**
     * Converts a tuple domain into one {@link Domain} per dynamic filter: a none-domain for every
     * filter when {@code tupleDomain} is none, otherwise its domains with {@code Domain.all}
     * filled in for filters it does not constrain.
     */
    private Map<DynamicFilterId, Domain> convertTupleDomain(TupleDomain<DynamicFilterId> tupleDomain) {
        if (tupleDomain.isNone()) {
            return buildChannels.keySet().stream()
                    .collect(ImmutableMap.toImmutableMap(Function.identity(), filterId -> Domain.none(filterBuildTypes.get(filterId))));
        }
        // getDomains() is only empty for a none-domain, which was handled above.
        Map<DynamicFilterId, Domain> domains = new HashMap<>(tupleDomain.getDomains().get());
        for (DynamicFilterId filterId : buildChannels.keySet()) {
            domains.putIfAbsent(filterId, Domain.all(filterBuildTypes.get(filterId)));
        }
        return ImmutableMap.copyOf(domains);
    }

    /**
     * Creates a consumer for the subset {@code set} of the join's dynamic filters, resolving each
     * filter's build channel from the join's right (build) side output symbols.
     *
     * @param joinNode join whose dynamic filters are collected
     * @param list types of the build side's output channels
     * @param set ids of the dynamic filters to collect; must be a non-empty subset of the join's filters
     * @param list2 consumers notified with the final domains
     * @param dataSize size limit for the accumulated summary
     */
    public static LocalDynamicFilterConsumer create(JoinNode joinNode, List<Type> list, Set<DynamicFilterId> set, List<Consumer<Map<DynamicFilterId, Domain>>> list2, DataSize dataSize) {
        Preconditions.checkArgument(!joinNode.getDynamicFilters().isEmpty(), "Join node dynamicFilters is empty.");
        Preconditions.checkArgument(!set.isEmpty(), "Collected dynamic filters set is empty");
        Preconditions.checkArgument(joinNode.getDynamicFilters().keySet().containsAll(set), "Collected dynamic filters set is not subset of join dynamic filters");

        PlanNode buildSide = joinNode.getRight();
        List<Symbol> buildOutputSymbols = buildSide.getOutputSymbols();
        Map<DynamicFilterId, Integer> buildChannels = joinNode.getDynamicFilters().entrySet().stream()
                .filter(entry -> set.contains(entry.getKey()))
                .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, entry -> {
                    int channel = buildOutputSymbols.indexOf(entry.getValue());
                    Verify.verify(channel >= 0);
                    return channel;
                }));
        Map<DynamicFilterId, Type> filterBuildTypes = buildChannels.entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, entry -> list.get(entry.getValue())));
        return new LocalDynamicFilterConsumer(buildChannels, filterBuildTypes, list2, dataSize);
    }

    /** Returns the immutable mapping from dynamic filter id to build channel index. */
    public Map<DynamicFilterId, Integer> getBuildChannels() {
        return buildChannels;
    }

    @Override
    public synchronized String toString() {
        return MoreObjects.toStringHelper(this)
                .add("buildChannels", buildChannels)
                .add("filterBuildTypes", filterBuildTypes)
                .add("domainSizeLimitInBytes", domainSizeLimitInBytes)
                .add("expectedPartitionCount", expectedPartitionCount)
                .add("collectedPartitionCount", collectedPartitionCount)
                .add("collected", collected)
                .add("summaryDomains", summaryDomains)
                .add("summaryDomainsRetainedSizeInBytes", summaryDomainsRetainedSizeInBytes)
                .toString();
    }

    /** Retained size of a tuple domain, counting each key via {@code DynamicFilterId}'s own retained size. */
    private static long getRetainedSizeInBytes(TupleDomain<DynamicFilterId> tupleDomain) {
        return tupleDomain.getRetainedSizeInBytes(DynamicFilterId::getRetainedSizeInBytes);
    }
}
