package org.apache.wayang.core.optimizer.enumeration;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.wayang.commons.util.profiledb.model.measurement.TimeMeasurement;
import org.apache.wayang.core.optimizer.OptimizationContext;
import org.apache.wayang.core.optimizer.channels.ChannelConversionGraph;
import org.apache.wayang.core.optimizer.enumeration.PlanImplementation;
import org.apache.wayang.core.plan.executionplan.Channel;
import org.apache.wayang.core.plan.executionplan.ExecutionTask;
import org.apache.wayang.core.plan.wayangplan.ExecutionOperator;
import org.apache.wayang.core.plan.wayangplan.InputSlot;
import org.apache.wayang.core.plan.wayangplan.Operator;
import org.apache.wayang.core.plan.wayangplan.OperatorAlternative;
import org.apache.wayang.core.plan.wayangplan.OutputSlot;
import org.apache.wayang.core.platform.Junction;
import org.apache.wayang.core.util.MultiMap;
import org.apache.wayang.core.util.Tuple;
import org.apache.wayang.core.util.WayangCollections;

/* loaded from: input_file:org/apache/wayang/core/optimizer/enumeration/PlanEnumeration.class */
/**
 * A collection of alternative {@link PlanImplementation}s that share one interface: the same
 * {@link #scope}, the same still-open {@link InputSlot}s and the same served {@link OutputSlot}s.
 * <p>Instances are combined either by {@link #unionInPlace(PlanEnumeration)} (same interface, more
 * alternatives) or by
 * {@link #concatenate(OutputSlot, Collection, Map, OptimizationContext, TimeMeasurement)}
 * (connecting a served output of this instance to inputs requested by other instances).</p>
 * <p>The accessors hand out the live internal collections; this class performs no synchronization.</p>
 */
public class PlanEnumeration {

    private static final Logger logger = LogManager.getLogger(PlanEnumeration.class);

    /**
     * The {@link OperatorAlternative}s covered by this instance. Widened by
     * {@link #escape(OperatorAlternative.Alternative)} and merged on concatenation/union.
     */
    final Set<OperatorAlternative> scope;

    /**
     * {@link InputSlot}s that this instance still needs to be fed from outside.
     */
    final Set<InputSlot<?>> requestedInputSlots;

    /**
     * Pairs of an {@link OutputSlot} served by this instance and an {@link InputSlot} occupying it.
     * If the {@link OutputSlot} has no occupying {@link InputSlot}, the second field is {@code null}.
     */
    final Set<Tuple<OutputSlot<?>, InputSlot<?>>> servingOutputSlots;

    /**
     * The {@link PlanImplementation}s contained in this instance.
     */
    final Collection<PlanImplementation> planImplementations;

    /**
     * {@link ExecutionTask}s keyed by their {@link ExecutionOperator}; merged from all participants
     * on concatenation.
     */
    final Map<ExecutionOperator, ExecutionTask> executedTasks;

    /**
     * Creates an empty instance without scope, slots, or plan implementations.
     */
    public PlanEnumeration() {
        this(new HashSet<>(), new HashSet<>(), new HashSet<>());
    }

    /**
     * Creates an instance with the given interface but no plan implementations or executed tasks.
     */
    private PlanEnumeration(Set<OperatorAlternative> scope,
                            Set<InputSlot<?>> requestedInputSlots,
                            Set<Tuple<OutputSlot<?>, InputSlot<?>>> servingOutputSlots) {
        this(scope, requestedInputSlots, servingOutputSlots, new LinkedList<>(), new HashMap<>());
    }

    /**
     * Basic constructor. All arguments are stored by reference, not copied.
     */
    private PlanEnumeration(Set<OperatorAlternative> scope,
                            Set<InputSlot<?>> requestedInputSlots,
                            Set<Tuple<OutputSlot<?>, InputSlot<?>>> servingOutputSlots,
                            Collection<PlanImplementation> planImplementations,
                            Map<ExecutionOperator, ExecutionTask> executedTasks) {
        this.scope = scope;
        this.requestedInputSlots = requestedInputSlots;
        this.servingOutputSlots = servingOutputSlots;
        this.planImplementations = planImplementations;
        this.executedTasks = executedTasks;
    }

    /**
     * Creates an instance for a single {@link ExecutionOperator} that contains exactly one
     * {@link PlanImplementation} consisting of that operator.
     *
     * @param executionOperator   the operator to wrap
     * @param optimizationContext handed to the created {@link PlanImplementation}
     * @return the new instance
     */
    public static PlanEnumeration createSingleton(ExecutionOperator executionOperator, OptimizationContext optimizationContext) {
        final PlanEnumeration singleton = createFor(executionOperator, executionOperator);
        final PlanImplementation onlyPlan = singleton.createSingletonPartialPlan(executionOperator, optimizationContext);
        singleton.add(onlyPlan);
        return singleton;
    }

    /**
     * Creates an empty-plan instance that requests all inputs of {@code operator} and serves all
     * outputs of {@code operator2}.
     *
     * @param operator  provides the requested {@link InputSlot}s
     * @param operator2 provides the served {@link OutputSlot}s
     * @return the new instance (without any {@link PlanImplementation})
     */
    static PlanEnumeration createFor(Operator operator, Operator operator2) {
        return createFor(operator, anyInput -> true, operator2, anyOutput -> true);
    }

    /**
     * Creates an empty-plan instance from selected slots of two {@link Operator}s.
     *
     * @param operator   provides the candidate requested {@link InputSlot}s
     * @param predicate  selects which {@link InputSlot}s of {@code operator} are requested
     * @param operator2  provides the candidate served {@link OutputSlot}s
     * @param predicate2 selects which {@link OutputSlot}s of {@code operator2} are served
     * @return the new instance (without any {@link PlanImplementation})
     */
    static PlanEnumeration createFor(Operator operator,
                                     Predicate<InputSlot<?>> predicate,
                                     Operator operator2,
                                     Predicate<OutputSlot<?>> predicate2) {
        final PlanEnumeration instance = new PlanEnumeration();
        for (InputSlot<?> inputSlot : operator.getAllInputs()) {
            if (predicate.test(inputSlot)) {
                instance.requestedInputSlots.add(inputSlot);
            }
        }
        for (OutputSlot<?> outputSlot : operator2.getAllOutputs()) {
            if (predicate2.test(outputSlot)) {
                registerServingOutput(instance, outputSlot);
            }
        }
        return instance;
    }

    /**
     * Adds one serving-output entry per {@link InputSlot} occupying {@code outputSlot} to
     * {@code target}, or a single entry with a {@code null} input if nothing occupies it.
     */
    private static void registerServingOutput(PlanEnumeration target, OutputSlot<?> outputSlot) {
        final List<? extends InputSlot<?>> occupiedSlots = outputSlot.getOccupiedSlots();
        if (occupiedSlots.isEmpty()) {
            target.servingOutputSlots.add(new Tuple<>(outputSlot, null));
            return;
        }
        for (InputSlot<?> occupiedSlot : occupiedSlots) {
            target.servingOutputSlots.add(new Tuple<>(outputSlot, occupiedSlot));
        }
    }

    /**
     * Verifies that two instances request the same inputs and serve the same outputs.
     *
     * @throws IllegalArgumentException if either set differs
     */
    private static void assertMatchingInterface(PlanEnumeration first, PlanEnumeration second) {
        if (!first.requestedInputSlots.equals(second.requestedInputSlots)) {
            throw new IllegalArgumentException("Input slots are not matching.");
        }
        if (!first.servingOutputSlots.equals(second.servingOutputSlots)) {
            throw new IllegalArgumentException("Output slots are not matching.");
        }
    }

    /**
     * Concatenates this instance, via one of its served outputs, with the instances that request
     * the connected inputs.
     * <p>The result covers the union of all participants' scopes, requested inputs, served outputs
     * and executed tasks, minus the inputs in {@code map} and every served-output entry of
     * {@code outputSlot}.</p>
     *
     * @param outputSlot          an {@link OutputSlot} served by this instance
     * @param collection          already existing {@link Channel}s for the junction search; may be
     *                            {@code null} or empty, in which case the search variant that takes
     *                            the breakpoint flag is used instead
     * @param map                 the {@link InputSlot}s to connect, each with the instance requesting it;
     *                            must not be empty
     * @param optimizationContext provides the {@link ChannelConversionGraph} and the job
     * @param timeMeasurement     optional parent stopwatch; may be {@code null}
     * @return a new instance containing the concatenated {@link PlanImplementation}s
     */
    public PlanEnumeration concatenate(OutputSlot<?> outputSlot,
                                       Collection<Channel> collection,
                                       Map<InputSlot<?>, PlanEnumeration> map,
                                       OptimizationContext optimizationContext,
                                       TimeMeasurement timeMeasurement) {
        assert this.getServingOutputSlots().stream()
                .map(Tuple::getField0)
                .anyMatch(outputSlot::equals)
                : String.format("Cannot concatenate %s: it is not a served output.", outputSlot);
        assert !map.isEmpty();

        final TimeMeasurement concatenationMeasurement =
                timeMeasurement == null ? null : timeMeasurement.start(new String[]{"Concatenation"});
        try {
            // Guard with the level that is actually logged so the message is only built when needed.
            if (logger.isDebugEnabled()) {
                final StringBuilder sb = new StringBuilder();
                sb.append("Concatenating ").append(this.getPlanImplementations().size());
                for (PlanEnumeration targetEnumeration : map.values()) {
                    sb.append("x").append(targetEnumeration.getPlanImplementations().size());
                }
                sb.append(" plan implementations.");
                logger.debug(sb.toString());
            }

            // Merge the interfaces of all participants...
            final PlanEnumeration concatenation = new PlanEnumeration();
            concatenation.mergeMetadataOf(this);
            for (PlanEnumeration targetEnumeration : map.values()) {
                concatenation.mergeMetadataOf(targetEnumeration);
            }
            // ...and drop the slots that get connected by this very concatenation.
            concatenation.requestedInputSlots.removeAll(map.keySet());
            concatenation.servingOutputSlots.removeIf(link -> link.getField0().equals(outputSlot));

            concatenation.planImplementations.addAll(this.concatenatePartialPlans(
                    outputSlot, collection, map, optimizationContext, concatenation, concatenationMeasurement
            ));
            logger.debug("Created {} plan implementations.", concatenation.getPlanImplementations().size());
            return concatenation;
        } finally {
            // Stop the stopwatch even if the concatenation fails.
            if (concatenationMeasurement != null) {
                concatenationMeasurement.stop();
            }
        }
    }

    /**
     * Adds the scope, requested inputs, served outputs and executed tasks of {@code other} to this instance.
     */
    private void mergeMetadataOf(PlanEnumeration other) {
        this.scope.addAll(other.getScope());
        this.requestedInputSlots.addAll(other.getRequestedInputSlots());
        this.servingOutputSlots.addAll(other.getServingOutputSlots());
        this.executedTasks.putAll(other.getExecutedTasks());
    }

    /**
     * Determines whether the job requests a breakpoint for the open output and delegates to
     * {@link #concatenatePartialPlansBatchwise}.
     */
    private Collection<PlanImplementation> concatenatePartialPlans(OutputSlot<?> openOutputSlot,
                                                                   Collection<Channel> openChannels,
                                                                   Map<InputSlot<?>, PlanEnumeration> targetEnumerations,
                                                                   OptimizationContext optimizationContext,
                                                                   PlanEnumeration concatenationEnumeration,
                                                                   TimeMeasurement concatenationMeasurement) {
        final boolean isRequestBreakpoint = optimizationContext.getJob().isRequestBreakpointFor(
                openOutputSlot,
                optimizationContext.getOperatorContext(openOutputSlot.getOwner())
        );
        return this.concatenatePartialPlansBatchwise(
                openOutputSlot,
                openChannels,
                targetEnumerations,
                optimizationContext,
                isRequestBreakpoint,
                concatenationEnumeration,
                concatenationMeasurement
        );
    }

    /**
     * Concatenates the {@link PlanImplementation}s of this instance with those of the target
     * enumerations. Implementations are first grouped by their concatenation group descriptor so
     * that one {@link Junction} is searched per combination of groups rather than per combination
     * of implementations.
     *
     * @return the concatenated {@link PlanImplementation}s; combinations without a {@link Junction}
     * or whose concatenation yields {@code null} are skipped
     */
    private Collection<PlanImplementation> concatenatePartialPlansBatchwise(
            OutputSlot<?> openOutputSlot,
            Collection<Channel> openChannels,
            Map<InputSlot<?>, PlanEnumeration> targetEnumerations,
            OptimizationContext optimizationContext,
            boolean isRequestBreakpoint,
            PlanEnumeration concatenationEnumeration,
            TimeMeasurement concatenationMeasurement) {

        final ChannelConversionGraph channelConversionGraph = optimizationContext.getChannelConversionGraph();
        final Collection<PlanImplementation> result = new LinkedList<>();

        // Bring the InputSlots into a fixed order.
        final List<InputSlot<?>> inputs = new ArrayList<>(targetEnumerations.keySet());

        // Several InputSlots may be requested by the very same enumeration: group them.
        final MultiMap<PlanEnumeration, InputSlot<?>> targetEnumerationGroups = new MultiMap<>();
        for (Map.Entry<InputSlot<?>, PlanEnumeration> entry : targetEnumerations.entrySet()) {
            targetEnumerationGroups.putSingle(entry.getValue(), entry.getKey());
        }

        // Describe every PlanImplementation and index the descriptors by enumeration and by group.
        final MultiMap<PlanEnumeration, PlanImplementation.ConcatenationGroupDescriptor> enum2concatGroup =
                new MultiMap<>();
        final MultiMap<PlanImplementation.ConcatenationGroupDescriptor, PlanImplementation.ConcatenationDescriptor>
                concatGroup2concatDescriptor = new MultiMap<>();
        for (Map.Entry<PlanEnumeration, Set<InputSlot<?>>> entry : targetEnumerationGroups.entrySet()) {
            final PlanEnumeration planEnumeration = entry.getKey();
            // Only this instance contributes the open output.
            final OutputSlot<?> groupOutput = planEnumeration == this ? openOutputSlot : null;
            final Set<InputSlot<?>> groupInputSet = entry.getValue();
            // Positional list aligned with `inputs`; null marks inputs not belonging to this group.
            final List<InputSlot<?>> groupInputs = new ArrayList<>(inputs.size());
            for (InputSlot<?> input : inputs) {
                groupInputs.add(groupInputSet.contains(input) ? input : null);
            }

            for (PlanImplementation planImplementation : planEnumeration.getPlanImplementations()) {
                final PlanImplementation.ConcatenationDescriptor concatDescriptor =
                        planImplementation.createConcatenationDescriptor(groupOutput, groupInputs);
                concatGroup2concatDescriptor.putSingle(concatDescriptor.groupDescriptor, concatDescriptor);
                enum2concatGroup.putSingle(planImplementation.getPlanEnumeration(), concatDescriptor.groupDescriptor);
            }
        }

        // If this instance is not among the targets, it still has to contribute the open output.
        if (!targetEnumerationGroups.containsKey(this)) {
            final List<InputSlot<?>> emptyGroupInputs = WayangCollections.createNullFilledArrayList(inputs.size());
            for (PlanImplementation planImplementation : this.getPlanImplementations()) {
                final PlanImplementation.ConcatenationDescriptor concatDescriptor =
                        planImplementation.createConcatenationDescriptor(openOutputSlot, emptyGroupInputs);
                concatGroup2concatDescriptor.putSingle(concatDescriptor.groupDescriptor, concatDescriptor);
                enum2concatGroup.putSingle(planImplementation.getPlanEnumeration(), concatDescriptor.groupDescriptor);
            }
        }

        if (logger.isInfoEnabled()) {
            logger.info("Concatenating {}={} concatenation groups ({} -> {} inputs).",
                    enum2concatGroup.values().stream()
                            .map(groups -> String.valueOf(groups.size()))
                            .collect(Collectors.joining("*")),
                    enum2concatGroup.values().stream()
                            .mapToInt(Set::size)
                            .reduce(1, (a, b) -> a * b),
                    openOutputSlot,
                    targetEnumerations.size()
            );
        }

        // Order the enumerations such that this (base) instance always comes first.
        final List<PlanEnumeration> orderedEnumerations = new ArrayList<>(enum2concatGroup.keySet());
        orderedEnumerations.remove(this);
        orderedEnumerations.add(0, this);
        final List<Set<PlanImplementation.ConcatenationGroupDescriptor>> orderedConcatGroups =
                new ArrayList<>(orderedEnumerations.size());
        for (PlanEnumeration enumeration : orderedEnumerations) {
            // NOTE(review): if this instance has no plan implementations, it presumably has no entry
            // in enum2concatGroup and null is added here - confirm streamedCrossProduct copes with that.
            orderedConcatGroups.add(enum2concatGroup.get(enumeration));
        }

        for (List<PlanImplementation.ConcatenationGroupDescriptor> concatGroupCombo :
                WayangCollections.streamedCrossProduct(orderedConcatGroups)) {

            // The first group stems from this instance and therefore carries the execution output.
            final PlanImplementation.ConcatenationGroupDescriptor baseConcatGroup = concatGroupCombo.get(0);
            final OutputSlot<?> execOutput = baseConcatGroup.execOutput;
            final Set<PlanImplementation.ConcatenationDescriptor> baseConcatDescriptors =
                    concatGroup2concatDescriptor.get(baseConcatGroup);
            final PlanImplementation innerPlanImplementation =
                    WayangCollections.getAny(baseConcatDescriptors).execOutputPlanImplementation;
            // All descriptors of the group must agree on the OptimizationContext of the output.
            assert baseConcatDescriptors.stream()
                    .map(descriptor -> descriptor.execOutputPlanImplementation)
                    .map(PlanImplementation::getOptimizationContext)
                    .collect(Collectors.toSet())
                    .size() == 1;

            // Collect the execution InputSlots of all groups in this combination.
            final List<InputSlot<?>> execInputs = new ArrayList<>(inputs.size());
            for (PlanImplementation.ConcatenationGroupDescriptor concatGroup : concatGroupCombo) {
                for (Set<InputSlot<?>> execInputSet : concatGroup.execInputs) {
                    if (execInputSet != null) {
                        execInputs.addAll(execInputSet);
                    }
                }
            }

            final Operator outputOperator = execOutput.getOwner();
            assert outputOperator.isExecutionOperator()
                    : String.format("Expected execution operator, found %s.", outputOperator);

            // Search the Junction; stop the stopwatch even if the search fails.
            final TimeMeasurement channelConversionMeasurement = concatenationMeasurement == null
                    ? null
                    : concatenationMeasurement.start(new String[]{"Channel Conversion"});
            final Junction junction;
            try {
                junction = (openChannels == null || openChannels.isEmpty())
                        ? channelConversionGraph.findMinimumCostJunction(
                                execOutput, execInputs, innerPlanImplementation.getOptimizationContext(), isRequestBreakpoint)
                        : channelConversionGraph.findMinimumCostJunction(
                                execOutput, openChannels, execInputs, innerPlanImplementation.getOptimizationContext());
            } finally {
                if (channelConversionMeasurement != null) {
                    channelConversionMeasurement.stop();
                }
            }
            if (junction == null) {
                continue;
            }

            // With a Junction at hand, enumerate all PlanImplementation combinations of the groups.
            final List<Set<PlanImplementation>> groupPlans = WayangCollections.map(
                    concatGroupCombo,
                    concatGroup -> {
                        final Set<PlanImplementation.ConcatenationDescriptor> concatDescriptors =
                                concatGroup2concatDescriptor.get(concatGroup);
                        final Set<PlanImplementation> groupPlanImplementations = new HashSet<>(concatDescriptors.size());
                        for (PlanImplementation.ConcatenationDescriptor concatDescriptor : concatDescriptors) {
                            groupPlanImplementations.add(concatDescriptor.getPlanImplementation());
                        }
                        return groupPlanImplementations;
                    }
            );
            for (List<PlanImplementation> planCombo : WayangCollections.streamedCrossProduct(groupPlans)) {
                final PlanImplementation basePlanImplementation = planCombo.get(0);
                // The full combination (including the base plan) is handed over.
                final List<PlanImplementation> concatenatedPlanImplementations = planCombo.subList(0, planCombo.size());
                final PlanImplementation concatenatedPlan = basePlanImplementation.concatenate(
                        concatenatedPlanImplementations, junction, basePlanImplementation, concatenationEnumeration
                );
                if (concatenatedPlan != null) {
                    result.add(concatenatedPlan);
                }
            }
        }

        return result;
    }

    /**
     * Groups the {@link PlanImplementation}s of this instance by the execution-level
     * {@link OutputSlot} that implements {@code output}. Each group entry pairs the
     * {@link PlanImplementation} with the one reported alongside that execution output.
     */
    private MultiMap<OutputSlot<?>, Tuple<PlanImplementation, PlanImplementation>> groupImplementationsByOutput(
            OutputSlot<?> output) {
        final MultiMap<OutputSlot<?>, Tuple<PlanImplementation, PlanImplementation>> basePlanGroups = new MultiMap<>();
        for (PlanImplementation basePlan : this.getPlanImplementations()) {
            final Collection<Tuple<OutputSlot<?>, PlanImplementation>> execOutputsWithContext =
                    basePlan.findExecutionOperatorOutputWithContext(output);
            final Tuple<OutputSlot<?>, PlanImplementation> execOutputWithContext =
                    WayangCollections.getSingleOrNull(execOutputsWithContext);
            assert execOutputsWithContext != null && !execOutputsWithContext.isEmpty()
                    : String.format("No outputs found for %s.", output);

            basePlanGroups.putSingle(
                    execOutputWithContext.getField0(),
                    new Tuple<>(basePlan, execOutputWithContext.getField1())
            );
        }
        return basePlanGroups;
    }

    /**
     * For each entry of {@code targetEnumerations} (in iteration order), groups the
     * {@link PlanImplementation}s of the enumeration by the set of execution-level
     * {@link InputSlot}s that implement the entry's {@link InputSlot}.
     */
    private static List<MultiMap<Set<InputSlot<?>>, PlanImplementation>> groupImplementationsByInput(
            Map<InputSlot<?>, PlanEnumeration> targetEnumerations) {
        final List<MultiMap<Set<InputSlot<?>>, PlanImplementation>> targetPlanGroupList =
                new ArrayList<>(targetEnumerations.size());
        for (Map.Entry<InputSlot<?>, PlanEnumeration> entry : targetEnumerations.entrySet()) {
            final InputSlot<?> input = entry.getKey();
            final PlanEnumeration targetEnumeration = entry.getValue();
            final MultiMap<Set<InputSlot<?>>, PlanImplementation> targetPlanGroups = new MultiMap<>();
            for (PlanImplementation planImplementation : targetEnumeration.getPlanImplementations()) {
                final Collection<InputSlot<?>> openInputs = planImplementation.findExecutionOperatorInputs(input);
                targetPlanGroups.putSingle(WayangCollections.asSet(openInputs), planImplementation);
            }
            targetPlanGroupList.add(targetPlanGroups);
        }
        return targetPlanGroupList;
    }

    /**
     * Adds a {@link PlanImplementation} to this instance and makes this instance its enumeration.
     *
     * @param planImplementation the implementation to add; expected to carry a time estimate
     */
    public void add(PlanImplementation planImplementation) {
        this.planImplementations.add(planImplementation);
        assert planImplementation.getTimeEstimate() != null;
        planImplementation.setPlanEnumeration(this);
    }

    /**
     * Creates a {@link PlanImplementation} for this instance that consists of the single given
     * {@link ExecutionOperator} and has no junctions.
     */
    private PlanImplementation createSingletonPartialPlan(ExecutionOperator executionOperator,
                                                          OptimizationContext optimizationContext) {
        return new PlanImplementation(
                this,
                new HashMap<>(0),
                Collections.singletonList(executionOperator),
                optimizationContext
        );
    }

    /**
     * Moves all {@link PlanImplementation}s of the given instance into this instance and merges
     * its scope. The given instance is left without any {@link PlanImplementation}s.
     *
     * @param planEnumeration the instance to absorb; passing this very instance is a no-op
     * @throws IllegalArgumentException if the two instances differ in their requested inputs or served outputs
     */
    public void unionInPlace(PlanEnumeration planEnumeration) {
        if (planEnumeration == this) {
            // Iterating and extending the same collection would fail; the union with itself is itself.
            return;
        }
        assertMatchingInterface(this, planEnumeration);
        this.scope.addAll(planEnumeration.scope);
        for (PlanImplementation movedImplementation : planEnumeration.planImplementations) {
            this.planImplementations.add(movedImplementation);
            movedImplementation.setPlanEnumeration(this);
        }
        planEnumeration.planImplementations.clear();
    }

    /**
     * Lifts this instance out of the given {@link OperatorAlternative.Alternative}: slots are
     * translated through the alternative's slot mapping and the enclosing
     * {@link OperatorAlternative} is added to the scope.
     *
     * @param alternative the alternative to escape from, or {@code null}
     * @return this instance if {@code alternative} is {@code null}, otherwise a new escaped instance
     * @throws IllegalStateException if any served output of this instance is connected to an {@link InputSlot}
     */
    public PlanEnumeration escape(OperatorAlternative.Alternative alternative) {
        if (alternative == null) {
            return this;
        }
        final PlanEnumeration escapedEnumeration = new PlanEnumeration();
        final OperatorAlternative operatorAlternative = alternative.getOperatorAlternative();

        // Copy and widen the scope.
        escapedEnumeration.scope.addAll(this.scope);
        escapedEnumeration.scope.add(operatorAlternative);

        // Translate the requested inputs; inputs without an upstream counterpart are dropped.
        for (InputSlot<?> innerInput : this.requestedInputSlots) {
            final InputSlot<?> outerInput = alternative.getSlotMapping().resolveUpstream(innerInput);
            if (outerInput != null) {
                escapedEnumeration.requestedInputSlots.add(outerInput);
            }
        }

        // Translate the served outputs and look up what occupies them outside of the alternative.
        for (Tuple<OutputSlot<?>, InputSlot<?>> link : this.servingOutputSlots) {
            if (link.field1 != null) {
                throw new IllegalStateException("Cannot escape a connected output slot.");
            }
            for (OutputSlot<?> outerOutput : alternative.getSlotMapping().resolveDownstream(link.field0.unchecked())) {
                registerServingOutput(escapedEnumeration, outerOutput);
            }
        }

        // Escape the plan implementations themselves.
        // NOTE(review): executedTasks are not carried over to the escaped instance - confirm this is intended.
        for (PlanImplementation planImplementation : this.planImplementations) {
            escapedEnumeration.planImplementations.add(planImplementation.escape(alternative, escapedEnumeration));
        }
        return escapedEnumeration;
    }

    /**
     * @return the (live) collection of {@link PlanImplementation}s in this instance
     */
    public Collection<PlanImplementation> getPlanImplementations() {
        return this.planImplementations;
    }

    /**
     * @return the (live) set of {@link InputSlot}s requested by this instance
     */
    public Set<InputSlot<?>> getRequestedInputSlots() {
        return this.requestedInputSlots;
    }

    /**
     * @return the (live) set of served {@link OutputSlot}s, each paired with an occupying
     * {@link InputSlot} or {@code null}
     */
    public Set<Tuple<OutputSlot<?>, InputSlot<?>>> getServingOutputSlots() {
        return this.servingOutputSlots;
    }

    /**
     * @return the (live) set of {@link OperatorAlternative}s covered by this instance
     */
    public Set<OperatorAlternative> getScope() {
        return this.scope;
    }

    /**
     * @return the (live) map of {@link ExecutionTask}s by their {@link ExecutionOperator}
     */
    public Map<ExecutionOperator, ExecutionTask> getExecutedTasks() {
        return this.executedTasks;
    }

    @Override
    public String toString() {
        return this.toIOString();
    }

    /**
     * Describes this instance by its number of implementations, requested inputs, and distinct served outputs.
     */
    private String toIOString() {
        final List<OutputSlot<?>> distinctOutputs = this.servingOutputSlots.stream()
                .map(Tuple::getField0)
                .distinct()
                .collect(Collectors.toList());
        return String.format("%s[%dx, inputs=%s, outputs=%s]",
                this.getClass().getSimpleName(),
                this.getPlanImplementations().size(),
                this.requestedInputSlots,
                distinctOutputs
        );
    }

    /**
     * Describes this instance by its number of implementations and its scope.
     */
    private String toScopeString() {
        return String.format("%s[%dx %s]",
                this.getClass().getSimpleName(),
                this.getPlanImplementations().size(),
                this.scope
        );
    }
}
