package org.apache.wayang.core.optimizer.cardinality;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.wayang.core.api.Configuration;
import org.apache.wayang.core.optimizer.OptimizationContext;
import org.apache.wayang.core.plan.wayangplan.InputSlot;
import org.apache.wayang.core.plan.wayangplan.LoopHeadOperator;
import org.apache.wayang.core.plan.wayangplan.LoopSubplan;
import org.apache.wayang.core.plan.wayangplan.OutputSlot;

/**
 * {@link CardinalityPusher} for a {@link LoopSubplan}.
 * <p>
 * Instead of treating the loop as an opaque operator, this pusher walks over every iteration context of the
 * {@link LoopSubplan}'s nested loop context:
 * <ul>
 * <li>in the final iteration, the loop head is pushed with its finalization pusher and the resulting cardinalities
 * are copied to the {@link LoopSubplan}'s outputs;</li>
 * <li>in any other iteration, the loop head is pushed with its initialization pusher (initial iteration) or its
 * regular cardinality pusher (otherwise), the cardinalities are propagated through the loop body, and the estimates
 * of the slots feeding the loop head's loop-body inputs are forwarded into the next iteration context.</li>
 * </ul>
 */
public class LoopSubplanCardinalityPusher extends CardinalityPusher {

    /** Pushes the loop head in the initial iteration (unless that iteration is also the final one). */
    private final CardinalityPusher loopHeadInitializationPusher;

    /** Pushes the loop head in iterations that are neither initial nor final. */
    private final CardinalityPusher loopHeadIterationPusher;

    /** Pushes the loop head in the final iteration. */
    private final CardinalityPusher loopHeadFinalizationPusher;

    /** Propagates cardinalities through the loop body within a single iteration context. */
    private final CardinalityEstimationTraversal bodyTraversal;

    /** The non-{@code null} occupants of the loop head's loop-body inputs, i.e., the body's feedback slots. */
    private final Set<OutputSlot<?>> bodyOutputSlots;

    /**
     * Creates a new instance.
     *
     * @param loopSubplan   the {@link LoopSubplan} whose cardinalities are to be pushed
     * @param configuration provides the loop head's {@link CardinalityPusher}s and configures the body traversal
     */
    public LoopSubplanCardinalityPusher(LoopSubplan loopSubplan, Configuration configuration) {
        super(loopSubplan);
        final LoopHeadOperator loopHead = loopSubplan.getLoopHead();
        this.loopHeadInitializationPusher = loopHead.getInitializationPusher(configuration);
        this.loopHeadIterationPusher = loopHead.getCardinalityPusher(configuration);
        this.loopHeadFinalizationPusher = loopHead.getFinalizationPusher(configuration);

        final Set<InputSlot<?>> bodyInputSlots = collectBodyInputSlots(loopSubplan, loopHead);
        // The explicit type witness keeps the wildcard capture of getOccupant() from leaking into the set type.
        this.bodyOutputSlots = loopHead.getLoopBodyInputs().stream()
                .<OutputSlot<?>>map(InputSlot::getOccupant)
                .filter(Objects::nonNull)
                .collect(Collectors.toSet());
        this.bodyTraversal = CardinalityEstimationTraversal.createPushTraversal(
                bodyInputSlots, loopHead.getLoopBodyInputs(), Collections.emptyList(), configuration
        );
    }

    /**
     * Collects the body {@link InputSlot}s that are handed to
     * {@link CardinalityEstimationTraversal#createPushTraversal}. They consist of all inner {@link InputSlot}s
     * reached by following the {@link LoopSubplan}'s inputs, minus the loop head's initialization inputs, plus all
     * {@link InputSlot}s occupied by the loop head's loop-body outputs.
     *
     * @param loopSubplan the {@link LoopSubplan} being pushed
     * @param loopHead    its {@link LoopHeadOperator}
     * @return a mutable set of the body {@link InputSlot}s
     */
    private static Set<InputSlot<?>> collectBodyInputSlots(LoopSubplan loopSubplan, LoopHeadOperator loopHead) {
        // Collectors.toSet() makes no mutability guarantee, but the set is modified right below,
        // so request a HashSet explicitly.
        final Set<InputSlot<?>> bodyInputSlots = Arrays.stream(loopSubplan.getAllInputs())
                .<InputSlot<?>>flatMap(inputSlot -> loopSubplan.followInput(inputSlot).stream())
                .collect(Collectors.toCollection(HashSet::new));
        bodyInputSlots.removeAll(loopHead.getLoopInitializationInputs());
        for (OutputSlot<?> loopBodyOutput : loopHead.getLoopBodyOutputs()) {
            bodyInputSlots.addAll(loopBodyOutput.getOccupiedSlots());
        }
        return bodyInputSlots;
    }

    /**
     * Pushes cardinalities through every iteration context of the nested loop context that belongs to the
     * {@link LoopSubplan} of {@code operatorContext}, and finally sets that {@link LoopSubplan}'s output
     * cardinalities.
     *
     * @param operatorContext the {@link OptimizationContext.OperatorContext} of the {@link LoopSubplan}
     * @param configuration   passed on to the loop head pushers and the body traversal
     */
    @Override
    protected void doPush(OptimizationContext.OperatorContext operatorContext, Configuration configuration) {
        final OptimizationContext optimizationContext = operatorContext.getOptimizationContext();
        final LoopSubplan loopSubplan = (LoopSubplan) operatorContext.getOperator();
        final OptimizationContext.LoopContext loopContext = optimizationContext.getNestedLoopContext(loopSubplan);
        final LoopHeadOperator loopHead = loopSubplan.getLoopHead();
        for (OptimizationContext iterationContext : loopContext.getIterationContexts()) {
            final OptimizationContext.OperatorContext loopHeadContext = iterationContext.getOperatorContext(loopHead);
            if (iterationContext.isFinalIteration()) {
                this.pushFinalIteration(operatorContext, loopSubplan, loopHeadContext, configuration);
            } else {
                this.pushNonFinalIteration(iterationContext, loopHead, loopHeadContext, configuration);
            }
        }
    }

    /**
     * Pushes the loop head with its finalization pusher and copies the resulting cardinalities to the
     * {@link LoopSubplan}'s outputs.
     *
     * @param subplanContext  the {@link OptimizationContext.OperatorContext} of the {@link LoopSubplan}
     * @param loopSubplan     the {@link LoopSubplan}
     * @param loopHeadContext the loop head's {@link OptimizationContext.OperatorContext} in the final iteration
     * @param configuration   passed on to the finalization pusher
     */
    private void pushFinalIteration(OptimizationContext.OperatorContext subplanContext,
                                    LoopSubplan loopSubplan,
                                    OptimizationContext.OperatorContext loopHeadContext,
                                    Configuration configuration) {
        this.loopHeadFinalizationPusher.push(loopHeadContext, configuration);
        for (int outputIndex = 0; outputIndex < loopSubplan.getNumOutputs(); outputIndex++) {
            final OutputSlot<?> innerOutput = loopSubplan.traceOutput(loopSubplan.getOutput(outputIndex));
            // Outputs that cannot be traced to an inner slot are left untouched.
            if (innerOutput != null) {
                // NOTE(review): the traced slot's index is resolved against the loop head's context, so this
                // presumes every traceable subplan output is provided by the loop head.
                subplanContext.setOutputCardinality(
                        outputIndex, loopHeadContext.getOutputCardinality(innerOutput.getIndex())
                );
            }
        }
    }

    /**
     * Handles a non-final iteration. It pushes the loop head, propagates the cardinalities through the loop body,
     * and forwards the feedback cardinalities into the next iteration context.
     *
     * @param iterationContext the {@link OptimizationContext} of the current iteration
     * @param loopHead         the {@link LoopHeadOperator} of the loop
     * @param loopHeadContext  the loop head's {@link OptimizationContext.OperatorContext} in this iteration
     * @param configuration    passed on to the loop head pusher and the body traversal
     */
    private void pushNonFinalIteration(OptimizationContext iterationContext,
                                       LoopHeadOperator loopHead,
                                       OptimizationContext.OperatorContext loopHeadContext,
                                       Configuration configuration) {
        final CardinalityPusher loopHeadPusher = iterationContext.isInitialIteration()
                ? this.loopHeadInitializationPusher
                : this.loopHeadIterationPusher;
        loopHeadPusher.push(loopHeadContext, configuration);

        // Hand the loop head's body-facing cardinalities to the consumers within this iteration.
        for (OutputSlot<?> loopBodyOutput : loopHead.getLoopBodyOutputs()) {
            loopHeadContext.pushCardinalityForward(loopBodyOutput.getIndex(), iterationContext);
        }

        this.bodyTraversal.traverse(iterationContext, configuration);

        // Forward the estimates of the slots that feed the loop head's loop-body inputs into the next iteration.
        for (OutputSlot<?> bodyOutputSlot : this.bodyOutputSlots) {
            final OptimizationContext.OperatorContext bodyOperatorContext =
                    iterationContext.getOperatorContext(bodyOutputSlot.getOwner());
            bodyOperatorContext.pushCardinalityForward(
                    bodyOutputSlot.getIndex(), iterationContext.getNextIterationContext()
            );
        }
    }
}
