package io.trino.execution.executor.dedicated;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Ticker;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.errorprone.annotations.ThreadSafe;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import com.google.inject.Inject;
import io.airlift.concurrent.Threads;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import io.opentelemetry.api.trace.Tracer;
import io.trino.execution.SplitRunner;
import io.trino.execution.TaskId;
import io.trino.execution.TaskManagerConfig;
import io.trino.execution.executor.RunningSplitInfo;
import io.trino.execution.executor.TaskExecutor;
import io.trino.execution.executor.TaskHandle;
import io.trino.execution.executor.scheduler.FairScheduler;
import io.trino.spi.VersionEmbedder;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.OptionalInt;
import java.util.Set;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.DoubleSupplier;
import java.util.function.Predicate;

/**
 * {@link TaskExecutor} that hands splits to a {@link FairScheduler} and throttles how many leaf
 * splits are running at once.
 *
 * <p>Intermediate splits are started immediately when enqueued. Leaf splits are queued on their
 * {@link TaskEntry} and admitted by {@link #scheduleMoreLeafSplits()} in two passes:
 * <ol>
 *   <li>every task is topped up to {@code minDriversPerTask} running leaf splits;</li>
 *   <li>the remaining global budget ({@code targetGlobalLeafDrivers} minus the number of running
 *       leaf drivers) is handed out round-robin, never taking a task past
 *       {@code min(task.targetConcurrency(), maxDriversPerTask)}.</li>
 * </ol>
 *
 * <p>All mutable state is guarded by {@code this}. That includes the periodic work run on
 * {@link #backgroundTasks}: leaf scheduling every 100 ms, concurrency adjustment every 10 ms and
 * debug diagnostics every 30 s. Stuck-split detection is not implemented.
 */
@ThreadSafe
public class ThreadPerDriverTaskExecutor implements TaskExecutor {
    private static final Logger LOG = Logger.get(ThreadPerDriverTaskExecutor.class);

    private final FairScheduler scheduler;
    private final Tracer tracer;
    private final VersionEmbedder versionEmbedder;
    // Global number of leaf drivers the round-robin pass tries to keep running.
    private final int targetGlobalLeafDrivers;
    // Each task is topped up to at least this many running leaf splits.
    private final int minDriversPerTask;
    // Upper bound on a task's running leaf splits during the round-robin pass.
    private final int maxDriversPerTask;
    private final ScheduledThreadPoolExecutor backgroundTasks =
            new ScheduledThreadPoolExecutor(2, Threads.daemonThreadsNamed("task-executor-scheduler-%s"));

    @GuardedBy("this")
    private final Map<TaskId, TaskEntry> tasks = new HashMap<>();

    @GuardedBy("this")
    private boolean closed;

    // Leaf splits started through scheduleLeafSplit and not yet reported done.
    @GuardedBy("this")
    private int runningLeafDrivers;

    @Inject
    public ThreadPerDriverTaskExecutor(TaskManagerConfig taskManagerConfig, Tracer tracer, VersionEmbedder versionEmbedder) {
        this(
                tracer,
                versionEmbedder,
                new FairScheduler(taskManagerConfig.getMaxWorkerThreads(), "SplitRunner-%d", Ticker.systemTicker()),
                taskManagerConfig.getMinDriversPerTask(),
                taskManagerConfig.getMaxDriversPerTask(),
                taskManagerConfig.getMinDrivers());
    }

    /**
     * @param tracer tracer passed to each created {@link TaskEntry}
     * @param versionEmbedder version embedder passed to each created {@link TaskEntry}
     * @param scheduler scheduler that runs the splits; started by {@link #start()} and closed by {@link #stop()}
     * @param minDriversPerTask leaf splits each task is topped up to before global round-robin
     * @param maxDriversPerTask cap on a task's running leaf splits during the round-robin pass
     * @param targetGlobalLeafDrivers global number of running leaf drivers to aim for
     */
    @VisibleForTesting
    public ThreadPerDriverTaskExecutor(Tracer tracer, VersionEmbedder versionEmbedder, FairScheduler scheduler, int minDriversPerTask, int maxDriversPerTask, int targetGlobalLeafDrivers) {
        this.scheduler = Objects.requireNonNull(scheduler, "scheduler is null");
        this.tracer = Objects.requireNonNull(tracer, "tracer is null");
        this.versionEmbedder = Objects.requireNonNull(versionEmbedder, "versionEmbedder is null");
        this.minDriversPerTask = minDriversPerTask;
        this.maxDriversPerTask = maxDriversPerTask;
        this.targetGlobalLeafDrivers = targetGlobalLeafDrivers;
    }

    /** Starts the scheduler and the periodic background work. */
    @Override
    @PostConstruct
    public synchronized void start() {
        scheduler.start();
        backgroundTasks.scheduleWithFixedDelay(this::scheduleMoreLeafSplits, 0, 100, TimeUnit.MILLISECONDS);
        backgroundTasks.scheduleWithFixedDelay(this::adjustConcurrency, 0, 10, TimeUnit.MILLISECONDS);
        backgroundTasks.scheduleWithFixedDelay(this::logDiagnostics, 0, 30, TimeUnit.SECONDS);
    }

    /**
     * Marks the executor closed, destroys every registered task that is not already destroyed,
     * stops the background work and closes the scheduler.
     */
    @Override
    @PreDestroy
    public synchronized void stop() {
        closed = true;
        for (TaskEntry task : tasks.values()) {
            // Same guard as removeTask: do not destroy a task twice (e.g. on a repeated stop()).
            if (!task.isDestroyed()) {
                task.destroy();
            }
        }
        backgroundTasks.shutdownNow();
        scheduler.close();
    }

    /**
     * Registers a new task.
     *
     * <p>{@code splitConcurrencyAdjustFrequency} and {@code maxDriversPerTaskOverride} are not used
     * by this implementation.
     *
     * @throws IllegalArgumentException if the executor has been stopped
     */
    @Override
    public synchronized TaskHandle addTask(TaskId taskId, DoubleSupplier utilizationSupplier, int initialSplitConcurrency, Duration splitConcurrencyAdjustFrequency, OptionalInt maxDriversPerTaskOverride) {
        Preconditions.checkArgument(!closed, "Executor is already closed");
        TaskEntry task = new TaskEntry(taskId, scheduler, versionEmbedder, tracer, initialSplitConcurrency, utilizationSupplier);
        tasks.put(taskId, task);
        return task;
    }

    /** Unregisters the task and destroys it unless it is already destroyed. */
    @Override
    public synchronized void removeTask(TaskHandle taskHandle) {
        TaskEntry task = (TaskEntry) taskHandle;
        tasks.remove(task.taskId());
        if (!task.isDestroyed()) {
            task.destroy();
        }
    }

    /**
     * Enqueues splits for a task.
     *
     * <p>Intermediate splits are started right away. Leaf splits are queued on the task and then
     * admitted, subject to the per-task and global limits, by an immediate scheduling pass.
     *
     * @return one future per split, in input order
     * @throws IllegalArgumentException if the executor has been stopped
     */
    @Override
    public synchronized List<ListenableFuture<Void>> enqueueSplits(TaskHandle taskHandle, boolean intermediate, List<? extends SplitRunner> splits) {
        Preconditions.checkArgument(!closed, "Executor is already closed");
        TaskEntry task = (TaskEntry) taskHandle;
        List<ListenableFuture<Void>> futures = new ArrayList<>(splits.size());
        for (SplitRunner split : splits) {
            if (intermediate) {
                futures.add(task.runSplit(split));
            }
            else {
                futures.add(task.enqueueLeafSplit(split));
            }
        }
        scheduleMoreLeafSplits();
        return futures;
    }

    /**
     * Tries to start one queued leaf split of {@code task}. When a split is started, it counts
     * towards {@link #runningLeafDrivers}.
     *
     * @return whether a leaf split was started
     */
    @GuardedBy("this")
    private boolean scheduleLeafSplit(TaskEntry task) {
        boolean started = task.dequeueAndRunLeafSplit(this::leafSplitDone);
        if (started) {
            runningLeafDrivers++;
        }
        return started;
    }

    /** Completion callback for leaf splits: frees a global slot and schedules more work. */
    private synchronized void leafSplitDone() {
        runningLeafDrivers--;
        scheduleMoreLeafSplits();
    }

    /** Admits queued leaf splits according to the two-pass policy described on the class. */
    private synchronized void scheduleMoreLeafSplits() {
        // Pass 1: bring each task up to minDriversPerTask running leaf splits.
        for (TaskEntry task : tasks.values()) {
            int missing = Math.max(0, minDriversPerTask - task.runningLeafSplits());
            for (int i = 0; i < missing; i++) {
                if (!scheduleLeafSplit(task)) {
                    break;
                }
            }
        }

        // Pass 2: spend the remaining global budget round-robin across tasks. Every poll consumes
        // one unit of budget, whether or not a split is actually started.
        ArrayDeque<TaskEntry> queue = new ArrayDeque<>(tasks.values());
        int budget = targetGlobalLeafDrivers - runningLeafDrivers;
        for (int i = 0; i < budget && !queue.isEmpty(); i++) {
            TaskEntry task = queue.poll();
            if (task.runningLeafSplits() < Math.min(task.targetConcurrency(), maxDriversPerTask)) {
                scheduleLeafSplit(task);
                if (task.hasPendingLeafSplits()) {
                    queue.add(task);
                }
            }
        }
    }

    /**
     * Lets every task update its concurrency target.
     *
     * <p>Runs on a background thread. It must hold the lock because {@code tasks} is a plain
     * {@link HashMap} mutated by {@link #addTask} and {@link #removeTask}. Unsynchronized iteration
     * could throw {@code ConcurrentModificationException}, and that would permanently cancel this
     * fixed-delay job.
     */
    private synchronized void adjustConcurrency() {
        for (TaskEntry task : tasks.values()) {
            task.updateConcurrency();
        }
    }

    /**
     * Logs the scheduler queue and per-task split counts at debug level.
     *
     * <p>The lock is held only while reading {@code tasks}. Scheduler diagnostics are gathered
     * outside it.
     */
    private void logDiagnostics() {
        if (!LOG.isDebugEnabled()) {
            return;
        }
        StringBuilder message = new StringBuilder();
        message.append("Queue:\n");
        message.append(scheduler.diagnostics().indent(4));
        message.append("Query tasks:\n");
        synchronized (this) {
            for (TaskEntry task : tasks.values()) {
                message.append("%s: [total running = %s, leaf running = %s, leaf pending = %s, target concurrency = %s]\n"
                        .formatted(task.taskId(), task.totalRunningSplits(), task.runningLeafSplits(), task.pendingLeafSplitCount(), task.targetConcurrency())
                        .indent(4));
            }
        }
        LOG.debug("\n" + message);
    }

    /** Stuck-split detection is not supported; always returns an empty set. */
    @Override
    public Set<TaskId> getStuckSplitTaskIds(Duration duration, Predicate<RunningSplitInfo> predicate) {
        return ImmutableSet.of();
    }
}
