package info.debatty.spark.knngraphs.partitioner;

import com.google.common.primitives.Ints;
import info.debatty.java.graphs.NeighborList;
import info.debatty.java.graphs.Node;
import info.debatty.spark.knngraphs.DistributedGraph;
import info.debatty.spark.knngraphs.partitioner.jabeja.Budget;
import info.debatty.spark.knngraphs.partitioner.jabeja.UnlimitedBudget;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.apache.spark.api.java.JavaPairRDD;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:info/debatty/spark/knngraphs/partitioner/JaBeJa.class */
public class JaBeJa<T> implements Partitioner<T> {
    private static final double T0 = 2.0d;
    private static final double DELTA = 0.003d;
    private static final int SWAPS_PER_ITERATION = 10;
    private static final int ITERATIONS_BEFORE_CHECKPOINT = 20;
    private static final int RDDS_TO_CACHE = 3;
    private static final Logger LOGGER;
    private final int partitions;
    private final Budget budget;
    static final /* synthetic */ boolean $assertionsDisabled;

    public JaBeJa(int i, Budget budget) {
        this.partitions = i;
        this.budget = budget;
    }

    public JaBeJa(int i) {
        this.partitions = i;
        this.budget = new UnlimitedBudget();
    }

    @Override // info.debatty.spark.knngraphs.partitioner.Partitioner
    public final Partitioning<T> partition(JavaPairRDD<Node<T>, NeighborList> javaPairRDD) {
        Partitioning<T> partitioning = new Partitioning<>();
        LinkedList linkedList = new LinkedList();
        partitioning.graph = randomize(javaPairRDD);
        partitioning.graph = DistributedGraph.moveNodes(partitioning.graph, this.partitions);
        partitioning.graph.cache();
        partitioning.graph.count();
        double d = 2.0d;
        int i = 0;
        while (true) {
            i++;
            LOGGER.info("Tr = {}", Double.valueOf(d));
            SwapResult<T> swap = swap(partitioning.graph, d, 10);
            LOGGER.info("Performed {} swaps", Integer.valueOf(swap.swaps));
            partitioning.graph = swap.graph;
            partitioning.graph.cache();
            if (i % ITERATIONS_BEFORE_CHECKPOINT == 0) {
                LOGGER.info("Checkpoint");
                partitioning.graph.rdd().localCheckpoint();
            }
            partitioning.graph.count();
            linkedList.add(partitioning.graph);
            if (i > RDDS_TO_CACHE) {
                ((JavaPairRDD) linkedList.pop()).unpersist();
            }
            if (swap.swaps > 0 && !this.budget.isExhausted(partitioning)) {
                d = Math.max(1.0d, d - DELTA);
            }
        }
        partitioning.graph = DistributedGraph.moveNodes(partitioning.graph, this.partitions);
        partitioning.graph.cache();
        partitioning.graph.count();
        partitioning.end_time = System.currentTimeMillis();
        return partitioning;
    }

    public static final <U> int countCrossEdges(JavaPairRDD<Node<U>, NeighborList> javaPairRDD, int i) {
        int[] buildColorIndex = buildColorIndex(javaPairRDD);
        return countCrossEdges(buildColorIndex, buildDegreesIndex(javaPairRDD, buildColorIndex, i));
    }

    public static final <U> double computeBalance(JavaPairRDD<Node<U>, NeighborList> javaPairRDD, int i) {
        return computeBalance(buildColorIndex(javaPairRDD), i);
    }

    final JavaPairRDD<Node<T>, NeighborList> randomize(JavaPairRDD<Node<T>, NeighborList> javaPairRDD) {
        return javaPairRDD.mapToPair(new RandomizeFunction(this.partitions));
    }

    public static final <U> int[] buildColorIndex(JavaPairRDD<Node<U>, NeighborList> javaPairRDD) {
        Map collectAsMap = javaPairRDD.mapToPair(new GetPartitionFunction()).collectAsMap();
        int[] iArr = new int[collectAsMap.size()];
        for (Map.Entry entry : collectAsMap.entrySet()) {
            iArr[((Integer) entry.getKey()).intValue()] = ((Integer) entry.getValue()).intValue();
        }
        return iArr;
    }

    public static final <U> double computeBalance(int[] iArr, int i) {
        int[] iArr2 = new int[i];
        for (int i2 : iArr) {
            iArr2[i2] = iArr2[i2] + 1;
        }
        LOGGER.info("Sizes: {}", iArr2);
        return ((1.0d * Ints.max(iArr2)) / sum(iArr2)) * i;
    }

    private static int sum(int[] iArr) {
        int i = 0;
        for (int i2 : iArr) {
            i += i2;
        }
        return i;
    }

    /* JADX WARN: Type inference failed for: r0v5, types: [int[], int[][]] */
    static final <U> int[][] buildDegreesIndex(JavaPairRDD<Node<U>, NeighborList> javaPairRDD, int[] iArr, int i) {
        Map collectAsMap = javaPairRDD.mapToPair(new GetDegreesFunction(i, iArr)).collectAsMap();
        ?? r0 = new int[collectAsMap.size()];
        for (Map.Entry entry : collectAsMap.entrySet()) {
            r0[((Integer) entry.getKey()).intValue()] = (int[]) entry.getValue();
        }
        return r0;
    }

    final SwapResult<T> swap(JavaPairRDD<Node<T>, NeighborList> javaPairRDD, double d, int i) {
        int[] buildColorIndex = buildColorIndex(javaPairRDD);
        int[][] buildDegreesIndex = buildDegreesIndex(javaPairRDD, buildColorIndex, this.partitions);
        LOGGER.info("Cross edges: {}", Integer.valueOf(countCrossEdges(buildColorIndex, buildDegreesIndex)));
        LOGGER.info("Imbalance: {}", Double.valueOf(computeBalance(buildColorIndex, this.partitions)));
        List collect = javaPairRDD.mapPartitions(new MakeRequestsFunction(buildColorIndex, buildDegreesIndex, d, i)).collect();
        LOGGER.info("Swaps: {}", Integer.valueOf(collect.size()));
        LOGGER.debug("{}", collect);
        List collect2 = javaPairRDD.mapPartitions(new ProcessRequestsFunction(collect)).collect();
        LOGGER.debug("{}", collect2);
        if ($assertionsDisabled || collect2.size() <= collect.size()) {
            return new SwapResult<>(javaPairRDD.mapPartitionsToPair(new PerformSwapFunction(collect2)), collect2.size());
        }
        throw new AssertionError();
    }

    private static int countCrossEdges(int[] iArr, int[][] iArr2) {
        int length = iArr2[0].length;
        int i = 0;
        for (int i2 = 0; i2 < iArr.length; i2++) {
            int i3 = iArr[i2];
            for (int i4 = 0; i4 < length; i4++) {
                if (i4 != i3) {
                    i += iArr2[i2][i4];
                }
            }
        }
        return i;
    }

    static {
        $assertionsDisabled = !JaBeJa.class.desiredAssertionStatus();
        LOGGER = LoggerFactory.getLogger(JaBeJa.class);
    }
}
