/*
 * Decompiled with CFR 0.152.
 */
package tagbio.umap;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import tagbio.umap.CsrMatrix;
import tagbio.umap.FlatTree;
import tagbio.umap.Hyperplane;
import tagbio.umap.MathUtils;
import tagbio.umap.Matrix;
import tagbio.umap.RandomProjectionTreeNode;
import tagbio.umap.SparseVector;
import tagbio.umap.UmapProgress;
import tagbio.umap.Utils;

final class RandomProjectionTree {
    private static final float EPS = 1.0E-8f;

    private RandomProjectionTree() {
    }

    private static Object[] angularRandomProjectionSplit(Matrix data, int[] indices, Random random) {
        int rightIndex;
        int dim = data.cols();
        int leftIndex = random.nextInt(indices.length);
        if (leftIndex == (rightIndex = random.nextInt(indices.length)) && ++rightIndex == indices.length) {
            rightIndex = 0;
        }
        int left = indices[leftIndex];
        int right = indices[rightIndex];
        float leftNorm = Utils.norm(data.row(left));
        float rightNorm = Utils.norm(data.row(right));
        if (Math.abs(leftNorm) < 1.0E-8f) {
            leftNorm = 1.0f;
        }
        if (Math.abs(rightNorm) < 1.0E-8f) {
            rightNorm = 1.0f;
        }
        float[] hyperplaneVector = new float[dim];
        for (int d = 0; d < dim; ++d) {
            hyperplaneVector[d] = data.get(left, d) / leftNorm - data.get(right, d) / rightNorm;
        }
        float hyperplaneNorm = Utils.norm(hyperplaneVector);
        if (Math.abs(hyperplaneNorm) < 1.0E-8f) {
            hyperplaneNorm = 1.0f;
        }
        int d = 0;
        while (d < dim) {
            int n = d++;
            hyperplaneVector[n] = hyperplaneVector[n] / hyperplaneNorm;
        }
        int nLeft = 0;
        int nRight = 0;
        boolean[] side = new boolean[indices.length];
        for (int i = 0; i < indices.length; ++i) {
            float margin = 0.0f;
            for (int d2 = 0; d2 < dim; ++d2) {
                margin += hyperplaneVector[d2] * data.get(indices[i], d2);
            }
            if (Math.abs(margin) < 1.0E-8f) {
                side[i] = random.nextBoolean();
                if (side[i]) {
                    ++nRight;
                    continue;
                }
                ++nLeft;
                continue;
            }
            if (margin > 0.0f) {
                side[i] = false;
                ++nLeft;
                continue;
            }
            side[i] = true;
            ++nRight;
        }
        int[] indicesLeft = new int[nLeft];
        int[] indicesRight = new int[nRight];
        nLeft = 0;
        nRight = 0;
        for (int i = 0; i < side.length; ++i) {
            if (side[i]) {
                indicesRight[nRight++] = indices[i];
                continue;
            }
            indicesLeft[nLeft++] = indices[i];
        }
        return new Object[]{indicesLeft, indicesRight, hyperplaneVector, null};
    }

    private static Object[] euclideanRandomProjectionSplit(Matrix data, int[] indices, Random random) {
        int rightIndex;
        int dim = data.cols();
        int leftIndex = random.nextInt(indices.length);
        if (leftIndex == (rightIndex = random.nextInt(indices.length)) && ++rightIndex == indices.length) {
            rightIndex = 0;
        }
        int left = indices[leftIndex];
        int right = indices[rightIndex];
        float hyperplaneOffset = 0.0f;
        float[] hyperplaneVector = new float[dim];
        for (int d = 0; d < dim; ++d) {
            float delta;
            float ld = data.get(left, d);
            float rd = data.get(right, d);
            hyperplaneVector[d] = delta = ld - rd;
            hyperplaneOffset -= delta * (ld + rd);
        }
        hyperplaneOffset /= 2.0f;
        int nLeft = 0;
        int nRight = 0;
        boolean[] side = new boolean[indices.length];
        for (int i = 0; i < indices.length; ++i) {
            float margin = hyperplaneOffset;
            for (int d = 0; d < dim; ++d) {
                margin += hyperplaneVector[d] * data.get(indices[i], d);
            }
            if (margin >= 1.0E-8f) {
                ++nLeft;
                continue;
            }
            if (margin <= -1.0E-8f) {
                side[i] = true;
                ++nRight;
                continue;
            }
            side[i] = random.nextBoolean();
            if (side[i]) {
                ++nRight;
                continue;
            }
            ++nLeft;
        }
        int[] indicesLeft = new int[nLeft];
        int[] indicesRight = new int[nRight];
        int l = 0;
        int r = 0;
        for (int i = 0; i < side.length; ++i) {
            if (side[i]) {
                indicesRight[r++] = indices[i];
                continue;
            }
            indicesLeft[l++] = indices[i];
        }
        return new Object[]{indicesLeft, indicesRight, hyperplaneVector, Float.valueOf(hyperplaneOffset)};
    }

    private static Object[] sparseAngularRandomProjectionSplit(CsrMatrix matrix, int[] indices, Random random) {
        int rightIndex;
        int leftIndex = random.nextInt(indices.length);
        if (leftIndex == (rightIndex = random.nextInt(indices.length)) && ++rightIndex == indices.length) {
            rightIndex = 0;
        }
        int left = indices[leftIndex];
        int right = indices[rightIndex];
        SparseVector leftVec = matrix.vector(left);
        SparseVector rightVec = matrix.vector(right);
        float leftNorm = leftVec.norm();
        float rightNorm = rightVec.norm();
        if (Math.abs(leftNorm) < 1.0E-8f) {
            leftNorm = 1.0f;
        }
        if (Math.abs(rightNorm) < 1.0E-8f) {
            rightNorm = 1.0f;
        }
        leftVec.divide(leftNorm);
        rightVec.divide(rightNorm);
        SparseVector sd = leftVec.subtract(rightVec);
        float hyperplaneNorm = sd.norm();
        if (Math.abs(hyperplaneNorm) < 1.0E-8f) {
            hyperplaneNorm = 1.0f;
        }
        sd.divide(hyperplaneNorm);
        int nLeft = 0;
        int nRight = 0;
        boolean[] side = new boolean[indices.length];
        for (int i = 0; i < indices.length; ++i) {
            SparseVector iVec = matrix.vector(indices[i]);
            SparseVector spm = sd.hadamardMultiply(iVec);
            float margin = spm.sum();
            if (Math.abs(margin) < 1.0E-8f) {
                side[i] = random.nextBoolean();
                if (side[i]) {
                    ++nRight;
                    continue;
                }
                ++nLeft;
                continue;
            }
            if (margin > 0.0f) {
                side[i] = false;
                ++nLeft;
                continue;
            }
            side[i] = true;
            ++nRight;
        }
        int[] indicesLeft = new int[nLeft];
        int[] indicesRight = new int[nRight];
        nLeft = 0;
        nRight = 0;
        for (int i = 0; i < side.length; ++i) {
            if (side[i]) {
                indicesRight[nRight++] = indices[i];
                continue;
            }
            indicesLeft[nLeft++] = indices[i];
        }
        Hyperplane hyperplane = new Hyperplane(sd.getIndices(), sd.getData());
        return new Object[]{indicesLeft, indicesRight, hyperplane, null};
    }

    private static Object[] sparseEuclideanRandomProjectionSplit(CsrMatrix matrix, int[] indices, Random random) {
        int rightIndex;
        int leftIndex = random.nextInt(indices.length);
        if (leftIndex == (rightIndex = random.nextInt(indices.length)) && ++rightIndex == indices.length) {
            rightIndex = 0;
        }
        int left = indices[leftIndex];
        int right = indices[rightIndex];
        SparseVector leftVec = matrix.vector(left);
        SparseVector rightVec = matrix.vector(right);
        SparseVector sd = leftVec.subtract(rightVec);
        SparseVector ss = leftVec.add(rightVec);
        ss.divide(2.0f);
        SparseVector sm = sd.hadamardMultiply(ss);
        float hyperplaneOffset = -sm.sum();
        int nLeft = 0;
        int nRight = 0;
        boolean[] side = new boolean[indices.length];
        for (int i = 0; i < indices.length; ++i) {
            SparseVector iVec = matrix.vector(indices[i]);
            SparseVector spm = sd.hadamardMultiply(iVec);
            float margin = hyperplaneOffset + spm.sum();
            if (Math.abs(margin) < 1.0E-8f) {
                side[i] = random.nextBoolean();
                if (side[i]) {
                    ++nLeft;
                    continue;
                }
                ++nRight;
                continue;
            }
            if (margin > 0.0f) {
                side[i] = false;
                ++nLeft;
                continue;
            }
            side[i] = true;
            ++nRight;
        }
        int[] indicesLeft = new int[nLeft];
        int[] indicesRight = new int[nRight];
        nLeft = 0;
        nRight = 0;
        for (int i = 0; i < side.length; ++i) {
            if (side[i]) {
                indicesRight[nRight++] = indices[i];
                continue;
            }
            indicesLeft[nLeft++] = indices[i];
        }
        Hyperplane hyperplane = new Hyperplane(sd.getIndices(), sd.getData());
        return new Object[]{indicesLeft, indicesRight, hyperplane, Float.valueOf(hyperplaneOffset)};
    }

    private static RandomProjectionTreeNode makeEuclideanTree(Matrix data, int[] indices, Random random, int leafSize) {
        if (indices.length > leafSize) {
            Object[] erps = RandomProjectionTree.euclideanRandomProjectionSplit(data, indices, random);
            int[] leftIndices = (int[])erps[0];
            int[] rightIndices = (int[])erps[1];
            Hyperplane hyperplane = new Hyperplane((float[])erps[2]);
            float offset = ((Float)erps[3]).floatValue();
            RandomProjectionTreeNode leftNode = RandomProjectionTree.makeEuclideanTree(data, leftIndices, random, leafSize);
            RandomProjectionTreeNode rightNode = RandomProjectionTree.makeEuclideanTree(data, rightIndices, random, leafSize);
            return new RandomProjectionTreeNode(null, hyperplane, Float.valueOf(offset), leftNode, rightNode);
        }
        return new RandomProjectionTreeNode(indices, null, null, null, null);
    }

    private static RandomProjectionTreeNode makeAngularTree(Matrix data, int[] indices, Random random, int leafSize) {
        if (indices.length > leafSize) {
            Object[] erps = RandomProjectionTree.angularRandomProjectionSplit(data, indices, random);
            int[] leftIndices = (int[])erps[0];
            int[] rightIndices = (int[])erps[1];
            Hyperplane hyperplane = new Hyperplane((float[])erps[2]);
            float offset = ((Float)erps[3]).floatValue();
            RandomProjectionTreeNode leftNode = RandomProjectionTree.makeAngularTree(data, leftIndices, random, leafSize);
            RandomProjectionTreeNode rightNode = RandomProjectionTree.makeAngularTree(data, rightIndices, random, leafSize);
            return new RandomProjectionTreeNode(null, hyperplane, Float.valueOf(offset), leftNode, rightNode);
        }
        return new RandomProjectionTreeNode(indices, null, null, null, null);
    }

    private static RandomProjectionTreeNode makeSparseEuclideanTree(CsrMatrix matrix, int[] indices, Random random, int leafSize) {
        if (indices.length > leafSize) {
            Object[] erps = RandomProjectionTree.sparseEuclideanRandomProjectionSplit(matrix, indices, random);
            int[] leftIndices = (int[])erps[0];
            int[] rightIndices = (int[])erps[1];
            Hyperplane hyperplane = (Hyperplane)erps[2];
            float offset = ((Float)erps[3]).floatValue();
            RandomProjectionTreeNode leftNode = RandomProjectionTree.makeSparseEuclideanTree(matrix, leftIndices, random, leafSize);
            RandomProjectionTreeNode rightNode = RandomProjectionTree.makeSparseEuclideanTree(matrix, rightIndices, random, leafSize);
            return new RandomProjectionTreeNode(null, hyperplane, Float.valueOf(offset), leftNode, rightNode);
        }
        return new RandomProjectionTreeNode(indices, null, null, null, null);
    }

    private static RandomProjectionTreeNode makeSparseAngularTree(CsrMatrix matrix, int[] indices, Random random, int leafSize) {
        if (indices.length > leafSize) {
            Object[] erps = RandomProjectionTree.sparseAngularRandomProjectionSplit(matrix, indices, random);
            int[] leftIndices = (int[])erps[0];
            int[] rightIndices = (int[])erps[1];
            Hyperplane hyperplane = (Hyperplane)erps[2];
            float offset = ((Float)erps[3]).floatValue();
            RandomProjectionTreeNode leftNode = RandomProjectionTree.makeSparseAngularTree(matrix, leftIndices, random, leafSize);
            RandomProjectionTreeNode rightNode = RandomProjectionTree.makeSparseAngularTree(matrix, rightIndices, random, leafSize);
            return new RandomProjectionTreeNode(null, hyperplane, Float.valueOf(offset), leftNode, rightNode);
        }
        return new RandomProjectionTreeNode(indices, null, null, null, null);
    }

    private static RandomProjectionTreeNode makeTree(Matrix data, Random random, int leafSize, boolean angular) {
        boolean isSparse = data instanceof CsrMatrix;
        int[] indices = MathUtils.identity(data.rows());
        if (isSparse) {
            CsrMatrix csrData = (CsrMatrix)data;
            if (angular) {
                return RandomProjectionTree.makeSparseAngularTree(csrData, indices, random, leafSize);
            }
            return RandomProjectionTree.makeSparseEuclideanTree(csrData, indices, random, leafSize);
        }
        if (angular) {
            return RandomProjectionTree.makeAngularTree(data, indices, random, leafSize);
        }
        return RandomProjectionTree.makeEuclideanTree(data, indices, random, leafSize);
    }

    static List<FlatTree> makeForest(Matrix data, int nNeighbors, int nTrees, Random random, boolean angular) {
        Random[] randoms = Utils.splitRandom(random, nTrees);
        ArrayList<FlatTree> result = new ArrayList<FlatTree>();
        int leafSize = Math.max(10, nNeighbors);
        try {
            for (int i = 0; i < nTrees; ++i) {
                result.add(RandomProjectionTree.makeTree(data, randoms[i], leafSize, angular).flatten());
                UmapProgress.update();
            }
        }
        catch (RuntimeException e) {
            Utils.message("Random Projection forest initialisation failed due to recursion limit being reached. Something is a little strange with your data, and this may take longer than normal to compute.");
            throw e;
        }
        return result;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     * WARNING - void declaration
     */
    static List<FlatTree> makeForest(Matrix data, int nNeighbors, int nTrees, Random random, boolean angular, int threads) {
        if (threads == 1) {
            return RandomProjectionTree.makeForest(data, nNeighbors, nTrees, random, angular);
        }
        Random[] randoms = Utils.splitRandom(random, nTrees);
        ExecutorService executor = Executors.newFixedThreadPool(threads);
        try {
            void var12_15;
            ArrayList<Future<FlatTree>> futures = new ArrayList<Future<FlatTree>>();
            int leafSize = Math.max(10, nNeighbors);
            Random[] randomArray = randoms;
            int n = randomArray.length;
            boolean bl = false;
            while (var12_15 < n) {
                Random rand = randomArray[var12_15];
                futures.add(executor.submit(() -> RandomProjectionTree.makeTree(data, rand, leafSize, angular).flatten()));
                ++var12_15;
            }
            ArrayList<FlatTree> result = new ArrayList<FlatTree>();
            try {
                for (Future future : futures) {
                    result.add((FlatTree)future.get());
                    UmapProgress.update();
                }
            }
            catch (InterruptedException | ExecutionException ex) {
                Utils.message("Random Projection forest initialisation failed due to recursion limit being reached. Something is a little strange with your data, and this may take longer than normal to compute.");
                throw new RuntimeException(ex);
            }
            ArrayList<FlatTree> arrayList = result;
            return arrayList;
        }
        finally {
            executor.shutdown();
        }
    }
}

