package org.neo4j.gds.embeddings.fastrp;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import org.neo4j.graphalgo.Algorithm;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.core.concurrency.ParallelUtil;
import org.neo4j.graphalgo.core.utils.ProgressLogger;
import org.neo4j.graphalgo.core.utils.mem.AllocationTracker;
import org.neo4j.graphalgo.core.utils.mem.MemoryEstimation;
import org.neo4j.graphalgo.core.utils.mem.MemoryEstimations;
import org.neo4j.graphalgo.core.utils.mem.MemoryUsage;
import org.neo4j.graphalgo.core.utils.paged.HugeObjectArray;
import org.neo4j.graphalgo.utils.CloseableThreadLocal;
import org.neo4j.graphalgo.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/embeddings/fastrp/FastRP.class */
public class FastRP extends Algorithm<FastRP, FastRP> {

    /**
     * Sparsity {@code s} of the random projection: each entry of an initial random vector is
     * non-zero with probability {@code 1/s} (half of that positive, half negative).
     */
    private static final int SPARSITY = 3;

    private final Graph graph;
    /** Concurrency handed to {@link ParallelUtil#parallelForEachNode}. */
    private final int concurrency;
    /** Exponent applied to a node's degree when scaling its initial random vector. */
    private final float normalizationStrength;
    /** Result: sum over all iterations of that iteration's normalized vector times its weight. */
    private final HugeObjectArray<float[]> embeddings;
    /** Intermediate buffer written in even iterations (see {@link #currentEmbedding(int)}). */
    private final HugeObjectArray<float[]> embeddingA;
    /** Intermediate buffer written in odd iterations; initially holds the random vectors. */
    private final HugeObjectArray<float[]> embeddingB;
    /** Adds a neighbour's vector into an accumulator; weighted iff the graph has a relationship property. */
    private final EmbeddingCombiner embeddingCombiner;
    private final int embeddingDimension;
    private final int iterations;
    /** Weight of each iteration's contribution to {@link #embeddings}, indexed by iteration. */
    private final List<Double> iterationWeights;

    /** Accumulates a neighbour's embedding into a target vector in place. */
    @FunctionalInterface
    private interface EmbeddingCombiner {
        /**
         * @param target    accumulator, modified in place
         * @param neighbour the neighbour's vector from the previous iteration
         * @param weight    the relationship property, or the fallback value passed to
         *                  {@code forEachRelationship} (1.0)
         */
        void combine(float[] target, float[] neighbour, double weight);
    }

    /**
     * 64-bit pseudo-random generator combining a linear congruential step ({@code u}), a xorshift
     * step ({@code v}) and a multiply-with-carry step ({@code w}); the constants and structure are
     * those of the {@code Ran} generator from Numerical Recipes (3rd edition).
     *
     * <p>State is mutated without synchronization, so each thread must use its own instance.
     * {@link #setSeed(long)} is inherited unchanged from {@link Random} and does not reseed
     * {@code u}/{@code v}/{@code w}.
     */
    public static class HighQualityRandom extends Random {
        /** Low 32 bits; must be a long literal, since the int literal 0xffffffff widens to -1L. */
        private static final long LOW_32_BITS = 0xFFFFFFFFL;

        private long u;
        private long v;
        private long w;

        /** Seeds from the current nano time and the calling thread's id. */
        public HighQualityRandom() {
            this(System.nanoTime() + 13 * Thread.currentThread().getId());
        }

        /**
         * Seeds the three sub-generators from {@code seed}, warming each up with one step.
         *
         * @param seed initial seed
         */
        public HighQualityRandom(long seed) {
            this.v = 4101842887655102017L;
            this.w = 1L;
            this.u = seed ^ this.v;
            nextLong();
            this.v = this.u;
            nextLong();
            this.w = this.v;
            nextLong();
        }

        @Override
        public long nextLong() {
            this.u = this.u * 2862933555777941757L + 7046029254386353087L;
            this.v ^= this.v >>> 17;
            this.v ^= this.v << 31;
            this.v ^= this.v >>> 8;
            // Multiply-with-carry: the low 32 bits are the multiplicand, the high 32 bits the carry.
            // Without the mask the upper half would leak into the product and break the MWC recurrence.
            this.w = 4294957665L * (this.w & LOW_32_BITS) + (this.w >>> 32);
            long x = this.u ^ (this.u << 21);
            x ^= x >>> 35;
            x ^= x << 4;
            return (x + this.v) ^ this.w;
        }

        /** Returns the top {@code bits} bits of {@link #nextLong()}, as {@link Random} expects. */
        @Override
        protected int next(int bits) {
            return (int) (nextLong() >>> (64 - bits));
        }
    }

    /**
     * Estimates memory for the three per-node arrays, each holding one
     * {@code float[embeddingDimension]} per node.
     */
    static MemoryEstimation memoryEstimation(FastRPBaseConfig fastRPBaseConfig) {
        long embeddingSize = MemoryUsage.sizeOfFloatArray(fastRPBaseConfig.embeddingDimension());
        return MemoryEstimations.builder(FastRP.class)
            .add("embeddings", HugeObjectArray.memoryEstimation(embeddingSize))
            .add("embeddingA", HugeObjectArray.memoryEstimation(embeddingSize))
            .add("embeddingB", HugeObjectArray.memoryEstimation(embeddingSize))
            .build();
    }

    /**
     * Allocates the result and both intermediate buffers, and zero-fills the result vectors.
     *
     * @param graph             graph to embed
     * @param fastRPBaseConfig  dimension, iterations, iteration weights, normalization strength, concurrency
     * @param progressLogger    receives progress and log messages
     * @param allocationTracker tracks the allocated arrays
     */
    public FastRP(Graph graph, FastRPBaseConfig fastRPBaseConfig, ProgressLogger progressLogger, AllocationTracker allocationTracker) {
        this.graph = graph;
        this.progressLogger = progressLogger;
        long nodeCount = graph.nodeCount();
        this.embeddings = HugeObjectArray.newArray(float[].class, nodeCount, allocationTracker);
        this.embeddingA = HugeObjectArray.newArray(float[].class, nodeCount, allocationTracker);
        this.embeddingB = HugeObjectArray.newArray(float[].class, nodeCount, allocationTracker);
        this.embeddingDimension = fastRPBaseConfig.embeddingDimension();
        this.iterations = fastRPBaseConfig.iterations();
        this.iterationWeights = fastRPBaseConfig.iterationWeights();
        this.normalizationStrength = fastRPBaseConfig.normalizationStrength();
        this.concurrency = fastRPBaseConfig.concurrency();
        this.embeddingCombiner = graph.hasRelationshipProperty()
            ? this::addArrayValuesWeighted
            : (target, neighbour, ignoredWeight) -> addArrayValues(target, neighbour);
        this.embeddings.setAll(nodeId -> new float[this.embeddingDimension]);
    }

    /**
     * Computes the embeddings: draws the initial random vectors, then runs all propagation
     * iterations, accumulating into {@link #embeddings()}.
     *
     * @return this instance
     */
    public FastRP compute() {
        this.progressLogger.logMessage(":: Start");
        initRandomVectors();
        propagateEmbeddings();
        this.progressLogger.logMessage(":: Finished");
        return me();
    }

    /**
     * Decompiler-generated alias of {@link #compute()}, kept for existing callers.
     *
     * @deprecated use {@link #compute()}
     */
    @Deprecated
    public FastRP m1compute() {
        return compute();
    }

    /** @return the result embeddings, one {@code float[embeddingDimension]} per node */
    public HugeObjectArray<float[]> embeddings() {
        return this.embeddings;
    }

    /**
     * Returns the buffer written during {@code iteration}: {@code embeddingA} for even iterations,
     * {@code embeddingB} for odd ones. The other buffer holds the previous iteration's vectors.
     */
    HugeObjectArray<float[]> currentEmbedding(int iteration) {
        return iteration % 2 == 0 ? this.embeddingA : this.embeddingB;
    }

    /** @return this instance */
    public FastRP me() {
        return this;
    }

    /**
     * Decompiler-generated alias of {@link #me()}, kept for existing callers.
     *
     * @deprecated use {@link #me()}
     */
    @Deprecated
    public FastRP m0me() {
        return me();
    }

    /** Releases the two intermediate buffers; the result {@link #embeddings()} is kept. */
    public void release() {
        this.embeddingA.release();
        this.embeddingB.release();
    }

    /**
     * Fills {@code embeddingB} with sparse random vectors and {@code embeddingA} with zeros.
     *
     * <p>Each entry is {@code +c} or {@code -c} with probability {@code 1/(2*SPARSITY)} each and
     * {@code 0} otherwise, where {@code c = degree^normalizationStrength * sqrt(SPARSITY / dimension)}
     * (the degree factor is 1 for isolated nodes).
     */
    void initRandomVectors() {
        // Computed in float and widened, so the threshold is bit-identical to the historic constant.
        double probability = 1.0f / (2.0f * SPARSITY);
        float sqrtSparsity = (float) Math.sqrt(SPARSITY);
        float sqrtEmbeddingDimension = (float) Math.sqrt(this.embeddingDimension);
        // HighQualityRandom is not thread-safe, so every worker thread gets its own instance.
        ThreadLocal<Random> random = ThreadLocal.withInitial(HighQualityRandom::new);
        this.progressLogger.logMessage("Initialising Random Vectors :: Start");
        ParallelUtil.parallelForEachNode(this.graph, this.concurrency, nodeId -> {
            int degree = this.graph.degree(nodeId);
            float degreeScale = degree == 0 ? 1.0f : (float) Math.pow(degree, this.normalizationStrength);
            float entryValue = degreeScale * sqrtSparsity / sqrtEmbeddingDimension;
            this.embeddingB.set(nodeId, computeRandomVector(random.get(), probability, entryValue));
            this.embeddingA.set(nodeId, new float[this.embeddingDimension]);
            this.progressLogger.logProgress();
        });
        this.progressLogger.logMessage("Initialising Random Vectors :: Finished");
    }

    /**
     * Runs {@code iterations} propagation rounds. In round {@code i}, each node's vector in
     * {@link #currentEmbedding(int) currentEmbedding(i)} becomes the combined vectors of its
     * neighbours from the previous round, divided by its degree and L2-normalized; that vector
     * times {@code iterationWeights.get(i)} is then added to the node's result embedding.
     */
    void propagateEmbeddings() {
        for (int i = 0; i < this.iterations; i++) {
            this.progressLogger.reset(this.graph.relationshipCount());
            this.progressLogger.logMessage(StringFormatting.formatWithLocale("Iteration %s :: Start", i + 1));
            HugeObjectArray<float[]> current = currentEmbedding(i);
            HugeObjectArray<float[]> previous = currentEmbedding(i + 1);
            double iterationWeight = this.iterationWeights.get(i);
            // One graph copy per worker thread for relationship traversal; closed after the round.
            try (CloseableThreadLocal<Graph> concurrentGraph = CloseableThreadLocal.withInitial(this.graph::concurrentCopy)) {
                ParallelUtil.parallelForEachNode(this.graph, this.concurrency, nodeId -> {
                    float[] embedding = this.embeddings.get(nodeId);
                    float[] propagated = current.get(nodeId);
                    Arrays.fill(propagated, 0.0f);
                    concurrentGraph.get().forEachRelationship(nodeId, 1.0d, (source, target, weight) -> {
                        this.embeddingCombiner.combine(propagated, previous.get(target), weight);
                        return true;
                    });
                    int degree = this.graph.degree(nodeId);
                    // Isolated nodes divide by 1 to avoid a division by zero.
                    multiplyArrayValues(propagated, 1.0f / (degree == 0 ? 1 : degree));
                    l2Normalize(propagated);
                    updateEmbeddings(iterationWeight, embedding, propagated);
                    this.progressLogger.logProgress(degree);
                });
            }
            this.progressLogger.logMessage(StringFormatting.formatWithLocale("Iteration %s :: Finished", i + 1));
        }
    }

    /** Draws a fresh vector of {@code embeddingDimension} entries via {@link #computeRandomEntry}. */
    private float[] computeRandomVector(Random random, double probability, float entryValue) {
        float[] vector = new float[this.embeddingDimension];
        for (int i = 0; i < this.embeddingDimension; i++) {
            vector[i] = computeRandomEntry(random, probability, entryValue);
        }
        return vector;
    }

    /** Returns {@code entryValue} w.p. {@code probability}, its negation w.p. {@code probability}, else 0. */
    private float computeRandomEntry(Random random, double probability, float entryValue) {
        double draw = random.nextDouble();
        if (draw < probability) {
            return entryValue;
        }
        if (draw < probability * 2.0d) {
            return -entryValue;
        }
        return 0.0f;
    }

    /** {@code embedding += weight * newEmbedding}, element-wise, in place. */
    private void updateEmbeddings(double weight, float[] embedding, float[] newEmbedding) {
        for (int i = 0; i < embedding.length; i++) {
            embedding[i] += weight * newEmbedding[i];
        }
    }

    /** {@code target += addend}, element-wise, in place. */
    private void addArrayValues(float[] target, float[] addend) {
        for (int i = 0; i < target.length; i++) {
            target[i] += addend[i];
        }
    }

    /** {@code target += weight * addend}, element-wise, in place (fused multiply-add in double). */
    private void addArrayValuesWeighted(float[] target, float[] addend, double weight) {
        for (int i = 0; i < target.length; i++) {
            target[i] = (float) Math.fma(addend[i], weight, target[i]);
        }
    }

    /** {@code array *= factor}, element-wise, in place. */
    private void multiplyArrayValues(float[] array, double factor) {
        for (int i = 0; i < array.length; i++) {
            array[i] *= factor;
        }
    }

    /**
     * Scales {@code array} in place to unit L2 norm; an all-zero array is left unchanged.
     *
     * @param array vector to normalize
     */
    static void l2Normalize(float[] array) {
        double sumOfSquares = 0.0d;
        for (float value : array) {
            sumOfSquares += (double) value * value;
        }
        double scale = 1.0d / (sumOfSquares == 0.0d ? 1.0d : Math.sqrt(sumOfSquares));
        for (int i = 0; i < array.length; i++) {
            array[i] *= scale;
        }
    }
}
