package org.deeplearning4j.graph.models.embeddings;

import java.beans.ConstructorProperties;
import java.util.Comparator;
import java.util.PriorityQueue;
import org.deeplearning4j.graph.api.IGraph;
import org.deeplearning4j.graph.api.Vertex;
import org.deeplearning4j.graph.models.GraphVectors;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/graph/models/embeddings/GraphVectorsImpl.class */
/**
 * Default {@link GraphVectors} implementation: pairs a graph with a
 * {@link GraphVectorLookupTable} that maps vertex indices to embedding vectors.
 *
 * @param <V> vertex value type of the graph
 * @param <E> edge value type of the graph
 */
public class GraphVectorsImpl<V, E> implements GraphVectors<V, E> {
    protected IGraph<V, E> graph;
    protected GraphVectorLookupTable lookupTable;

    /** Orders (similarity, vertexIndex) pairs by DESCENDING similarity. */
    private static class PairComparator implements Comparator<Pair<Double, Integer>> {
        private PairComparator() {
        }

        @Override
        public int compare(Pair<Double, Integer> a, Pair<Double, Integer> b) {
            // Arguments are swapped so the head of a PriorityQueue is the most similar vertex.
            return Double.compare(b.getFirst(), a.getFirst());
        }
    }

    /** @return the graph these vectors were built for */
    @Override
    public IGraph<V, E> getGraph() {
        return this.graph;
    }

    /** @return the number of vertices held by the lookup table */
    @Override
    public int numVertices() {
        return this.lookupTable.getNumVertices();
    }

    /** @return the embedding vector size reported by the lookup table */
    @Override
    public int getVectorSize() {
        return this.lookupTable.vectorSize();
    }

    /** @return the embedding vector for the given vertex (looked up by its vertex ID) */
    @Override
    public INDArray getVertexVector(Vertex<V> vertex) {
        return this.lookupTable.getVector(vertex.vertexID());
    }

    /** @return the embedding vector for the vertex with the given index */
    @Override
    public INDArray getVertexVector(int i) {
        return this.lookupTable.getVector(i);
    }

    /**
     * Returns the indices of the vertices whose vectors have the highest cosine
     * similarity to the vector of {@code vertexIdx}, most similar first. The query
     * vertex itself is excluded.
     *
     * <p>If fewer than {@code top} other vertices exist, all of them are returned
     * (the result array is shorter than {@code top}). Vertices whose similarity is
     * undefined (a zero-norm vector, giving 0/0) are ranked last.
     *
     * @param vertexIdx index of the query vertex
     * @param top       maximum number of neighbours to return; must be >= 0
     * @return neighbour vertex indices ordered by descending cosine similarity
     * @throws IllegalArgumentException if {@code top} is negative
     */
    @Override
    public int[] verticesNearest(int vertexIdx, int top) {
        if (top < 0) {
            throw new IllegalArgumentException("Number of nearest vertices must be >= 0, got " + top);
        }

        INDArray vec = this.lookupTable.getVector(vertexIdx).dup();
        double norm2 = vec.norm2Number().doubleValue();
        int n = numVertices();

        // PriorityQueue rejects an initial capacity < 1, so guard the empty-table case.
        PriorityQueue<Pair<Double, Integer>> pq =
                new PriorityQueue<Pair<Double, Integer>>(Math.max(1, n), new PairComparator());

        Level1 l1 = Nd4j.getBlasWrapper().level1();
        for (int j = 0; j < n; j++) {
            if (j == vertexIdx) {
                continue;
            }
            INDArray other = this.lookupTable.getVector(j);
            double cosineSim = l1.dot(vec.length(), 1.0d, vec, other)
                    / (norm2 * other.norm2Number().doubleValue());
            if (Double.isNaN(cosineSim)) {
                // Zero-norm vectors give 0/0 = NaN; Double.compare orders NaN above every
                // value, which would rank such vertices as MOST similar. Rank them last.
                cosineSim = Double.NEGATIVE_INFINITY;
            }
            pq.add(new Pair<Double, Integer>(cosineSim, j));
        }

        // Never poll more elements than exist (remove() would throw NoSuchElementException).
        int count = Math.min(top, pq.size());
        int[] out = new int[count];
        for (int k = 0; k < count; k++) {
            out[k] = pq.remove().getSecond();
        }
        return out;
    }

    /** @return cosine similarity between the vectors of the two given vertices */
    @Override
    public double similarity(Vertex<V> vertex, Vertex<V> vertex2) {
        return similarity(vertex.vertexID(), vertex2.vertexID());
    }

    /**
     * Computes the dot product of the two vertices' vectors after normalizing each
     * with {@code Transforms.unitVec} (i.e. their cosine similarity).
     *
     * @return 1.0 when both indices are equal, otherwise the computed similarity
     */
    @Override
    public double similarity(int i, int i2) {
        if (i == i2) {
            return 1.0d;
        }
        return Nd4j.getBlasWrapper().dot(Transforms.unitVec(getVertexVector(i)), Transforms.unitVec(getVertexVector(i2)));
    }

    @ConstructorProperties({"graph", "lookupTable"})
    public GraphVectorsImpl(IGraph<V, E> iGraph, GraphVectorLookupTable graphVectorLookupTable) {
        this.graph = iGraph;
        this.lookupTable = graphVectorLookupTable;
    }

    /** Creates an empty instance; {@code graph} and {@code lookupTable} must be assigned before use. */
    public GraphVectorsImpl() {
    }
}
