package org.broadinstitute.hellbender.utils.svd;

import java.util.Arrays;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Matrices;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.SingularValueDecomposition;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.linalg.distributed.RowMatrix;
import org.broadinstitute.hellbender.tools.walkers.genotyper.StandardCallerArgumentCollection;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.spark.SparkConverter;

/**
 * {@link SingularValueDecomposer} backed by Spark MLlib.
 *
 * <p>The input {@link RealMatrix} is distributed as a Spark {@link RowMatrix} and factored with
 * {@link RowMatrix#computeSVD}. The pseudoinverse is then assembled from the resulting factors.
 */
public final class SparkSingularValueDecomposer implements SingularValueDecomposer {
    /** Scale factor for the tolerance below which singular values are not inverted in the pseudoinverse. */
    private static final double EPS = 1.0E-32d;

    /**
     * Reciprocal condition number handed to Spark's {@code computeSVD}.
     * Per Spark's documentation, singular values below {@code rCond * sigma_max} are treated as zero.
     */
    private static final double SVD_RCOND = 1.0E-9d;

    private static final Logger logger = LogManager.getLogger(SparkSingularValueDecomposer.class);

    /** Number of slices used when converting the input matrix into a distributed Spark {@link RowMatrix}. */
    private static final int NUM_SLICES = 60;

    private final JavaSparkContext sc;

    /**
     * @param javaSparkContext Spark context used to distribute the input matrix; must not be {@code null}
     */
    public SparkSingularValueDecomposer(JavaSparkContext javaSparkContext) {
        Utils.nonNull(javaSparkContext, "Cannot perform Spark MLLib SVD using a null JavaSparkContext.");
        this.sc = javaSparkContext;
    }

    /**
     * Computes the SVD of {@code realMatrix} with Spark MLlib and derives its pseudoinverse.
     *
     * @param realMatrix matrix to decompose; must not be {@code null}
     * @return a {@link SimpleSVD} holding the local U factor, the singular values, the local
     *         V-related factor, and the pseudoinverse
     */
    @Override
    public SVD createSVD(RealMatrix realMatrix) {
        Utils.nonNull(realMatrix, "Cannot perform Spark MLLib SVD on a null matrix.");
        final int numRows = realMatrix.getRowDimension();

        final RowMatrix distributed = SparkConverter.convertRealMatrixToSparkRowMatrix(this.sc, realMatrix, NUM_SLICES);
        // Request all singular values; Spark may still return fewer if some fall below the rCond cutoff.
        final SingularValueDecomposition<RowMatrix, Matrix> svd =
                distributed.computeSVD((int) distributed.numCols(), true, SVD_RCOND);
        final RowMatrix u = svd.U();
        final double[] singularValues = svd.s().toArray();
        final Matrix vt = svd.V().transpose();

        logger.info("Converting distributed Spark matrix to local matrix...");
        final RealMatrix localU = SparkConverter.convertSparkRowMatrixToRealMatrix(u, numRows);
        logger.info("Done converting distributed Spark matrix to local matrix...");
        logger.info("Converting Spark matrix to local matrix...");
        // NOTE(review): the transpose of V is what gets converted and handed to SimpleSVD. Whether the
        // result is V or V' depends on SparkConverter.convertSparkMatrixToRealMatrix's handling of
        // Spark's column-major layout -- confirm before changing.
        final RealMatrix localV = SparkConverter.convertSparkMatrixToRealMatrix(vt);
        logger.info("Done converting Spark matrix to local matrix...");

        logger.info("Calculating the pseudoinverse...");
        logger.info("Pinv: calculating tolerance...");
        final double tolerance = Math.max(realMatrix.getColumnDimension(), numRows) * realMatrix.getNorm() * EPS;
        logger.info("Pinv: inverting the singular values (with tolerance) and creating a diagonal matrix...");
        final Matrix invS = Matrices.diag(Vectors.dense(
                Arrays.stream(singularValues).map(d -> invertSVWithTolerance(d, tolerance)).toArray()));
        logger.info("Pinv: Multiplying V * invS * U' to get the pinv (using pinv transpose = U * invS' * V') ...");
        // U * invS * V' keeps U's distributed row layout, so build pinv' first and transpose once it is local.
        final RowMatrix pinvTransposed = u.multiply(invS).multiply(vt);
        logger.info("Pinv: Converting back to local matrix ...");
        final RealMatrix pinv = SparkConverter.convertSparkRowMatrixToRealMatrix(pinvTransposed, numRows).transpose();
        logger.info("Done calculating the pseudoinverse and converting it...");
        return new SimpleSVD(localU, singularValues, localV, pinv);
    }

    /**
     * Inverts one singular value for the pseudoinverse, zeroing values that do not exceed the tolerance.
     *
     * <p>NOTE(review): decompiler metadata marks this method as originally {@code private}. It is left
     * {@code public} so that any existing callers keep compiling.
     *
     * @param singularValue singular value to invert
     * @param tolerance     values {@code <= tolerance} are mapped to {@code 0.0} instead of being inverted
     * @return {@code 1 / singularValue}, or {@code 0.0} when {@code singularValue <= tolerance}
     */
    public static double invertSVWithTolerance(double singularValue, double tolerance) {
        return singularValue <= tolerance ? 0.0 : 1.0 / singularValue;
    }
}
