package org.apache.wayang.core.plan.wayangplan;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.wayang.core.plan.wayangplan.OperatorAlternative;

/* loaded from: input_file:org/apache/wayang/core/plan/wayangplan/PlanTraversal.class */
/**
 * Walks a graph of {@link Operator}s along their slot connections, optionally descending into the
 * {@link OperatorContainer}s of composite operators.
 *
 * <p>Obtain an instance via {@link #fanOut()}, {@link #downstream()} or {@link #upstream()}. Configure it
 * with the fluent configuration methods, e.g. {@link #withCallback(Callback)},
 * {@link #followingInputsIf(Predicate)} or {@link #traversingHierarchically()}. Then run it with one of
 * the {@code traverse} methods. An instance may be run several times; operators recorded in an earlier
 * run are not reported again.
 *
 * <p>Instances keep mutable state and use no synchronization. The traversal is implemented recursively.
 */
public class PlanTraversal {

    private static final Logger logger = LogManager.getLogger(PlanTraversal.class);

    /** Whether to proceed from an operator to the owners of the {@link OutputSlot}s occupying its inputs. */
    private final boolean isFollowInputs;

    /** Whether to proceed from an operator to the owners of the {@link InputSlot}s occupied by its outputs. */
    private final boolean isFollowOutputs;

    /**
     * Visited operators that count as relevant. These are operators none of whose containers were entered,
     * and entered composite operators accepted by the condition set via
     * {@link #consideringEnteredOperatorsIf(Predicate)}.
     */
    public Set<Operator> visitedRelevantOperators = new HashSet<>();

    /** Entered composite operators that the relevance condition rejected. */
    public Set<Operator> visitedIrrelevantOperators = new HashSet<>();

    /** Decides which containers of a composite operator are entered; by default, none. */
    private Predicate<OperatorContainer> containerEnterCondition = container -> false;

    /** Decides whether an entered composite operator is relevant; by default, it is not. */
    private Predicate<CompositeOperator> compositeRelevanceCondition = compositeOperator -> false;

    /** Invoked for relevant operators on their first visit; {@code null} disables notification. */
    private Callback traversalCallback = null;

    /** Decides which {@link OutputSlot}s are followed downstream; by default, all. */
    private Predicate<OutputSlot<?>> outputFollowPredicate = outputSlot -> true;

    /** Decides which {@link InputSlot}s are followed upstream; by default, all. */
    private Predicate<InputSlot<?>> inputFollowPredicate = inputSlot -> true;

    /** Decides which {@link InputSlot}s reached while moving downstream are traversed; by default, all. */
    private Predicate<InputSlot<?>> inputFollowDownstreamPredicate = inputSlot -> true;

    /**
     * Receives the relevant operators visited during a traversal.
     */
    @FunctionalInterface
    public interface Callback {

        /** A {@link Callback} that does nothing. */
        Callback NOP = (operator, fromInputSlot, fromOutputSlot) -> {
        };

        /**
         * Called when a relevant operator is visited for the first time.
         *
         * @param operator       the visited operator
         * @param fromInputSlot  the {@link InputSlot} of {@code operator} through which it was reached,
         *                       or {@code null}
         * @param fromOutputSlot the {@link OutputSlot} of {@code operator} through which it was reached,
         *                       or {@code null}
         */
        void traverse(Operator operator, InputSlot<?> fromInputSlot, OutputSlot<?> fromOutputSlot);
    }

    /**
     * Simplified {@link Callback} that receives only the visited operator.
     */
    @FunctionalInterface
    public interface SimpleCallback {

        /**
         * Called when a relevant operator is visited for the first time.
         *
         * @param operator the visited operator
         */
        void traverse(Operator operator);
    }

    /**
     * Creates a new instance.
     *
     * @param isFollowInputs  whether to proceed upstream via the operators' {@link InputSlot}s
     * @param isFollowOutputs whether to proceed downstream via the operators' {@link OutputSlot}s
     * @deprecated use {@link #fanOut()}, {@link #downstream()} or {@link #upstream()} instead
     */
    @Deprecated
    public PlanTraversal(boolean isFollowInputs, boolean isFollowOutputs) {
        this.isFollowInputs = isFollowInputs;
        this.isFollowOutputs = isFollowOutputs;
    }

    /**
     * @return a new instance that follows both inputs and outputs
     */
    public static PlanTraversal fanOut() {
        return new PlanTraversal(true, true);
    }

    /**
     * @return a new instance that follows only outputs
     */
    public static PlanTraversal downstream() {
        return new PlanTraversal(false, true);
    }

    /**
     * @return a new instance that follows only inputs
     */
    public static PlanTraversal upstream() {
        return new PlanTraversal(true, false);
    }

    /**
     * Sets the {@link Callback} to invoke for newly visited relevant operators.
     *
     * @param callback the callback; {@code null} disables notification
     * @return this instance
     */
    public PlanTraversal withCallback(Callback callback) {
        this.traversalCallback = callback;
        return this;
    }

    /**
     * Sets a {@link SimpleCallback} to invoke for newly visited relevant operators.
     *
     * @param simpleCallback the callback; the slots through which an operator was reached are discarded
     * @return this instance
     */
    public PlanTraversal withCallback(SimpleCallback simpleCallback) {
        this.traversalCallback = (operator, fromInputSlot, fromOutputSlot) -> simpleCallback.traverse(operator);
        return this;
    }

    /**
     * Restricts which {@link InputSlot}s are followed upstream.
     *
     * @param predicate accepts the input slots to follow
     * @return this instance
     */
    public PlanTraversal followingInputsIf(Predicate<InputSlot<?>> predicate) {
        this.inputFollowPredicate = predicate;
        return this;
    }

    /**
     * Restricts which {@link InputSlot}s, reached while following outputs downstream, are traversed.
     *
     * @param predicate accepts the input slots to traverse
     * @return this instance
     */
    public PlanTraversal followingInputsDownstreamIf(Predicate<InputSlot<?>> predicate) {
        this.inputFollowDownstreamPredicate = predicate;
        return this;
    }

    /**
     * Restricts which {@link OutputSlot}s are followed downstream.
     *
     * @param predicate accepts the output slots to follow
     * @return this instance
     */
    public PlanTraversal followingOutputsIf(Predicate<OutputSlot<?>> predicate) {
        this.outputFollowPredicate = predicate;
        return this;
    }

    /**
     * Decides which {@link OperatorContainer}s of composite operators are entered.
     *
     * @param predicate accepts the containers to enter
     * @return this instance
     */
    public PlanTraversal enteringContainersIf(Predicate<OperatorContainer> predicate) {
        this.containerEnterCondition = predicate;
        return this;
    }

    /**
     * Decides whether an entered composite operator counts as relevant. Relevant operators are recorded in
     * {@link #visitedRelevantOperators} and reported to the callback; the others are recorded in
     * {@link #visitedIrrelevantOperators}.
     *
     * @param predicate accepts the entered composite operators that are relevant
     * @return this instance
     */
    public PlanTraversal consideringEnteredOperatorsIf(Predicate<CompositeOperator> predicate) {
        this.compositeRelevanceCondition = predicate;
        return this;
    }

    /**
     * Enters all containers of composite operators and treats the entered composites as irrelevant.
     *
     * @return this instance
     */
    public PlanTraversal traversingHierarchically() {
        return this.enteringContainersIf(container -> true)
                .consideringEnteredOperatorsIf(compositeOperator -> false);
    }

    /**
     * Enters no containers, so every visited operator, composite or not, counts as relevant.
     *
     * @return this instance
     */
    public PlanTraversal traversingFlat() {
        return this.enteringContainersIf(container -> false);
    }

    /**
     * Traverses starting from each of the given operators.
     *
     * @param operators the start operators
     * @return this instance
     */
    public PlanTraversal traverse(Collection<? extends Operator> operators) {
        operators.forEach(this::traverse);
        return this;
    }

    /**
     * Traverses starting from each of the given operators.
     *
     * @param operators the start operators
     * @return this instance
     */
    public PlanTraversal traverse(Stream<? extends Operator> operators) {
        operators.forEach(this::traverse);
        return this;
    }

    /**
     * Traverses starting from the given operator, which is not reached through any slot.
     *
     * @param operator the start operator
     * @return this instance
     */
    public PlanTraversal traverse(Operator operator) {
        return this.traverse(operator, null, null);
    }

    /**
     * Marks {@code operator} as a visited relevant operator, without invoking the callback. Then traverses
     * downstream only via the given {@link OutputSlot}s.
     *
     * @param operator    the start operator
     * @param outputSlots output slots of {@code operator} to follow
     * @return this instance
     */
    public PlanTraversal traverseFocused(Operator operator, Collection<OutputSlot<?>> outputSlots) {
        this.visitedRelevantOperators.add(operator);
        assert outputSlots.stream().allMatch(outputSlot -> outputSlot.getOwner() == operator);
        this.followOutputs(outputSlots.stream());
        return this;
    }

    /**
     * Visits {@code operator}. If it had not been recorded before, continues to its inputs and/or outputs,
     * depending on how this instance was created.
     *
     * @param operator       the operator to visit
     * @param fromInputSlot  the {@link InputSlot} of {@code operator} through which it was reached, or {@code null}
     * @param fromOutputSlot the {@link OutputSlot} of {@code operator} through which it was reached, or {@code null}
     * @return this instance
     */
    public PlanTraversal traverse(Operator operator, InputSlot<?> fromInputSlot, OutputSlot<?> fromOutputSlot) {
        if (this.visit(operator, fromInputSlot, fromOutputSlot)) {
            if (this.isFollowInputs) {
                this.followInputs(operator);
            }
            if (this.isFollowOutputs) {
                this.followOutputs(operator);
            }
        }
        return this;
    }

    /**
     * Enters the qualifying containers of {@code operator} (if composite) and records it as relevant or
     * irrelevant. The callback is invoked if the operator is relevant and newly recorded.
     *
     * @return whether {@code operator} was newly recorded
     */
    private boolean visit(Operator operator, InputSlot<?> fromInputSlot, OutputSlot<?> fromOutputSlot) {
        // Note: entering happens before the "already recorded" check below, so a composite operator's
        // containers are entered again on every visit, e.g. once per slot through which it is reached.
        boolean isEntered = false;
        if (!operator.isElementary()) {
            for (OperatorContainer container : ((CompositeOperator) operator).getContainers()) {
                if (this.containerEnterCondition.test(container)) {
                    this.enter(container, fromInputSlot, fromOutputSlot);
                    isEntered = true;
                }
            }
        }

        // Operators that were not entered are always relevant. The cast is safe here: only composites get entered.
        boolean isRelevant = !isEntered || this.compositeRelevanceCondition.test((CompositeOperator) operator);
        Set<Operator> visitedOperators = isRelevant ? this.visitedRelevantOperators : this.visitedIrrelevantOperators;
        boolean isNewlyVisited = visitedOperators.add(operator);
        if (isNewlyVisited && isRelevant && this.traversalCallback != null) {
            this.traversalCallback.traverse(operator, fromInputSlot, fromOutputSlot);
        }
        return isNewlyVisited;
    }

    /**
     * Enters a {@link Subplan} or all alternatives of an {@link OperatorAlternative}. It does nothing for
     * elementary operators.
     * NOTE(review): this private method is not called from within this class.
     *
     * @return {@code true} if {@code operator} is an {@link OperatorAlternative}, {@code false} otherwise
     */
    private boolean traverseHierarchical(Operator operator, InputSlot<?> fromInputSlot, OutputSlot<?> fromOutputSlot) {
        if (operator.isSubplan()) {
            this.enter((Subplan) operator, fromInputSlot, fromOutputSlot);
        } else if (operator.isAlternative()) {
            for (OperatorAlternative.Alternative alternative : ((OperatorAlternative) operator).getAlternatives()) {
                this.enter(alternative, fromInputSlot, fromOutputSlot);
            }
            return true;
        } else {
            // Only reached for non-subplan, non-alternative operators, which must then be elementary.
            // (Previously subplans also fell into this check and tripped it under -ea.)
            assert operator.isElementary() : String.format("Unknown composite operator: %s", operator);
        }
        return false;
    }

    /**
     * Continues the traversal inside {@code container}. An outer input slot is mapped via
     * {@code followInput(...)}, an outer output slot via {@code traceOutput(...)}. Without any slot, the
     * traversal starts at the container's sink or source.
     */
    private void enter(OperatorContainer container, InputSlot<?> fromInputSlot, OutputSlot<?> fromOutputSlot) {
        if (fromInputSlot != null) {
            for (InputSlot<?> innerInputSlot : container.followInput(fromInputSlot.unchecked())) {
                this.traverse(innerInputSlot.getOwner(), innerInputSlot, null);
            }
        } else if (fromOutputSlot != null) {
            OutputSlot<?> innerOutputSlot = container.traceOutput(fromOutputSlot.unchecked());
            this.traverse(innerOutputSlot.getOwner(), null, innerOutputSlot);
        } else if (container.isSink()) {
            this.traverse(container.getSink(), null, null);
        } else if (container.isSource()) {
            this.traverse(container.getSource(), null, null);
        } else {
            logger.warn("Could not enter {} during hierarchical traversal.", container);
        }
    }

    /**
     * Traverses the owners of the {@link OutputSlot}s occupying those inputs of {@code operator} that pass
     * the input-follow predicate. Unoccupied inputs are skipped.
     *
     * @param operator the operator whose inputs should be followed
     */
    protected void followInputs(Operator operator) {
        Arrays.stream(operator.getAllInputs())
                .filter(this.inputFollowPredicate)
                .map(InputSlot::getOccupant)
                .filter(occupant -> occupant != null)
                .forEach(occupant -> this.traverse(occupant.getOwner(), null, occupant));
    }

    /**
     * Follows all outputs of {@code operator}.
     */
    private void followOutputs(Operator operator) {
        this.followOutputs(Arrays.stream(operator.getAllOutputs()));
    }

    /**
     * Traverses the owners of the {@link InputSlot}s occupied by the given output slots. Only output slots
     * passing the output-follow predicate are considered. Only input slots that are non-{@code null} and
     * pass the downstream input predicate are traversed.
     *
     * @param outputSlots the output slots to follow
     */
    protected void followOutputs(Stream<OutputSlot<?>> outputSlots) {
        outputSlots
                .filter(this.outputFollowPredicate)
                .map(outputSlot -> outputSlot.getOccupiedSlots())
                .flatMap(Collection::stream)
                .filter(inputSlot -> inputSlot != null)
                .filter(this.inputFollowDownstreamPredicate)
                .forEach(inputSlot -> this.traverse(inputSlot.getOwner(), inputSlot, null));
    }

    /**
     * @param predicate selects the operators of interest
     * @return a new collection of the visited relevant operators that satisfy {@code predicate}
     */
    public Collection<Operator> getTraversedNodesWith(Predicate<Operator> predicate) {
        return this.visitedRelevantOperators.stream().filter(predicate).collect(Collectors.toList());
    }

    /**
     * @return a new collection of all visited relevant operators
     */
    public Collection<Operator> getTraversedNodes() {
        return this.getTraversedNodesWith(operator -> true);
    }
}
