package io.activej.dataflow.graph;

import io.activej.async.process.AsyncCloseable;
import io.activej.codec.StructuredCodec;
import io.activej.codec.StructuredCodecs;
import io.activej.codec.json.JsonUtils;
import io.activej.common.ref.RefInt;
import io.activej.dataflow.DataflowClient;
import io.activej.dataflow.node.Node;
import io.activej.dataflow.node.NodeDownload;
import io.activej.dataflow.node.NodeUpload;
import io.activej.promise.Promise;
import io.activej.promise.Promises;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;

/**
 * A dataflow computation graph: the {@link Node}s to run, the {@link Partition} each node is
 * assigned to, and an index from every registered {@link StreamId} to the node that produces it.
 *
 * <p>The graph can be executed on its partitions ({@link #execute()}), dumped as JSON per partition
 * ({@link #toString()}) or rendered as a GraphViz DOT document ({@link #toGraphViz()}).
 *
 * <p>NOTE(review): no synchronization is performed here; presumably the graph is built and used from
 * a single thread / eventloop — confirm with callers.
 */
public final class DataflowGraph {
    /** Partition assignment of every added node, in insertion order. */
    private final Map<Node, Partition> nodePartitions = new LinkedHashMap<>();
    /** Producing node of every registered stream, in insertion order. */
    private final Map<StreamId, Node> streams = new LinkedHashMap<>();
    private final DataflowClient client;
    private final List<Partition> availablePartitions;
    /** Codec used by {@link #toString()} to render one partition's node list as JSON. */
    private final StructuredCodec<List<Node>> listNodeCodec;

    /**
     * An open client session to a single partition, paired with that partition so the graph can
     * look up which nodes to send over it.
     */
    public static class PartitionSession implements AsyncCloseable {
        private final Partition partition;
        private final DataflowClient.Session session;

        private PartitionSession(Partition partition, DataflowClient.Session session) {
            this.partition = partition;
            this.session = session;
        }

        /**
         * Sends the given nodes for execution over this session.
         *
         * @param nodes nodes assigned to this session's partition
         * @return the promise returned by the underlying session's {@code execute}
         */
        public Promise<Void> execute(List<Node> nodes) {
            return session.execute(nodes);
        }

        /** Closes the underlying session with the given cause. */
        @Override
        public void closeEx(@NotNull Throwable e) {
            session.closeEx(e);
        }
    }

    /**
     * @param client              client used by {@link #execute()} to open a session per partition
     * @param availablePartitions partitions exposed via {@link #getAvailablePartitions()}; stored as-is (not copied)
     * @param nodeCodec           codec of a single node, used by {@link #toString()}
     */
    public DataflowGraph(DataflowClient client, List<Partition> availablePartitions, StructuredCodec<Node> nodeCodec) {
        this.client = client;
        this.availablePartitions = availablePartitions;
        this.listNodeCodec = StructuredCodecs.ofList(nodeCodec);
    }

    /** Returns the partitions passed to the constructor (the same list instance, not a copy). */
    public List<Partition> getAvailablePartitions() {
        return availablePartitions;
    }

    /** Returns the partition {@code node} was added to, or {@code null} if it was never added. */
    public Partition getPartition(Node node) {
        return nodePartitions.get(node);
    }

    /**
     * Returns the partition of the node that produces {@code streamId}, or {@code null} if the stream
     * is not registered or its producer was never assigned a partition.
     */
    public Partition getPartition(StreamId streamId) {
        return getPartition(streams.get(streamId));
    }

    /**
     * Groups nodes by their partition. Partitions appear in the order their first node was added, and
     * nodes keep their insertion order, so callers ({@link #toGraphViz(boolean, int)}'s partition limit,
     * {@link #toString()}) produce deterministic results.
     */
    private Map<Partition, List<Node>> getNodesByPartition() {
        Map<Partition, List<Node>> nodesByPartition = new LinkedHashMap<>();
        for (Map.Entry<Node, Partition> entry : nodePartitions.entrySet()) {
            nodesByPartition.computeIfAbsent(entry.getValue(), partition -> new ArrayList<>()).add(entry.getKey());
        }
        return nodesByPartition;
    }

    /**
     * Connects to every partition that has at least one node and sends each partition its nodes.
     *
     * <p>If not all partitions can be connected, the returned promise fails and the sessions that were
     * opened are closed. If executing on any partition fails, all sessions are closed.
     *
     * @return a promise completed once every partition's execution promise has completed
     */
    public Promise<Void> execute() {
        Map<Partition, List<Node>> nodesByPartition = getNodesByPartition();
        return connect(nodesByPartition.keySet())
                .then(sessions -> Promises.all(sessions.stream()
                                .map(session -> session.execute(nodesByPartition.get(session.partition))))
                        .whenException(() -> sessions.forEach(PartitionSession::close)));
    }

    /**
     * Opens a session to each of the given partitions.
     *
     * <p>All connection attempts are awaited. If any of them fails, every successfully opened session is
     * closed and the promise fails with an exception carrying each connection failure as a suppressed
     * exception.
     */
    private Promise<List<PartitionSession>> connect(Set<Partition> partitions) {
        return Promises.toList(partitions.stream()
                        .map(partition -> client.connect(partition.getAddress())
                                .map(session -> new PartitionSession(partition, session))
                                .toTry()))
                .then(tries -> {
                    List<PartitionSession> sessions = tries.stream()
                            .filter(t -> t.isSuccess())
                            .map(t -> t.get())
                            .collect(Collectors.toList());
                    if (sessions.size() == partitions.size()) {
                        return Promise.of(sessions);
                    }
                    sessions.forEach(PartitionSession::close);
                    Exception failure = new Exception("Can't connect to all partitions");
                    // Keep the individual connection errors instead of silently dropping them.
                    tries.stream()
                            .filter(t -> !t.isSuccess())
                            .forEach(t -> failure.addSuppressed(t.getException()));
                    return Promise.ofException(failure);
                });
    }

    /**
     * Assigns {@code node} to {@code partition} and registers {@code node} as the producer of each of
     * its outputs. Re-adding a node overwrites its previous partition.
     */
    public void addNode(Partition partition, Node node) {
        nodePartitions.put(node, partition);
        for (StreamId output : node.getOutputs()) {
            streams.put(output, node);
        }
    }

    /** Registers {@code node} as the producer of {@code streamId}, replacing any previous producer. */
    public void addNodeStream(Node node, StreamId streamId) {
        streams.put(streamId, node);
    }

    /**
     * Looks up the partition of each stream's producer, see {@link #getPartition(StreamId)}.
     *
     * @return a new mutable list, positionally matching {@code streamIds}; entries may be {@code null}
     */
    public List<Partition> getPartitions(List<? extends StreamId> streamIds) {
        List<Partition> partitions = new ArrayList<>(streamIds.size());
        for (StreamId streamId : streamIds) {
            partitions.add(getPartition(streamId));
        }
        return partitions;
    }

    /** Same as {@code toGraphViz(false, 2)}. */
    public String toGraphViz() {
        return toGraphViz(false, 2);
    }

    /** Same as {@code toGraphViz(streamLabels, 2)}. */
    public String toGraphViz(boolean streamLabels) {
        return toGraphViz(streamLabels, 2);
    }

    /** Same as {@code toGraphViz(false, maxPartitions)}. */
    public String toGraphViz(int maxPartitions) {
        return toGraphViz(false, maxPartitions);
    }

    /**
     * Renders the graph as a GraphViz DOT document.
     *
     * <p>Each of the first {@code maxPartitions} partitions becomes a cluster labelled with its address;
     * each node is drawn with its class' simple name. Download nodes are not drawn: a stream that is
     * uploaded and downloaded again is drawn as a dashed edge straight from its producer to the consumer
     * of the downloaded stream. An upload node is drawn only if its stream is never downloaded. Streams
     * whose consumer is not drawn end in a point-shaped node and are always labelled.
     *
     * @param streamLabels  whether to label every edge with its stream id
     * @param maxPartitions maximum number of partitions to render
     * @throws IllegalArgumentException if {@code maxPartitions} is negative
     */
    public String toGraphViz(boolean streamLabels, int maxPartitions) {
        if (maxPartitions < 0) {
            throw new IllegalArgumentException("maxPartitions must not be negative: " + maxPartitions);
        }

        // uploaded stream id -> stream id re-emitted by the download node that receives it
        Map<StreamId, StreamId> network = new HashMap<>();
        // stream id -> node consuming it (download nodes are never drawn, so they are not indexed)
        Map<StreamId, Node> nodesByInput = new HashMap<>();
        List<NodeUpload<?>> uploads = new ArrayList<>();
        for (Node node : nodePartitions.keySet()) {
            if (node instanceof NodeDownload) {
                NodeDownload<?> download = (NodeDownload<?>) node;
                network.put(download.getStreamId(), download.getOutput());
            } else if (node instanceof NodeUpload) {
                uploads.add((NodeUpload<?>) node);
            } else {
                node.getInputs().forEach(input -> nodesByInput.put(input, node));
            }
        }
        // An upload whose stream is never downloaded is a sink of the graph, so it is drawn as a consumer.
        for (NodeUpload<?> upload : uploads) {
            StreamId streamId = upload.getStreamId();
            if (!network.containsKey(streamId)) {
                nodesByInput.put(streamId, upload);
            }
        }

        StringBuilder sb = new StringBuilder("digraph {\n\n");
        Map<Node, String> nodeIds = appendPartitionClusters(sb, network, maxPartitions);
        appendStreamEdges(sb, nodeIds, nodesByInput, network, streamLabels);
        sb.append("}");
        return sb.toString();
    }

    /**
     * Appends one {@code subgraph cluster_K} per partition (at most {@code maxPartitions}) with a
     * declaration for each drawable node.
     *
     * @return the DOT id assigned to each drawn node, in drawing order
     */
    private Map<Node, String> appendPartitionClusters(StringBuilder sb, Map<StreamId, StreamId> network, int maxPartitions) {
        Map<Node, String> nodeIds = new LinkedHashMap<>();
        int clusterCount = 0;
        int nodeCount = 0;
        for (Map.Entry<Partition, List<Node>> entry : getNodesByPartition().entrySet()) {
            if (clusterCount == maxPartitions) {
                break;
            }
            clusterCount++;
            sb.append("  subgraph cluster_").append(clusterCount).append(" {\n")
                    .append("    label=\"").append(entry.getKey().getAddress()).append("\";\n    style=rounded;\n\n");
            for (Node node : entry.getValue()) {
                // Downloads are never drawn; an upload is drawn only if nothing downloads its stream.
                boolean hidden = node instanceof NodeDownload
                        || (node instanceof NodeUpload && network.containsKey(((NodeUpload<?>) node).getStreamId()));
                if (hidden) {
                    continue;
                }
                nodeCount++;
                String nodeId = "n" + nodeCount;
                sb.append("    ").append(nodeId)
                        .append(" [label=\"").append(node.getClass().getSimpleName()).append("\"];\n");
                nodeIds.put(node, nodeId);
            }
            sb.append("  }\n\n");
        }
        return nodeIds;
    }

    /**
     * Appends an edge from every drawn node to the consumer of each of its outputs, followed by
     * point-shaped declarations for streams that have no drawn consumer.
     */
    private static void appendStreamEdges(StringBuilder sb, Map<Node, String> nodeIds,
            Map<StreamId, Node> nodesByInput, Map<StreamId, StreamId> network, boolean streamLabels) {
        Set<String> danglingIds = new LinkedHashSet<>();
        for (Map.Entry<Node, String> producer : nodeIds.entrySet()) {
            String fromId = producer.getValue();
            for (StreamId output : producer.getKey().getOutputs()) {
                StreamId stream = output;
                StreamId uploadedStream = null; // set when the stream crosses the network
                boolean viaNetwork = false;
                Node consumer = nodesByInput.get(stream);
                if (consumer == null) {
                    StreamId downloadedStream = network.get(stream);
                    if (downloadedStream != null) {
                        uploadedStream = stream;
                        stream = downloadedStream;
                        consumer = nodesByInput.get(stream);
                        viaNetwork = consumer != null;
                    }
                }

                String toId = nodeIds.get(consumer);
                boolean forceLabel = false;
                if (toId == null && !viaNetwork) {
                    toId = "s" + stream.getId();
                    danglingIds.add(toId);
                    forceLabel = true;
                }
                if (toId == null) {
                    // the consumer exists but its partition was cut off by maxPartitions
                    continue;
                }

                sb.append("  ").append(fromId).append(" -> ").append(toId).append(" [");
                if (streamLabels || forceLabel) {
                    if (uploadedStream != null) {
                        sb.append("taillabel=\"").append(uploadedStream).append("\", headlabel=\"").append(stream).append("\"");
                    } else {
                        sb.append("xlabel=\"").append(stream).append("\"");
                    }
                    if (viaNetwork) {
                        sb.append(", ");
                    }
                }
                if (viaNetwork) {
                    sb.append("style=dashed");
                }
                sb.append("];\n");
            }
        }
        if (!danglingIds.isEmpty()) {
            sb.append('\n');
            for (String danglingId : danglingIds) {
                sb.append("  ").append(danglingId).append(" [shape=point];\n");
            }
        }
    }

    /**
     * Renders each partition as a {@code --- <partition>} header followed by its nodes encoded as a
     * JSON list.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<Partition, List<Node>> entry : getNodesByPartition().entrySet()) {
            sb.append("--- ").append(entry.getKey()).append("\n\n");
            sb.append(JsonUtils.toJson(listNodeCodec, entry.getValue()));
            sb.append("\n\n");
        }
        return sb.toString();
    }
}
