package org.apache.wayang.core.plan.executionplan;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.wayang.core.optimizer.enumeration.ExecutionTaskFlow;
import org.apache.wayang.core.optimizer.enumeration.StageAssignmentTraversal;
import org.apache.wayang.core.util.Counter;

/**
 * A graph of {@link ExecutionStage}s, reachable via successor links from a set of starting stages.
 */
public class ExecutionPlan {

    private final Logger logger = LogManager.getLogger(getClass());

    /** Stages from which every traversal of this plan begins. */
    private final Collection<ExecutionStage> startingStages = new LinkedList<>();

    /**
     * Registers an additional stage from which traversals of this plan begin.
     *
     * @param executionStage the stage to add to the starting stages
     */
    public void addStartingStage(ExecutionStage executionStage) {
        this.startingStages.add(executionStage);
    }

    /**
     * Returns the starting stages of this plan.
     *
     * <p>Note: this is the live, mutable backing collection, not a copy; changes made through it
     * affect this plan.
     *
     * @return the starting stages
     */
    public Collection<ExecutionStage> getStartingStages() {
        return this.startingStages;
    }

    /**
     * Renders every stage reachable from the starting stages, in breadth-first order.
     *
     * @return a multi-line description of this plan
     * @see #toExtensiveString(boolean)
     */
    public String toExtensiveString() {
        return toExtensiveString(false);
    }

    /**
     * Renders the stages of this plan as text, one {@code >>> stage:} section per stage followed by
     * the stage's own plan description.
     *
     * @param isStrictlyOrdering if {@code false}, every stage reachable from the starting stages is
     *     rendered, in breadth-first order; if {@code true}, a stage is only rendered after all of
     *     its predecessors have been rendered, unless it is a loop head
     * @return a multi-line description of this plan
     */
    public String toExtensiveString(boolean isStrictlyOrdering) {
        StringBuilder sb = new StringBuilder();
        visitStages(isStrictlyOrdering, stage -> {
            sb.append(">>> ").append(stage).append(":\n");
            stage.getPlanAsString(sb, "> ");
            sb.append("\n");
        });
        return sb.toString();
    }

    /**
     * Converts the stages of this plan into a list of JSON-like maps.
     *
     * <p>Stages are visited in strict order: a stage is only emitted after all of its predecessors
     * have been emitted, unless it is a loop head. Each map is obtained from
     * {@link ExecutionStage#toJsonMap()} and additionally receives a {@code "sequence_number"} entry
     * holding its index within the returned list.
     *
     * @return the stage maps, in visiting order
     */
    // toJsonMap() is consumed as a raw Map, matching the raw List<Map> public signature.
    @SuppressWarnings({"rawtypes", "unchecked"})
    public List<Map> toJsonList() {
        List<Map> stageMaps = new ArrayList<>();
        visitStages(true, stage -> {
            Map stageMap = stage.toJsonMap();
            stageMap.put("sequence_number", stageMaps.size());
            stageMaps.add(stageMap);
        });
        return stageMaps;
    }

    /**
     * Breadth-first walk over the stages reachable from {@link #startingStages}, handing each stage
     * to {@code visitor} at most once.
     *
     * <p>Every time a stage is visited, each of its successors has its activation count increased by
     * one. The successor is queued if ordering is not strict, if its activation count now equals its
     * number of predecessors, or if it is a loop head.
     *
     * @param isStrictlyOrdering whether successors must wait for all their predecessors
     * @param visitor callback invoked for each visited stage, in visiting order
     */
    private void visitStages(boolean isStrictlyOrdering, Consumer<ExecutionStage> visitor) {
        Counter<ExecutionStage> activationCounter = new Counter<>();
        Queue<ExecutionStage> activatedStages = new ArrayDeque<>(this.startingStages);
        Set<ExecutionStage> seenStages = new HashSet<>();
        while (!activatedStages.isEmpty()) {
            ExecutionStage stage = activatedStages.poll();
            if (!seenStages.add(stage)) {
                continue;
            }
            visitor.accept(stage);
            for (ExecutionStage successor : stage.getSuccessors()) {
                // Always count the activation, even when ordering is not strict.
                int activationCount = activationCounter.add(successor, 1);
                if (!isStrictlyOrdering
                        || activationCount == successor.getPredecessors().size()
                        || successor.isLoopHead()) {
                    activatedStages.add(successor);
                }
            }
        }
    }

    /**
     * Restricts this plan to the given stages.
     *
     * <p>For each given stage, calls {@code retain} on:
     * <ul>
     *   <li>each of its outbound channels;</li>
     *   <li>the stage's successors ({@code retainSuccessors});</li>
     *   <li>its platform execution.</li>
     * </ul>
     *
     * @param retainableStages the stages to keep
     * @return the outbound channels whose {@code Channel#retain} call returned {@code true}
     *     (NOTE(review): presumably channels that are no longer needed — confirm against
     *     {@code Channel#retain})
     */
    public Set<Channel> retain(Set<ExecutionStage> retainableStages) {
        Set<Channel> affectedChannels = new HashSet<>();
        for (ExecutionStage stage : retainableStages) {
            for (Channel channel : stage.getOutboundChannels()) {
                if (channel.retain(retainableStages)) {
                    affectedChannels.add(channel);
                }
            }
            stage.retainSuccessors(retainableStages);
            stage.getPlatformExecution().retain(retainableStages);
        }
        return affectedChannels;
    }

    /**
     * Collects all stages reachable from {@link #getStartingStages()} via successor links.
     *
     * @return a fresh, mutable set of all reachable stages
     */
    public Set<ExecutionStage> getStages() {
        Set<ExecutionStage> seenStages = new HashSet<>();
        Queue<ExecutionStage> openStages = new ArrayDeque<>(this.getStartingStages());
        while (!openStages.isEmpty()) {
            ExecutionStage stage = openStages.poll();
            if (seenStages.add(stage)) {
                openStages.addAll(stage.getSuccessors());
            }
        }
        return seenStages;
    }

    /**
     * Collects the tasks of all stages returned by {@link #getStages()}.
     *
     * @return a fresh, mutable set of all tasks in this plan
     */
    public Set<ExecutionTask> collectAllTasks() {
        Set<ExecutionTask> allTasks = new HashSet<>();
        for (ExecutionStage stage : this.getStages()) {
            allTasks.addAll(stage.getAllTasks());
        }
        return allTasks;
    }

    /**
     * Splices another plan into this one.
     *
     * <p>For each input channel of {@code executionPlan} that is a copy:
     * <ol>
     *   <li>the copy is merged into its original channel;</li>
     *   <li>the stage of the original's producer is made a predecessor of every distinct consumer
     *       stage of that original.</li>
     * </ol>
     *
     * @param executionPlan the plan whose copied input channels should be connected to this plan
     */
    public void expand(ExecutionPlan executionPlan) {
        for (Channel openChannel : executionPlan.getOpenInputChannels()) {
            openChannel.mergeIntoOriginal();
            final Channel original = openChannel.getOriginal();
            final ExecutionStage producerStage = original.getProducer().getStage();
            assert producerStage != null : String.format("No stage found for %s.", original.getProducer());
            for (ExecutionTask consumer : original.getConsumers()) {
                final ExecutionStage consumerStage = consumer.getStage();
                assert consumerStage != null : String.format("No stage found for %s.", consumer);
                // Producer and consumer may share a stage; do not add self-loops.
                if (producerStage != consumerStage) {
                    producerStage.addSuccessor(consumerStage);
                }
            }
        }
    }

    /**
     * Finds the input channels of all tasks in this plan that are copies
     * ({@code Channel#isCopy()}).
     *
     * @return the copied input channels, in a fresh list (may contain duplicates if a channel feeds
     *     several tasks)
     */
    public Collection<Channel> getOpenInputChannels() {
        return this.collectAllTasks().stream()
                .flatMap(task -> Arrays.stream(task.getInputChannels()))
                .filter(Channel::isCopy)
                .collect(Collectors.toList());
    }

    /**
     * Checks structural invariants of this plan, logging an error for each violation found.
     *
     * <p>The checks are:
     * <ol>
     *   <li>every task has a stage;</li>
     *   <li>no input or output channel of any task is a copy;</li>
     *   <li>every sibling of such a channel is itself an input or output channel of some task in
     *       this plan.</li>
     * </ol>
     *
     * @return {@code true} iff all checks pass
     */
    public boolean isSane() {
        final Set<ExecutionTask> allTasks = this.collectAllTasks();

        // 1. Every task must be assigned to a stage.
        boolean isAllTasksAssigned = allTasks.stream().allMatch(task -> task.getStage() != null);
        if (!isAllTasksAssigned) {
            this.logger.error("There are tasks without stages.");
        }

        // 2. No channel in the plan may be a copy.
        final Set<Channel> allChannels = allTasks.stream()
                .flatMap(task -> Stream.concat(
                        Arrays.stream(task.getInputChannels()),
                        Arrays.stream(task.getOutputChannels())))
                .collect(Collectors.toSet());
        boolean isNoChannelCopied = allChannels.stream().noneMatch(Channel::isCopy);
        if (!isNoChannelCopied) {
            this.logger.error("There are channels that are copies.");
        }

        // 3. Siblings must themselves be channels of this plan.
        boolean isAllSiblingsValid = true;
        for (Channel channel : allChannels) {
            for (Channel sibling : channel.getSiblings()) {
                if (!allChannels.contains(sibling)) {
                    this.logger.error("A sibling of {}, namely {}, seems to be invalid.", channel, sibling);
                    isAllSiblingsValid = false;
                }
            }
        }

        return isAllTasksAssigned && isNoChannelCopied && isAllSiblingsValid;
    }

    /**
     * Builds an {@link ExecutionPlan} from an {@link ExecutionTaskFlow} by delegating to
     * {@link StageAssignmentTraversal#assignStages}.
     *
     * @param executionTaskFlow the task flow to partition into stages
     * @param stageSplittingCriterion decides where stage boundaries are drawn
     * @return the resulting plan
     */
    public static ExecutionPlan createFrom(ExecutionTaskFlow executionTaskFlow,
                                           StageAssignmentTraversal.StageSplittingCriterion stageSplittingCriterion) {
        return StageAssignmentTraversal.assignStages(executionTaskFlow, stageSplittingCriterion);
    }
}
