/*
 * Decompiled with CFR 0.152.
 */
package io.activej.dataflow.graph;

import io.activej.async.process.AsyncCloseable;
import io.activej.common.collection.Try;
import io.activej.common.ref.RefInt;
import io.activej.dataflow.DataflowClient;
import io.activej.dataflow.DataflowException;
import io.activej.dataflow.graph.Partition;
import io.activej.dataflow.graph.StreamId;
import io.activej.dataflow.json.JsonCodec;
import io.activej.dataflow.json.JsonUtils;
import io.activej.dataflow.node.Node;
import io.activej.dataflow.node.NodeDownload;
import io.activej.dataflow.node.NodeUpload;
import io.activej.promise.Promise;
import io.activej.promise.Promises;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;

public final class DataflowGraph {
    private final Map<Node, Partition> nodePartitions = new LinkedHashMap<Node, Partition>();
    private final Map<StreamId, Node> streams = new LinkedHashMap<StreamId, Node>();
    private final DataflowClient client;
    private final List<Partition> availablePartitions;
    private final JsonCodec<List<Node>> listNodeCodec;

    public DataflowGraph(DataflowClient client, List<Partition> availablePartitions, JsonCodec<List<Node>> listNodeCodec) {
        this.client = client;
        this.availablePartitions = availablePartitions;
        this.listNodeCodec = listNodeCodec;
    }

    public List<Partition> getAvailablePartitions() {
        return this.availablePartitions;
    }

    public Partition getPartition(Node node) {
        return this.nodePartitions.get(node);
    }

    public Partition getPartition(StreamId streamId) {
        return this.getPartition(this.streams.get(streamId));
    }

    private Map<Partition, List<Node>> getNodesByPartition() {
        return this.nodePartitions.entrySet().stream().collect(Collectors.groupingBy(Map.Entry::getValue, Collectors.mapping(Map.Entry::getKey, Collectors.toList())));
    }

    public Promise<Void> execute() {
        Map<Partition, List<Node>> nodesByPartition = this.getNodesByPartition();
        long taskId = ThreadLocalRandom.current().nextInt() & 0x3FFFFFFF;
        return this.connect(nodesByPartition.keySet()).then(sessions -> Promises.all(sessions.stream().map(session -> session.execute(taskId, (List)nodesByPartition.get(((PartitionSession)session).partition)))).whenException(() -> sessions.forEach(AsyncCloseable::close)));
    }

    private Promise<List<PartitionSession>> connect(Set<Partition> partitions) {
        return Promises.toList(partitions.stream().map(partition -> this.client.connect(partition.getAddress()).map(session -> new PartitionSession((Partition)partition, (DataflowClient.Session)session)).toTry())).map(tries -> {
            List<PartitionSession> sessions = tries.stream().filter(Try::isSuccess).map(Try::get).collect(Collectors.toList());
            if (sessions.size() != partitions.size()) {
                sessions.forEach(AsyncCloseable::close);
                throw new DataflowException("Cannot connect to all partitions");
            }
            return sessions;
        });
    }

    public void addNode(Partition partition, Node node) {
        this.nodePartitions.put(node, partition);
        for (StreamId streamId : node.getOutputs()) {
            this.streams.put(streamId, node);
        }
    }

    public void addNodeStream(Node node, StreamId streamId) {
        this.streams.put(streamId, node);
    }

    public List<Partition> getPartitions(List<? extends StreamId> channels) {
        ArrayList<Partition> partitions = new ArrayList<Partition>();
        for (StreamId streamId : channels) {
            Partition partition = this.getPartition(streamId);
            partitions.add(partition);
        }
        return partitions;
    }

    public String toGraphViz() {
        return this.toGraphViz(false, 2);
    }

    public String toGraphViz(boolean streamLabels) {
        return this.toGraphViz(streamLabels, 2);
    }

    public String toGraphViz(int maxPartitions) {
        return this.toGraphViz(false, maxPartitions);
    }

    public String toGraphViz(boolean streamLabels, int maxPartitions) {
        StringBuilder sb = new StringBuilder("digraph {\n\n");
        RefInt nodeCounter = new RefInt(0);
        RefInt clusterCounter = new RefInt(0);
        HashMap<StreamId, NodeUpload> nodesByInput = new HashMap<StreamId, NodeUpload>();
        HashMap<StreamId, StreamId> network = new HashMap<StreamId, StreamId>();
        ArrayList<NodeUpload> uploads = new ArrayList<NodeUpload>();
        for (Node node2 : this.nodePartitions.keySet()) {
            if (node2 instanceof NodeDownload) {
                NodeDownload download = (NodeDownload)node2;
                network.put(download.getStreamId(), download.getOutput());
                continue;
            }
            if (node2 instanceof NodeUpload) {
                uploads.add((NodeUpload)node2);
                continue;
            }
            node2.getInputs().forEach(input -> nodesByInput.put((StreamId)input, (NodeUpload)node2));
        }
        for (NodeUpload upload : uploads) {
            StreamId streamId = upload.getStreamId();
            if (network.containsKey(streamId)) continue;
            nodesByInput.put(streamId, upload);
        }
        HashMap<Node, String> nodeIds = new HashMap<Node, String>();
        this.getNodesByPartition().entrySet().stream().limit(maxPartitions).forEach(e -> {
            sb.append("  subgraph cluster_").append(++clusterCounter.value).append(" {\n").append("    label=\"").append(((Partition)e.getKey()).getAddress()).append("\";\n    style=rounded;\n\n");
            for (Node node : (List)e.getValue()) {
                if (node instanceof NodeDownload || node instanceof NodeUpload && network.containsKey(((NodeUpload)node).getStreamId())) continue;
                String nodeId = "n" + ++nodeCounter.value;
                String name = node.getClass().getSimpleName();
                sb.append("    ").append(nodeId).append(" [label=\"").append(name.startsWith("Node") ? name.substring(4) : name).append("\"];\n");
                nodeIds.put(node, nodeId);
            }
            sb.append("  }\n\n");
        });
        HashSet notFound = new HashSet();
        nodeIds.forEach((node, id) -> {
            for (StreamId output : node.getOutputs()) {
                String nodeId;
                StreamId through;
                Node outputNode = (Node)nodesByInput.get(output);
                boolean net = false;
                boolean forceLabel = false;
                StreamId prev = null;
                if (outputNode == null && (through = (StreamId)network.get(output)) != null) {
                    prev = output;
                    output = through;
                    outputNode = (Node)nodesByInput.get(output);
                    boolean bl = net = outputNode != null;
                }
                if ((nodeId = (String)nodeIds.get(outputNode)) == null && !net) {
                    nodeId = "s" + output.getId();
                    notFound.add(nodeId);
                    forceLabel = true;
                }
                if (nodeId == null) continue;
                sb.append("  ").append((String)id).append(" -> ").append(nodeId).append(" [");
                if (streamLabels || forceLabel) {
                    if (prev != null) {
                        sb.append("taillabel=\"").append(prev).append("\", headlabel=\"").append(output).append("\"");
                    } else {
                        sb.append("xlabel=\"").append(output).append("\"");
                    }
                    if (net) {
                        sb.append(", ");
                    }
                }
                if (net) {
                    sb.append("style=dashed");
                }
                sb.append("];\n");
            }
        });
        if (!notFound.isEmpty()) {
            sb.append('\n');
            notFound.forEach(id -> sb.append("  " + id + " [shape=point];\n"));
        }
        sb.append("}");
        return sb.toString();
    }

    public String toString() {
        StringBuilder sb = new StringBuilder();
        Map<Partition, List<Node>> map = this.getNodesByPartition();
        for (Map.Entry<Partition, List<Node>> entry : map.entrySet()) {
            sb.append("--- " + entry.getKey() + "\n\n");
            sb.append(JsonUtils.toJson(this.listNodeCodec, entry.getValue()));
            sb.append("\n\n");
        }
        return sb.toString();
    }

    private static class PartitionSession
    implements AsyncCloseable {
        private final Partition partition;
        private final DataflowClient.Session session;

        private PartitionSession(Partition partition, DataflowClient.Session session) {
            this.partition = partition;
            this.session = session;
        }

        public Promise<Void> execute(long taskId, List<Node> nodes) {
            return this.session.execute(taskId, nodes);
        }

        public void closeEx(@NotNull Exception e) {
            this.session.closeEx(e);
        }
    }
}

