package org.tribuo.math.neighbour.kdtree;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.tribuo.math.distance.Distance;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.neighbour.NeighboursQuery;

/* loaded from: input_file:org/tribuo/math/neighbour/kdtree/KDTree.class */
public final class KDTree implements NeighboursQuery {
    // The indexed points, stored exactly as passed to the constructor; queryAll(int) runs its queries over this array.
    private final SGDVector[] data;
    // Thread count for batch queries: 1 selects a sequential loop, any other value sizes a fixed thread pool.
    private final int numThreads;
    // Root node of the tree produced by generateTree in the constructor.
    private final DimensionNode root;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/tribuo/math/neighbour/kdtree/KDTree$DistanceIntAndVectorBoundedMinHeap.class */
    /**
     * A bounded collection of candidate neighbours keyed by their distance to a query point.
     * <p>
     * At most {@code size} tuples are held. Because {@link MutableDistIntAndVectorTuple} orders
     * itself by descending distance, the head of the underlying {@link PriorityQueue} is always
     * the candidate with the largest distance, which is the one replaced when a closer candidate
     * is offered to a full heap. The indices of the points currently held are tracked in a set
     * so that the same point is never stored twice.
     */
    public static final class DistanceIntAndVectorBoundedMinHeap {
        private final HashSet<Integer> ids = new HashSet<>();
        final int size;
        private final PriorityQueue<MutableDistIntAndVectorTuple> queue;

        /**
         * Creates an empty heap.
         *
         * @param capacity the maximum number of tuples to hold; also used as the initial capacity
         *                 of the backing queue.
         */
        DistanceIntAndVectorBoundedMinHeap(int capacity) {
            this.queue = new PriorityQueue<>(capacity);
            this.size = capacity;
        }

        /**
         * Offers a candidate to the heap.
         * <p>
         * The candidate is ignored if its index is already present. While the heap has spare
         * room it is always accepted; once full it is accepted only if it is strictly closer than
         * the current farthest candidate, which it then replaces.
         *
         * @param candidate the point (and its index) being offered.
         * @param distance the distance between the candidate and the query point.
         */
        public void boundedOffer(IntAndVector candidate, double distance) {
            final Integer id = candidate.idx;
            if (ids.contains(id)) {
                return;
            }
            if (queue.size() >= size) {
                if (Double.compare(distance, queue.peek().dist) >= 0) {
                    // Not closer than the current farthest candidate: nothing to do.
                    return;
                }
                // Recycle the evicted tuple rather than allocating a new one.
                final MutableDistIntAndVectorTuple recycled = poll();
                recycled.dist = distance;
                recycled.intAndVector = candidate;
                queue.offer(recycled);
            } else {
                queue.offer(new MutableDistIntAndVectorTuple(distance, candidate));
            }
            ids.add(id);
        }

        /**
         * Returns, without removing it, the head of the queue (the farthest candidate held).
         *
         * @return the head tuple, or {@code null} if the heap is empty.
         */
        public MutableDistIntAndVectorTuple peek() {
            return queue.peek();
        }

        /**
         * Removes and returns the head of the queue and forgets its index.
         * Must not be called on an empty heap: the removed tuple is dereferenced unconditionally.
         *
         * @return the removed tuple.
         */
        MutableDistIntAndVectorTuple poll() {
            final MutableDistIntAndVectorTuple head = queue.poll();
            ids.remove(Integer.valueOf(head.intAndVector.idx));
            return head;
        }

        /**
         * @return {@code true} when the number of held tuples equals the configured bound.
         */
        public boolean isFull() {
            return size == queue.size();
        }

        /**
         * @return {@code true} when no tuples are held.
         */
        boolean isEmpty() {
            return queue.isEmpty();
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/tribuo/math/neighbour/kdtree/KDTree$IntAndVector.class */
    /**
     * An immutable pairing of a vector with an integer index.
     * The tree constructor assigns each input vector its position in the input array as the index.
     */
    public static final class IntAndVector {
        final int idx;
        final SGDVector vector;

        /**
         * @param index the index associated with the vector.
         * @param features the vector itself.
         */
        public IntAndVector(int index, SGDVector features) {
            this.vector = features;
            this.idx = index;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/tribuo/math/neighbour/kdtree/KDTree$MutableDistIntAndVectorTuple.class */
    /**
     * A mutable (distance, point) tuple. Both fields are reassigned in place when a tuple is
     * recycled by the bounded heap.
     * <p>
     * The natural ordering is by <em>descending</em> distance, so a {@code PriorityQueue} of these
     * tuples keeps the largest distance at its head.
     */
    public static final class MutableDistIntAndVectorTuple implements Comparable<MutableDistIntAndVectorTuple> {
        double dist;
        IntAndVector intAndVector;

        /**
         * @param distance the distance value.
         * @param point the point (and its index) the distance refers to.
         */
        public MutableDistIntAndVectorTuple(double distance, IntAndVector point) {
            this.intAndVector = point;
            this.dist = distance;
        }

        /**
         * Compares by distance in reverse: a tuple with a larger distance sorts first.
         */
        @Override // java.lang.Comparable
        public int compareTo(MutableDistIntAndVectorTuple other) {
            // Arguments are deliberately swapped to obtain the descending order.
            return Double.compare(other.dist, dist);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/tribuo/math/neighbour/kdtree/KDTree$SingleQueryRunnable.class */
    /**
     * A task that runs one nearest-neighbour query against the enclosing tree and writes the
     * result into a fixed slot of a shared output array.
     */
    public final class SingleQueryRunnable implements Runnable {
        private final SGDVector point;
        private final int k;
        private final int index;
        final List<Pair<Integer, Double>>[] indexDistancePairListArray;

        /**
         * @param slot the position in {@code output} that receives this task's result.
         * @param queryPoint the point to query.
         * @param numNeighbours the number of neighbours to request.
         * @param output the shared array of results; this task only writes {@code output[slot]}.
         */
        SingleQueryRunnable(int slot, SGDVector queryPoint, int numNeighbours, List<Pair<Integer, Double>>[] output) {
            this.index = slot;
            this.point = queryPoint;
            this.k = numNeighbours;
            this.indexDistancePairListArray = output;
        }

        /**
         * Runs the query on the enclosing tree and stores its result in this task's slot.
         */
        @Override // java.lang.Runnable
        public void run() {
            final List<Pair<Integer, Double>> neighbours = KDTree.this.query(point, k);
            indexDistancePairListArray[index] = neighbours;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Builds a tree over the supplied points.
     *
     * @param sGDVectorArr the points to index. Must be non-empty, and every vector must report the
     *                     same {@code size()}. The array is stored as-is (not copied).
     * @param distance the distance handed to every tree node.
     * @param i the number of threads used for batch queries; {@code 1} runs them sequentially.
     * @throws IllegalArgumentException if the array is empty or the vectors differ in size.
     */
    public KDTree(SGDVector[] sGDVectorArr, Distance distance, int i) {
        // Fail with a clear message instead of an ArrayIndexOutOfBoundsException on sGDVectorArr[0].
        if (sGDVectorArr.length == 0) {
            throw new IllegalArgumentException("At least one SGDVector is required to build a KDTree.");
        }
        this.data = sGDVectorArr;
        this.numThreads = i;
        final int vectorSize = sGDVectorArr[0].size();
        final IntAndVector[] points = new IntAndVector[sGDVectorArr.length];
        for (int idx = 0; idx < sGDVectorArr.length; idx++) {
            if (sGDVectorArr[idx].size() != vectorSize) {
                throw new IllegalArgumentException("All the SGDVectors must be the same size.");
            }
            // Each point remembers its position in the input array; queries report this index.
            points[idx] = new IntAndVector(idx, sGDVectorArr[idx]);
        }
        // Split dimensions cycle through 0 .. vectorSize - 1, starting at dimension 0.
        this.root = generateTree(0, vectorSize - 1, points, 0, sGDVectorArr.length - 1, distance);
    }

    /**
     * Finds the nearest neighbours of a single point.
     * <p>
     * Candidates are gathered in a bounded heap whose head is the farthest candidate, so the heap
     * is drained from the back of the result towards the front; the returned list is therefore
     * ordered by increasing distance.
     *
     * @param sGDVector the query point.
     * @param i the number of neighbours requested; must be at least 1.
     * @return a fixed-size list of (index, distance) pairs. It contains {@code i} entries, or fewer
     *         if the search produced fewer candidates; it never contains {@code null} entries.
     * @throws IllegalArgumentException if {@code i} is less than 1 (raised when the heap's backing
     *         queue is created).
     */
    @Override // org.tribuo.math.neighbour.NeighboursQuery
    public List<Pair<Integer, Double>> query(SGDVector sGDVector, int i) {
        final DistanceIntAndVectorBoundedMinHeap heap = new DistanceIntAndVectorBoundedMinHeap(i);
        // Seed the heap from the approximate parent node first, then search from the root.
        initializeQueue(sGDVector, heap);
        this.root.nearest(sGDVector, heap, false);

        // Generic array creation is not allowed; the cast is safe because the array never escapes
        // with any other element type.
        @SuppressWarnings("unchecked")
        final Pair<Integer, Double>[] neighbours = (Pair<Integer, Double>[]) new Pair[i];
        int firstFilled = i;
        while (!heap.isEmpty()) {
            final MutableDistIntAndVectorTuple tuple = heap.poll();
            neighbours[--firstFilled] = new Pair<>(tuple.intAndVector.idx, tuple.dist);
        }
        if (firstFilled > 0) {
            // Fewer than i neighbours were found: drop the unfilled prefix rather than return nulls.
            return Arrays.asList(Arrays.copyOfRange(neighbours, firstFilled, i));
        }
        return Arrays.asList(neighbours);
    }

    /**
     * Finds the nearest neighbours of every point in an array.
     * <p>
     * With {@code numThreads == 1} the queries run sequentially on the calling thread. Otherwise a
     * fixed thread pool of {@code numThreads} threads is created for the duration of the call, and
     * each query writes its result into its own slot of the output array.
     *
     * @param sGDVectorArr the query points.
     * @param i the number of neighbours requested per point.
     * @return one result list per query point, in the same order as {@code sGDVectorArr}.
     * @throws RuntimeException if waiting for the pool fails or the calling thread is interrupted;
     *         on interruption the thread's interrupt flag is restored before throwing.
     */
    @Override // org.tribuo.math.neighbour.NeighboursQuery
    public List<List<Pair<Integer, Double>>> query(SGDVector[] sGDVectorArr, int i) {
        final int numQueries = sGDVectorArr.length;
        // Generic array creation is not allowed; the cast is safe because only query results are stored.
        @SuppressWarnings("unchecked")
        final List<Pair<Integer, Double>>[] results = (List<Pair<Integer, Double>>[]) new List[numQueries];
        if (this.numThreads == 1) {
            for (int q = 0; q < numQueries; q++) {
                results[q] = query(sGDVectorArr[q], i);
            }
        } else {
            final ExecutorService executor = Executors.newFixedThreadPool(this.numThreads);
            try {
                for (int q = 0; q < numQueries; q++) {
                    executor.execute(new SingleQueryRunnable(q, sGDVectorArr[q], i, results));
                }
                executor.shutdown();
                if (!executor.awaitTermination(Long.MAX_VALUE, TimeUnit.MINUTES)) {
                    throw new RuntimeException("Parallel execution failed");
                }
            } catch (InterruptedException e) {
                // Preserve the interrupt for callers further up the stack.
                Thread.currentThread().interrupt();
                throw new RuntimeException("Parallel execution failed", e);
            } finally {
                // On any failure path make sure the worker threads do not outlive this call.
                if (!executor.isTerminated()) {
                    executor.shutdownNow();
                }
            }
        }
        return Arrays.asList(results);
    }

    /**
     * Queries the tree with every point it was built from.
     *
     * @param i the number of neighbours requested per point.
     * @return one result list per indexed point, in the order of the array given to the constructor.
     */
    @Override // org.tribuo.math.neighbour.NeighboursQuery
    public List<List<Pair<Integer, Double>>> queryAll(int i) {
        final SGDVector[] indexedPoints = this.data;
        return query(indexedPoints, i);
    }

    /**
     * Recursively builds the subtree for the points in {@code points[left..right]} (inclusive).
     * <p>
     * The range is rearranged so that the element at the median rank along {@code dimension} sits
     * at its selected position; that element becomes the node, the elements before it form the
     * "below" subtree and the elements after it the "above" subtree. Child levels split on the
     * next dimension, wrapping back to 0 after {@code maxDimension}.
     *
     * @param dimension the dimension this level splits on.
     * @param maxDimension the largest valid dimension index.
     * @param points the working array of points; reordered in place.
     * @param left the first index of the range (inclusive).
     * @param right the last index of the range (inclusive).
     * @param distance the distance handed to each created node.
     * @return the root of the subtree, or {@code null} when the range is empty.
     */
    private static DimensionNode generateTree(int dimension, int maxDimension, IntAndVector[] points, int left, int right, Distance distance) {
        if (right < left) {
            return null;
        }
        if (right == left) {
            // A single point: a leaf node.
            return new DimensionNode(dimension, points[left], distance);
        }

        // 1-based rank of the median within the range, and its absolute array position.
        final int medianRank = 1 + ((right - left) / 2);
        setMedian(points, medianRank, left, right, dimension);
        final int medianIndex = left + medianRank - 1;
        final DimensionNode node = new DimensionNode(dimension, points[medianIndex], distance);

        final int childDimension = (dimension + 1 > maxDimension) ? 0 : dimension + 1;

        final DimensionNode belowSubtree = generateTree(childDimension, maxDimension, points, left, medianIndex - 1, distance);
        node.setBelow(belowSubtree);
        final DimensionNode aboveSubtree = generateTree(childDimension, maxDimension, points, medianIndex + 1, right, distance);
        node.setAbove(aboveSubtree);
        return node;
    }

    /**
     * Rearranges {@code points[left..right]} so that the element of the given 1-based rank along
     * {@code dimension} ends up at index {@code left + rank - 1}.
     * <p>
     * Repeatedly partitions around a pivot and narrows the range to the side that contains the
     * target position, until the pivot lands exactly on it.
     *
     * @param points the working array of points; reordered in place.
     * @param rank the 1-based rank to select, relative to {@code left}.
     * @param left the first index of the range (inclusive).
     * @param right the last index of the range (inclusive).
     * @param dimension the dimension whose values are compared.
     */
    private static void setMedian(IntAndVector[] points, int rank, int left, int right, int dimension) {
        for (;;) {
            final int pivotCandidate = getPivotPointIndex(points, left, right, dimension);
            final int pivotIndex = partitionOnIndex(points, left, right, pivotCandidate, dimension);
            final int targetIndex = (left + rank) - 1;
            if (targetIndex == pivotIndex) {
                return;
            }
            if (targetIndex < pivotIndex) {
                // Target lies left of the pivot: discard the pivot and everything after it.
                right = pivotIndex - 1;
            } else {
                // Target lies right of the pivot: re-express the rank relative to the new left edge.
                rank -= (pivotIndex + 1) - left;
                left = pivotIndex + 1;
            }
        }
    }

    /**
     * Chooses a pivot for partitioning {@code points[left..right]} using a median-of-three of the
     * first, middle and last elements, compared along {@code dimension}.
     *
     * @param points the working array of points.
     * @param left the first index of the range (inclusive).
     * @param right the last index of the range (inclusive).
     * @param dimension the dimension whose values are compared.
     * @return the index (one of {@code left}, the midpoint, or {@code right}) holding the median
     *         of the three sampled values.
     */
    private static int getPivotPointIndex(IntAndVector[] points, int left, int right, int dimension) {
        // left + (right - left) / 2 equals (left + right) / 2 but cannot overflow for large indices.
        int larger = left + (right - left) / 2;
        int smaller = left;
        if (compareByDimension(points[smaller], points[larger], dimension) >= 0) {
            // Order the two samples so that points[smaller] <= points[larger].
            smaller = larger;
            larger = left;
        }
        if (compareByDimension(points[right], points[smaller], dimension) <= 0) {
            // The last element is the smallest of the three, so 'smaller' is the median.
            return smaller;
        }
        if (compareByDimension(points[right], points[larger], dimension) <= 0) {
            // The last element falls between the other two.
            return right;
        }
        return larger;
    }

    /**
     * Partitions {@code points[left..right]} around the element at {@code pivotIndex}, comparing
     * along {@code dimension}.
     * <p>
     * The pivot is parked at {@code right}, every element comparing less than or equal to it is
     * moved to the front of the range, and the pivot is finally swapped into the boundary slot.
     *
     * @param points the working array of points; reordered in place.
     * @param left the first index of the range (inclusive).
     * @param right the last index of the range (inclusive).
     * @param pivotIndex the index of the pivot element before partitioning.
     * @param dimension the dimension whose values are compared.
     * @return the final index of the pivot element.
     */
    private static int partitionOnIndex(IntAndVector[] points, int left, int right, int pivotIndex, int dimension) {
        final IntAndVector pivot = points[pivotIndex];
        swap(points, right, pivotIndex);
        int boundary = left;
        int scan = left;
        while (scan < right) {
            if (compareByDimension(points[scan], pivot, dimension) <= 0) {
                swap(points, scan, boundary);
                boundary++;
            }
            scan++;
        }
        swap(points, right, boundary);
        return boundary;
    }

    /**
     * Exchanges two elements of the array; does nothing when both indices are the same.
     *
     * @param points the array to modify.
     * @param first the index of one element.
     * @param second the index of the other element.
     */
    private static void swap(IntAndVector[] points, int first, int second) {
        if (first != second) {
            final IntAndVector held = points[first];
            points[first] = points[second];
            points[second] = held;
        }
    }

    /**
     * Seeds the candidate heap by running a nearest search that starts at the node returned by
     * {@link #approximateParentNode(SGDVector)} instead of at the root.
     *
     * @param sGDVector the query point.
     * @param distanceIntAndVectorBoundedMinHeap the heap that receives the candidates.
     */
    private void initializeQueue(SGDVector sGDVector, DistanceIntAndVectorBoundedMinHeap distanceIntAndVectorBoundedMinHeap) {
        final DimensionNode startNode = approximateParentNode(sGDVector);
        // NOTE(review): the 'true' flag distinguishes this seeding pass from the root search,
        // which passes 'false'; confirm its exact meaning against DimensionNode.nearest.
        startNode.nearest(sGDVector, distanceIntAndVectorBoundedMinHeap, true);
    }

    /**
     * Walks from the root towards the region containing the point and returns the deepest node
     * on that path that has both a "below" and an "above" child.
     * <p>
     * At each node the walk follows the "below" child when {@code isBelow} reports the point is
     * below the node, otherwise the "above" child, and stops when that child is missing. If no
     * node on the path has two children, the root is returned.
     *
     * @param sGDVector the query point.
     * @return the last node on the descent path with two children, or the root if there is none.
     */
    public DimensionNode approximateParentNode(SGDVector sGDVector) {
        DimensionNode lastWithBothChildren = this.root;
        for (DimensionNode current = this.root;
                current != null;
                current = current.isBelow(sGDVector) ? current.getBelow() : current.getAbove()) {
            if (current.getBelow() != null && current.getAbove() != null) {
                lastWithBothChildren = current;
            }
        }
        return lastWithBothChildren;
    }

    /**
     * Compares two points by their value in a single dimension.
     *
     * @param first the first point.
     * @param second the second point.
     * @param dimension the dimension whose values are compared.
     * @return the result of {@link Double#compare(double, double)} on the two values.
     */
    private static int compareByDimension(IntAndVector first, IntAndVector second, int dimension) {
        final double firstValue = first.vector.get(dimension);
        final double secondValue = second.vector.get(dimension);
        return Double.compare(firstValue, secondValue);
    }
}
