package org.apache.wayang.core.plan.wayangplan;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.wayang.core.util.WayangCollections;

/* loaded from: input_file:org/apache/wayang/core/plan/wayangplan/SlotMapping.class */
/**
 * Maps {@link Slot}s to their <i>upstream</i> counterparts, i.e. the slot one step against the data flow.
 * <p>
 * Input slots are only ever paired with input slots and output slots with output slots by the
 * {@code mapUpstream} overloads. For a wrapping operator (see {@link #wrap(Operator, Operator)}), this means:
 * <ul>
 * <li>each input of the wrapped operator maps upstream to the corresponding input of the wrapper, and</li>
 * <li>each output of the wrapper maps upstream to the corresponding output of the wrapped operator.</li>
 * </ul>
 * The inverse ("downstream") view is derived lazily from the upstream mapping and cached; every mutation
 * performed through this class invalidates that cache.
 */
public class SlotMapping {

    private final Logger logger = LogManager.getLogger(getClass());

    /** Slot -> its upstream slot. */
    private final Map<Slot<?>, Slot<?>> upstreamMapping = new HashMap<>();

    /** Lazily built inverse of {@link #upstreamMapping}; {@code null} whenever it is stale. */
    private Map<Slot<?>, Collection<Slot<?>>> downstreamMapping = null;

    /**
     * Creates a mapping that connects every slot of the given operator with itself.
     *
     * @param operator the operator whose slots are mapped onto themselves
     * @return the identity mapping
     */
    public static SlotMapping createIdentityMapping(Operator operator) {
        return wrap(operator, operator);
    }

    /**
     * Creates a mapping in which {@code wrapper} encloses {@code wrappee}. Inputs of the wrappee map
     * upstream to the wrapper's inputs, and outputs of the wrapper map upstream to the wrappee's outputs
     * (matched by index).
     *
     * @param wrappee the enclosed operator
     * @param wrapper the enclosing operator
     * @return the new mapping
     * @throws IllegalArgumentException if the slot counts or slot types of both operators do not match
     */
    public static SlotMapping wrap(Operator wrappee, Operator wrapper) {
        final SlotMapping slotMapping = new SlotMapping();
        slotMapping.mapAllUpsteam(wrapper.getAllOutputs(), wrappee.getAllOutputs());
        slotMapping.mapAllUpsteam(wrappee.getAllInputs(), wrapper.getAllInputs());
        return slotMapping;
    }

    /**
     * Maps each {@code slots[i]} upstream to {@code upstreamSlots[i]}.
     * (The misspelled method name is kept for API compatibility.)
     *
     * @throws IllegalArgumentException if the arrays differ in length or a pair of slots is incompatible
     */
    public void mapAllUpsteam(InputSlot<?>[] slots, InputSlot<?>[] upstreamSlots) {
        if (slots.length != upstreamSlots.length) {
            throw new IllegalArgumentException(String.format("Incompatible number of input slots between %s and %s.", Arrays.toString(slots), Arrays.toString(upstreamSlots)));
        }
        for (int i = 0; i < slots.length; i++) {
            this.mapUpstream(slots[i], upstreamSlots[i]);
        }
    }

    /**
     * Maps each {@code slots[i]} upstream to {@code upstreamSlots[i]}.
     * (The misspelled method name is kept for API compatibility.)
     *
     * @throws IllegalArgumentException if the arrays differ in length or a pair of slots is incompatible
     */
    public void mapAllUpsteam(OutputSlot<?>[] slots, OutputSlot<?>[] upstreamSlots) {
        if (slots.length != upstreamSlots.length) {
            // Same diagnostic quality as the InputSlot overload (previously no message at all).
            throw new IllegalArgumentException(String.format("Incompatible number of output slots between %s and %s.", Arrays.toString(slots), Arrays.toString(upstreamSlots)));
        }
        for (int i = 0; i < slots.length; i++) {
            this.mapUpstream(slots[i], upstreamSlots[i]);
        }
    }

    /**
     * Maps {@code slot} upstream to {@code upstreamSlot}, or removes the mapping of {@code slot}
     * if {@code upstreamSlot} is {@code null}.
     *
     * @throws IllegalArgumentException if both slots are not compatible
     */
    public void mapUpstream(InputSlot<?> slot, InputSlot<?> upstreamSlot) {
        if (upstreamSlot != null && !slot.isCompatibleWith(upstreamSlot)) {
            throw new IllegalArgumentException(String.format("Incompatible slots given: %s -> %s", slot, upstreamSlot));
        }
        this.putOrRemove(slot, upstreamSlot);
    }

    /**
     * Maps {@code slot} upstream to {@code upstreamSlot}, or removes the mapping of {@code slot}
     * if {@code upstreamSlot} is {@code null}.
     *
     * @throws IllegalArgumentException if both slots are not compatible
     */
    public void mapUpstream(OutputSlot<?> slot, OutputSlot<?> upstreamSlot) {
        if (upstreamSlot != null && !slot.isCompatibleWith(upstreamSlot)) {
            throw new IllegalArgumentException(String.format("Incompatible slots given: %s -> %s", slot, upstreamSlot));
        }
        this.putOrRemove(slot, upstreamSlot);
    }

    /** Stores or (for a {@code null} value) removes a mapping and invalidates the downstream cache. */
    private void putOrRemove(Slot<?> slot, Slot<?> upstreamSlot) {
        if (upstreamSlot == null) {
            this.upstreamMapping.remove(slot);
        } else {
            this.upstreamMapping.put(slot, upstreamSlot);
        }
        this.downstreamMapping = null;
    }

    /**
     * Looks up the upstream counterpart of an {@link InputSlot}. Logs a warning if the slot is
     * occupied.
     *
     * @return the upstream input slot or {@code null} if there is none
     */
    @SuppressWarnings("unchecked") // mapUpstream only pairs InputSlots with InputSlots
    public <T> InputSlot<T> resolveUpstream(InputSlot<T> inputSlot) {
        if (inputSlot.getOccupant() != null) {
            this.logger.warn("Trying to resolve (upstream) an InputSlot with an occupant.");
        }
        return (InputSlot<T>) this.upstreamMapping.get(inputSlot);
    }

    /**
     * Looks up the upstream counterpart of an {@link OutputSlot}.
     *
     * @return the upstream output slot or {@code null} if there is none
     */
    @SuppressWarnings("unchecked") // mapUpstream only pairs OutputSlots with OutputSlots
    public <T> OutputSlot<T> resolveUpstream(OutputSlot<T> outputSlot) {
        return (OutputSlot<T>) this.upstreamMapping.get(outputSlot);
    }

    /**
     * Finds all input slots that map upstream to the given one.
     *
     * @return a read-only, possibly empty collection
     */
    @SuppressWarnings("unchecked") // keys mapped to an InputSlot are InputSlots themselves
    public <T> Collection<InputSlot<T>> resolveDownstream(InputSlot<T> inputSlot) {
        final Collection<Slot<?>> downstream = this.getOrCreateDownstreamMapping().getOrDefault(inputSlot, Collections.emptyList());
        return (Collection<InputSlot<T>>) (Collection<?>) downstream;
    }

    /**
     * Finds all output slots that map upstream to the given one. Logs a warning if the slot has
     * occupiers.
     *
     * @return a read-only, possibly empty collection
     */
    @SuppressWarnings("unchecked") // keys mapped to an OutputSlot are OutputSlots themselves
    public <T> Collection<OutputSlot<T>> resolveDownstream(OutputSlot<T> outputSlot) {
        if (!outputSlot.getOccupiedSlots().isEmpty()) {
            this.logger.warn("Trying to resolve (downstream) an OutputSlot with occupiers.");
        }
        final Collection<Slot<?>> downstream = this.getOrCreateDownstreamMapping().getOrDefault(outputSlot, Collections.emptyList());
        return (Collection<OutputSlot<T>>) (Collection<?>) downstream;
    }

    /**
     * Returns the inverse of {@link #upstreamMapping}, building and caching it if necessary.
     * The value collections are unmodifiable so that callers cannot corrupt the cache.
     */
    private Map<Slot<?>, Collection<Slot<?>>> getOrCreateDownstreamMapping() {
        if (this.downstreamMapping == null) {
            final Map<Slot<?>, Collection<Slot<?>>> inverse = new HashMap<>();
            for (Map.Entry<Slot<?>, Slot<?>> entry : this.upstreamMapping.entrySet()) {
                inverse.computeIfAbsent(entry.getValue(), key -> new ArrayList<>()).add(entry.getKey());
            }
            inverse.replaceAll((key, downstream) -> Collections.unmodifiableCollection(downstream));
            this.downstreamMapping = inverse;
        }
        return this.downstreamMapping;
    }

    /**
     * Replaces the input-slot mappings of {@code oldOperator} with ones for {@code newOperator}.
     * <p>
     * If {@code newOperator} is not the parent of {@code oldOperator}, inputs are matched by index
     * (a warning is logged). Otherwise, the slot mapping of {@code oldOperator}'s container is used to
     * find the replacement input for each unoccupied, mapped input of {@code oldOperator}.
     *
     * @param oldOperator the operator whose mappings are replaced
     * @param newOperator the operator that takes over the mappings
     */
    public void replaceInputSlotMappings(Operator oldOperator, Operator newOperator) {
        if (oldOperator.getParent() != newOperator) {
            this.logger.warn("Using bare indices to replace {} (parent {}) with {}.", oldOperator, oldOperator.getParent(), newOperator);
            assert oldOperator.getNumInputs() == newOperator.getNumInputs()
                    : String.format("Operators %s and %s are not matching.", oldOperator, newOperator);
            for (int i = 0; i < oldOperator.getNumInputs(); i++) {
                final InputSlot<?> oldInput = oldOperator.getInput(i);
                final InputSlot<?> newInput = newOperator.getInput(i);
                final InputSlot<?> upstreamInput = this.resolveUpstream(oldInput);
                if (upstreamInput != null) {
                    this.mapUpstream(newInput, upstreamInput);
                    this.delete(oldInput);
                }
            }
            return;
        }

        final SlotMapping containerSlotMapping = oldOperator.getContainer().getSlotMapping();
        for (int i = 0; i < oldOperator.getNumInputs(); i++) {
            final InputSlot<?> oldInput = oldOperator.getInput(i);
            assert oldInput != null : String.format("No %dth input for %s (for %s).", i, oldOperator, newOperator);
            // Occupied inputs are fed by an operator, not by this mapping: leave them alone.
            if (oldInput.getOccupant() != null) {
                continue;
            }
            final InputSlot<?> upstreamInput = this.resolveUpstream(oldInput);
            if (upstreamInput == null) {
                continue;
            }
            this.delete(oldInput);
            // Presumably the corresponding input of newOperator (the container's upstream of oldInput).
            final InputSlot<?> newInput = containerSlotMapping.resolveUpstream(oldInput);
            if (newInput != null) {
                this.mapUpstream(newInput, upstreamInput);
            }
        }
    }

    /** Removes the mapping of an input slot and invalidates the downstream cache. */
    private void delete(InputSlot<?> inputSlot) {
        this.putOrRemove(inputSlot, null);
    }

    /** Removes the mapping of an output slot and invalidates the downstream cache. */
    private void delete(OutputSlot<?> outputSlot) {
        this.putOrRemove(outputSlot, null);
    }

    /**
     * Redirects every output slot that maps upstream to an output of {@code oldOperator} so that it maps
     * to the corresponding output of {@code newOperator} instead.
     * <p>
     * If {@code newOperator} is not the parent of {@code oldOperator}, outputs are matched by index
     * (a warning is logged). Otherwise, the replacement output is the single downstream slot of the old
     * output in {@code oldOperator}'s container slot mapping. If there is no unique replacement, the
     * redirected mappings are removed.
     *
     * @param oldOperator the operator whose mappings are replaced
     * @param newOperator the operator that takes over the mappings
     */
    public void replaceOutputSlotMappings(Operator oldOperator, Operator newOperator) {
        if (oldOperator.getParent() != newOperator) {
            this.logger.warn("Using bare indices to replace {} (parent {}) with {}.", oldOperator, oldOperator.getParent(), newOperator);
            assert oldOperator.getNumOutputs() == newOperator.getNumOutputs()
                    : String.format("Operators %s and %s are not matching.", oldOperator, newOperator);
            for (int i = 0; i < oldOperator.getNumOutputs(); i++) {
                final OutputSlot<?> oldOutput = oldOperator.getOutput(i);
                final OutputSlot<?> newOutput = newOperator.getOutput(i);
                // Collect first, mutate afterwards: never modify the map while iterating it. All referrers
                // are redirected (previously only an arbitrary first one), matching the parent-based branch.
                final List<OutputSlot<?>> referrers = new ArrayList<>();
                for (Map.Entry<Slot<?>, Slot<?>> entry : this.upstreamMapping.entrySet()) {
                    if (entry.getValue() == oldOutput) {
                        referrers.add((OutputSlot<?>) entry.getKey());
                    }
                }
                for (OutputSlot<?> referrer : referrers) {
                    this.mapUpstream(referrer, newOutput);
                }
            }
            return;
        }

        // Snapshot: mapUpstream() below invalidates this.downstreamMapping, but this local view stays intact.
        final Map<Slot<?>, Collection<Slot<?>>> downstreamSnapshot = this.getOrCreateDownstreamMapping();
        final SlotMapping containerSlotMapping = oldOperator.getContainer().getSlotMapping();
        for (int i = 0; i < oldOperator.getNumOutputs(); i++) {
            final OutputSlot<?> oldOutput = oldOperator.getOutput(i);
            final Collection<Slot<?>> referrers = downstreamSnapshot.get(oldOutput);
            if (referrers == null) {
                continue;
            }
            final OutputSlot<?> newOutput = (OutputSlot<?>) WayangCollections.getSingleOrNull(
                    containerSlotMapping.getOrCreateDownstreamMapping().getOrDefault(oldOutput, Collections.emptySet()));
            for (Slot<?> referrer : referrers) {
                // A null newOutput removes the mapping of the referrer.
                this.mapUpstream((OutputSlot<?>) referrer, newOutput);
            }
        }
    }

    /**
     * For every input slot mapped in {@code slotMapping}, resolves the occupant of its upstream input
     * through this mapping.
     *
     * @param slotMapping the mapping whose input-slot entries are composed with this one
     * @return input slots of {@code slotMapping} -> output slot resolved upstream in this mapping;
     *         entries without an occupant or without a resolution are omitted
     */
    public Map<InputSlot, OutputSlot> compose(SlotMapping slotMapping) {
        final Map<InputSlot, OutputSlot> composition = new HashMap<>(2);
        for (Map.Entry<Slot<?>, Slot<?>> entry : slotMapping.upstreamMapping.entrySet()) {
            if (!entry.getKey().isInputSlot()) {
                continue;
            }
            final OutputSlot<?> occupant = ((InputSlot<?>) entry.getValue()).getOccupant();
            if (occupant == null) {
                continue;
            }
            final OutputSlot<?> resolved = this.resolveUpstream(occupant);
            if (resolved != null) {
                composition.put((InputSlot<?>) entry.getKey(), resolved);
            }
        }
        return composition;
    }

    /**
     * Provides the underlying slot -> upstream-slot map.
     * <p>
     * NOTE(review): this is the live internal map. Modifying it directly does not invalidate the cached
     * downstream view, so later {@code resolveDownstream} calls may return stale results. Prefer
     * {@code mapUpstream} for changes.
     */
    public Map<Slot<?>, Slot<?>> getUpstreamMapping() {
        return this.upstreamMapping;
    }
}
