package org.apache.wayang.core.plan.wayangplan;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.commons.lang3.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.wayang.core.util.OneTimeExecutable;

/* loaded from: input_file:org/apache/wayang/core/plan/wayangplan/LoopIsolator.class */
/**
 * Isolates the loops of a {@link WayangPlan}: each {@link LoopHeadOperator} found in the plan is wrapped,
 * together with the operators reached through its loop body outputs, via {@link LoopSubplan#wrap}.
 * <p>
 * Use {@link #isolateLoops(WayangPlan)} to process a whole plan, or {@link #isolate(Operator)} to wrap a
 * single loop head.
 */
public class LoopIsolator extends OneTimeExecutable {

    /** The plan whose loops are to be isolated. */
    private final WayangPlan wayangPlan;

    private LoopIsolator(WayangPlan wayangPlan) {
        this.wayangPlan = wayangPlan;
    }

    /**
     * Isolates every loop head reachable upstream from the sinks of the given plan and then marks the plan
     * as loop-isolated. Has no effect if the plan already reports its loops as isolated.
     *
     * @param wayangPlan the plan to process
     */
    public static void isolateLoops(WayangPlan wayangPlan) {
        new LoopIsolator(wayangPlan).run();
    }

    /** Triggers the one-time execution inherited from {@link OneTimeExecutable}. */
    private void run() {
        this.execute();
    }

    @Override // org.apache.wayang.core.util.OneTimeExecutable
    protected void doExecute() {
        if (this.wayangPlan.isLoopsIsolated()) {
            return;
        }
        // Collect all loop heads upstream of the sinks and wrap each of them into its own subplan.
        PlanTraversal.upstream()
                .traverse(this.wayangPlan.getSinks())
                .getTraversedNodesWith(Operator::isLoopHead)
                .forEach(LoopIsolator::isolate);
        this.wayangPlan.setLoopsIsolated();
    }

    /**
     * Wraps the given loop head and its loop body into a {@link LoopSubplan}.
     * <p>
     * The subplan's inputs are the loop head's initialization inputs followed by all "inbound" inputs of the
     * loop body, i.e. inputs of body operators that are fed from outside the body. Its outputs are the loop
     * head's final loop outputs.
     *
     * @param operator the operator to isolate
     * @return the subplan created by {@link LoopSubplan#wrap}, or {@code null} if {@code operator} is not a
     *         loop head
     * @throws IllegalArgumentException if the loop is malformed (e.g., contains a sink, or a loop head input
     *                                  is fed from the wrong side of the loop body)
     * @throws NullPointerException     if a loop body input of the loop head is unconnected
     */
    public static LoopSubplan isolate(Operator operator) {
        if (!operator.isLoopHead()) {
            return null;
        }
        LoopHeadOperator loopHead = (LoopHeadOperator) operator;

        // Subplan inputs: initialization inputs first, then the inbound inputs of the loop body.
        Collection<InputSlot<?>> inboundInputs = collectInboundInputs(loopHead);
        List<InputSlot<?>> loopInputs =
                new ArrayList<>(loopHead.getLoopInitializationInputs().size() + inboundInputs.size());
        loopInputs.addAll(loopHead.getLoopInitializationInputs());
        loopInputs.addAll(inboundInputs);

        // Subplan outputs: the final loop outputs of the loop head.
        List<OutputSlot<?>> loopOutputs = new ArrayList<>(loopHead.getFinalLoopOutputs());

        return LoopSubplan.wrap(loopHead, loopInputs, loopOutputs);
    }

    /**
     * Determines the loop body of {@code loopHead} and collects the input slots of body operators (other than
     * the loop head) whose occupant is owned by an operator outside the body. Also sanity-checks the loop.
     *
     * @param loopHead the loop head whose body is inspected
     * @return the inbound input slots (hash-based set, so iteration order is unspecified)
     * @throws IllegalArgumentException if the loop body contains a sink or a loop head slot is wired illegally
     * @throws NullPointerException     if a loop body input of the loop head is unconnected
     */
    private static Collection<InputSlot<?>> collectInboundInputs(LoopHeadOperator loopHead) {
        // The traversal result may contain the loop head itself; it is skipped below.
        Collection<Operator> loopBodyOperators = PlanTraversal.downstream()
                .traverseFocused(loopHead, loopHead.getLoopBodyOutputs())
                .getTraversedNodes();

        Set<InputSlot<?>> inboundInputs = new HashSet<>();
        for (Operator bodyOperator : loopBodyOperators) {
            if (bodyOperator == loopHead) {
                continue;
            }
            for (InputSlot<?> input : bodyOperator.getAllInputs()) {
                OutputSlot<?> occupant = input.getOccupant();
                // An input fed from outside the loop body must become an input of the subplan.
                if (occupant != null && !loopBodyOperators.contains(occupant.getOwner())) {
                    inboundInputs.add(input);
                }
            }
            Validate.isTrue(!bodyOperator.isSink(), "Disallowed sink %s in loop body of %s.", bodyOperator, loopHead);
        }

        validateLoopHeadInputs(loopHead, loopBodyOperators);
        validateLoopHeadOutputs(loopHead, loopBodyOperators);
        return inboundInputs;
    }

    /**
     * Checks that every loop body input of {@code loopHead} is connected and fed from inside the loop body,
     * and that every connected initialization input is fed from outside the loop body.
     */
    private static void validateLoopHeadInputs(LoopHeadOperator loopHead, Collection<Operator> loopBodyOperators) {
        for (InputSlot<?> bodyInput : loopHead.getLoopBodyInputs()) {
            Validate.notNull(bodyInput.getOccupant(), "Loop body input %s is unconnected.", bodyInput);
            Validate.isTrue(loopBodyOperators.contains(bodyInput.getOccupant().getOwner()),
                    "Illegal input for loop head input %s.", bodyInput);
        }
        for (InputSlot<?> initializationInput : loopHead.getLoopInitializationInputs()) {
            OutputSlot<?> occupant = initializationInput.getOccupant();
            // Unconnected initialization inputs are tolerated here.
            if (occupant != null) {
                Validate.isTrue(!loopBodyOperators.contains(occupant.getOwner()),
                        "Illegal input for loop head input %s.", initializationInput);
            }
        }
    }

    /**
     * Warns about loop body outputs of {@code loopHead} that feed no input slot, and checks that no final
     * loop output feeds an operator inside the loop body.
     */
    private static void validateLoopHeadOutputs(LoopHeadOperator loopHead, Collection<Operator> loopBodyOperators) {
        for (OutputSlot<?> bodyOutput : loopHead.getLoopBodyOutputs()) {
            if (bodyOutput.getOccupiedSlots().isEmpty()) {
                LogManager.getLogger(LoopIsolator.class).warn("{} is not feeding any input slot.", bodyOutput);
            }
        }
        for (OutputSlot<?> finalOutput : loopHead.getFinalLoopOutputs()) {
            for (InputSlot<?> consumer : finalOutput.getOccupiedSlots()) {
                Validate.isTrue(!loopBodyOperators.contains(consumer.getOwner()),
                        "%s is inside and outside the loop body of %s.", consumer.getOccupant(), loopHead);
            }
        }
    }
}
