package info.debatty.spark.knngraphs.builder;

import info.debatty.java.graphs.Graph;
import info.debatty.java.graphs.NeighborList;
import info.debatty.java.graphs.Node;
import info.debatty.java.graphs.SimilarityInterface;
import info.debatty.spark.knngraphs.ApproximateSearch;
import info.debatty.spark.knngraphs.DistributedGraph;
import info.debatty.spark.knngraphs.partitioner.KMedoids;
import info.debatty.spark.knngraphs.partitioner.KMedoidsPartitioning;
import info.debatty.spark.knngraphs.partitioner.NodePartitioner;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

/* loaded from: input_file:info/debatty/spark/knngraphs/builder/Online.class */
/**
 * Online (incremental) maintenance of a distributed k-NN graph.
 *
 * <p>The initial graph is partitioned with {@link KMedoids}. Each added node is
 * assigned to the partition whose medoid is most similar to it, weighted by how
 * full that partition already is. Every {@code fastAdd} / {@code fastRemove}
 * call derives a new cached RDD generation from the previous one. Only the last
 * {@code RDDS_TO_CACHE} generations are kept persisted, and every
 * {@code ITERATIONS_BEFORE_CHECKPOINT} operations a generation is locally
 * checkpointed.
 *
 * <p>The driver-side bookkeeping (partition sizes, RDD history, counters) is
 * mutated without any synchronization; confine an instance to one thread.
 *
 * @param <T> type of the node values
 */
public class Online<T> {
    /** Depth handed to {@code UpdateFunction} and {@code SearchNeighbors}. */
    private static final int DEFAULT_UPDATE_DEPTH = 2;
    private static final double DEFAULT_MEDOID_UPDATE_RATIO = 0.1d;
    /** A new generation is locally checkpointed once every this many operations. */
    private static final int ITERATIONS_BEFORE_CHECKPOINT = 100;
    /** Maximum number of RDD generations kept persisted at the same time. */
    private static final int RDDS_TO_CACHE = 3;
    /** Source of randomness used to break ties in {@link #argmax}. */
    private static final Random RANDOM = new Random();

    private final JavaSparkContext spark_context;
    /** Number of neighbors searched for each added node. */
    private final int k;
    private final SimilarityInterface<T> similarity;
    /** Driver-side count of nodes per partition, indexed by partition id. */
    private final long[] partitions_size;
    /** Countdown of additions until the medoids are due for an update. */
    private long nodes_before_update_medoids;
    /** Latest generation of the distributed graph. */
    private JavaRDD<Graph<T>> distributed_graph;
    private final ArrayList<Node<T>> medoids;
    /** Soft capacity of a partition is {@code imbalance * nodes / partitions}. */
    private final double imbalance = 1.1d;
    private long nodes_added_or_removed = 0;
    private final double medoid_update_ratio = DEFAULT_MEDOID_UPDATE_RATIO;
    private final double search_speedup = 10.0d;
    private final int search_random_jumps = 2;
    private final double search_expansion = 1.2d;
    private final int update_depth = DEFAULT_UPDATE_DEPTH;
    /** Persisted RDD generations, oldest first; bounded by RDDS_TO_CACHE. */
    private final LinkedList<JavaRDD<Graph<T>>> previous_rdds = new LinkedList<>();

    /**
     * Partitions an existing k-NN graph and prepares it for online updates.
     *
     * @param k number of neighbors searched for each added node
     * @param similarity similarity used for searches and partition assignment
     * @param spark_context Spark context (stored by this class)
     * @param initial_graph initial k-NN graph as (node, neighbor list) pairs
     * @param partitions passed to {@link KMedoids}; presumably the number of
     *     partitions / medoids — TODO confirm
     */
    public Online(
            final int k,
            final SimilarityInterface<T> similarity,
            final JavaSparkContext spark_context,
            final JavaPairRDD<Node<T>, NeighborList> initial_graph,
            final int partitions) {
        this.similarity = similarity;
        this.k = k;
        this.spark_context = spark_context;

        KMedoidsPartitioning<T> partitioning =
                new KMedoids(similarity, partitions).partition((JavaPairRDD) initial_graph);
        this.medoids = partitioning.medoids;
        this.distributed_graph = DistributedGraph.toGraph(partitioning.graph, similarity);
        this.distributed_graph.cache();
        this.distributed_graph.count();
        // Track the initial generation as well, so that it is eventually
        // unpersisted (by eviction or clean()) instead of staying cached forever.
        this.previous_rdds.add(this.distributed_graph);

        this.partitions_size = getPartitionsSize(this.distributed_graph);
        this.nodes_before_update_medoids = computeNodesBeforeUpdate();
    }

    /**
     * Unpersists every RDD generation still tracked, including the current
     * graph, and forgets them. Afterwards the current graph is no longer
     * cached, so later operations may recompute it from its lineage.
     */
    public final void clean() {
        for (JavaRDD<Graph<T>> rdd : previous_rdds) {
            rdd.unpersist();
        }
        previous_rdds.clear();
    }

    /**
     * Returns the number of nodes according to the driver-side partition counters.
     *
     * @return sum of all partition sizes
     */
    public final long getSize() {
        return sum(partitions_size);
    }

    /**
     * Adds a node without collecting statistics.
     *
     * @param node node to add
     */
    public final void fastAdd(final Node<T> node) {
        fastAdd(node, null);
    }

    /**
     * Adds a node. It first searches the approximate {@code k} nearest neighbors
     * of the node and assigns the node to a partition. It then derives and
     * materializes a new graph generation in which existing neighbor lists are
     * updated and the node is inserted.
     *
     * @param node node to add; its partition attribute is set as a side effect
     * @param stats_accumulator accumulator handed to the update function; may be
     *     {@code null}
     */
    public final void fastAdd(
            final Node<T> node, final StatisticsAccumulator stats_accumulator) {
        NeighborList neighbors = new ApproximateSearch(distributed_graph).search(
                node, k, null, search_speedup, search_random_jumps, search_expansion);

        // assign() already records the node in partitions_size; it must not be
        // counted a second time here.
        assign(node);

        JavaRDD<Graph<T>> next = distributed_graph
                .map(new UpdateFunction(
                        node, neighbors, similarity, stats_accumulator, update_depth))
                .map(new AddNode(node, neighbors))
                .cache();
        commit(next);

        nodes_before_update_medoids--;
        // "<=" rather than "==": computeNodesBeforeUpdate() returns 0 whenever
        // getSize() * medoid_update_ratio < 1, which would otherwise push the
        // countdown below zero for good.
        if (nodes_before_update_medoids <= 0) {
            // NOTE(review): only the countdown is reset here; this class never
            // actually recomputes the medoids.
            nodes_before_update_medoids = computeNodesBeforeUpdate();
        }
    }

    /**
     * Chooses a partition for {@code node}. The chosen partition maximizes
     * {@code similarity(medoid, node) * (1 - size / capacity)}. The method
     * records the node in {@code partitions_size} and stores the partition id
     * in the node's {@code PARTITION_KEY} attribute.
     *
     * @throws IllegalStateException if there are no medoids
     */
    private void assign(final Node<T> node) {
        final int partitions = medoids.size();
        if (partitions == 0) {
            throw new IllegalStateException("Cannot assign node: no medoids");
        }
        // Total number of nodes once this one is added.
        final long total = sum(partitions_size) + 1;
        // Kept as a double: an int capacity truncates to 0 on small graphs
        // (division by zero), and long / int division flattens the ratio.
        final double capacity = imbalance * total / partitions;

        // NOTE(review): assumes partitions_size.length == medoids.size().
        double[] scores = new double[partitions];
        for (int p = 0; p < partitions; p++) {
            double sim = similarity.similarity(medoids.get(p).value, node.value);
            scores[p] = sim * (1.0 - partitions_size[p] / capacity);
        }

        int chosen = argmax(scores);
        partitions_size[chosen]++;
        node.setAttribute(NodePartitioner.PARTITION_KEY, chosen);
    }

    /** Sums the given counters. */
    private static long sum(final long[] values) {
        long total = 0;
        for (long value : values) {
            total += value;
        }
        return total;
    }

    /**
     * Returns the index of the largest value. Ties are broken uniformly at
     * random, and NaN entries are never selected.
     *
     * @throws IllegalArgumentException if {@code values} is empty or all NaN
     */
    private static int argmax(final double[] values) {
        double best = Double.NEGATIVE_INFINITY;
        List<Integer> ties = new ArrayList<>();
        for (int i = 0; i < values.length; i++) {
            if (values[i] > best) {
                best = values[i];
                ties.clear();
                ties.add(i);
            } else if (values[i] == best) {
                ties.add(i);
            }
        }
        if (ties.isEmpty()) {
            throw new IllegalArgumentException(
                    "argmax is undefined for an empty or all-NaN array");
        }
        return ties.size() == 1 ? ties.get(0) : ties.get(RANDOM.nextInt(ties.size()));
    }

    /**
     * Removes a node. It collects the nodes that need updating (via
     * {@code FindNodesToUpdate}; presumably those whose neighbor lists contain
     * the node) and gathers candidate neighbors around them. It then derives
     * and materializes a new graph generation without the node.
     *
     * @param node node to remove
     * @param stats_accumulator accumulator handed to the update function
     */
    public final void fastRemove(
            final Node<T> node, final StatisticsAccumulator stats_accumulator) {
        List<Node<T>> nodes_to_update =
                distributed_graph.flatMap(new FindNodesToUpdate(node)).collect();

        LinkedList<Node<T>> initial_candidates = new LinkedList<>();
        initial_candidates.add(node);
        initial_candidates.addAll(nodes_to_update);

        LinkedList<Node<T>> candidates = new LinkedList<Node<T>>(
                distributed_graph
                        .flatMap(new SearchNeighbors(initial_candidates, update_depth))
                        .collect());

        // In a single pass, find the partition of the removed node (first
        // occurrence carrying a partition attribute) and drop every occurrence
        // of the node from the candidates.
        Integer partition = null;
        Iterator<Node<T>> it = candidates.iterator();
        while (it.hasNext()) {
            Node<T> candidate = it.next();
            if (node.equals(candidate)) {
                Object key = candidate.getAttribute(NodePartitioner.PARTITION_KEY);
                if (partition == null && key != null) {
                    partition = (Integer) key;
                }
                it.remove();
            }
        }

        JavaRDD<Graph<T>> next = distributed_graph
                .map(new RemoveUpdate(node, nodes_to_update, candidates, stats_accumulator))
                .cache();

        // Only adjust the counters when the node was actually located;
        // defaulting to partition 0 would silently corrupt that partition's size.
        if (partition != null) {
            partitions_size[partition]--;
        }
        commit(next);
    }

    /**
     * Makes {@code next} the current graph generation. It optionally requests a
     * local checkpoint, materializes the RDD, evicts generations beyond
     * {@code RDDS_TO_CACHE} and increments the operation counter.
     */
    private void commit(final JavaRDD<Graph<T>> next) {
        // Request the checkpoint before count() below materializes the RDD;
        // this periodically truncates the ever-growing lineage.
        if (nodes_added_or_removed % ITERATIONS_BEFORE_CHECKPOINT == 0) {
            next.rdd().localCheckpoint();
        }
        previous_rdds.add(next);
        // Materialize the new generation while its ancestors are still cached,
        // and only then release the oldest generations.
        next.count();
        while (previous_rdds.size() > RDDS_TO_CACHE) {
            previous_rdds.pop().unpersist();
        }
        distributed_graph = next;
        nodes_added_or_removed++;
    }

    /**
     * Returns the current graph generation as an RDD of subgraphs.
     *
     * @return the current distributed graph
     */
    public final JavaRDD<Graph<T>> getDistributedGraph() {
        return distributed_graph;
    }

    /**
     * Returns the current graph as (node, neighbor list) pairs, merged from the
     * subgraphs with {@code MergeGraphs}.
     *
     * @return the current graph as a pair RDD
     */
    public final JavaPairRDD<Node<T>, NeighborList> getGraph() {
        return distributed_graph.flatMapToPair(new MergeGraphs());
    }

    /**
     * Collects the size of each subgraph of {@code graph}.
     * NOTE(review): the array is indexed by collection order, which is assumed
     * to match the partition ids stored under PARTITION_KEY — confirm.
     */
    private long[] getPartitionsSize(final JavaRDD<Graph<T>> graph) {
        List<Long> sizes = graph.map(new SubgraphSizeFunction()).collect();
        long[] result = new long[sizes.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = sizes.get(i);
        }
        return result;
    }

    /**
     * Computes how many additions to wait before the next medoid update.
     * Returns {@code Long.MAX_VALUE} when updates are disabled (ratio 0).
     */
    private long computeNodesBeforeUpdate() {
        if (medoid_update_ratio == 0.0d) {
            return Long.MAX_VALUE;
        }
        return (long) (getSize() * medoid_update_ratio);
    }
}
