/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.inject.Inject;
import io.airlift.concurrent.ThreadPoolExecutorMBean;
import io.airlift.concurrent.Threads;
import io.airlift.log.Logger;
import io.airlift.node.NodeInfo;
import io.airlift.stats.CounterStat;
import io.airlift.stats.GcMonitor;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.Tracer;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cache.NonEvictableLoadingCache;
import io.trino.cache.SafeCaches;
import io.trino.connector.ConnectorServicesProvider;
import io.trino.event.SplitMonitor;
import io.trino.exchange.ExchangeManagerRegistry;
import io.trino.execution.DynamicFiltersCollector;
import io.trino.execution.LocationFactory;
import io.trino.execution.SplitAssignment;
import io.trino.execution.SqlTask;
import io.trino.execution.SqlTaskExecutionFactory;
import io.trino.execution.SqlTaskIoStats;
import io.trino.execution.StateMachine;
import io.trino.execution.TaskFailureListener;
import io.trino.execution.TaskId;
import io.trino.execution.TaskInfo;
import io.trino.execution.TaskManagementExecutor;
import io.trino.execution.TaskManagerConfig;
import io.trino.execution.TaskState;
import io.trino.execution.TaskStatus;
import io.trino.execution.buffer.BufferResult;
import io.trino.execution.buffer.OutputBuffers;
import io.trino.execution.buffer.PipelinedOutputBuffers;
import io.trino.execution.executor.TaskExecutor;
import io.trino.execution.executor.timesharing.PrioritizedSplitRunner;
import io.trino.memory.LocalMemoryManager;
import io.trino.memory.NodeMemoryConfig;
import io.trino.memory.QueryContext;
import io.trino.metadata.LanguageFunctionProvider;
import io.trino.operator.RetryPolicy;
import io.trino.operator.scalar.JoniRegexpFunctions;
import io.trino.operator.scalar.JoniRegexpReplaceLambdaFunction;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.QueryId;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.VersionEmbedder;
import io.trino.spi.catalog.CatalogProperties;
import io.trino.spi.connector.CatalogHandle;
import io.trino.spi.function.FunctionId;
import io.trino.spi.predicate.Domain;
import io.trino.spiller.LocalSpillManager;
import io.trino.spiller.NodeSpillConfig;
import io.trino.sql.planner.LocalExecutionPlanner;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.routine.ir.IrRoutine;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import java.io.Closeable;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.joda.time.DateTime;
import org.joda.time.ReadableInstant;
import org.weakref.jmx.Flatten;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

/**
 * Worker-side registry of {@link SqlTask}s.
 * <p>
 * Tasks are created lazily the first time a {@link TaskId} is looked up, and each task is handed the
 * {@link QueryContext} of its query, which is itself created lazily per {@link QueryId}. Query contexts
 * are cached with weak values, so a context stays cached only while it is strongly referenced elsewhere.
 * <p>
 * {@link #start()} schedules background jobs that:
 * <ul>
 *   <li>evict tasks whose end time is older than the configured info max age,</li>
 *   <li>fail unfinished tasks whose last heartbeat is older than the configured client timeout,</li>
 *   <li>refresh the aggregated I/O statistics exposed over JMX, and</li>
 *   <li>optionally fail tasks that contain stuck splits.</li>
 * </ul>
 */
public class SqlTaskManager
implements Closeable {
    private static final Logger log = Logger.get(SqlTaskManager.class);

    // A split whose stack trace contains any of these classes is considered an interruptible stuck split
    // by the default STUCK_SPLIT_STACK_TRACE_PREDICATE below.
    private static final Set<String> JONI_REGEXP_FUNCTION_CLASS_NAMES = ImmutableSet.of(
            JoniRegexpFunctions.class.getName(),
            JoniRegexpReplaceLambdaFunction.class.getName());

    // Default stuck-split predicate: true when any frame of the split's stack trace belongs to a Joni regexp function class.
    private static final Predicate<List<StackTraceElement>> STUCK_SPLIT_STACK_TRACE_PREDICATE = elements -> elements.stream()
            .map(StackTraceElement::getClassName)
            .anyMatch(JONI_REGEXP_FUNCTION_CLASS_NAMES::contains);

    private final VersionEmbedder versionEmbedder;
    private final ConnectorServicesProvider connectorServicesProvider;
    private final ExecutorService taskNotificationExecutor;
    private final ThreadPoolExecutorMBean taskNotificationExecutorMBean;
    private final ScheduledExecutorService taskManagementExecutor;
    private final ScheduledExecutorService driverYieldExecutor;
    private final ScheduledExecutorService driverTimeoutExecutor;
    // How long a finished task's info is retained before removeOldTasks() evicts it.
    private final Duration infoCacheTime;
    // Maximum time without a heartbeat before an unfinished task is failed as abandoned.
    private final Duration clientTimeout;
    private final NonEvictableLoadingCache<QueryId, QueryContext> queryContexts;
    private final NonEvictableLoadingCache<TaskId, SqlTask> tasks;
    // Snapshot published over JMX; rebuilt periodically by updateStats().
    private final SqlTaskIoStats cachedStats = new SqlTaskIoStats();
    // Accumulated I/O stats of tasks whose completion callback has run.
    private final SqlTaskIoStats finishedTaskStats = new SqlTaskIoStats();
    private final long queryMaxMemoryPerNode;
    private final CounterStat failedTasks = new CounterStat();
    private final Optional<StuckSplitTasksInterrupter> stuckSplitTasksInterrupter;
    private final LanguageFunctionProvider languageFunctionProvider;
    // Read lock: held while a task loads its catalogs. Write lock: held while pruning catalogs.
    private final ReentrantReadWriteLock catalogsLock = new ReentrantReadWriteLock();

    @Inject
    public SqlTaskManager(VersionEmbedder versionEmbedder, ConnectorServicesProvider connectorServicesProvider, LocalExecutionPlanner planner, LanguageFunctionProvider languageFunctionProvider, LocationFactory locationFactory, TaskExecutor taskExecutor, SplitMonitor splitMonitor, NodeInfo nodeInfo, LocalMemoryManager localMemoryManager, TaskManagementExecutor taskManagementExecutor, TaskManagerConfig config, NodeMemoryConfig nodeMemoryConfig, LocalSpillManager localSpillManager, NodeSpillConfig nodeSpillConfig, GcMonitor gcMonitor, Tracer tracer, ExchangeManagerRegistry exchangeManagerRegistry) {
        this(versionEmbedder, connectorServicesProvider, planner, languageFunctionProvider, locationFactory, taskExecutor, splitMonitor, nodeInfo, localMemoryManager, taskManagementExecutor, config, nodeMemoryConfig, localSpillManager, nodeSpillConfig, gcMonitor, tracer, exchangeManagerRegistry, STUCK_SPLIT_STACK_TRACE_PREDICATE);
    }

    /**
     * Creates the manager with a caller-supplied stuck-split predicate.
     * <p>
     * This constructor exists so tests can control which splits count as stuck. The injected
     * constructor uses the Joni-regexp predicate instead.
     *
     * @param stuckSplitStackTracePredicate decides, from a split's stack trace, whether that split's task should be failed
     */
    @VisibleForTesting
    public SqlTaskManager(VersionEmbedder versionEmbedder, ConnectorServicesProvider connectorServicesProvider, LocalExecutionPlanner planner, LanguageFunctionProvider languageFunctionProvider, LocationFactory locationFactory, TaskExecutor taskExecutor, SplitMonitor splitMonitor, NodeInfo nodeInfo, LocalMemoryManager localMemoryManager, TaskManagementExecutor taskManagementExecutor, TaskManagerConfig config, NodeMemoryConfig nodeMemoryConfig, LocalSpillManager localSpillManager, NodeSpillConfig nodeSpillConfig, GcMonitor gcMonitor, Tracer tracer, ExchangeManagerRegistry exchangeManagerRegistry, Predicate<List<StackTraceElement>> stuckSplitStackTracePredicate) {
        this.connectorServicesProvider = Objects.requireNonNull(connectorServicesProvider, "connectorServicesProvider is null");
        this.languageFunctionProvider = languageFunctionProvider;
        Objects.requireNonNull(nodeInfo, "nodeInfo is null");
        // Validate once, eagerly, instead of on every task creation inside the cache loader.
        ExchangeManagerRegistry exchangeManagers = Objects.requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null");

        this.infoCacheTime = config.getInfoMaxAge();
        this.clientTimeout = config.getClientTimeout();
        DataSize maxBufferSize = config.getSinkMaxBufferSize();
        DataSize maxBroadcastBufferSize = config.getSinkMaxBroadcastBufferSize();
        this.versionEmbedder = Objects.requireNonNull(versionEmbedder, "versionEmbedder is null");

        this.taskNotificationExecutor = Executors.newFixedThreadPool(config.getTaskNotificationThreads(), Threads.threadsNamed("task-notification-%s"));
        this.taskNotificationExecutorMBean = new ThreadPoolExecutorMBean((ThreadPoolExecutor) this.taskNotificationExecutor);
        this.taskManagementExecutor = taskManagementExecutor.getExecutor();
        this.driverYieldExecutor = Executors.newScheduledThreadPool(config.getTaskYieldThreads(), Threads.threadsNamed("task-yield-%s"));
        this.driverTimeoutExecutor = Executors.newScheduledThreadPool(config.getDriverTimeoutThreads(), Threads.threadsNamed("task-driver-timeout-%s"));

        SqlTaskExecutionFactory sqlTaskExecutionFactory = new SqlTaskExecutionFactory(this.taskNotificationExecutor, taskExecutor, planner, splitMonitor, tracer, config);

        DataSize maxQueryMemoryPerNode = nodeMemoryConfig.getMaxQueryMemoryPerNode();
        DataSize maxQuerySpillPerNode = nodeSpillConfig.getQueryMaxSpillPerNode();
        this.queryMaxMemoryPerNode = maxQueryMemoryPerNode.toBytes();

        // Weak values: a query context remains cached only while something (such as one of the tasks
        // created below, which receive their query's context) still references it.
        this.queryContexts = SafeCaches.buildNonEvictableCache(
                CacheBuilder.newBuilder().weakValues(),
                CacheLoader.from((QueryId queryId) -> this.createQueryContext(queryId, localMemoryManager, localSpillManager, gcMonitor, maxQueryMemoryPerNode, maxQuerySpillPerNode)));

        this.tasks = SafeCaches.buildNonEvictableCache(
                CacheBuilder.newBuilder(),
                CacheLoader.from((TaskId taskId) -> SqlTask.createSqlTask(
                        taskId,
                        locationFactory.createLocalTaskLocation(taskId),
                        nodeInfo.getNodeId(),
                        this.queryContexts.getUnchecked(taskId.getQueryId()),
                        tracer,
                        sqlTaskExecutionFactory,
                        this.taskNotificationExecutor,
                        sqlTask -> {
                            // Completion callback (presumably run when the task finishes; see createSqlTask):
                            // drop the task's language functions and fold its I/O stats into the finished totals.
                            languageFunctionProvider.unregisterTask(taskId);
                            this.finishedTaskStats.merge(sqlTask.getIoStats());
                        },
                        maxBufferSize,
                        maxBroadcastBufferSize,
                        exchangeManagers,
                        this.failedTasks)));

        this.stuckSplitTasksInterrupter = this.createStuckSplitTasksInterrupter(
                config.isInterruptStuckSplitTasksEnabled(),
                config.getInterruptStuckSplitTasksWarningThreshold(),
                config.getInterruptStuckSplitTasksTimeout(),
                config.getInterruptStuckSplitTasksDetectionInterval(),
                stuckSplitStackTracePredicate,
                taskExecutor);
    }

    // Builds the per-query context backing every task of the given query on this node.
    private QueryContext createQueryContext(QueryId queryId, LocalMemoryManager localMemoryManager, LocalSpillManager localSpillManager, GcMonitor gcMonitor, DataSize maxQueryUserMemoryPerNode, DataSize maxQuerySpillPerNode) {
        return new QueryContext(queryId, maxQueryUserMemoryPerNode, localMemoryManager.getMemoryPool(), gcMonitor, this.taskNotificationExecutor, this.driverYieldExecutor, this.driverTimeoutExecutor, maxQuerySpillPerNode, localSpillManager.getSpillSpaceTracker());
    }

    /**
     * Schedules the periodic maintenance jobs on the task management executor.
     * <p>
     * Each job catches and logs every {@link Throwable}. An escaping exception would suppress all
     * further executions of a periodic task.
     */
    @PostConstruct
    public void start() {
        this.taskManagementExecutor.scheduleWithFixedDelay(() -> {
            try {
                this.removeOldTasks();
            }
            catch (Throwable e) {
                log.warn(e, "Error removing old tasks");
            }
            try {
                this.failAbandonedTasks();
            }
            catch (Throwable e) {
                log.warn(e, "Error canceling abandoned tasks");
            }
        }, 200, 200, TimeUnit.MILLISECONDS);

        this.taskManagementExecutor.scheduleWithFixedDelay(() -> {
            try {
                this.updateStats();
            }
            catch (Throwable e) {
                log.warn(e, "Error updating stats");
            }
        }, 0, 1, TimeUnit.SECONDS);

        this.stuckSplitTasksInterrupter.ifPresent(interrupter -> {
            // Schedule in milliseconds. Rounding to whole seconds loses precision, and it yields a zero
            // period (rejected by scheduleAtFixedRate) for intervals shorter than half a second.
            long intervalMillis = interrupter.getStuckSplitsDetectionInterval().toMillis();
            this.taskManagementExecutor.scheduleAtFixedRate(() -> {
                try {
                    this.failStuckSplitTasks();
                }
                catch (Throwable e) {
                    log.warn(e, "Error failing stuck split tasks");
                }
            }, 0, intervalMillis, TimeUnit.MILLISECONDS);
        });
    }

    /**
     * Fails every unfinished task with {@code SERVER_SHUTTING_DOWN}, then shuts down the executors this
     * manager created.
     * <p>
     * If at least one task was failed, this method first sleeps five seconds. Presumably that gives the
     * failure time to propagate to remote callers (TODO confirm). The task management executor is
     * obtained from {@link TaskManagementExecutor} and is not shut down here.
     */
    @Override
    @PreDestroy
    public void close() {
        boolean taskCanceled = false;
        for (SqlTask task : this.tasks.asMap().values()) {
            if (task.getTaskState().isDone()) {
                continue;
            }
            task.failed(new TrinoException(StandardErrorCode.SERVER_SHUTTING_DOWN, String.format("Server is shutting down. Task %s has been canceled", task.getTaskId())));
            taskCanceled = true;
        }
        if (taskCanceled) {
            try {
                TimeUnit.SECONDS.sleep(5);
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
        this.taskNotificationExecutor.shutdownNow();
        this.driverYieldExecutor.shutdownNow();
        this.driverTimeoutExecutor.shutdownNow();
    }

    /** Returns the I/O stats snapshot last computed by the periodic stats job. */
    @Managed
    @Flatten
    public SqlTaskIoStats getIoStats() {
        return this.cachedStats;
    }

    @Managed(description="Task notification executor")
    @Nested
    public ThreadPoolExecutorMBean getTaskNotificationExecutor() {
        return this.taskNotificationExecutorMBean;
    }

    @Managed(description="Failed tasks counter")
    @Nested
    public CounterStat getFailedTasks() {
        return this.failedTasks;
    }

    /** Returns an immutable snapshot of all tasks currently held by this manager. */
    public List<SqlTask> getAllTasks() {
        return ImmutableList.copyOf(this.tasks.asMap().values());
    }

    /** Returns an immutable snapshot of the info of every task currently held by this manager. */
    public List<TaskInfo> getAllTaskInfo() {
        return this.tasks.asMap().values().stream()
                .map(SqlTask::getTaskInfo)
                .collect(ImmutableList.toImmutableList());
    }

    /** Returns the current info of the task (creating it if absent) and records a heartbeat. */
    public TaskInfo getTaskInfo(TaskId taskId) {
        return this.getTaskWithHeartbeat(taskId).getTaskInfo();
    }

    /** Returns the current status of the task (creating it if absent) and records a heartbeat. */
    public TaskStatus getTaskStatus(TaskId taskId) {
        return this.getTaskWithHeartbeat(taskId).getTaskStatus();
    }

    /** Returns a future for the task info keyed on {@code currentVersion} (see {@link SqlTask}) and records a heartbeat. */
    public ListenableFuture<TaskInfo> getTaskInfo(TaskId taskId, long currentVersion) {
        return this.getTaskWithHeartbeat(taskId).getTaskInfo(currentVersion);
    }

    /** Returns a future for the task status keyed on {@code currentVersion} (see {@link SqlTask}) and records a heartbeat. */
    public ListenableFuture<TaskStatus> getTaskStatus(TaskId taskId, long currentVersion) {
        return this.getTaskWithHeartbeat(taskId).getTaskStatus(currentVersion);
    }

    /** Delegates to {@link SqlTask#acknowledgeAndGetNewDynamicFilterDomains} after recording a heartbeat. */
    public DynamicFiltersCollector.VersionedDynamicFilterDomains acknowledgeAndGetNewDynamicFilterDomains(TaskId taskId, long currentDynamicFiltersVersion) {
        return this.getTaskWithHeartbeat(taskId).acknowledgeAndGetNewDynamicFilterDomains(currentDynamicFiltersVersion);
    }

    /**
     * Asks the connector services provider to prune catalogs. The catalogs kept are {@code activeCatalogs}
     * plus every catalog recorded on a task that is not yet done.
     * <p>
     * Runs under the catalogs write lock, so no task can be loading catalogs (which happens under the
     * read lock in {@code doUpdateTask}) while the in-use set is computed and pruning happens.
     */
    public void pruneCatalogs(Set<CatalogHandle> activeCatalogs) {
        Set<CatalogHandle> catalogsInUse = new HashSet<>(activeCatalogs);
        ReentrantReadWriteLock.WriteLock pruneLock = this.catalogsLock.writeLock();
        pruneLock.lock();
        try {
            for (SqlTask task : this.tasks.asMap().values()) {
                if (!task.getTaskState().isDone()) {
                    task.getCatalogs().ifPresent(catalogsInUse::addAll);
                }
            }
            this.connectorServicesProvider.pruneCatalogs(catalogsInUse);
        }
        finally {
            pruneLock.unlock();
        }
    }

    /**
     * Creates or updates the task, running the update through the {@link VersionEmbedder}.
     * Checked exceptions are wrapped in {@link RuntimeException}; unchecked ones propagate unchanged.
     */
    public TaskInfo updateTask(Session session, TaskId taskId, Span stageSpan, Optional<PlanFragment> fragment, List<SplitAssignment> splitAssignments, OutputBuffers outputBuffers, Map<DynamicFilterId, Domain> dynamicFilterDomains, boolean speculative) {
        try {
            return this.versionEmbedder.embedVersion(() -> this.doUpdateTask(session, taskId, stageSpan, fragment, splitAssignments, outputBuffers, dynamicFilterDomains, speculative)).call();
        }
        catch (Exception e) {
            Throwables.throwIfUnchecked(e);
            throw new RuntimeException(e);
        }
    }

    // Initializes query memory limits on first use, loads the fragment's catalogs and language functions,
    // records a heartbeat, and forwards the update to the task.
    private TaskInfo doUpdateTask(Session session, TaskId taskId, Span stageSpan, Optional<PlanFragment> fragment, List<SplitAssignment> splitAssignments, OutputBuffers outputBuffers, Map<DynamicFilterId, Domain> dynamicFilterDomains, boolean speculative) {
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(taskId, "taskId is null");
        Objects.requireNonNull(stageSpan, "stageSpan is null");
        Objects.requireNonNull(fragment, "fragment is null");
        Objects.requireNonNull(splitAssignments, "splitAssignments is null");
        Objects.requireNonNull(outputBuffers, "outputBuffers is null");

        SqlTask sqlTask = this.tasks.getUnchecked(taskId);
        QueryContext queryContext = sqlTask.getQueryContext();
        if (!queryContext.isMemoryLimitsInitialized()) {
            if (SystemSessionProperties.getRetryPolicy(session) == RetryPolicy.TASK) {
                // TASK retry policy: no overcommit and no per-node memory cap.
                queryContext.initializeMemoryLimits(false, Long.MAX_VALUE);
            }
            else {
                // Never allow the session to raise the limit above the node-wide maximum.
                long sessionQueryMaxMemoryPerNode = SystemSessionProperties.getQueryMaxMemoryPerNode(session).toBytes();
                queryContext.initializeMemoryLimits(SystemSessionProperties.resourceOvercommit(session), Math.min(sessionQueryMaxMemoryPerNode, this.queryMaxMemoryPerNode));
            }
        }

        fragment.map(PlanFragment::getActiveCatalogs).ifPresent(activeCatalogs -> {
            Set<CatalogHandle> catalogHandles = activeCatalogs.stream()
                    .map(CatalogProperties::catalogHandle)
                    .collect(ImmutableSet.toImmutableSet());
            sqlTask.setCatalogs(catalogHandles);
            if (!sqlTask.catalogsLoaded()) {
                // Shared lock: many tasks may load catalogs concurrently, but not while pruneCatalogs runs.
                ReentrantReadWriteLock.ReadLock catalogInitLock = this.catalogsLock.readLock();
                catalogInitLock.lock();
                try {
                    this.connectorServicesProvider.ensureCatalogsLoaded(session, activeCatalogs);
                    sqlTask.setCatalogsLoaded();
                }
                finally {
                    catalogInitLock.unlock();
                }
            }
        });

        fragment.map(PlanFragment::getLanguageFunctions)
                .ifPresent(languageFunctions -> this.languageFunctionProvider.registerTask(taskId, languageFunctions));

        sqlTask.recordHeartbeat();
        return sqlTask.updateTask(session, stageSpan, fragment, splitAssignments, outputBuffers, dynamicFilterDomains, speculative);
    }

    /**
     * Requests results from one of the task's output buffers. No heartbeat is recorded here. The
     * returned {@link SqlTaskWithResults} exposes {@link SqlTaskWithResults#recordHeartbeat()} for that.
     *
     * @throws IllegalArgumentException if {@code startingSequenceId} is negative
     */
    public SqlTaskWithResults getTaskResults(TaskId taskId, PipelinedOutputBuffers.OutputBufferId bufferId, long startingSequenceId, DataSize maxSize) {
        Objects.requireNonNull(taskId, "taskId is null");
        Objects.requireNonNull(bufferId, "bufferId is null");
        Preconditions.checkArgument(startingSequenceId >= 0, "startingSequenceId is negative");
        Objects.requireNonNull(maxSize, "maxSize is null");
        SqlTask task = this.tasks.getUnchecked(taskId);
        return new SqlTaskWithResults(task, task.getTaskResults(bufferId, startingSequenceId, maxSize));
    }

    /**
     * Delegates to {@link SqlTask#acknowledgeTaskResults} for the given output buffer and sequence id.
     *
     * @throws IllegalArgumentException if {@code sequenceId} is negative
     */
    public void acknowledgeTaskResults(TaskId taskId, PipelinedOutputBuffers.OutputBufferId bufferId, long sequenceId) {
        Objects.requireNonNull(taskId, "taskId is null");
        Objects.requireNonNull(bufferId, "bufferId is null");
        Preconditions.checkArgument(sequenceId >= 0, "sequenceId is negative");
        this.tasks.getUnchecked(taskId).acknowledgeTaskResults(bufferId, sequenceId);
    }

    /** Delegates to {@link SqlTask#destroyTaskResults} for the given output buffer. */
    public TaskInfo destroyTaskResults(TaskId taskId, PipelinedOutputBuffers.OutputBufferId bufferId) {
        Objects.requireNonNull(taskId, "taskId is null");
        Objects.requireNonNull(bufferId, "bufferId is null");
        return this.tasks.getUnchecked(taskId).destroyTaskResults(bufferId);
    }

    /** Cancels the task (creating it if absent) and returns its resulting info. */
    public TaskInfo cancelTask(TaskId taskId) {
        return this.getTask(taskId).cancel();
    }

    /** Aborts the task (creating it if absent) and returns its resulting info. */
    public TaskInfo abortTask(TaskId taskId) {
        return this.getTask(taskId).abort();
    }

    /** Fails the task (creating it if absent) with {@code failure} and returns its resulting info. */
    public TaskInfo failTask(TaskId taskId, Throwable failure) {
        Objects.requireNonNull(failure, "failure is null");
        return this.getTask(taskId).failed(failure);
    }

    /**
     * Evicts every task whose recorded end time is older than the info max age. A failure while inspecting
     * one task is logged and does not stop the others from being checked.
     */
    @VisibleForTesting
    void removeOldTasks() {
        DateTime oldestAllowedTask = DateTime.now().minus(this.infoCacheTime.toMillis());
        this.tasks.asMap().values().stream()
                .map(SqlTask::getTaskInfo)
                .filter(Objects::nonNull)
                .forEach(taskInfo -> {
                    TaskId taskId = taskInfo.taskStatus().getTaskId();
                    try {
                        DateTime endTime = taskInfo.stats().getEndTime();
                        if (endTime != null && endTime.isBefore(oldestAllowedTask)) {
                            this.tasks.unsafeInvalidate(taskId);
                        }
                    }
                    catch (RuntimeException e) {
                        log.warn(e, "Error while inspecting age of complete task %s", taskId);
                    }
                });
    }

    // Fails, with ABANDONED_TASK, every unfinished task whose last heartbeat is older than the client timeout.
    private void failAbandonedTasks() {
        DateTime now = DateTime.now();
        DateTime oldestAllowedHeartbeat = now.minus(this.clientTimeout.toMillis());
        for (SqlTask sqlTask : this.tasks.asMap().values()) {
            try {
                TaskInfo taskInfo = sqlTask.getTaskInfo();
                TaskStatus taskStatus = taskInfo.taskStatus();
                if (taskStatus.getState().isDone()) {
                    continue;
                }
                DateTime lastHeartbeat = taskInfo.lastHeartbeat();
                if (lastHeartbeat != null && lastHeartbeat.isBefore(oldestAllowedHeartbeat)) {
                    log.info("Failing abandoned task %s", taskStatus.getTaskId());
                    sqlTask.failed(new TrinoException(StandardErrorCode.ABANDONED_TASK, String.format("Task %s has not been accessed since %s: currentTime %s", taskStatus.getTaskId(), lastHeartbeat, now)));
                }
            }
            catch (RuntimeException e) {
                log.warn(e, "Error while inspecting age of task %s", sqlTask.getTaskId());
            }
        }
    }

    // Rebuilds the JMX snapshot: totals of finished tasks plus the live stats of every task that is not done.
    private void updateStats() {
        SqlTaskIoStats tempIoStats = new SqlTaskIoStats();
        tempIoStats.merge(this.finishedTaskStats);
        for (SqlTask task : this.tasks.asMap().values()) {
            if (!task.getTaskState().isDone()) {
                tempIoStats.merge(task.getIoStats());
            }
        }
        this.cachedStats.resetTo(tempIoStats);
    }

    /** Registers a listener for state changes of the task (creating it if absent). */
    public void addStateChangeListener(TaskId taskId, StateMachine.StateChangeListener<TaskState> stateChangeListener) {
        this.getTask(taskId).addStateChangeListener(stateChangeListener);
    }

    /** Registers a listener for source-task failures of the task (creating it if absent). */
    public void addSourceTaskFailureListener(TaskId taskId, TaskFailureListener listener) {
        this.getTask(taskId).addSourceTaskFailureListener(listener);
    }

    /** Returns the trace token of the task (creating it if absent). */
    public Optional<String> getTraceToken(TaskId taskId) {
        return this.getTask(taskId).getTraceToken();
    }

    @VisibleForTesting
    public QueryContext getQueryContext(QueryId queryId) {
        return this.queryContexts.getUnchecked(queryId);
    }

    /** Runs one stuck-split detection pass; a no-op when stuck-split interruption is disabled. */
    @VisibleForTesting
    public void failStuckSplitTasks() {
        this.stuckSplitTasksInterrupter.ifPresent(StuckSplitTasksInterrupter::failStuckSplitTasks);
    }

    // Returns an interrupter only when stuck-split interruption is enabled in the config.
    private Optional<StuckSplitTasksInterrupter> createStuckSplitTasksInterrupter(boolean enableInterruptStuckSplitTasks, Duration stuckSplitsWarningThreshold, Duration interruptStuckSplitTasksTimeout, Duration stuckSplitsDetectionInterval, Predicate<List<StackTraceElement>> stuckSplitStackTracePredicate, TaskExecutor taskExecutor) {
        if (!enableInterruptStuckSplitTasks) {
            return Optional.empty();
        }
        return Optional.of(new StuckSplitTasksInterrupter(stuckSplitsWarningThreshold, interruptStuckSplitTasksTimeout, stuckSplitsDetectionInterval, stuckSplitStackTracePredicate, taskExecutor));
    }

    // Looks up the task, creating it through the cache loader if absent; rejects a null id with a clear message.
    private SqlTask getTask(TaskId taskId) {
        Objects.requireNonNull(taskId, "taskId is null");
        return this.tasks.getUnchecked(taskId);
    }

    // Same as getTask, but also records a heartbeat so the task is not treated as abandoned.
    private SqlTask getTaskWithHeartbeat(TaskId taskId) {
        SqlTask sqlTask = this.getTask(taskId);
        sqlTask.recordHeartbeat();
        return sqlTask;
    }

    /** A task paired with the pending future for results requested from one of its output buffers. */
    public static final class SqlTaskWithResults {
        private final SqlTask task;
        private final ListenableFuture<BufferResult> resultsFuture;

        public SqlTaskWithResults(SqlTask task, ListenableFuture<BufferResult> resultsFuture) {
            this.task = Objects.requireNonNull(task, "task is null");
            this.resultsFuture = Objects.requireNonNull(resultsFuture, "resultsFuture is null");
        }

        public void recordHeartbeat() {
            this.task.recordHeartbeat();
        }

        public String getTaskInstanceId() {
            return this.task.getTaskInstanceId();
        }

        /** True when the task is aborted, aborting, failed or failing. */
        public boolean isTaskFailedOrFailing() {
            return switch (this.task.getTaskState()) {
                case ABORTED, ABORTING, FAILED, FAILING -> true;
                default -> false;
            };
        }

        public ListenableFuture<BufferResult> getResultsFuture() {
            return this.resultsFuture;
        }
    }

    /**
     * Finds splits that have been running longer than the interrupt timeout. Each such split's stack
     * trace is logged once. Tasks owning splits that match the stuck-split predicate are failed.
     * Non-static because it fails tasks through the enclosing manager.
     */
    private class StuckSplitTasksInterrupter {
        private final Duration interruptStuckSplitTasksTimeout;
        private final Duration stuckSplitsDetectionInterval;
        private final Predicate<List<StackTraceElement>> stuckSplitStackTracePredicate;
        private final TaskExecutor taskExecutor;

        /**
         * @throws IllegalArgumentException if the interrupt timeout is shorter than
         *         {@code PrioritizedSplitRunner.SPLIT_RUN_QUANTA}, or shorter than the warning threshold
         */
        public StuckSplitTasksInterrupter(Duration stuckSplitsWarningThreshold, Duration interruptStuckSplitTasksTimeout, Duration stuckSplitDetectionInterval, Predicate<List<StackTraceElement>> stuckSplitStackTracePredicate, TaskExecutor taskExecutor) {
            Preconditions.checkArgument(interruptStuckSplitTasksTimeout.compareTo(PrioritizedSplitRunner.SPLIT_RUN_QUANTA) >= 0, "interruptStuckSplitTasksTimeout must be at least %s", PrioritizedSplitRunner.SPLIT_RUN_QUANTA);
            Preconditions.checkArgument(stuckSplitsWarningThreshold.compareTo(interruptStuckSplitTasksTimeout) <= 0, "interruptStuckSplitTasksTimeout cannot be less than stuckSplitsWarningThreshold");
            this.interruptStuckSplitTasksTimeout = Objects.requireNonNull(interruptStuckSplitTasksTimeout, "interruptStuckSplitTasksTimeout is null");
            this.stuckSplitsDetectionInterval = Objects.requireNonNull(stuckSplitDetectionInterval, "stuckSplitsDetectionInterval is null");
            this.stuckSplitStackTracePredicate = Objects.requireNonNull(stuckSplitStackTracePredicate, "stuckSplitStackTracePredicate is null");
            this.taskExecutor = Objects.requireNonNull(taskExecutor, "taskExecutor is null");
        }

        public Duration getStuckSplitsDetectionInterval() {
            return this.stuckSplitsDetectionInterval;
        }

        private void failStuckSplitTasks() {
            Set<TaskId> stuckSplitTaskIds = this.taskExecutor.getStuckSplitTaskIds(this.interruptStuckSplitTasksTimeout, splitInfo -> {
                List<StackTraceElement> stackTraceElements = Arrays.asList(splitInfo.getThread().getStackTrace());
                // Log each long-running split's stack trace only once.
                if (!splitInfo.isPrinted()) {
                    splitInfo.setPrinted();
                    log.warn("%s is long running with stackTrace:\n%s", splitInfo.getSplitInfo(), stackTraceElements.stream().map(Object::toString).collect(Collectors.joining(System.lineSeparator())));
                }
                return this.stuckSplitStackTracePredicate.test(stackTraceElements);
            });
            for (TaskId stuckSplitTaskId : stuckSplitTaskIds) {
                SqlTaskManager.this.failTask(stuckSplitTaskId, new TrinoException(StandardErrorCode.GENERIC_USER_ERROR, String.format("Task %s is failed, due to containing long running stuck splits.", stuckSplitTaskId)));
            }
        }
    }
}

