/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.kge.scorers;

import com.carrotsearch.hppc.DoubleArrayList;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.ml.kge.scorers.LinkScorer;

/**
 * {@link LinkScorer} that computes a DistMult-style trilinear score over float-array node embeddings.
 *
 * <p>For a source embedding {@code h}, the relationship-type vector {@code r} and a target embedding
 * {@code t}, the score is {@code sum_i (h[i] * r[i]) * t[i]}. The element-wise product {@code h[i] * r[i]}
 * is computed once per source in {@link #init(long)} and cached (narrowed to {@code float}). Each
 * {@link #computeScore(long)} call therefore needs only a single pass over the target embedding.
 *
 * <p>Instances are stateful: {@link #computeScore(long)} always scores against the source passed to the
 * most recent {@link #init(long)} call.
 */
public class FloatDistMultLinkScorer implements LinkScorer {
    // Node embeddings; read via floatArrayValue, so presumably holding float[] values per node.
    NodePropertyValues embeddings;
    // Relationship-type vector; its length fixes the number of dimensions used when scoring.
    double[] relationshipTypeEmbedding;
    // Node id passed to the most recent init call.
    long currentSourceNode;
    // Cached element-wise product of the current source embedding and the relationship-type vector.
    float[] currentCandidateTarget;

    /**
     * Creates a scorer over the given node embeddings for one relationship type.
     *
     * @param embeddings                node property values providing each node's float-array embedding
     * @param relationshipTypeEmbedding relationship-type vector, converted to a {@code double[]} via
     *                                  {@code toArray()}
     */
    FloatDistMultLinkScorer(NodePropertyValues embeddings, DoubleArrayList relationshipTypeEmbedding) {
        double[] relation = relationshipTypeEmbedding.toArray();
        this.embeddings = embeddings;
        this.relationshipTypeEmbedding = relation;
        // Scratch buffer refilled by every init call; one slot per relationship-type dimension.
        this.currentCandidateTarget = new float[relation.length];
    }

    /**
     * Makes {@code sourceNode} the source for subsequent {@link #computeScore(long)} calls.
     * It also caches {@code source[i] * relation[i]} for every relationship-type dimension.
     *
     * <p>The source embedding must have at least as many entries as the relationship-type vector;
     * otherwise indexing it throws an {@link ArrayIndexOutOfBoundsException}.
     *
     * @param sourceNode node id whose embedding is used as the source (head) vector
     */
    @Override
    public void init(long sourceNode) {
        this.currentSourceNode = sourceNode;
        final float[] source = this.embeddings.floatArrayValue(sourceNode);
        final double[] relation = this.relationshipTypeEmbedding;
        final float[] cached = this.currentCandidateTarget;
        int dim = 0;
        while (dim < relation.length) {
            // float * double is evaluated in double; the product is narrowed back to float for the cache.
            cached[dim] = (float) (source[dim] * relation[dim]);
            dim++;
        }
    }

    /**
     * Scores the link from the current source node to {@code targetNode}.
     *
     * <p>Each term {@code cached[i] * target[i]} is multiplied in {@code float} precision and then
     * accumulated into a {@code double}. Only the first {@code currentCandidateTarget.length} entries
     * of the target embedding are read. A shorter target embedding causes an
     * {@link ArrayIndexOutOfBoundsException}.
     *
     * @param targetNode node id whose embedding is used as the target (tail) vector
     * @return the trilinear score for (current source, relationship type, target)
     */
    @Override
    public double computeScore(long targetNode) {
        final float[] cached = this.currentCandidateTarget;
        final float[] target = this.embeddings.floatArrayValue(targetNode);
        double score = 0.0;
        for (int dim = 0; dim < cached.length; dim++) {
            // The product stays in float precision; it is only widened when added to the double sum.
            float term = cached[dim] * target[dim];
            score += term;
        }
        return score;
    }

    /**
     * No-op: this scorer does not acquire any resources of its own, so there is nothing to release.
     */
    @Override
    public void close() throws Exception {
        // Intentionally empty.
    }
}

