package org.neo4j.gds.kmeans;

import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutorService;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.Intersections;
import org.neo4j.gds.core.utils.paged.HugeIntArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;

/* loaded from: input_file:org/neo4j/gds/kmeans/Kmeans.class */
public class Kmeans extends Algorithm<KmeansResult> {
    /** Community id meaning "not assigned yet"; equals the -1 that compute writes for every node before the first iteration. */
    private static final int UNASSIGNED = -1;
    /** Cluster id per node; allocated in the constructor with one slot per graph node. */
    private final HugeIntArray communities;
    /** Graph being clustered; only its node count is read in this class. */
    private final Graph graph;
    /** Number of clusters requested. */
    private final int k;
    /** Concurrency passed to the partitioning and to the parallel task runner. */
    private final int concurrency;
    /** Executor on which the per-partition tasks are run. */
    private final ExecutorService executorService;
    /** Random source handed to the sampler that picks the initial centers. */
    private final SplittableRandom random;
    /** Per-node double-array property that serves as the node's coordinates. */
    private final NodePropertyValues nodePropertyValues;
    /** Coordinate length, taken from the property array of node 0 in the constructor. */
    private final int dimensions;
    /** Decides after every iteration, from the swap count and iteration number, whether to stop. */
    private final KmeansIterationStopper kmeansIterationStopper;
    // Compiler-generated assertion flag surfaced by the decompiler; initialised in the static block of this class.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/neo4j/gds/kmeans/Kmeans$KmeansTask.class */
    /**
     * Worker that handles one {@link Partition} of nodes during a K-Means iteration.
     *
     * <p>Each call to {@link #run()} assigns every node of the partition to the
     * nearest cluster center (Euclidean distance), writes that assignment into
     * the shared {@code communities} array, and accumulates the per-cluster
     * coordinate sums, the per-cluster node counts and the number of nodes
     * whose assignment changed. The accumulators are reset at the start of
     * every run, so the getters always describe the most recent run only.
     */
    public static final class KmeansTask implements Runnable {
        // Stored by the constructor but not read by any method of this class.
        private final ProgressTracker progressTracker;
        private final Partition partition;
        private final NodePropertyValues nodePropertyValues;
        private final HugeIntArray communities;
        // Centers supplied by the caller; this class only reads them.
        private final double[][] clusterCenters;
        // Per-cluster sums of the coordinates of the nodes assigned by this task.
        private final double[][] communityCoordinateSums;
        // Per-cluster number of nodes assigned by this task.
        private final long[] communitySizes;
        private final int k;
        private final int dimensions;
        private long swaps;

        /**
         * @param clusterCenters     current centers, indexed by cluster then dimension
         * @param nodePropertyValues source of each node's coordinate array
         * @param communities        per-node cluster assignment, updated by {@link #run()}
         * @param k                  number of clusters
         * @param dimensions         number of coordinates per node
         * @param partition          node range processed by this task
         * @param progressTracker    tracker kept on the task
         */
        KmeansTask(double[][] clusterCenters, NodePropertyValues nodePropertyValues, HugeIntArray communities, int k, int dimensions, Partition partition, ProgressTracker progressTracker) {
            this.k = k;
            this.dimensions = dimensions;
            this.clusterCenters = clusterCenters;
            this.nodePropertyValues = nodePropertyValues;
            this.communities = communities;
            this.partition = partition;
            this.progressTracker = progressTracker;
            this.communityCoordinateSums = new double[k][dimensions];
            this.communitySizes = new long[k];
        }

        /** Returns the coordinate sums this task accumulated for the given cluster. */
        double[] getCenterContribution(int i) {
            return this.communityCoordinateSums[i];
        }

        /** Returns how many nodes this task assigned to the given cluster. */
        long getNumAssignedAtCenter(int i) {
            return this.communitySizes[i];
        }

        /** Returns how many nodes changed cluster during the most recent run. */
        long getSwaps() {
            return this.swaps;
        }

        /** Euclidean distance over the first {@code dArr2.length} coordinates. */
        private double euclidean(double[] dArr, double[] dArr2) {
            double squaredDistance = Intersections.sumSquareDelta(dArr, dArr2, dArr2.length);
            return Math.sqrt(squaredDistance);
        }

        /** Clears the swap counter, the cluster sizes and the coordinate sums. */
        private void resetAccumulators() {
            this.swaps = 0L;
            for (int cluster = 0; cluster < this.k; cluster++) {
                this.communitySizes[cluster] = 0;
                for (int dimension = 0; dimension < this.dimensions; dimension++) {
                    this.communityCoordinateSums[cluster][dimension] = 0.0d;
                }
            }
        }

        /**
         * Finds the index of the center closest to the given coordinates.
         * Ties go to the lowest index; index 0 is returned when no distance
         * compares below {@code Double.MAX_VALUE}.
         */
        private int nearestCenter(double[] coordinates) {
            int closest = 0;
            double closestDistance = Double.MAX_VALUE;
            for (int cluster = 0; cluster < this.k; cluster++) {
                double distance = euclidean(coordinates, this.clusterCenters[cluster]);
                if (Double.compare(distance, closestDistance) < 0) {
                    closestDistance = distance;
                    closest = cluster;
                }
            }
            return closest;
        }

        @Override
        public void run() {
            long firstNode = this.partition.startNode();
            long endExclusive = firstNode + this.partition.nodeCount();
            resetAccumulators();
            for (long nodeId = firstNode; nodeId < endExclusive; nodeId++) {
                double[] coordinates = this.nodePropertyValues.doubleArrayValue(nodeId);
                int closest = nearestCenter(coordinates);
                this.communitySizes[closest]++;
                // A swap is any node whose cluster differs from its previous assignment.
                if (this.communities.get(nodeId) != closest) {
                    this.swaps++;
                }
                this.communities.set(nodeId, closest);
                for (int dimension = 0; dimension < this.dimensions; dimension++) {
                    this.communityCoordinateSums[closest][dimension] += coordinates[dimension];
                }
            }
        }
    }

    /**
     * Builds a {@link Kmeans} instance from the user configuration and the runtime context.
     *
     * @param graph            graph to cluster; also the source of the node property values
     * @param kmeansBaseConfig supplies k, concurrency, maxIterations, deltaThreshold,
     *                         the node property name and the optional random seed
     * @param kmeansContext    supplies the progress tracker and the executor
     * @return a new, not yet executed algorithm instance
     */
    public static Kmeans createKmeans(Graph graph, KmeansBaseConfig kmeansBaseConfig, KmeansContext kmeansContext) {
        // Locals are filled in the same order the arguments are consumed below.
        ProgressTracker tracker = kmeansContext.progressTracker();
        ExecutorService executor = kmeansContext.executor();
        int clusterCount = kmeansBaseConfig.k();
        int threads = kmeansBaseConfig.concurrency();
        int iterationLimit = kmeansBaseConfig.maxIterations();
        double swapThreshold = kmeansBaseConfig.deltaThreshold();
        NodePropertyValues coordinates = graph.nodeProperties(kmeansBaseConfig.nodeProperty());
        SplittableRandom rng = getSplittableRandom(kmeansBaseConfig.randomSeed());
        return new Kmeans(
            tracker,
            executor,
            graph,
            clusterCount,
            threads,
            iterationLimit,
            swapThreshold,
            coordinates,
            rng
        );
    }

    /**
     * Creates the algorithm and allocates the per-node assignment array.
     *
     * @param progressTracker    tracker forwarded to the {@code Algorithm} base class
     * @param executorService    executor used to run the per-partition tasks
     * @param graph              graph to cluster
     * @param k                  number of clusters
     * @param concurrency        number of partitions / parallel tasks
     * @param maxIterations      iteration limit given to the iteration stopper
     * @param deltaThreshold     threshold given to the iteration stopper
     * @param nodePropertyValues per-node coordinate arrays
     * @param random             random source for the initial center sampling
     */
    Kmeans(ProgressTracker progressTracker, ExecutorService executorService, Graph graph, int k, int concurrency, int maxIterations, double deltaThreshold, NodePropertyValues nodePropertyValues, SplittableRandom random) {
        super(progressTracker);
        this.graph = graph;
        this.nodePropertyValues = nodePropertyValues;
        this.executorService = executorService;
        this.random = random;
        this.k = k;
        this.concurrency = concurrency;
        this.communities = HugeIntArray.newArray(graph.nodeCount());
        // The dimensionality is read from node 0 only.
        // NOTE(review): presumably every node has an array of this length and the graph is non-empty - confirm.
        this.dimensions = nodePropertyValues.doubleArrayValue(0L).length;
        this.kmeansIterationStopper = new KmeansIterationStopper(deltaThreshold, maxIterations, graph.nodeCount());
    }

    /* renamed from: compute, reason: merged with bridge method [inline-methods] */
    /**
     * Runs K-Means and returns the cluster id of every node.
     *
     * <p>If more clusters are requested than there are nodes, a warning is
     * logged and every node is placed in its own cluster. Otherwise the
     * initial centers are sampled from the nodes, and the method alternates
     * between a parallel assignment step and a center recomputation step
     * until the iteration stopper says to quit.
     *
     * @return result wrapping the per-node cluster assignment
     */
    public KmeansResult m22compute() {
        this.progressTracker.beginSubTask();
        final long nodeCount = this.graph.nodeCount();
        if (this.k > nodeCount) {
            // Fewer nodes than clusters: each node becomes its own cluster.
            this.progressTracker.logWarning("Number of requested clusters is larger than the number of nodes.");
            this.communities.setAll(nodeId -> (int) nodeId);
            this.progressTracker.endSubTask();
            return ImmutableKmeansResult.of(this.communities);
        }

        final double[][] clusterCenters = new double[this.k][this.dimensions];
        this.communities.setAll(nodeId -> UNASSIGNED);

        // Divide as long first and only then narrow: casting nodeCount to int
        // before the division truncates graphs with more than
        // Integer.MAX_VALUE nodes and can even yield a negative batch size.
        final int minBatchSize = (int) Math.min(Integer.MAX_VALUE, nodeCount / this.concurrency);
        List<KmeansTask> tasks = PartitionUtils.rangePartition(
            this.concurrency,
            nodeCount,
            partition -> new KmeansTask(
                clusterCenters,
                this.nodePropertyValues,
                this.communities,
                this.k,
                this.dimensions,
                partition,
                this.progressTracker
            ),
            Optional.of(minBatchSize)
        );
        // Kept as an explicit flag check rather than an `assert` statement because
        // this class declares the $assertionsDisabled field itself.
        if (!$assertionsDisabled && tasks.size() > this.concurrency) {
            throw new AssertionError();
        }

        List<Long> initialCenterNodes = new KmeansUniformSampler().sampleClusters(this.random, this.nodePropertyValues, nodeCount, this.k);
        assignCenters(clusterCenters, initialCenterNodes, this.dimensions);

        int iteration = 0;
        while (true) {
            // Assignment step: every task re-assigns the nodes of its partition.
            ParallelUtil.runWithConcurrency(this.concurrency, tasks, this.executorService);
            long swaps = 0;
            for (KmeansTask task : tasks) {
                swaps += task.getSwaps();
            }
            iteration++;
            if (this.kmeansIterationStopper.shouldQuit(swaps, iteration)) {
                break;
            }
            // Update step: centers become the mean of their assigned nodes.
            recomputeCenters(clusterCenters, tasks);
        }
        this.progressTracker.endSubTask();
        return ImmutableKmeansResult.of(this.communities);
    }

    /**
     * Replaces each cluster center with the mean of the nodes currently
     * assigned to it, combining the partial sums and counts of all tasks.
     *
     * <p>A cluster that received no nodes keeps its previous center. Dividing
     * its zero sum by a zero count would turn the center into NaN, and a NaN
     * distance never compares as smaller in the assignment step, so the
     * cluster could never attract a node again.
     *
     * @param centers centers to update in place, indexed by cluster then dimension
     * @param tasks   tasks holding the per-partition sums and counts of the last assignment step
     */
    private void recomputeCenters(double[][] centers, List<KmeansTask> tasks) {
        // First pass: total number of nodes per cluster over all partitions.
        long[] clusterSizes = new long[this.k];
        for (KmeansTask task : tasks) {
            for (int cluster = 0; cluster < this.k; cluster++) {
                clusterSizes[cluster] += task.getNumAssignedAtCenter(cluster);
            }
        }

        for (int cluster = 0; cluster < this.k; cluster++) {
            if (clusterSizes[cluster] == 0) {
                // Empty cluster: leave the previous center untouched.
                continue;
            }
            double[] center = centers[cluster];
            for (int dimension = 0; dimension < this.dimensions; dimension++) {
                center[dimension] = 0.0d;
            }
            // Sum the contributions in task order, then take the mean.
            for (KmeansTask task : tasks) {
                double[] contribution = task.getCenterContribution(cluster);
                for (int dimension = 0; dimension < this.dimensions; dimension++) {
                    center[dimension] += contribution[dimension];
                }
            }
            for (int dimension = 0; dimension < this.dimensions; dimension++) {
                center[dimension] /= clusterSizes[cluster];
            }
        }
    }

    /**
     * Does nothing; the body is empty.
     * NOTE(review): presumably implements a release hook of the Algorithm base class - confirm.
     */
    public void release() {
    }

    /**
     * Creates the random source for the algorithm.
     *
     * @param optional optional seed; when present the generator is seeded with it,
     *                 otherwise an unseeded {@link SplittableRandom} is created
     * @return a new generator, never {@code null}
     */
    @NotNull
    private static SplittableRandom getSplittableRandom(Optional<Long> optional) {
        if (optional.isPresent()) {
            long seed = optional.get();
            return new SplittableRandom(seed);
        }
        return new SplittableRandom();
    }

    /**
     * Initialises the centers from sampled nodes: the n-th node id in the
     * list provides the coordinates of the n-th center.
     *
     * @param centers        center matrix to fill, indexed by cluster then dimension
     * @param sampledNodeIds ids of the nodes whose coordinates become the centers
     * @param dimensions     number of coordinates copied per center
     */
    private void assignCenters(double[][] centers, List<Long> sampledNodeIds, int dimensions) {
        int centerIndex = 0;
        for (Long sampledNodeId : sampledNodeIds) {
            double[] coordinates = this.nodePropertyValues.doubleArrayValue(sampledNodeId.longValue());
            for (int dimension = 0; dimension < dimensions; dimension++) {
                centers[centerIndex][dimension] = coordinates[dimension];
            }
            centerIndex++;
        }
    }

    // Sets the assertion flag from this class's desired assertion status:
    // the flag is true when assertions are NOT enabled for Kmeans.
    static {
        $assertionsDisabled = !Kmeans.class.desiredAssertionStatus();
    }
}
