package io.prestosql.memory;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import io.airlift.stats.GcMonitor;
import io.airlift.units.DataSize;
import io.prestosql.ExceededMemoryLimitException;
import io.prestosql.ExceededSpillLimitException;
import io.prestosql.Session;
import io.prestosql.execution.TaskId;
import io.prestosql.execution.TaskStateMachine;
import io.prestosql.memory.context.AggregatedMemoryContext;
import io.prestosql.memory.context.MemoryReservationHandler;
import io.prestosql.memory.context.MemoryTrackingContext;
import io.prestosql.operator.Operator;
import io.prestosql.operator.TaskContext;
import io.prestosql.spi.QueryId;
import io.prestosql.spiller.SpillSpaceTracker;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.OptionalInt;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.BiFunction;
import java.util.function.BiPredicate;
import java.util.stream.Collectors;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;

@ThreadSafe
/* loaded from: input_file:io/prestosql/memory/QueryContext.class */
public class QueryContext {
    /** One megabyte expressed in bytes; passed as the second argument when the root user-memory context is created. */
    private static final long GUARANTEED_MEMORY = DataSize.of(1, DataSize.Unit.MEGABYTE).toBytes();

    private final QueryId queryId;

    // Collaborators that are only forwarded to TaskContext.createTaskContext by addTaskContext.
    private final GcMonitor gcMonitor;
    private final Executor notificationExecutor;
    private final ScheduledExecutorService yieldExecutor;

    /** Upper bound, in bytes, that {@code reserveSpill} allows {@link #spillUsed} to reach. */
    private final long maxSpill;
    private final SpillSpaceTracker spillSpaceTracker;

    /** Limit, in bytes, checked by the user-memory reservation paths. */
    @GuardedBy("this")
    private long maxUserMemory;

    /** Limit, in bytes, checked by the system-memory reservation path. */
    @GuardedBy("this")
    private long maxTotalMemory;

    /** Pool that all reservations of this query are currently charged to; replaced by {@code setMemoryPool}. */
    @GuardedBy("this")
    private MemoryPool memoryPool;

    /** Bytes reserved through {@code reserveSpill} and not yet returned through {@code freeSpill}. */
    @GuardedBy("this")
    private long spillUsed;

    /** Task contexts registered by {@code addTaskContext}, keyed by task id. */
    private final Map<TaskId, TaskContext> taskContexts = new ConcurrentHashMap<>();
    /**
     * Root memory tracking context of the query, built from three root aggregated contexts whose
     * reservation callbacks are wired to this instance:
     * <ol>
     *   <li>{@code updateUserMemory} / {@code tryUpdateUserMemory}, created with {@link #GUARANTEED_MEMORY};</li>
     *   <li>{@code updateRevocableMemory}, with try-reserve unsupported;</li>
     *   <li>{@code updateSystemMemory}, with try-reserve unsupported.</li>
     * </ol>
     */
    private final MemoryTrackingContext queryMemoryContext = new MemoryTrackingContext(
            AggregatedMemoryContext.newRootAggregatedMemoryContext(
                    new QueryMemoryReservationHandler(this::updateUserMemory, this::tryUpdateUserMemory),
                    GUARANTEED_MEMORY),
            AggregatedMemoryContext.newRootAggregatedMemoryContext(
                    new QueryMemoryReservationHandler(this::updateRevocableMemory, this::tryReserveMemoryNotSupported),
                    0),
            AggregatedMemoryContext.newRootAggregatedMemoryContext(
                    new QueryMemoryReservationHandler(this::updateSystemMemory, this::tryReserveMemoryNotSupported),
                    0));

    /* loaded from: input_file:io/prestosql/memory/QueryContext$QueryMemoryReservationHandler.class */
    /**
     * {@link MemoryReservationHandler} that delegates both operations to caller-supplied functions.
     */
    private static class QueryMemoryReservationHandler implements MemoryReservationHandler {
        private final BiFunction<String, Long, ListenableFuture<?>> reserveMemoryFunction;
        private final BiPredicate<String, Long> tryReserveMemoryFunction;

        /**
         * @param reserveMemoryFunction    invoked by {@link #reserveMemory}; must not be null
         * @param tryReserveMemoryFunction invoked by {@link #tryReserveMemory}; must not be null
         * @throws NullPointerException if either argument is null
         */
        public QueryMemoryReservationHandler(
                BiFunction<String, Long, ListenableFuture<?>> reserveMemoryFunction,
                BiPredicate<String, Long> tryReserveMemoryFunction) {
            // requireNonNull is generic, so no (raw) cast is needed on its result.
            this.reserveMemoryFunction = Objects.requireNonNull(reserveMemoryFunction, "reserveMemoryFunction is null");
            this.tryReserveMemoryFunction = Objects.requireNonNull(tryReserveMemoryFunction, "tryReserveMemoryFunction is null");
        }

        /** Forwards the tag and delta to the reserve function and returns its future. */
        public ListenableFuture<?> reserveMemory(String allocationTag, long delta) {
            return reserveMemoryFunction.apply(allocationTag, delta);
        }

        /** Forwards the tag and delta to the try-reserve predicate and returns its verdict. */
        public boolean tryReserveMemory(String allocationTag, long delta) {
            return tryReserveMemoryFunction.test(allocationTag, delta);
        }
    }

    /**
     * Creates the context for one query. Every argument is null-checked in declaration order.
     *
     * @param queryId                  id used for all pool reservations made by this context
     * @param dataSize                 initial user-memory limit (stored in bytes as {@code maxUserMemory})
     * @param dataSize2                initial total-memory limit (stored in bytes as {@code maxTotalMemory})
     * @param memoryPool               pool the query initially reserves from
     * @param gcMonitor                forwarded to task contexts created by {@code addTaskContext}
     * @param executor                 notification executor forwarded to task contexts
     * @param scheduledExecutorService yield executor forwarded to task contexts
     * @param dataSize3                spill limit (stored in bytes as {@code maxSpill})
     * @param spillSpaceTracker        tracker that spill reservations are charged to
     * @throws NullPointerException if any argument is null
     */
    public QueryContext(QueryId queryId, DataSize dataSize, DataSize dataSize2, MemoryPool memoryPool, GcMonitor gcMonitor, Executor executor, ScheduledExecutorService scheduledExecutorService, DataSize dataSize3, SpillSpaceTracker spillSpaceTracker) {
        this.queryId = Objects.requireNonNull(queryId, "queryId is null");
        this.maxUserMemory = Objects.requireNonNull(dataSize, "maxUserMemory is null").toBytes();
        this.maxTotalMemory = Objects.requireNonNull(dataSize2, "maxTotalMemory is null").toBytes();
        this.memoryPool = Objects.requireNonNull(memoryPool, "memoryPool is null");
        this.gcMonitor = Objects.requireNonNull(gcMonitor, "gcMonitor is null");
        this.notificationExecutor = Objects.requireNonNull(executor, "notificationExecutor is null");
        this.yieldExecutor = Objects.requireNonNull(scheduledExecutorService, "yieldExecutor is null");
        this.maxSpill = Objects.requireNonNull(dataSize3, "maxSpill is null").toBytes();
        this.spillSpaceTracker = Objects.requireNonNull(spillSpaceTracker, "spillSpaceTracker is null");
    }

    /**
     * Sets both the user-memory and the total-memory limit to the maximum size reported by the
     * current memory pool.
     */
    public synchronized void setResourceOvercommit() {
        MemoryPool pool = memoryPool;
        maxUserMemory = pool.getMaxBytes();
        maxTotalMemory = pool.getMaxBytes();
    }

    /** Exposes the root memory tracking context of this query; intended for tests. */
    @VisibleForTesting
    MemoryTrackingContext getQueryMemoryContext() {
        return queryMemoryContext;
    }

    /**
     * Reservation callback for user memory.
     *
     * <p>A negative delta releases {@code -delta} bytes back to the pool and returns
     * {@code Operator.NOT_BLOCKED}. A non-negative delta (including zero) is first checked
     * against {@code maxUserMemory} and then reserved from the pool.
     *
     * @param allocationTag tag passed through to the memory pool
     * @param delta         bytes to reserve (non-negative) or release (negative)
     * @return the pool's reservation future, or {@code Operator.NOT_BLOCKED} after a release
     */
    private synchronized ListenableFuture<?> updateUserMemory(String allocationTag, long delta) {
        if (delta < 0) {
            memoryPool.free(queryId, allocationTag, -delta);
            return Operator.NOT_BLOCKED;
        }
        enforceUserMemoryLimit(queryMemoryContext.getUserMemory(), delta, maxUserMemory);
        return memoryPool.reserve(queryId, allocationTag, delta);
    }

    /**
     * Reservation callback for revocable memory. No limit is enforced here and the tag is not
     * forwarded to the pool.
     *
     * @param allocationTag unused
     * @param delta         bytes to reserve (non-negative) or release (negative)
     * @return the pool's reservation future, or {@code Operator.NOT_BLOCKED} after a release
     */
    private synchronized ListenableFuture<?> updateRevocableMemory(String allocationTag, long delta) {
        if (delta < 0) {
            memoryPool.freeRevocable(queryId, -delta);
            return Operator.NOT_BLOCKED;
        }
        return memoryPool.reserveRevocable(queryId, delta);
    }

    /**
     * Reservation callback for system memory.
     *
     * <p>The total-memory limit is checked against the pool's current reservation for this
     * query (not against the tracking context). That reservation is read up front, before the
     * sign of {@code delta} is examined.
     *
     * @param allocationTag tag passed through to the memory pool
     * @param delta         bytes to reserve (non-negative) or release (negative)
     * @return the pool's reservation future, or {@code Operator.NOT_BLOCKED} after a release
     */
    private synchronized ListenableFuture<?> updateSystemMemory(String allocationTag, long delta) {
        long reservedInPool = memoryPool.getQueryMemoryReservation(queryId);
        if (delta < 0) {
            memoryPool.free(queryId, allocationTag, -delta);
            return Operator.NOT_BLOCKED;
        }
        enforceTotalMemoryLimit(reservedInPool, delta, maxTotalMemory);
        return memoryPool.reserve(queryId, allocationTag, delta);
    }

    /**
     * Reserves spill space for this query.
     *
     * @param j number of bytes to reserve; must not be negative
     * @return the future returned by the spill space tracker
     * @throws IllegalArgumentException if {@code j} is negative
     */
    public synchronized ListenableFuture<?> reserveSpill(long j) {
        Preconditions.checkArgument(j >= 0, "bytes is negative");
        boolean overLimit = spillUsed + j > maxSpill;
        if (overLimit) {
            throw ExceededSpillLimitException.exceededPerQueryLocalLimit(DataSize.succinctBytes(maxSpill));
        }
        // The counter is bumped only after the tracker call returns, so an exception thrown by
        // the tracker leaves spillUsed untouched.
        ListenableFuture<?> trackerFuture = spillSpaceTracker.reserve(j);
        spillUsed += j;
        return trackerFuture;
    }

    /**
     * Non-blocking variant of {@code updateUserMemory}.
     *
     * <p>For a positive delta it returns {@code false} instead of throwing when the user-memory
     * limit would be exceeded, and otherwise returns whatever the pool's {@code tryReserve}
     * reports. Zero and negative deltas are delegated to {@code updateUserMemory} and always
     * report {@code true}.
     *
     * @param allocationTag tag passed through to the memory pool
     * @param delta         bytes to reserve (positive) or release (negative)
     */
    private synchronized boolean tryUpdateUserMemory(String allocationTag, long delta) {
        if (delta <= 0) {
            ListenableFuture<?> future = updateUserMemory(allocationTag, delta);
            // Only a release is verified to have completed; for delta == 0 the returned
            // future is not inspected.
            if (delta < 0) {
                Verify.verify(future.isDone(), "future should be done");
            }
            return true;
        }
        if (queryMemoryContext.getUserMemory() + delta > maxUserMemory) {
            return false;
        }
        return memoryPool.tryReserve(queryId, allocationTag, delta);
    }

    /**
     * Returns previously reserved spill space.
     *
     * @param j number of bytes to release; must be non-negative and no larger than the amount
     *          currently reserved
     * @throws IllegalArgumentException if {@code j} is negative or exceeds the reserved amount
     */
    public synchronized void freeSpill(long j) {
        // Mirror reserveSpill: without this check a negative value would pass the test below
        // and silently *increase* spillUsed before the tracker is even consulted.
        Preconditions.checkArgument(j >= 0, "bytes is negative");
        Preconditions.checkArgument(spillUsed - j >= 0, "tried to free more memory than is reserved");
        spillUsed -= j;
        spillSpaceTracker.free(j);
    }

    /**
     * Moves this query to another memory pool. Passing the pool already in use is a no-op.
     *
     * <p>Otherwise the current pool is asked to move the query, the pool reference is swapped
     * immediately, and a listener on the move future calls {@code moreMemoryAvailable()} on
     * every registered task context.
     *
     * @param memoryPool the pool to switch to; must not be null
     * @throws NullPointerException if {@code memoryPool} is null
     */
    public synchronized void setMemoryPool(MemoryPool memoryPool) {
        Objects.requireNonNull(memoryPool, "newMemoryPool is null");
        if (this.memoryPool != memoryPool) {
            ListenableFuture<?> moved = this.memoryPool.moveQuery(queryId, memoryPool);
            this.memoryPool = memoryPool;
            // Listener is registered with a direct executor, i.e. no hand-off to another executor.
            moved.addListener(
                    () -> taskContexts.values().forEach(TaskContext::moreMemoryAvailable),
                    MoreExecutors.directExecutor());
        }
    }

    /**
     * Replaces the user-memory limit.
     *
     * @param j new limit in bytes; stored as-is, without validation
     */
    public synchronized void setMaxUserMemory(long j) {
        maxUserMemory = j;
    }

    /**
     * Replaces the total-memory limit.
     *
     * @param j new limit in bytes; stored as-is, without validation
     */
    public synchronized void setMaxTotalMemory(long j) {
        maxTotalMemory = j;
    }

    /** Returns the memory pool this query is currently assigned to. */
    public synchronized MemoryPool getMemoryPool() {
        return memoryPool;
    }

    /**
     * Creates a task context backed by a new child of the query's memory tracking context and
     * registers it under the state machine's task id. An entry already stored for that id is
     * overwritten.
     *
     * <p>{@code session}, {@code runnable}, {@code z}, {@code z2} and {@code optionalInt} are
     * forwarded unchanged to {@code TaskContext.createTaskContext}.
     *
     * @param taskStateMachine state machine of the task; also supplies the task id used as key
     * @return the newly created task context
     */
    public TaskContext addTaskContext(TaskStateMachine taskStateMachine, Session session, Runnable runnable, boolean z, boolean z2, OptionalInt optionalInt) {
        TaskContext taskContext = TaskContext.createTaskContext(
                this,
                taskStateMachine,
                gcMonitor,
                notificationExecutor,
                yieldExecutor,
                session,
                queryMemoryContext.newMemoryTrackingContext(),
                runnable,
                z,
                z2,
                optionalInt);
        taskContexts.put(taskStateMachine.getTaskId(), taskContext);
        return taskContext;
    }

    /**
     * Dispatches this query context to the given visitor.
     *
     * @param queryContextVisitor visitor whose {@code visitQueryContext} is invoked
     * @param c                   caller-supplied context object handed to the visitor unchanged
     * @return whatever the visitor returns
     */
    public <C, R> R accept(QueryContextVisitor<C, R> queryContextVisitor, C c) {
        return queryContextVisitor.visitQueryContext(this, c);
    }

    /**
     * Applies the visitor to every registered task context.
     *
     * @param queryContextVisitor visitor passed to each task context's {@code accept}
     * @param c                   caller-supplied context object handed to each call unchanged
     * @return the per-task results, in the iteration order of the task-context map's values
     */
    public <C, R> List<R> acceptChildren(QueryContextVisitor<C, R> queryContextVisitor, C c) {
        return taskContexts.values().stream()
                .map(child -> child.accept(queryContextVisitor, c))
                .collect(Collectors.toList());
    }

    /**
     * Looks up a previously registered task context.
     *
     * @param taskId id the task context was registered under
     * @return the registered task context; never null
     */
    public TaskContext getTaskContextByTaskId(TaskId taskId) {
        TaskContext registered = taskContexts.get(taskId);
        // verifyNotNull fails the lookup (rather than returning null) for an unknown task id.
        return Verify.verifyNotNull(registered, "task does not exist");
    }

    /**
     * Try-reserve callback for the memory contexts that do not offer that operation.
     * Both arguments are ignored.
     *
     * @throws UnsupportedOperationException always
     */
    private boolean tryReserveMemoryNotSupported(String allocationTag, long delta) {
        throw new UnsupportedOperationException("tryReserveMemory is not supported");
    }

    /**
     * Throws the "exceeded local user memory limit" exception when {@code allocated + delta}
     * is greater than {@code maxMemory}; otherwise does nothing.
     *
     * @param allocated bytes already accounted for
     * @param delta     additional bytes being requested
     * @param maxMemory limit in bytes
     */
    @GuardedBy("this")
    private void enforceUserMemoryLimit(long allocated, long delta, long maxMemory) {
        if (allocated + delta <= maxMemory) {
            return;
        }
        throw ExceededMemoryLimitException.exceededLocalUserMemoryLimit(
                DataSize.succinctBytes(maxMemory),
                getAdditionalFailureInfo(allocated, delta));
    }

    /**
     * Throws the "exceeded local total memory limit" exception when {@code allocated + delta}
     * is greater than {@code maxMemory}; otherwise does nothing.
     *
     * @param allocated bytes already accounted for
     * @param delta     additional bytes being requested
     * @param maxMemory limit in bytes
     */
    @GuardedBy("this")
    private void enforceTotalMemoryLimit(long allocated, long delta, long maxMemory) {
        if (allocated + delta <= maxMemory) {
            return;
        }
        throw ExceededMemoryLimitException.exceededLocalTotalMemoryLimit(
                DataSize.succinctBytes(maxMemory),
                getAdditionalFailureInfo(allocated, delta));
    }

    /**
     * Builds the detail text attached to memory-limit failures.
     *
     * <p>The text always contains the allocated amount and the requested delta. When the pool
     * has tagged allocations recorded for this query, the three largest ones are appended as
     * "Top Consumers"; entries with a negative size among those three are left out.
     *
     * @param allocated bytes already accounted for
     * @param delta     additional bytes that were requested
     * @return the formatted detail text
     */
    @GuardedBy("this")
    private String getAdditionalFailureInfo(long allocated, long delta) {
        Map<String, Long> queryAllocations = memoryPool.getTaggedMemoryAllocations().get(queryId);
        String summary = String.format("Allocated: %s, Delta: %s", DataSize.succinctBytes(allocated), DataSize.succinctBytes(delta));

        // The pool may have no tagged allocations recorded for this query.
        if (queryAllocations == null) {
            return summary;
        }

        String topConsumers = queryAllocations.entrySet().stream()
                .sorted(Map.Entry.<String, Long>comparingByValue(Comparator.reverseOrder()))
                .limit(3)
                .filter(allocation -> allocation.getValue() >= 0)
                .collect(ImmutableMap.toImmutableMap(
                        Map.Entry::getKey,
                        allocation -> DataSize.succinctBytes(allocation.getValue())))
                .toString();
        return String.format("%s, Top Consumers: %s", summary, topConsumers);
    }
}
