/*
 * Decompiled with CFR 0.152.
 */
package io.activej.dataflow.dataset.impl;

import io.activej.dataflow.dataset.Dataset;
import io.activej.dataflow.dataset.DatasetUtils;
import io.activej.dataflow.dataset.SortedDataset;
import io.activej.dataflow.graph.DataflowContext;
import io.activej.dataflow.graph.DataflowGraph;
import io.activej.dataflow.graph.StreamId;
import io.activej.dataflow.node.NodeJoin;
import io.activej.datastream.processor.StreamJoin;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Function;

/**
 * A {@link SortedDataset} produced by joining two key-sorted datasets.
 *
 * <p>The result takes the left dataset's key comparator and key type. When channels are built,
 * the right input is repartitioned and sorted onto the partitions of the left input's streams.
 * A {@link NodeJoin} then combines each pair of streams using the supplied
 * {@link StreamJoin.Joiner}.
 *
 * @param <K> join key type
 * @param <L> element type of the left dataset
 * @param <R> element type of the right dataset
 * @param <V> element type of the joined result
 */
public final class DatasetJoin<K, L, R, V>
extends SortedDataset<K, V> {
    private final SortedDataset<K, L> left;
    private final SortedDataset<K, R> right;
    private final StreamJoin.Joiner<K, L, R, V> joiner;
    // Random per instance. It is fixed on the context that both inputs are expanded with.
    // NOTE(review): presumably this keeps sharding of left and right consistent — confirm.
    private final int sharderNonce = ThreadLocalRandom.current().nextInt();

    /**
     * Creates a join of {@code left} and {@code right}.
     *
     * @param left        left input; its key comparator and key type become those of the result
     * @param right       right input, repartitioned and sorted to match {@code left}'s partitions
     * @param joiner      combines matching left and right elements into result elements
     * @param resultType  class of the joined elements
     * @param keyFunction extracts the key from a joined element
     * @throws NullPointerException if {@code left}, {@code right} or {@code joiner} is null
     */
    public DatasetJoin(SortedDataset<K, L> left, SortedDataset<K, R> right, StreamJoin.Joiner<K, L, R, V> joiner, Class<V> resultType, Function<V, K> keyFunction) {
        super(resultType, left.keyComparator(), left.keyType(), keyFunction);
        this.left = left;
        // Fail fast here rather than later, while the dataflow graph is being built.
        this.right = Objects.requireNonNull(right, "right");
        this.joiner = Objects.requireNonNull(joiner, "joiner");
    }

    /**
     * Adds one {@link NodeJoin} per left input stream to the context's graph.
     *
     * <p>Both inputs are expanded with a copy of {@code context} whose nonce is fixed to this
     * dataset's sharder nonce. The right input is repartitioned and sorted onto the partitions
     * of the left streams. The i-th left stream is then joined with the i-th right stream, on
     * the partition that hosts the left stream. All join nodes share a single node index.
     *
     * @param context dataflow context whose graph receives the join nodes
     * @return output streams of the join nodes, in left-stream order
     * @throws IllegalStateException if repartitioning the right input does not produce exactly
     *                               one stream per left stream
     */
    @Override
    public List<StreamId> channels(DataflowContext context) {
        DataflowGraph graph = context.getGraph();
        DataflowContext next = context.withFixedNonce(this.sharderNonce);
        List<StreamId> leftStreamIds = this.left.channels(next);
        List<StreamId> rightStreamIds = DatasetUtils.repartitionAndSort(next, this.right, graph.getPartitions(leftStreamIds));
        // Checked explicitly instead of with `assert`, which is disabled by default. Without
        // this check, a mismatch would either throw IndexOutOfBoundsException after some nodes
        // were already added, or silently leave extra right-side streams unjoined.
        if (leftStreamIds.size() != rightStreamIds.size()) {
            throw new IllegalStateException("Left and right stream counts differ: "
                    + leftStreamIds.size() + " vs " + rightStreamIds.size());
        }
        int index = context.generateNodeIndex();
        List<StreamId> outputStreamIds = new ArrayList<>(leftStreamIds.size());
        for (int i = 0; i < leftStreamIds.size(); ++i) {
            StreamId leftStreamId = leftStreamIds.get(i);
            StreamId rightStreamId = rightStreamIds.get(i);
            NodeJoin node = new NodeJoin(index, leftStreamId, rightStreamId, this.left.keyComparator(), this.left.keyFunction(), this.right.keyFunction(), this.joiner);
            // Place the join on the same partition as its left input stream.
            graph.addNode(graph.getPartition(leftStreamId), node);
            outputStreamIds.add(node.getOutput());
        }
        return outputStreamIds;
    }

    /**
     * Returns the datasets this join is built from.
     *
     * @return the left and right inputs, in that order
     */
    @Override
    public Collection<Dataset<?>> getBases() {
        return Arrays.asList(this.left, this.right);
    }
}

