/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.execution.scheduler;

import com.facebook.presto.OutputBuffers;
import com.facebook.presto.Session;
import com.facebook.presto.execution.LocationFactory;
import com.facebook.presto.execution.NodeTaskMap;
import com.facebook.presto.execution.QueryState;
import com.facebook.presto.execution.QueryStateMachine;
import com.facebook.presto.execution.RemoteTask;
import com.facebook.presto.execution.RemoteTaskFactory;
import com.facebook.presto.execution.SqlStageExecution;
import com.facebook.presto.execution.StageId;
import com.facebook.presto.execution.StageInfo;
import com.facebook.presto.execution.StageState;
import com.facebook.presto.execution.scheduler.BroadcastOutputBufferManager;
import com.facebook.presto.execution.scheduler.DynamicSplitPlacementPolicy;
import com.facebook.presto.execution.scheduler.ExecutionPolicy;
import com.facebook.presto.execution.scheduler.ExecutionSchedule;
import com.facebook.presto.execution.scheduler.FixedCountScheduler;
import com.facebook.presto.execution.scheduler.FixedSourcePartitionedScheduler;
import com.facebook.presto.execution.scheduler.NodeScheduler;
import com.facebook.presto.execution.scheduler.NodeSelector;
import com.facebook.presto.execution.scheduler.OutputBufferManager;
import com.facebook.presto.execution.scheduler.PartitionedOutputBufferManager;
import com.facebook.presto.execution.scheduler.ScheduleResult;
import com.facebook.presto.execution.scheduler.SourcePartitionedScheduler;
import com.facebook.presto.execution.scheduler.StageScheduler;
import com.facebook.presto.spi.ErrorCodeSupplier;
import com.facebook.presto.spi.Node;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.split.SplitSource;
import com.facebook.presto.sql.planner.NodePartitionMap;
import com.facebook.presto.sql.planner.NodePartitioningManager;
import com.facebook.presto.sql.planner.PartitioningHandle;
import com.facebook.presto.sql.planner.StageExecutionPlan;
import com.facebook.presto.sql.planner.SystemPartitioningHandle;
import com.facebook.presto.sql.planner.plan.PlanFragmentId;
import com.facebook.presto.util.Failures;
import com.facebook.presto.util.ImmutableCollectors;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.concurrent.MoreFutures;
import io.airlift.concurrent.SetThreadName;
import io.airlift.units.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

/**
 * Schedules all stages of a single query.
 * <p>
 * The constructor turns a {@link StageExecutionPlan} into a tree of {@link SqlStageExecution}s,
 * chooses a {@link StageScheduler} for every stage based on its fragment partitioning, and wires
 * parents and children together through {@link StageLinkage}. {@link #start()} submits the
 * scheduling loop to the supplied executor at most once.
 */
public class SqlQueryScheduler {
    private final QueryStateMachine queryStateMachine;
    private final ExecutionPolicy executionPolicy;
    // every stage of the query, keyed by its id
    private final Map<StageId, SqlStageExecution> stages;
    private final ExecutorService executor;
    private final StageId rootStageId;
    private final Map<StageId, StageScheduler> stageSchedulers;
    private final Map<StageId, StageLinkage> stageLinkages;
    private final boolean summarizeTaskInfo;
    // ensures start() submits the scheduling loop only once
    private final AtomicBoolean started = new AtomicBoolean();

    /**
     * Builds the stage tree for {@code plan} and registers the query-level state listeners.
     *
     * @param queryStateMachine state machine of the query; must not be null
     * @param locationFactory creates the location of each new stage
     * @param plan execution plan of the root stage (sub-plans are reached through {@link StageExecutionPlan#getSubStages()})
     * @param nodePartitioningManager resolves non-source partitionings to a {@link NodePartitionMap}
     * @param nodeScheduler creates node selectors for source-partitioned stages
     * @param remoteTaskFactory passed to every {@link SqlStageExecution}
     * @param session the query session
     * @param summarizeTaskInfo passed to every {@link SqlStageExecution}
     * @param splitBatchSize batch size handed to source-partitioned schedulers
     * @param executor executor used by the stages and by the scheduling loop
     * @param rootOutputBuffers output buffers installed on the root stage
     * @param nodeTaskMap passed to every {@link SqlStageExecution}
     * @param executionPolicy decides which stages are scheduled in each loop iteration; must not be null
     */
    public SqlQueryScheduler(QueryStateMachine queryStateMachine, LocationFactory locationFactory, StageExecutionPlan plan, NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, RemoteTaskFactory remoteTaskFactory, Session session, boolean summarizeTaskInfo, int splitBatchSize, ExecutorService executor, OutputBuffers rootOutputBuffers, NodeTaskMap nodeTaskMap, ExecutionPolicy executionPolicy) {
        this.queryStateMachine = Objects.requireNonNull(queryStateMachine, "queryStateMachine is null");
        this.executionPolicy = Objects.requireNonNull(executionPolicy, "executionPolicy is null");
        this.summarizeTaskInfo = summarizeTaskInfo;

        ImmutableMap.Builder<StageId, StageScheduler> stageSchedulers = ImmutableMap.builder();
        ImmutableMap.Builder<StageId, StageLinkage> stageLinkages = ImmutableMap.builder();

        // each distinct partitioning handle is resolved to a NodePartitionMap only once while building the tree
        Map<PartitioningHandle, NodePartitionMap> partitioningCache = new HashMap<>();
        Function<PartitioningHandle, NodePartitionMap> partitioningLookup = partitioningHandle ->
                partitioningCache.computeIfAbsent(partitioningHandle, handle -> nodePartitioningManager.getNodePartitioningMap(session, handle));

        List<SqlStageExecution> stages = createStages(
                Optional.empty(),
                new AtomicInteger(),
                locationFactory,
                plan.withBucketToPartition(Optional.of(new int[1])),
                nodeScheduler,
                remoteTaskFactory,
                session,
                splitBatchSize,
                partitioningLookup,
                executor,
                nodeTaskMap,
                stageSchedulers,
                stageLinkages);

        // createStages lists the stage it was called for first, so element 0 is the root
        SqlStageExecution rootStage = stages.get(0);
        rootStage.setOutputBuffers(rootOutputBuffers);
        this.rootStageId = rootStage.getStageId();

        Map<StageId, SqlStageExecution> stagesById = stages.stream().collect(ImmutableCollectors.toImmutableMap(SqlStageExecution::getStageId));
        this.stages = stagesById;
        this.stageSchedulers = stageSchedulers.build();
        this.stageLinkages = stageLinkages.build();
        this.executor = executor;

        // the root stage decides how the query as a whole ends
        rootStage.addStateChangeListener(state -> {
            if (state == StageState.FINISHED) {
                queryStateMachine.transitionToFinishing();
            } else if (state == StageState.CANCELED) {
                queryStateMachine.transitionToFailed(new PrestoException(StandardErrorCode.USER_CANCELED, "Query was canceled"));
            }
        });

        // any stage can fail the query, or move it from STARTING to RUNNING once it has tasks
        for (SqlStageExecution stage : stages) {
            stage.addStateChangeListener(state -> {
                if (queryStateMachine.isDone()) {
                    return;
                }
                if (state == StageState.FAILED) {
                    queryStateMachine.transitionToFailed(stage.getStageInfo().getFailureCause().toException());
                } else if (state == StageState.ABORTED) {
                    queryStateMachine.transitionToFailed(new PrestoException(StandardErrorCode.INTERNAL_ERROR, "Query stage was aborted"));
                } else if (queryStateMachine.getQueryState() == QueryState.STARTING && stage.hasTasks()) {
                    queryStateMachine.transitionToRunning();
                }
            });
        }
    }

    /**
     * Recursively creates the stage for {@code plan} and all of its sub-stages.
     * <p>
     * For every stage created, a {@link StageScheduler} and a {@link StageLinkage} are added to the
     * given builders.
     *
     * @return the created stages, with the stage for {@code plan} first, followed by each
     *         sub-stage's subtree in {@code plan.getSubStages()} order
     */
    private List<SqlStageExecution> createStages(Optional<SqlStageExecution> parent, AtomicInteger nextStageId, LocationFactory locationFactory, StageExecutionPlan plan, NodeScheduler nodeScheduler, RemoteTaskFactory remoteTaskFactory, Session session, int splitBatchSize, Function<PartitioningHandle, NodePartitionMap> partitioningCache, ExecutorService executor, NodeTaskMap nodeTaskMap, ImmutableMap.Builder<StageId, StageScheduler> stageSchedulers, ImmutableMap.Builder<StageId, StageLinkage> stageLinkages) {
        ImmutableList.Builder<SqlStageExecution> stages = ImmutableList.builder();

        StageId stageId = new StageId(queryStateMachine.getQueryId(), String.valueOf(nextStageId.getAndIncrement()));
        SqlStageExecution stage = new SqlStageExecution(stageId, locationFactory.createStageLocation(stageId), plan.getFragment(), remoteTaskFactory, session, summarizeTaskInfo, nodeTaskMap, executor);
        stages.add(stage);

        // pick a scheduler for this stage; bucketToPartition is handed down to the sub-stage plans
        Optional<int[]> bucketToPartition;
        PartitioningHandle partitioningHandle = plan.getFragment().getPartitioning();
        if (partitioningHandle.equals(SystemPartitioningHandle.SOURCE_DISTRIBUTION)) {
            SplitSource splitSource = plan.getDataSource().get();
            NodeSelector nodeSelector = nodeScheduler.createNodeSelector(splitSource.getDataSourceName());
            DynamicSplitPlacementPolicy placementPolicy = new DynamicSplitPlacementPolicy(nodeSelector, stage::getAllTasks);
            stageSchedulers.put(stageId, new SourcePartitionedScheduler(stage, splitSource, placementPolicy, splitBatchSize));
            bucketToPartition = Optional.of(new int[1]);
        } else {
            NodePartitionMap nodePartitionMap = partitioningCache.apply(partitioningHandle);
            if (plan.getDataSource().isPresent()) {
                stageSchedulers.put(stageId, new FixedSourcePartitionedScheduler(stage, plan.getDataSource().get(), nodePartitionMap, splitBatchSize, nodeScheduler.createNodeSelector(null)));
            } else {
                Map<Integer, Node> partitionToNode = nodePartitionMap.getPartitionToNode();
                Failures.checkCondition(!partitionToNode.isEmpty(), StandardErrorCode.NO_NODES_AVAILABLE, "No worker nodes available");
                stageSchedulers.put(stageId, new FixedCountScheduler(stage, partitionToNode));
            }
            bucketToPartition = Optional.of(nodePartitionMap.getBucketToPartition());
        }

        ImmutableSet.Builder<SqlStageExecution> childStagesBuilder = ImmutableSet.builder();
        for (StageExecutionPlan subStagePlan : plan.getSubStages()) {
            List<SqlStageExecution> subTree = createStages(Optional.of(stage), nextStageId, locationFactory, subStagePlan.withBucketToPartition(bucketToPartition), nodeScheduler, remoteTaskFactory, session, splitBatchSize, partitioningCache, executor, nodeTaskMap, stageSchedulers, stageLinkages);
            stages.addAll(subTree);
            // the first element of a subtree is that subtree's own root, i.e. a direct child of this stage
            childStagesBuilder.add(subTree.get(0));
        }
        Set<SqlStageExecution> childStages = childStagesBuilder.build();

        // once this stage reaches a done state, all of its direct children are canceled
        stage.addStateChangeListener(newState -> {
            if (newState.isDone()) {
                childStages.forEach(SqlStageExecution::cancel);
            }
        });

        stageLinkages.put(stageId, new StageLinkage(plan.getFragment().getId(), parent, childStages));
        return stages.build();
    }

    /**
     * Returns a snapshot of the root stage's info, with child stage infos nested according to the
     * stage tree.
     */
    public StageInfo getStageInfo() {
        Map<StageId, StageInfo> stageInfos = stages.values().stream()
                .map(SqlStageExecution::getStageInfo)
                .collect(ImmutableCollectors.toImmutableMap(StageInfo::getStageId));
        return buildStageInfo(rootStageId, stageInfos);
    }

    /**
     * Rebuilds the StageInfo for {@code stageId} so that it contains the (recursively built)
     * infos of its child stages.
     *
     * @throws IllegalArgumentException if {@code stageInfos} has no entry for {@code stageId}
     */
    private StageInfo buildStageInfo(StageId stageId, Map<StageId, StageInfo> stageInfos) {
        StageInfo parent = stageInfos.get(stageId);
        // report the id that was looked up; the looked-up value is null on this path
        Preconditions.checkArgument(parent != null, "No stageInfo for %s", stageId);
        List<StageInfo> childStages = stageLinkages.get(stageId).getChildStageIds().stream()
                .map(childStageId -> buildStageInfo(childStageId, stageInfos))
                .collect(ImmutableCollectors.toImmutableList());
        if (childStages.isEmpty()) {
            return parent;
        }
        return new StageInfo(parent.getStageId(), parent.getState(), parent.getSelf(), parent.getPlan(), parent.getTypes(), parent.getStageStats(), parent.getTasks(), childStages, parent.getFailureCause());
    }

    /**
     * Returns the sum of {@link SqlStageExecution#getMemoryReservation()} over all stages.
     */
    public long getTotalMemoryReservation() {
        return stages.values().stream().mapToLong(SqlStageExecution::getMemoryReservation).sum();
    }

    /**
     * Returns the total CPU time of all stages, summed at millisecond granularity.
     */
    public Duration getTotalCpuTime() {
        long millis = stages.values().stream()
                .mapToLong(stage -> stage.getTotalCpuTime().toMillis())
                .sum();
        return new Duration(millis, TimeUnit.MILLISECONDS);
    }

    /**
     * Submits the scheduling loop to the executor. Calls after the first one have no effect.
     */
    public void start() {
        if (started.compareAndSet(false, true)) {
            executor.submit(this::schedule);
        }
    }

    /**
     * Scheduling loop: repeatedly schedules the stages chosen by the execution policy until the
     * schedule is finished.
     * <p>
     * Any failure fails the query and is rethrown. Stage schedulers are always closed afterwards.
     * If closing fails, the query is failed with the close error and a RuntimeException carrying
     * the close errors as suppressed exceptions is thrown.
     */
    private void schedule() {
        try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) {
            Set<StageId> completedStages = new HashSet<>();
            ExecutionSchedule executionSchedule = executionPolicy.createExecutionSchedule(stages.values());
            while (!executionSchedule.isFinished()) {
                List<CompletableFuture<?>> blockedStages = new ArrayList<>();
                for (SqlStageExecution stage : executionSchedule.getStagesToSchedule()) {
                    stage.beginScheduling();

                    ScheduleResult result = stageSchedulers.get(stage.getStageId()).schedule();
                    if (result.isFinished()) {
                        stage.schedulingComplete();
                    } else if (!result.getBlocked().isDone()) {
                        blockedStages.add(result.getBlocked());
                    }

                    // propagate newly created tasks to the parent and children of this stage
                    stageLinkages.get(stage.getStageId()).processScheduleResults(stage.getState(), result.getNewTasks());
                }

                // stages can reach a done state asynchronously; update their linkage exactly once when that happens
                for (SqlStageExecution stage : stages.values()) {
                    if (!completedStages.contains(stage.getStageId()) && stage.getState().isDone()) {
                        stageLinkages.get(stage.getStageId()).processScheduleResults(stage.getState(), ImmutableSet.of());
                        completedStages.add(stage.getStageId());
                    }
                }

                // wait (bounded to 100ms) for any blocked stage to unblock, then try again
                if (!blockedStages.isEmpty()) {
                    MoreFutures.tryGetFutureValue(MoreFutures.firstCompletedFuture(blockedStages), 100, TimeUnit.MILLISECONDS);
                    for (CompletableFuture<?> blockedStage : blockedStages) {
                        blockedStage.cancel(true);
                    }
                }
            }

            // sanity check: after scheduling, every stage must be SCHEDULED, RUNNING or done
            for (SqlStageExecution stage : stages.values()) {
                StageState stageState = stage.getState();
                if (stageState != StageState.SCHEDULED && stageState != StageState.RUNNING && !stageState.isDone()) {
                    throw new PrestoException(StandardErrorCode.INTERNAL_ERROR, String.format("Scheduling is complete, but stage %s is in state %s", stage.getStageId(), stageState));
                }
            }
        } catch (Throwable t) {
            queryStateMachine.transitionToFailed(t);
            throw Throwables.propagate(t);
        } finally {
            closeStageSchedulers();
        }
    }

    /**
     * Closes every stage scheduler, even if some of them fail to close. Each failure fails the
     * query; if there was at least one, a RuntimeException carrying all of them as suppressed
     * exceptions is thrown.
     */
    private void closeStageSchedulers() {
        RuntimeException closeError = new RuntimeException();
        for (StageScheduler scheduler : stageSchedulers.values()) {
            try {
                scheduler.close();
            } catch (Throwable t) {
                queryStateMachine.transitionToFailed(t);
                closeError.addSuppressed(t);
            }
        }
        if (closeError.getSuppressed().length > 0) {
            throw closeError;
        }
    }

    /**
     * Cancels a single stage.
     *
     * @throws NullPointerException if no stage with {@code stageId} exists
     */
    public void cancelStage(StageId stageId) {
        try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) {
            SqlStageExecution stage = Objects.requireNonNull(stages.get(stageId), () -> String.format("Stage %s does not exist", stageId));
            stage.cancel();
        }
    }

    /**
     * Aborts every stage of the query.
     */
    public void abort() {
        try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) {
            stages.values().forEach(SqlStageExecution::abort);
        }
    }

    /**
     * Connects one stage to its parent and children. New tasks of the stage are reported to the
     * parent as exchange locations and to every child as output buffers.
     */
    private static class StageLinkage {
        private final PlanFragmentId currentStageFragmentId;
        private final Optional<SqlStageExecution> parent;
        private final Set<OutputBufferManager> childOutputBufferManagers;
        private final Set<StageId> childStageIds;

        public StageLinkage(PlanFragmentId fragmentId, Optional<SqlStageExecution> parent, Set<SqlStageExecution> children) {
            this.currentStageFragmentId = fragmentId;
            this.parent = parent;
            this.childOutputBufferManagers = children.stream()
                    .map(StageLinkage::createOutputBufferManager)
                    .collect(ImmutableCollectors.toImmutableSet());
            this.childStageIds = children.stream()
                    .map(SqlStageExecution::getStageId)
                    .collect(ImmutableCollectors.toImmutableSet());
        }

        // broadcast children get a BroadcastOutputBufferManager, every other child a PartitionedOutputBufferManager
        private static OutputBufferManager createOutputBufferManager(SqlStageExecution childStage) {
            if (childStage.getFragment().getPartitionFunction().getPartitioningHandle().equals(SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION)) {
                return new BroadcastOutputBufferManager(childStage::setOutputBuffers);
            }
            return new PartitionedOutputBufferManager(childStage::setOutputBuffers);
        }

        public Set<StageId> getChildStageIds() {
            return childStageIds;
        }

        /**
         * Propagates newly created tasks of this stage to its parent and children.
         *
         * @param newState current state of this stage; decides whether the "no more tasks" flag is sent
         * @param newTasks tasks created since the last call (may be empty)
         */
        public void processScheduleResults(StageState newState, Set<RemoteTask> newTasks) {
            boolean noMoreTasks = false;
            switch (newState) {
                case PLANNED:
                case SCHEDULING:
                    // still scheduling: do not signal noMoreTasks yet
                    break;
                case SCHEDULING_SPLITS:
                case SCHEDULED:
                case RUNNING:
                case FINISHED:
                case CANCELED:
                    noMoreTasks = true;
                    break;
                case ABORTED:
                case FAILED:
                    // NOTE(review): deliberately left false, presumably so that a failed/aborted stage
                    // is not reported to its parent as complete — confirm before changing
                    break;
                default:
                    break;
            }

            if (parent.isPresent()) {
                // hand the locations (TaskInfo.getSelf()) of the new tasks to the parent stage
                parent.get().addExchangeLocations(
                        currentStageFragmentId,
                        newTasks.stream()
                                .map(task -> task.getTaskInfo().getSelf())
                                .collect(ImmutableCollectors.toImmutableSet()),
                        noMoreTasks);
            }

            if (!childOutputBufferManagers.isEmpty()) {
                List<OutputBufferManager.OutputBuffer> newOutputBuffers = newTasks.stream()
                        .map(task -> new OutputBufferManager.OutputBuffer(task.getTaskId(), task.getPartition()))
                        .collect(ImmutableCollectors.toImmutableList());
                for (OutputBufferManager child : childOutputBufferManagers) {
                    child.addOutputBuffers(newOutputBuffers, noMoreTasks);
                }
            }
        }
    }
}

