package io.trino.memory;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.MoreCollectors;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import com.google.common.io.Closer;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import com.google.inject.Inject;
import io.airlift.concurrent.Threads;
import io.airlift.http.client.HttpClient;
import io.airlift.json.JsonCodec;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;
import io.trino.ExceededMemoryLimitException;
import io.trino.SystemSessionProperties;
import io.trino.execution.LocationFactory;
import io.trino.execution.QueryExecution;
import io.trino.execution.QueryInfo;
import io.trino.execution.StageInfo;
import io.trino.execution.TaskId;
import io.trino.execution.TaskInfo;
import io.trino.execution.scheduler.NodeSchedulerConfig;
import io.trino.memory.LowMemoryKiller;
import io.trino.metadata.InternalNode;
import io.trino.metadata.InternalNodeManager;
import io.trino.metadata.NodeState;
import io.trino.operator.RetryPolicy;
import io.trino.server.BasicQueryInfo;
import io.trino.server.ServerConfig;
import io.trino.spi.QueryId;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.memory.MemoryPoolInfo;
import jakarta.annotation.PreDestroy;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Stream;
import org.weakref.jmx.JmxException;
import org.weakref.jmx.MBeanExporter;
import org.weakref.jmx.Managed;

/* loaded from: input_file:io/trino/memory/ClusterMemoryManager.class */
/**
 * Coordinator-side manager of cluster memory.
 *
 * <p>On every {@link #process} call it:
 * <ul>
 *   <li>sums the user/total memory reservations of the running queries,</li>
 *   <li>fails queries that exceed the configured or session memory limits,</li>
 *   <li>asks the configured {@link LowMemoryKiller}s for a kill target when the cluster is out of
 *       memory and no query was failed during this call,</li>
 *   <li>feeds the latest per-node {@link MemoryInfo} into the exported {@link ClusterMemoryPool}
 *       and notifies registered listeners,</li>
 *   <li>reconciles the tracked node set with the node manager and triggers an asynchronous
 *       refresh of every tracked node.</li>
 * </ul>
 *
 * <p>Mutable state ({@code nodes}, {@code changeListeners}, {@code lastKillTarget}) is guarded by
 * the instance monitor; the JMX counters are plain atomics.
 */
public class ClusterMemoryManager {
    private static final Logger log = Logger.get(ClusterMemoryManager.class);

    /** Name under which the cluster memory pool is exported (and later unexported) via JMX. */
    private static final String EXPORTED_POOL_NAME = "general";

    private final InternalNodeManager nodeManager;
    private final LocationFactory locationFactory;
    private final HttpClient httpClient;
    private final MBeanExporter exporter;
    private final JsonCodec<MemoryInfo> memoryInfoCodec;
    private final DataSize maxQueryMemory;
    private final DataSize maxQueryTotalMemory;
    private final boolean includeCoordinator;
    // Killers are consulted in list order: task-level killer first, then query-level killer.
    private final List<LowMemoryKiller> lowMemoryKillers;
    private final ClusterMemoryPool pool;
    // Single daemon thread so listener callbacks run outside of the caller's thread, in submission order.
    private final ExecutorService listenerExecutor = Executors.newSingleThreadExecutor(Threads.daemonThreadsNamed("cluster-memory-manager-listener-%s"));
    private final ClusterMemoryLeakDetector memoryLeakDetector = new ClusterMemoryLeakDetector();
    private final AtomicLong totalAvailableProcessors = new AtomicLong();
    private final AtomicLong clusterUserMemoryReservation = new AtomicLong();
    private final AtomicLong clusterTotalMemoryReservation = new AtomicLong();
    private final AtomicLong clusterMemoryBytes = new AtomicLong();
    private final AtomicLong queriesKilledDueToOutOfMemory = new AtomicLong();
    private final AtomicLong tasksKilledDueToOutOfMemory = new AtomicLong();

    /** Tracked nodes keyed by node identifier. */
    @GuardedBy("this")
    private final Map<String, RemoteNodeMemory> nodes = new HashMap<>();

    @GuardedBy("this")
    private final List<Consumer<MemoryPoolInfo>> changeListeners = new ArrayList<>();

    /** The most recent target handed to a low-memory kill; empty when nothing is pending. */
    @GuardedBy("this")
    private Optional<KillTarget> lastKillTarget = Optional.empty();

    /**
     * Creates the manager and exports its memory pool via JMX.
     *
     * @throws IllegalStateException if {@code serverConfig} does not describe a coordinator
     * @throws NullPointerException if any of the injected collaborators is {@code null}
     * @throws com.google.common.base.VerifyException if the configured max query memory exceeds
     *         the configured max query total memory
     */
    @Inject
    public ClusterMemoryManager(@ForMemoryManager HttpClient httpClient, InternalNodeManager internalNodeManager, LocationFactory locationFactory, MBeanExporter mBeanExporter, JsonCodec<MemoryInfo> jsonCodec, @LowMemoryKiller.ForTaskLowMemoryKiller LowMemoryKiller lowMemoryKiller, @LowMemoryKiller.ForQueryLowMemoryKiller LowMemoryKiller lowMemoryKiller2, ServerConfig serverConfig, MemoryManagerConfig memoryManagerConfig, NodeSchedulerConfig nodeSchedulerConfig) {
        Preconditions.checkState(serverConfig.isCoordinator(), "ClusterMemoryManager must not be bound on worker");
        this.nodeManager = Objects.requireNonNull(internalNodeManager, "nodeManager is null");
        this.locationFactory = Objects.requireNonNull(locationFactory, "locationFactory is null");
        this.httpClient = Objects.requireNonNull(httpClient, "httpClient is null");
        this.exporter = Objects.requireNonNull(mBeanExporter, "exporter is null");
        this.memoryInfoCodec = Objects.requireNonNull(jsonCodec, "memoryInfoCodec is null");
        Objects.requireNonNull(lowMemoryKiller, "taskLowMemoryKiller is null");
        Objects.requireNonNull(lowMemoryKiller2, "queryLowMemoryKiller is null");
        this.lowMemoryKillers = ImmutableList.of(lowMemoryKiller, lowMemoryKiller2);
        this.maxQueryMemory = memoryManagerConfig.getMaxQueryMemory();
        this.maxQueryTotalMemory = memoryManagerConfig.getMaxQueryTotalMemory();
        this.includeCoordinator = nodeSchedulerConfig.isIncludeCoordinator();
        Verify.verify(this.maxQueryMemory.toBytes() <= this.maxQueryTotalMemory.toBytes(), "maxQueryMemory cannot be greater than maxQueryTotalMemory");
        this.pool = new ClusterMemoryPool();
        exportMemoryPool();
    }

    /** Exports the pool via JMX; an export failure is logged and otherwise ignored. */
    private void exportMemoryPool() {
        try {
            exporter.exportWithGeneratedName(pool, ClusterMemoryPool.class, EXPORTED_POOL_NAME);
        } catch (JmxException e) {
            log.error(e, "Error exporting memory pool");
        }
    }

    /**
     * Registers a listener that receives the pool info after each memory pool update.
     * Listeners are invoked on the internal listener executor, not on the calling thread.
     */
    public synchronized void addChangeListener(Consumer<MemoryPoolInfo> consumer) {
        changeListeners.add(consumer);
    }

    /**
     * Runs one memory-management pass over the currently running queries.
     *
     * @param iterable the running queries; iterated more than once during a pass
     * @param supplier supplies info for all queries, used only by the memory leak detector
     */
    public synchronized void process(Iterable<QueryExecution> iterable, Supplier<List<BasicQueryInfo>> supplier) {
        memoryLeakDetector.checkForMemoryLeaks(supplier, pool.getQueryMemoryReservations());

        boolean outOfMemory = isClusterOutOfMemory();
        boolean queryKilled = false;
        long totalUserMemoryBytes = 0;
        long totalMemoryBytes = 0;

        for (QueryExecution query : iterable) {
            boolean resourceOvercommit = SystemSessionProperties.resourceOvercommit(query.getSession());
            long userMemoryReservation = query.getUserMemoryReservation().toBytes();
            long totalMemoryReservation = query.getTotalMemoryReservation().toBytes();
            totalUserMemoryBytes += userMemoryReservation;
            totalMemoryBytes += totalMemoryReservation;

            // Queries using the TASK retry policy are exempt from the limit checks below;
            // their reservations still count toward the cluster totals accumulated above.
            if (SystemSessionProperties.getRetryPolicy(query.getSession()) == RetryPolicy.TASK) {
                continue;
            }

            // An overcommitting query is only failed once the cluster itself is out of memory.
            if (resourceOvercommit && outOfMemory) {
                query.fail(new TrinoException(StandardErrorCode.CLUSTER_OUT_OF_MEMORY, String.format("The cluster is out of memory and %s=true, so this query was killed. It was using %s of memory", SystemSessionProperties.RESOURCE_OVERCOMMIT, DataSize.succinctBytes(getQueryMemoryReservation(query)))));
                queryKilled = true;
            }

            if (!resourceOvercommit) {
                // The effective limit is the stricter of the configured limit and the session limit.
                long userMemoryLimit = Math.min(maxQueryMemory.toBytes(), SystemSessionProperties.getQueryMaxMemory(query.getSession()).toBytes());
                if (userMemoryReservation > userMemoryLimit) {
                    query.fail(ExceededMemoryLimitException.exceededGlobalUserLimit(DataSize.succinctBytes(userMemoryLimit)));
                    queryKilled = true;
                }
                long totalMemoryLimit = Math.min(maxQueryTotalMemory.toBytes(), SystemSessionProperties.getQueryMaxTotalMemory(query.getSession()).toBytes());
                if (totalMemoryReservation > totalMemoryLimit) {
                    query.fail(ExceededMemoryLimitException.exceededGlobalTotalLimit(DataSize.succinctBytes(totalMemoryLimit)));
                    queryKilled = true;
                }
            }
        }

        clusterUserMemoryReservation.set(totalUserMemoryBytes);
        clusterTotalMemoryReservation.set(totalMemoryBytes);

        // Only consult the low-memory killers when nothing was failed above in this pass.
        if (!lowMemoryKillers.isEmpty() && outOfMemory && !queryKilled) {
            if (isLastKillTargetGone()) {
                callOomKiller(iterable);
            } else {
                log.debug("Last killed target is still not gone: %s", lastKillTarget);
            }
        }

        updateMemoryPool(Iterables.size(iterable));
        updateNodes();
    }

    /**
     * Asks each low-memory killer in turn for a target and acts on the first one that returns a
     * target; later killers are not consulted once a target has been chosen, even if no running
     * query matched it.
     */
    private synchronized void callOomKiller(Iterable<QueryExecution> runningQueries) {
        List<LowMemoryKiller.RunningQueryInfo> runningQueryInfos = Streams.stream(runningQueries)
                .map(this::createQueryMemoryInfo)
                .collect(ImmutableList.toImmutableList());

        // Only nodes that currently have memory info available take part in the decision.
        Map<String, MemoryInfo> nodeMemoryInfosByNode = nodes.entrySet().stream()
                .filter(entry -> entry.getValue().getInfo().isPresent())
                .collect(ImmutableMap.toImmutableMap(
                        Map.Entry::getKey,
                        entry -> entry.getValue().getInfo().get()));

        for (LowMemoryKiller killer : lowMemoryKillers) {
            List<MemoryInfo> nodeMemoryInfos = ImmutableList.copyOf(nodeMemoryInfosByNode.values());
            Optional<KillTarget> killTarget = killer.chooseTargetToKill(runningQueryInfos, nodeMemoryInfos);
            if (killTarget.isEmpty()) {
                continue;
            }

            KillTarget target = killTarget.get();
            if (target.isWholeQuery()) {
                QueryId queryId = target.getQuery();
                log.debug("Low memory killer chose %s", queryId);
                Optional<QueryExecution> chosenQuery = findRunningQuery(runningQueries, queryId);
                if (chosenQuery.isPresent()) {
                    chosenQuery.get().fail(new TrinoException(StandardErrorCode.CLUSTER_OUT_OF_MEMORY, "Query killed because the cluster is out of memory. Please try again in a few minutes."));
                    queriesKilledDueToOutOfMemory.incrementAndGet();
                    lastKillTarget = killTarget;
                    logQueryKill(queryId, nodeMemoryInfosByNode);
                }
                return;
            }

            Set<TaskId> tasks = target.getTasks();
            log.debug("Low memory killer chose %s", tasks);
            ImmutableSet.Builder<TaskId> killedTasksBuilder = ImmutableSet.builder();
            for (TaskId task : tasks) {
                Optional<QueryExecution> owningQuery = findRunningQuery(runningQueries, task.getQueryId());
                if (owningQuery.isPresent()) {
                    owningQuery.get().failTask(task, new TrinoException(StandardErrorCode.CLUSTER_OUT_OF_MEMORY, "Task killed because the cluster is out of memory."));
                    tasksKilledDueToOutOfMemory.incrementAndGet();
                    killedTasksBuilder.add(task);
                }
            }
            // Remember only the tasks that were actually failed, not everything the killer proposed.
            ImmutableSet<TaskId> killedTasks = killedTasksBuilder.build();
            if (!killedTasks.isEmpty()) {
                lastKillTarget = Optional.of(KillTarget.selectedTasks(killedTasks));
                logTasksKill(killedTasks, nodeMemoryInfosByNode);
            }
            return;
        }
    }

    /** Returns {@code true} when there is no pending kill target or the pending one has disappeared. */
    @GuardedBy("this")
    private boolean isLastKillTargetGone() {
        if (lastKillTarget.isEmpty()) {
            return true;
        }
        KillTarget target = lastKillTarget.get();
        return target.isWholeQuery() ? isQueryGone(target.getQuery()) : areTasksGone(target.getTasks());
    }

    /**
     * A query counts as gone when the leak detector reports it as possibly leaked (in which case
     * the pending kill target is also cleared), or when the pool no longer holds a reservation for it.
     */
    @GuardedBy("this")
    private boolean isQueryGone(QueryId queryId) {
        if (memoryLeakDetector.wasQueryPossiblyLeaked(queryId)) {
            lastKillTarget = Optional.empty();
            return true;
        }
        return !pool.getQueryMemoryReservations().containsKey(queryId);
    }

    /** Returns {@code true} when none of the given tasks appears in any node's task reservations. */
    @GuardedBy("this")
    private boolean areTasksGone(Set<TaskId> set) {
        ImmutableSet<TaskId> runningTasks = getRunningTasks();
        return set.stream().noneMatch(runningTasks::contains);
    }

    /** Collects the task ids that have a memory reservation on any node with available memory info. */
    @GuardedBy("this")
    private ImmutableSet<TaskId> getRunningTasks() {
        return nodes.values().stream()
                .map(RemoteNodeMemory::getInfo)
                .filter(Optional::isPresent)
                .map(Optional::get)
                .flatMap(memoryInfo -> memoryInfo.getPool().getTaskMemoryReservations().keySet().stream())
                .map(TaskId::valueOf)
                .collect(ImmutableSet.toImmutableSet());
    }

    /** Finds the running query with the given id; the collector rejects more than one match. */
    private Optional<QueryExecution> findRunningQuery(Iterable<QueryExecution> iterable, QueryId queryId) {
        return Streams.stream(iterable)
                .filter(query -> queryId.equals(query.getQueryId()))
                .collect(MoreCollectors.toOptional());
    }

    /** Logs, at info level, which query was killed together with the per-node memory snapshot. */
    private void logQueryKill(QueryId queryId, Map<String, MemoryInfo> map) {
        if (!log.isInfoEnabled()) {
            return;
        }
        StringBuilder description = new StringBuilder();
        description.append("Query Kill Decision: Killed ").append(queryId).append("\n");
        description.append(formatKillScenario(map));
        log.info("%s", description);
    }

    /** Logs, at info level, which tasks were killed together with the per-node memory snapshot. */
    private void logTasksKill(Set<TaskId> set, Map<String, MemoryInfo> map) {
        if (!log.isInfoEnabled()) {
            return;
        }
        StringBuilder description = new StringBuilder();
        description.append("Query Kill Decision: Tasks Killed ").append(set).append("\n");
        description.append(formatKillScenario(map));
        log.info("%s", description);
    }

    /** Renders one line per node with its max bytes, free bytes and query/task reservations. */
    private String formatKillScenario(Map<String, MemoryInfo> map) {
        StringBuilder description = new StringBuilder();
        for (Map.Entry<String, MemoryInfo> entry : map.entrySet()) {
            MemoryPoolInfo poolInfo = entry.getValue().getPool();
            description.append("Node[").append(entry.getKey()).append("]: ");
            description.append("MaxBytes ").append(poolInfo.getMaxBytes()).append(' ');
            // Reserved revocable bytes are reported as part of the free figure.
            description.append("FreeBytes ").append(poolInfo.getFreeBytes() + poolInfo.getReservedRevocableBytes()).append(' ');
            description.append("Queries ");
            Joiner.on(",").withKeyValueSeparator("=").appendTo(description, poolInfo.getQueryMemoryReservations()).append(' ');
            description.append("Tasks ");
            Joiner.on(",").withKeyValueSeparator("=").appendTo(description, poolInfo.getTaskMemoryReservations());
            description.append('\n');
        }
        return description.toString();
    }

    @VisibleForTesting
    ClusterMemoryPool getPool() {
        return pool;
    }

    /** The cluster is treated as out of memory when the pool reports at least one blocked node. */
    private boolean isClusterOutOfMemory() {
        return pool.getBlockedNodes() > 0;
    }

    /** Builds the killer-facing summary of a query: id, total reservation, task infos and retry policy. */
    private LowMemoryKiller.RunningQueryInfo createQueryMemoryInfo(QueryExecution queryExecution) {
        QueryInfo queryInfo = queryExecution.getQueryInfo();
        ImmutableMap.Builder<TaskId, TaskInfo> taskInfos = ImmutableMap.builder();
        queryInfo.getOutputStage().ifPresent(stage -> getTaskInfos(stage, taskInfos));
        return new LowMemoryKiller.RunningQueryInfo(
                queryExecution.getQueryId(),
                queryExecution.getTotalMemoryReservation().toBytes(),
                taskInfos.buildOrThrow(),
                SystemSessionProperties.getRetryPolicy(queryExecution.getSession()));
    }

    /** Recursively adds the tasks of the given stage and of all its sub-stages to the builder. */
    private void getTaskInfos(StageInfo stageInfo, ImmutableMap.Builder<TaskId, TaskInfo> builder) {
        for (TaskInfo taskInfo : stageInfo.getTasks()) {
            builder.put(taskInfo.taskStatus().getTaskId(), taskInfo);
        }
        for (StageInfo subStage : stageInfo.getSubStages()) {
            getTaskInfos(subStage, builder);
        }
    }

    private long getQueryMemoryReservation(QueryExecution queryExecution) {
        return queryExecution.getTotalMemoryReservation().toBytes();
    }

    /**
     * Reconciles the tracked nodes with the ACTIVE and SHUTTING_DOWN nodes reported by the node
     * manager, then triggers an asynchronous refresh of every tracked node.
     */
    private synchronized void updateNodes() {
        ImmutableSet<InternalNode> aliveNodes = ImmutableSet.<InternalNode>builder()
                .addAll(nodeManager.getNodes(NodeState.ACTIVE))
                .addAll(nodeManager.getNodes(NodeState.SHUTTING_DOWN))
                .build();
        ImmutableSet<String> aliveNodeIds = aliveNodes.stream()
                .map(InternalNode::getNodeIdentifier)
                .collect(ImmutableSet.toImmutableSet());

        // Sets.difference is a live view over nodes.keySet(); copy it before removing from that key set.
        ImmutableSet<String> deadNodeIds = ImmutableSet.copyOf(Sets.difference(nodes.keySet(), aliveNodeIds));
        nodes.keySet().removeAll(deadNodeIds);

        for (InternalNode node : aliveNodes) {
            if (!nodes.containsKey(node.getNodeIdentifier())) {
                nodes.put(node.getNodeIdentifier(), new RemoteNodeMemory(node, httpClient, memoryInfoCodec, locationFactory.createMemoryInfoLocation(node)));
            }
        }

        for (RemoteNodeMemory node : nodes.values()) {
            node.asyncRefresh();
        }
    }

    /**
     * Recomputes the aggregate gauges from the nodes that have memory info, updates the pool and
     * schedules a notification for every registered listener.
     *
     * @param i the number of running queries, forwarded to the pool update
     */
    private synchronized void updateMemoryPool(int i) {
        List<MemoryInfo> nodeMemoryInfos = nodes.values().stream()
                .map(RemoteNodeMemory::getInfo)
                .filter(Optional::isPresent)
                .map(Optional::get)
                .collect(ImmutableList.toImmutableList());

        totalAvailableProcessors.set(nodeMemoryInfos.stream()
                .mapToLong(MemoryInfo::getAvailableProcessors)
                .sum());
        clusterMemoryBytes.set(nodeMemoryInfos.stream()
                .mapToLong(memoryInfo -> memoryInfo.getPool().getMaxBytes())
                .sum());
        pool.update(nodeMemoryInfos, i);

        if (changeListeners.isEmpty()) {
            return;
        }
        // Take one snapshot so that every listener observes the same pool info.
        MemoryPoolInfo info = pool.getInfo();
        for (Consumer<MemoryPoolInfo> listener : changeListeners) {
            listenerExecutor.execute(() -> listener.accept(info));
        }
    }

    /**
     * Returns a fresh mutable map of node identifier to latest memory info. Coordinator nodes are
     * included only when the scheduler is configured to include the coordinator.
     */
    public synchronized Map<String, Optional<MemoryInfo>> getWorkersMemoryInfo() {
        Map<String, Optional<MemoryInfo>> memoryInfo = new HashMap<>();
        for (Map.Entry<String, RemoteNodeMemory> entry : nodes.entrySet()) {
            if (includeCoordinator || !entry.getValue().getNode().isCoordinator()) {
                memoryInfo.put(entry.getKey(), entry.getValue().getInfo());
            }
        }
        return memoryInfo;
    }

    /** Returns a fresh mutable map of node identifier to latest memory info for every tracked node. */
    public synchronized Map<String, Optional<MemoryInfo>> getAllNodesMemoryInfo() {
        Map<String, Optional<MemoryInfo>> memoryInfo = new HashMap<>();
        for (Map.Entry<String, RemoteNodeMemory> entry : nodes.entrySet()) {
            memoryInfo.put(entry.getKey(), entry.getValue().getInfo());
        }
        return memoryInfo;
    }

    /**
     * Unexports the memory pool from JMX and shuts down the listener executor.
     *
     * @throws IOException if closing either registered resource fails
     */
    @PreDestroy
    public synchronized void destroy() throws IOException {
        // Closer closes in reverse registration order: the executor is shut down first, then the pool is unexported.
        try (Closer closer = Closer.create()) {
            closer.register(() -> exporter.unexportWithGeneratedName(ClusterMemoryPool.class, EXPORTED_POOL_NAME));
            closer.register(listenerExecutor::shutdownNow);
        }
    }

    @Managed
    public long getTotalAvailableProcessors() {
        return totalAvailableProcessors.get();
    }

    @Managed
    public int getNumberOfLeakedQueries() {
        return memoryLeakDetector.getNumberOfLeakedQueries();
    }

    @Managed
    public long getClusterUserMemoryReservation() {
        return clusterUserMemoryReservation.get();
    }

    @Managed
    public long getClusterTotalMemoryReservation() {
        return clusterTotalMemoryReservation.get();
    }

    @Managed
    public long getClusterMemoryBytes() {
        return clusterMemoryBytes.get();
    }

    @Managed
    public long getQueriesKilledDueToOutOfMemory() {
        return queriesKilledDueToOutOfMemory.get();
    }

    @Managed
    public long getTasksKilledDueToOutOfMemory() {
        return tasksKilledDueToOutOfMemory.get();
    }
}
