package org.neo4j.gds.embeddings.hashgnn;

import com.carrotsearch.hppc.BitSet;
import com.carrotsearch.hppc.BitSetIterator;
import java.util.List;
import java.util.SplittableRandom;
import java.util.stream.Collectors;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;

/**
 * Turns the binary (bit set) HashGNN features into dense {@code double[]} embeddings of
 * length {@code config.outputDimension()}.
 *
 * <p>The dense vector of a node is the sum of the projection-matrix rows selected by the node's
 * set bits, i.e. the product of its 0/1 feature vector with a random matrix whose entries are
 * {@code +c} with probability {@code 1/6}, {@code -c} with probability {@code 1/6} and {@code 0}
 * otherwise, where {@code c = sqrt(SPARSITY) / sqrt(outputDimension)}.
 *
 * <p>Each instance handles one {@link Partition}; {@link #compute} builds one task per partition
 * and runs them through {@link RunWithConcurrency}. Tasks share the input/output arrays and the
 * read-only projection matrix and write only the node ids their own partition visits
 * (presumably partitions are disjoint — confirm against callers).
 */
public class DensifyTask implements Runnable {
    /** A projection-matrix entry is non-zero with probability {@code 1 / SPARSITY}. */
    static final int SPARSITY = 3;
    /** Probability of each non-zero sign, {@code 1 / (2 * SPARSITY)}. */
    static final double ENTRY_PROBABILITY = 0.16666666666666666d;

    private final Partition partition;
    private final HashGNNConfig config;
    /** Output: one dense vector per node, written by {@link #run()}. */
    private final HugeObjectArray<double[]> denseFeatures;
    /** Input: one bit set per node, read by {@link #run()}. */
    private final HugeObjectArray<BitSet> binaryFeatures;
    /** Rows indexed by bit position, columns by output dimension. */
    private final float[][] projectionMatrix;
    private final ProgressTracker progressTracker;

    DensifyTask(
        Partition partition,
        HashGNNConfig config,
        HugeObjectArray<double[]> denseFeatures,
        HugeObjectArray<BitSet> binaryFeatures,
        float[][] projectionMatrix,
        ProgressTracker progressTracker
    ) {
        this.partition = partition;
        this.config = config;
        this.denseFeatures = denseFeatures;
        this.binaryFeatures = binaryFeatures;
        this.projectionMatrix = projectionMatrix;
        this.progressTracker = progressTracker;
    }

    /**
     * Densifies the binary features of every node of {@code graph}.
     *
     * @param graph           graph whose node count sizes the output
     * @param partitions      node partitions; one task is created per partition
     * @param config          supplies {@code outputDimension()} (must be present) and {@code concurrency()}
     * @param random          source of randomness for the projection matrix
     * @param binaryFeatures  per-node bit sets; node 0's capacity determines the number of matrix rows
     * @param progressTracker wrapped in a "Densify output embeddings" sub task
     * @param terminationFlag forwarded to the concurrent runner
     * @return a new array holding the dense vector of each node (empty for an empty graph)
     * @throws java.util.NoSuchElementException if the graph is non-empty and {@code outputDimension()} is absent
     */
    public static HugeObjectArray<double[]> compute(
        Graph graph,
        List<Partition> partitions,
        HashGNNConfig config,
        SplittableRandom random,
        HugeObjectArray<BitSet> binaryFeatures,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        progressTracker.beginSubTask("Densify output embeddings");

        long nodeCount = graph.nodeCount();
        HugeObjectArray<double[]> denseFeatures = HugeObjectArray.newArray(double[].class, nodeCount);

        // The matrix height is taken from node 0's bit set, which does not exist in an empty
        // graph — and there is nothing to densify there anyway.
        if (nodeCount > 0) {
            int outputDimension = config.outputDimension().orElseThrow().intValue();
            int inputDimension = (int) binaryFeatures.get(0L).capacity();
            float[][] projectionMatrix = projectionMatrix(random, outputDimension, inputDimension);

            List<Runnable> tasks = partitions.stream()
                .<Runnable>map(partition -> new DensifyTask(
                    partition,
                    config,
                    denseFeatures,
                    binaryFeatures,
                    projectionMatrix,
                    progressTracker
                ))
                .collect(Collectors.toList());

            RunWithConcurrency.builder()
                .concurrency(config.concurrency())
                .tasks(tasks)
                .terminationFlag(terminationFlag)
                .run();
        }

        progressTracker.endSubTask("Densify output embeddings");
        return denseFeatures;
    }

    /**
     * Builds an {@code inputDimension x outputDimension} random matrix, filled row by row with
     * {@link #computeRandomEntry}.
     */
    private static float[][] projectionMatrix(SplittableRandom random, int outputDimension, int inputDimension) {
        float entryValue = (float) Math.sqrt(SPARSITY) / (float) Math.sqrt(outputDimension);
        // Allocate only the outer array; each row is allocated exactly once below.
        float[][] matrix = new float[inputDimension][];
        for (int row = 0; row < inputDimension; row++) {
            float[] values = new float[outputDimension];
            for (int col = 0; col < outputDimension; col++) {
                values[col] = computeRandomEntry(random, entryValue);
            }
            matrix[row] = values;
        }
        return matrix;
    }

    /**
     * Draws one matrix entry: {@code entryValue} with probability {@code ENTRY_PROBABILITY},
     * {@code -entryValue} with the same probability, and {@code 0} otherwise.
     * Consumes exactly one {@code nextDouble()} from {@code random}.
     */
    private static float computeRandomEntry(SplittableRandom random, float entryValue) {
        double draw = random.nextDouble();
        if (draw < ENTRY_PROBABILITY) {
            return entryValue;
        }
        if (draw < 2 * ENTRY_PROBABILITY) {
            return -entryValue;
        }
        return 0.0f;
    }

    /**
     * For every node of this task's partition, sums the projection rows selected by the node's
     * set bits into a fresh dense vector and stores it; reports the partition's node count as
     * progress once done.
     */
    @Override
    public void run() {
        // Every matrix row built by compute() has exactly outputDimension columns; using it
        // instead of projectionMatrix[0].length also works when the matrix has no rows.
        int outputDimension = config.outputDimension().orElseThrow().intValue();
        partition.consume(nodeId -> {
            double[] dense = new double[outputDimension];
            BitSetIterator setBits = binaryFeatures.get(nodeId).iterator();
            // nextSetBit() returns -1 once the bits are exhausted.
            for (int bit = setBits.nextSetBit(); bit != -1; bit = setBits.nextSetBit()) {
                float[] row = projectionMatrix[bit];
                for (int i = 0; i < outputDimension; i++) {
                    dense[i] += row[i];
                }
            }
            denseFeatures.set(nodeId, dense);
        });
        progressTracker.logProgress(partition.nodeCount());
    }
}
