/*
 * Decompiled with CFR 0.152.
 */
package io.activej.dataflow.dataset.impl;

import io.activej.dataflow.dataset.Dataset;
import io.activej.dataflow.dataset.DatasetUtils;
import io.activej.dataflow.graph.DataflowContext;
import io.activej.dataflow.graph.DataflowGraph;
import io.activej.dataflow.graph.Partition;
import io.activej.dataflow.graph.StreamId;
import io.activej.dataflow.node.NodeShard;
import io.activej.dataflow.node.NodeUnion;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import org.jetbrains.annotations.Nullable;

/**
 * A {@link Dataset} that wraps an input dataset and, when its channels are requested,
 * wires the input streams through {@link NodeShard} nodes into one {@link NodeUnion}
 * node per target partition.
 *
 * <p>The value type of this dataset is the value type of the input dataset.
 *
 * @param <T> type of the items in the dataset
 * @param <K> type of the key produced by the key function handed to each {@link NodeShard}
 */
public final class DatasetRepartition<T, K> extends Dataset<T> {
    private final Dataset<T> input;
    private final Function<T, K> keyFunction;
    @Nullable
    private final List<Partition> partitions;

    /**
     * Creates a repartitioning view over {@code input}.
     *
     * @param input       dataset whose channels are re-sharded; also supplies the value type
     * @param keyFunction key extractor passed unchanged to every {@link NodeShard}
     * @param partitions  target partitions, or {@code null} to use the partitions the graph
     *                    reports as available at the time {@link #channels} is called;
     *                    the list is stored by reference, not copied
     */
    public DatasetRepartition(Dataset<T> input, Function<T, K> keyFunction, @Nullable List<Partition> partitions) {
        super(input.valueType());
        this.input = input;
        this.keyFunction = keyFunction;
        this.partitions = partitions;
    }

    /**
     * Adds the sharding and union nodes to the graph of {@code context} and returns the
     * output stream of each union node.
     *
     * <p>For every input stream one {@link NodeShard} is added on the partition that owns
     * that stream. For every target partition, each sharder gets a new output stream that
     * is forwarded to that partition, and a {@link NodeUnion} over all forwarded streams
     * is added there.
     *
     * @param context dataflow context that supplies the graph, the nonce and the node indexes
     * @return one stream id per target partition, in the order of the target partition list
     */
    @Override
    public List<StreamId> channels(DataflowContext context) {
        DataflowGraph graph = context.getGraph();
        // An explicit partition list wins; otherwise fall back to whatever the graph offers.
        List<Partition> targetPartitions = this.partitions == null ? graph.getAvailablePartitions() : this.partitions;
        int nonce = context.getNonce();
        List<StreamId> outputStreamIds = new ArrayList<>();
        List<NodeShard<K, T>> sharders = new ArrayList<>();

        // All sharders are created with the same node index.
        int shardIndex = context.generateNodeIndex();
        for (StreamId inputStreamId : this.input.channels(context.withoutFixedNonce())) {
            // The sharder is placed on the partition that owns its input stream.
            Partition sourcePartition = graph.getPartition(inputStreamId);
            NodeShard<K, T> sharder = new NodeShard<>(shardIndex, this.keyFunction, inputStreamId, nonce);
            graph.addNode(sourcePartition, sharder);
            sharders.add(sharder);
        }

        // All union nodes are created with the same node index.
        int unionIndex = context.generateNodeIndex();
        // One index per sharder and one per target partition, both handed to forwardChannel below.
        int[] downloadIndexes = DatasetUtils.generateIndexes(context, sharders.size());
        int[] uploadIndexes = DatasetUtils.generateIndexes(context, targetPartitions.size());

        for (int i = 0; i < targetPartitions.size(); ++i) {
            Partition targetPartition = targetPartitions.get(i);
            List<StreamId> unionInputs = new ArrayList<>();
            for (int j = 0; j < sharders.size(); ++j) {
                NodeShard<K, T> sharder = sharders.get(j);
                // Each sharder gets a fresh output stream for this target partition.
                StreamId sharderOutput = sharder.newPartition();
                graph.addNodeStream(sharder, sharderOutput);
                StreamId unionInput = DatasetUtils.forwardChannel(context, this.input.valueType(), sharderOutput, targetPartition, uploadIndexes[i], downloadIndexes[j]);
                unionInputs.add(unionInput);
            }
            // NOTE(review): NodeUnion is used exactly as before; if it is declared generic,
            // it should be parameterized as NodeUnion<T> - confirm against its declaration.
            NodeUnion nodeUnion = new NodeUnion(unionIndex, unionInputs);
            graph.addNode(targetPartition, nodeUnion);
            outputStreamIds.add(nodeUnion.getOutput());
        }
        return outputStreamIds;
    }
}

