package com.gengoai.apollo.math.linalg;

import com.gengoai.Validation;
import com.gengoai.stream.MStream;
import com.gengoai.stream.StreamingContext;
import com.gengoai.stream.spark.SparkStream;
import java.lang.invoke.SerializedLambda;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.SingularValueDecomposition;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.linalg.distributed.RowMatrix;
import org.jblas.DoubleMatrix;

/* loaded from: input_file:com/gengoai/apollo/math/linalg/SparkLinearAlgebra.class */
/**
 * Static helpers that bridge {@link NDArray} and Spark MLlib's distributed {@link RowMatrix}
 * to run PCA and SVD, and to convert results back into local {@link NDArray} instances.
 *
 * <p>This class is not instantiable.
 */
public final class SparkLinearAlgebra {
    /**
     * Value passed as the third argument ({@code rCond}) of
     * {@link RowMatrix#computeSVD(int, boolean, double)}.
     */
    private static final double SVD_RECIPROCAL_CONDITION = 1.0E-9d;

    private SparkLinearAlgebra() {
        throw new IllegalAccessError();
    }

    /**
     * Projects the rows of the given matrix onto its first {@code i} principal components and
     * returns the projection as a local {@link NDArray}.
     *
     * @param rowMatrix the distributed matrix to analyse
     * @param i         the number of principal components; must be {@code > 0}
     * @return the product of {@code rowMatrix} and its principal-component matrix
     */
    public static NDArray pca(RowMatrix rowMatrix, int i) {
        // sparkPCA performs the argument validation (same message) and the projection.
        return toMatrix(sparkPCA(rowMatrix, i));
    }

    /**
     * Computes the first {@code i} principal components of the given {@link NDArray}.
     *
     * <p>NOTE(review): unlike {@link #pca(RowMatrix, int)}, this overload returns the matrix
     * produced by {@link RowMatrix#computePrincipalComponents(int)} itself and does NOT multiply
     * the input by it. The behaviour is preserved as-is; confirm whether the asymmetry is intended.
     *
     * @param nDArray the local matrix to analyse; each row is turned into one Spark vector
     * @param i       the number of principal components; must be {@code > 0}
     * @return the principal-component matrix converted to an {@link NDArray}
     */
    public static NDArray pca(NDArray nDArray, int i) {
        Validation.checkArgument(i > 0, "Number of principal components must be > 0");
        return toMatrix(toRowMatrix(nDArray).computePrincipalComponents(i));
    }

    /**
     * Projects the rows of the given matrix onto its first {@code i} principal components,
     * keeping the result distributed.
     *
     * @param rowMatrix the distributed matrix to analyse
     * @param i         the number of principal components; must be {@code > 0}
     * @return the product of {@code rowMatrix} and its principal-component matrix
     */
    public static RowMatrix sparkPCA(RowMatrix rowMatrix, int i) {
        Validation.checkArgument(i > 0, "Number of principal components must be > 0");
        return rowMatrix.multiply(rowMatrix.computePrincipalComponents(i));
    }

    /**
     * Runs Spark's singular value decomposition on the given matrix, requesting that the
     * {@code U} factor be computed as well.
     *
     * @param rowMatrix the distributed matrix to decompose
     * @param i         the {@code k} argument handed to Spark; must be {@code > 0}
     * @return Spark's decomposition result holding {@code U}, {@code s} and {@code V}
     */
    public static SingularValueDecomposition<RowMatrix, org.apache.spark.mllib.linalg.Matrix> sparkSVD(RowMatrix rowMatrix, int i) {
        Validation.checkArgument(i > 0, "K must be > 0");
        return rowMatrix.computeSVD(i, true, SVD_RECIPROCAL_CONDITION);
    }

    /**
     * Decomposes the given distributed matrix and converts each factor to a local {@link NDArray}.
     *
     * @param rowMatrix the distributed matrix to decompose
     * @param i         the {@code k} argument handed to Spark; must be {@code > 0}
     * @return a three-element array: {@code U}, the singular values as a diagonal matrix, {@code V}
     */
    public static NDArray[] svd(RowMatrix rowMatrix, int i) {
        SingularValueDecomposition<RowMatrix, org.apache.spark.mllib.linalg.Matrix> decomposition = sparkSVD(rowMatrix, i);
        return new NDArray[]{
            toMatrix(decomposition.U()),
            toDiagonalMatrix(decomposition.s()),
            toMatrix(decomposition.V())
        };
    }

    /**
     * Decomposes the given local matrix by first distributing it with
     * {@link #toRowMatrix(NDArray)}.
     *
     * @param nDArray the local matrix to decompose
     * @param i       the {@code k} argument handed to Spark; must be {@code > 0}
     * @return a three-element array: {@code U}, the singular values as a diagonal matrix, {@code V}
     */
    public static NDArray[] svd(NDArray nDArray, int i) {
        return svd(toRowMatrix(nDArray), i);
    }

    /**
     * Builds a square matrix whose diagonal holds the values of the given Spark vector.
     *
     * @param vector the values to place on the diagonal
     * @return a {@link DenseMatrix} wrapping the jblas diagonal matrix
     */
    public static NDArray toDiagonalMatrix(Vector vector) {
        DoubleMatrix diagonalValues = new DoubleMatrix(vector.toArray());
        return new DenseMatrix(DoubleMatrix.diag(diagonalValues));
    }

    /**
     * Copies a distributed {@link RowMatrix} into a local dense matrix, one row at a time.
     *
     * @param rowMatrix the distributed matrix to copy
     * @return a {@link DenseMatrix} with the same dimensions and contents
     * @throws ArithmeticException if the row or column count does not fit in an {@code int}
     */
    public static NDArray toMatrix(RowMatrix rowMatrix) {
        // Spark reports dimensions as long; fail loudly instead of wrapping around on narrowing.
        final int rowCount = Math.toIntExact(rowMatrix.numRows());
        final int columnCount = Math.toIntExact(rowMatrix.numCols());
        final DoubleMatrix local = new DoubleMatrix(rowCount, columnCount);
        rowMatrix.rows().toJavaRDD().zipWithIndex().toLocalIterator().forEachRemaining(indexedRow -> {
            Vector row = indexedRow._1();
            // The zipped index selects the destination row in the local matrix.
            local.putRow(indexedRow._2().intValue(), new DoubleMatrix(1, row.size(), row.toArray()));
        });
        return new DenseMatrix(local);
    }

    /**
     * Converts a local Spark matrix into a dense {@link NDArray} with the same dimensions.
     *
     * <p>NOTE(review): the values are passed exactly as returned by {@code matrix.toArray()};
     * this assumes {@code NDArrayFactory.DENSE.array} expects the same element order — confirm.
     *
     * @param matrix the Spark matrix to convert
     * @return the dense array created by {@code NDArrayFactory.DENSE}
     */
    public static NDArray toMatrix(org.apache.spark.mllib.linalg.Matrix matrix) {
        return NDArrayFactory.DENSE.array(matrix.numRows(), matrix.numCols(), matrix.toArray());
    }

    /**
     * Distributes the rows of the given {@link NDArray} as a Spark {@link RowMatrix}.
     * The backing stream is cached before its RDD is handed to Spark.
     *
     * @param nDArray the local matrix; row {@code r} becomes the {@code r}-th dense vector
     * @return a row matrix with one dense vector per row of {@code nDArray}
     */
    public static RowMatrix toRowMatrix(NDArray nDArray) {
        return new RowMatrix(StreamingContext.distributed()
            .range(0, nDArray.rows())
            .map(num -> Vectors.dense(nDArray.getRow(num.intValue()).toDoubleArray()))
            .cache()
            .getRDD()
            .rdd());
    }

    /**
     * Distributes a list of {@link NDArray}s as a Spark {@link RowMatrix}, one row per element.
     * The backing stream is cached before its RDD is handed to Spark.
     *
     * @param list the arrays to use as rows; element {@code r} becomes the {@code r}-th dense vector
     * @return a row matrix with one dense vector per list element
     */
    public static RowMatrix toRowMatrix(List<NDArray> list) {
        return new RowMatrix(StreamingContext.distributed()
            .range(0, list.size())
            .map(num -> Vectors.dense(list.get(num.intValue()).toDoubleArray()))
            .cache()
            .getRDD()
            .rdd());
    }

    /**
     * Converts a stream of {@link NDArray}s into a cached RDD of Spark dense vectors.
     *
     * @param mStream the stream of arrays to convert
     * @return a cached RDD holding one {@link DenseVector} per array in the stream
     */
    public static JavaRDD<Vector> toVectors(MStream<NDArray> mStream) {
        SparkStream<NDArray> sparkStream = new SparkStream<>(mStream);
        // The explicit type witness keeps the element type as Vector rather than DenseVector.
        return sparkStream.getRDD()
            .<Vector>map(nDArray -> new DenseVector(nDArray.toDoubleArray()))
            .cache();
    }

    // No hand-written $deserializeLambda$ here: the compiler synthesizes that method for the
    // serializable lambdas above, and a source-level copy would clash with the generated one.
}
