package info.debatty.spark.knngraphs;

import info.debatty.java.graphs.Dijkstra;
import info.debatty.java.graphs.Graph;
import info.debatty.java.graphs.NeighborList;
import info.debatty.java.graphs.Node;
import info.debatty.java.graphs.SimilarityInterface;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import scala.Tuple2;

/* loaded from: input_file:info/debatty/spark/knngraphs/BalancedKMedoidsPartitioner.class */
/**
 * Splits a k-nn graph, given as an RDD of (node, neighbor list) pairs, into
 * {@link #partitions} sub-graphs using an iterative, size-balanced k-medoids
 * scheme.
 *
 * <p>Each round assigns every node to the medoid that maximizes
 * {@code similarity * (1 - partitionSize / capacity)} and then recomputes one
 * medoid per resulting partition. The chosen partition index is stored on the
 * node under the attribute {@link #PARTITION_KEY}.
 *
 * @param <T> type of the value carried by the graph nodes
 */
public class BalancedKMedoidsPartitioner<T> implements Serializable {
    /** Node attribute under which the selected partition index is stored. */
    public static final String PARTITION_KEY = "BKMP_PARTITION_ID";

    /** Similarity used to compare a node value with a medoid value. */
    public SimilarityInterface<T> similarity;

    /** Number of assign / recompute-medoids rounds run by {@link #partition}. */
    public int iterations = 5;

    /** Number of partitions (and medoids) to produce. */
    public int partitions = 4;

    /** Multiplier applied to the average partition size to obtain the capacity. */
    public double imbalance = 1.1d;

    /** Current medoids; index i is the medoid of partition i. */
    List<Node<T>> medoids;

    /** Spark partitioner built by {@link #partition} for {@link #partitions} partitions. */
    NodePartitioner internal_partitioner;

    /**
     * Per-Spark-partition assignment step: tags every node of the incoming
     * iterator with a partition index and returns the same tuples.
     */
    private class AssignFunction implements PairFlatMapFunction<Iterator<Tuple2<Node<T>, NeighborList>>, Node<T>, NeighborList> {
        private final List<Node<T>> medoids;

        public AssignFunction(List<Node<T>> list) {
            this.medoids = list;
        }

        @Override
        public Iterable<Tuple2<Node<T>, NeighborList>> call(Iterator<Tuple2<Node<T>, NeighborList>> it) throws Exception {
            // The capacity depends on the number of tuples, so they must be buffered first.
            List<Tuple2<Node<T>, NeighborList>> tuples = new ArrayList<Tuple2<Node<T>, NeighborList>>();
            while (it.hasNext()) {
                tuples.add(it.next());
            }

            double capacity = (BalancedKMedoidsPartitioner.this.imbalance * tuples.size())
                    / BalancedKMedoidsPartitioner.this.partitions;
            long[] counts = new long[BalancedKMedoidsPartitioner.this.partitions];
            for (Tuple2<Node<T>, NeighborList> tuple : tuples) {
                assignToBestPartition(this.medoids, tuple._1(), counts, capacity);
            }
            return tuples;
        }
    }

    /**
     * Partitions the graph.
     *
     * <p>Caches the input RDD, seeds the medoids from a random sample, runs
     * {@link #iterations} rounds of assignment and medoid recomputation, then
     * performs a final assignment and builds one graph per partition.
     *
     * @param javaPairRDD the graph as (node, neighbor list) pairs
     * @return an RDD of sub-graphs
     * @throws IllegalStateException if the random sample holds fewer nodes
     *         than {@link #partitions}, so the medoids cannot be seeded
     */
    public JavaRDD<Graph<T>> partition(JavaPairRDD<Node<T>, NeighborList> javaPairRDD) {
        javaPairRDD.cache();
        this.internal_partitioner = new NodePartitioner(this.partitions);

        // Aim for ~10 candidates per partition; a fraction above 1.0 is not a
        // valid sampling fraction, so small inputs are sampled entirely.
        double fraction = Math.min(1.0d, (10.0d * this.partitions) / javaPairRDD.count());
        List<Tuple2<Node<T>, NeighborList>> sample = javaPairRDD.sample(false, fraction).collect();
        if (sample.size() < this.partitions) {
            throw new IllegalStateException("Cannot pick " + this.partitions
                    + " initial medoids: the sample only contains " + sample.size() + " node(s)");
        }

        List<Node<T>> initialMedoids = new ArrayList<Node<T>>(this.partitions);
        for (int i = 0; i < this.partitions; i++) {
            initialMedoids.add(sample.get(i)._1());
        }
        this.medoids = initialMedoids;

        for (int round = 0; round < this.iterations; round++) {
            computeNewMedoids(assignAndShuffle(javaPairRDD));
        }
        return assignAndShuffle(javaPairRDD).mapPartitions(new NeighborListToGraph(this.similarity));
    }

    /** Tags every node with a partition using the current medoids, then repartitions by that tag. */
    private JavaPairRDD<Node<T>, NeighborList> assignAndShuffle(JavaPairRDD<Node<T>, NeighborList> graph) {
        return graph.mapPartitionsToPair(new AssignFunction(this.medoids), true).partitionBy(this.internal_partitioner);
    }

    /** Returns the sum of all elements of {@code values}. */
    private static long sum(long[] values) {
        long total = 0;
        for (long value : values) {
            total += value;
        }
        return total;
    }

    /**
     * Returns the index of the largest element; ties are broken uniformly at
     * random.
     *
     * @param dArr the values to scan
     * @return index of a maximal element
     * @throws IllegalArgumentException if {@code dArr} is empty or contains
     *         only NaN values
     */
    public static int argmax(double[] dArr) {
        double best = Double.NEGATIVE_INFINITY;
        List<Integer> candidates = new ArrayList<Integer>();
        for (int i = 0; i < dArr.length; i++) {
            if (dArr[i] > best) {
                best = dArr[i];
                candidates.clear();
                candidates.add(Integer.valueOf(i));
            } else if (dArr[i] == best) {
                candidates.add(Integer.valueOf(i));
            }
        }
        if (candidates.isEmpty()) {
            // NaN never compares greater than or equal to anything, so it is never a candidate.
            throw new IllegalArgumentException("argmax requires at least one non-NaN value");
        }
        if (candidates.size() == 1) {
            return candidates.get(0).intValue();
        }
        return candidates.get(new Random().nextInt(candidates.size())).intValue();
    }

    /**
     * Assigns a single node to a partition given the current partition sizes.
     *
     * <p>The capacity is {@code imbalance * (total + 1) / partitions}, where
     * {@code total} is the sum of {@code counts}. It is kept as a
     * {@code double} so that the balancing weight is a real fraction and the
     * division can never be an integer division by zero.
     *
     * @param node the node to assign; receives the {@link #PARTITION_KEY} attribute
     * @param counts current size of each partition; the chosen entry is incremented
     */
    public void assign(Node<T> node, long[] counts) {
        double capacity = (this.imbalance * (sum(counts) + 1)) / this.partitions;
        assignToBestPartition(this.medoids, node, counts, capacity);
    }

    /**
     * Scores every partition as {@code similarity * (1 - size / capacity)},
     * picks the best one, increments its size and tags the node with it.
     *
     * @return the chosen partition index
     */
    private int assignToBestPartition(List<Node<T>> centers, Node<T> node, long[] counts, double capacity) {
        double[] scores = new double[this.partitions];
        for (int p = 0; p < this.partitions; p++) {
            double sim = this.similarity.similarity(centers.get(p).value, node.value);
            scores[p] = sim * (1.0d - (counts[p] / capacity));
        }
        int chosen = argmax(scores);
        counts[chosen]++;
        node.setAttribute(PARTITION_KEY, Integer.valueOf(chosen));
        return chosen;
    }

    /** Returns the partitioner created by the last call to {@link #partition}, or null before that. */
    public NodePartitioner getInternalPartitioner() {
        return this.internal_partitioner;
    }

    /**
     * Replaces the medoids with the result of mapping {@code ComputeMedoids}
     * over the given sub-graphs.
     *
     * @param javaRDD one graph per partition
     */
    public void computeNewMedoids(JavaRDD<Graph<T>> javaRDD) {
        this.medoids = javaRDD.map(new ComputeMedoids()).collect();
    }

    /**
     * Replaces the medoids with one node per non-empty Spark partition.
     *
     * <p>For each partition the nodes are loaded into a graph, the largest
     * strongly connected component is selected, and the node with the smallest
     * non-zero value of {@code Dijkstra.getLargestDistance()} inside that
     * component becomes the medoid. If no node yields a non-zero value, the
     * first node of the component is used.
     *
     * <p>NOTE(review): an empty partition contributes no medoid, so the
     * resulting list can hold fewer than {@link #partitions} entries; the
     * assignment step indexes the list up to {@code partitions - 1}.
     *
     * @param javaPairRDD the graph, already distributed by partition
     */
    public void computeNewMedoids(JavaPairRDD<Node<T>, NeighborList> javaPairRDD) {
        this.medoids = javaPairRDD.mapPartitions(new FlatMapFunction<Iterator<Tuple2<Node<T>, NeighborList>>, Node<T>>() {
            @Override
            public Iterable<Node<T>> call(Iterator<Tuple2<Node<T>, NeighborList>> it) throws Exception {
                List<Node<T>> result = new ArrayList<Node<T>>(1);

                Graph<T> partitionGraph = new Graph<T>();
                while (it.hasNext()) {
                    Tuple2<Node<T>, NeighborList> next = it.next();
                    partitionGraph.put(next._1(), next._2());
                }
                if (partitionGraph.size() == 0) {
                    return result;
                }

                // Keep the first component of maximal size.
                List<Graph<T>> components = partitionGraph.stronglyConnectedComponents();
                Graph<T> largest = components.get(0);
                int largestSize = 0;
                for (Graph<T> component : components) {
                    if (component.size() > largestSize) {
                        largest = component;
                        largestSize = component.size();
                    }
                }

                // Pick the node with the smallest non-zero largest distance.
                int smallest = Integer.MAX_VALUE;
                Node<T> medoid = largest.getNodes().iterator().next();
                for (Node<T> candidate : largest.getNodes()) {
                    int largestDistance = new Dijkstra(largest, candidate).getLargestDistance();
                    if (largestDistance != 0 && largestDistance < smallest) {
                        smallest = largestDistance;
                        medoid = candidate;
                    }
                }

                result.add(medoid);
                return result;
            }
        }).collect();
    }
}
