package org.broadinstitute.hellbender.tools.copynumber.denoising;

import com.google.common.primitives.Doubles;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.IntStream;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.DefaultRealMatrixChangingVisitor;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.stat.descriptive.rank.Median;
import org.apache.commons.math3.stat.descriptive.rank.Percentile;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.copynumber.CreateReadCountPanelOfNormals;
import org.broadinstitute.hellbender.tools.copynumber.arguments.CopyNumberArgumentValidationUtils;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.SimpleCountCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.metadata.SampleLocatableMetadata;
import org.broadinstitute.hellbender.tools.walkers.genotyper.StandardCallerArgumentCollection;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.MatrixSummaryUtils;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

/* loaded from: input_file:org/broadinstitute/hellbender/tools/copynumber/denoising/SVDDenoisingUtils.class */
public final class SVDDenoisingUtils {
    private static final Logger logger = LogManager.getLogger(SVDDenoisingUtils.class);
    private static final double INV_LN2 = MathUtils.INV_LOG_2;
    private static final double EPSILON = 1.0E-9d;
    private static final double LN2_EPSILON = Math.log(EPSILON) * INV_LN2;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/broadinstitute/hellbender/tools/copynumber/denoising/SVDDenoisingUtils$PreprocessedStandardizedResult.class */
    public static final class PreprocessedStandardizedResult {
        final RealMatrix preprocessedStandardizedValues;
        final double[] panelIntervalFractionalMedians;
        final boolean[] filterSamples;
        final boolean[] filterIntervals;

        private PreprocessedStandardizedResult(RealMatrix realMatrix, double[] dArr, boolean[] zArr, boolean[] zArr2) {
            this.preprocessedStandardizedValues = realMatrix;
            this.panelIntervalFractionalMedians = dArr;
            this.filterSamples = zArr;
            this.filterIntervals = zArr2;
        }
    }

    private SVDDenoisingUtils() {
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public static PreprocessedStandardizedResult preprocessAndStandardizePanel(RealMatrix realMatrix, double[] dArr, double d, double d2, double d3, double d4, boolean z, double d5) {
        logger.info("Preprocessing read counts...");
        PreprocessedStandardizedResult preprocessPanel = preprocessPanel(realMatrix, dArr, d, d2, d3, d4, z, d5);
        logger.info("Panel read counts preprocessed.");
        logger.info("Standardizing read counts...");
        divideBySampleMedianAndTransformToLog2(preprocessPanel.preprocessedStandardizedValues);
        logger.info("Subtracting median of sample medians...");
        final double evaluate = new Median().evaluate(MatrixSummaryUtils.getRowMedians(preprocessPanel.preprocessedStandardizedValues));
        preprocessPanel.preprocessedStandardizedValues.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() { // from class: org.broadinstitute.hellbender.tools.copynumber.denoising.SVDDenoisingUtils.1
            public double visit(int i, int i2, double d6) {
                return d6 - evaluate;
            }
        });
        logger.info("Panel read counts standardized.");
        return preprocessPanel;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public static SVDDenoisedCopyRatioResult denoise(SVDReadCountPanelOfNormals sVDReadCountPanelOfNormals, SimpleCountCollection simpleCountCollection, int i) {
        RealMatrix subtractProjection;
        Utils.nonNull(sVDReadCountPanelOfNormals);
        if (!CopyNumberArgumentValidationUtils.isSameDictionary(sVDReadCountPanelOfNormals.getSequenceDictionary(), ((SampleLocatableMetadata) simpleCountCollection.getMetadata()).getSequenceDictionary())) {
            logger.warn("Sequence dictionaries in panel and case sample do not match.");
        }
        ParamUtils.isPositive(i, "Number of eigensamples to use for denoising must be positive.");
        Utils.validateArg(i <= sVDReadCountPanelOfNormals.getNumEigensamples(), "Number of eigensamples to use for denoising is greater than the number available in the panel of normals.");
        logger.info("Validating sample intervals against original intervals used to build panel of normals...");
        Utils.validateArg(sVDReadCountPanelOfNormals.getOriginalIntervals().equals(simpleCountCollection.getIntervals()), "Sample intervals must be identical to the original intervals used to build the panel of normals.");
        logger.info("Preprocessing and standardizing sample read counts...");
        RealMatrix preprocessAndStandardizeSample = preprocessAndStandardizeSample(sVDReadCountPanelOfNormals, simpleCountCollection.getCounts());
        logger.info(String.format("Using %d out of %d eigensamples to denoise...", Integer.valueOf(i), Integer.valueOf(sVDReadCountPanelOfNormals.getNumEigensamples())));
        if (sVDReadCountPanelOfNormals.getOriginalReadCounts().length == 1) {
            logger.warn("Only a single sample was used to build the panel of normals, not cannot perform denoising...");
            subtractProjection = preprocessAndStandardizeSample;
        } else {
            logger.info("Subtracting projection onto space spanned by eigensamples...");
            subtractProjection = subtractProjection(preprocessAndStandardizeSample, sVDReadCountPanelOfNormals.getEigensampleVectors(), i);
        }
        logger.info("Sample denoised.");
        return new SVDDenoisedCopyRatioResult((SampleLocatableMetadata) simpleCountCollection.getMetadata(), sVDReadCountPanelOfNormals.getPanelIntervals(), preprocessAndStandardizeSample, subtractProjection);
    }

    /* JADX WARN: Type inference failed for: r2v1, types: [double[], double[][]] */
    public static RealMatrix preprocessAndStandardizeSample(double[] dArr, double[] dArr2) {
        Utils.nonNull(dArr);
        Utils.validateArg(dArr2 == null || dArr.length == dArr2.length, "Number of intervals for read counts must match those for GC-content annotations.");
        Array2DRowRealMatrix array2DRowRealMatrix = new Array2DRowRealMatrix((double[][]) new double[]{dArr});
        logger.info("Preprocessing read counts...");
        transformToFractionalCoverage(array2DRowRealMatrix);
        performOptionalGCBiasCorrection(array2DRowRealMatrix, dArr2);
        logger.info("Sample read counts preprocessed.");
        logger.info("Standardizing read counts...");
        divideBySampleMedianAndTransformToLog2(array2DRowRealMatrix);
        logger.info("Subtracting sample median...");
        final double[] rowMedians = MatrixSummaryUtils.getRowMedians(array2DRowRealMatrix);
        array2DRowRealMatrix.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() { // from class: org.broadinstitute.hellbender.tools.copynumber.denoising.SVDDenoisingUtils.2
            public double visit(int i, int i2, double d) {
                return d - rowMedians[i];
            }
        });
        logger.info("Sample read counts standardized.");
        return array2DRowRealMatrix;
    }

    private static PreprocessedStandardizedResult preprocessPanel(RealMatrix realMatrix, double[] dArr, double d, double d2, double d3, double d4, boolean z, double d5) {
        transformToFractionalCoverage(realMatrix);
        performOptionalGCBiasCorrection(realMatrix, dArr);
        int rowDimension = realMatrix.getRowDimension();
        int columnDimension = realMatrix.getColumnDimension();
        boolean[] zArr = new boolean[rowDimension];
        boolean[] zArr2 = new boolean[columnDimension];
        double[] columnMedians = MatrixSummaryUtils.getColumnMedians(realMatrix);
        if (d == StandardCallerArgumentCollection.DEFAULT_CONTAMINATION_FRACTION) {
            logger.info(String.format("A value of 0 was provided for argument %s, so the corresponding filtering step will be skipped...", CreateReadCountPanelOfNormals.MINIMUM_INTERVAL_MEDIAN_PERCENTILE_LONG_NAME));
        } else {
            logger.info(String.format("Filtering intervals with median (across samples) below the %.2f percentile...", Double.valueOf(d)));
            double evaluate = new Percentile(d).evaluate(columnMedians);
            IntStream.range(0, columnDimension).filter(i -> {
                return columnMedians[i] < evaluate;
            }).forEach(i2 -> {
                zArr2[i2] = true;
            });
            logger.info(String.format("After filtering, %d out of %d intervals remain...", Integer.valueOf(countNumberPassingFilter(zArr2)), Integer.valueOf(columnDimension)));
        }
        logger.info("Dividing by interval medians...");
        IntStream.range(0, columnDimension).filter(i3 -> {
            return !zArr2[i3];
        }).forEach(i4 -> {
            IntStream.range(0, rowDimension).filter(i4 -> {
                return !zArr[i4];
            }).forEach(i5 -> {
                realMatrix.setEntry(i5, i4, realMatrix.getEntry(i5, i4) / columnMedians[i4]);
            });
        });
        if (d2 == 100.0d) {
            logger.info(String.format("A value of 100 was provided for argument %s, so the corresponding filtering step will be skipped...", CreateReadCountPanelOfNormals.MAXIMUM_ZEROS_IN_SAMPLE_PERCENTAGE_LONG_NAME));
        } else {
            logger.info(String.format("Filtering samples with a fraction of zero-coverage intervals above %.2f percent...", Double.valueOf(d2)));
            int calculateMaximumZerosCount = calculateMaximumZerosCount(countNumberPassingFilter(zArr2), d2);
            IntStream.range(0, rowDimension).filter(i5 -> {
                return !zArr[i5];
            }).forEach(i6 -> {
                if (((int) IntStream.range(0, columnDimension).filter(i6 -> {
                    return !zArr2[i6] && realMatrix.getEntry(i6, i6) == StandardCallerArgumentCollection.DEFAULT_CONTAMINATION_FRACTION;
                }).count()) > calculateMaximumZerosCount) {
                    zArr[i6] = true;
                }
            });
            logger.info(String.format("After filtering, %d out of %d samples remain...", Integer.valueOf(countNumberPassingFilter(zArr)), Integer.valueOf(rowDimension)));
        }
        if (d3 == 100.0d) {
            logger.info(String.format("A value of 100 was provided for argument %s, so the corresponding filtering step will be skipped...", CreateReadCountPanelOfNormals.MAXIMUM_ZEROS_IN_INTERVAL_PERCENTAGE_LONG_NAME));
        } else {
            logger.info(String.format("Filtering intervals with a fraction of zero-coverage samples above %.2f percent...", Double.valueOf(d3)));
            int calculateMaximumZerosCount2 = calculateMaximumZerosCount(countNumberPassingFilter(zArr), d3);
            IntStream.range(0, columnDimension).filter(i7 -> {
                return !zArr2[i7];
            }).forEach(i8 -> {
                if (((int) IntStream.range(0, rowDimension).filter(i8 -> {
                    return !zArr[i8] && realMatrix.getEntry(i8, i8) == StandardCallerArgumentCollection.DEFAULT_CONTAMINATION_FRACTION;
                }).count()) > calculateMaximumZerosCount2) {
                    zArr2[i8] = true;
                }
            });
            logger.info(String.format("After filtering, %d out of %d intervals remain...", Integer.valueOf(countNumberPassingFilter(zArr2)), Integer.valueOf(columnDimension)));
        }
        if (d4 == StandardCallerArgumentCollection.DEFAULT_CONTAMINATION_FRACTION) {
            logger.info(String.format("A value of 0 was provided for argument %s, so the corresponding filtering step will be skipped...", CreateReadCountPanelOfNormals.EXTREME_SAMPLE_MEDIAN_PERCENTILE_LONG_NAME));
        } else {
            logger.info(String.format("Filtering samples with a median (across intervals) below the %.2f percentile or above the %.2f percentile...", Double.valueOf(d4), Double.valueOf(100.0d - d4)));
            double[] array = IntStream.range(0, rowDimension).mapToDouble(i9 -> {
                return new Median().evaluate(IntStream.range(0, columnDimension).filter(i9 -> {
                    return !zArr2[i9];
                }).mapToDouble(i10 -> {
                    return realMatrix.getEntry(i9, i10);
                }).toArray());
            }).toArray();
            double evaluate2 = new Percentile(d4).evaluate(array);
            double evaluate3 = new Percentile(100.0d - d4).evaluate(array);
            IntStream.range(0, rowDimension).filter(i10 -> {
                return array[i10] < evaluate2 || array[i10] > evaluate3;
            }).forEach(i11 -> {
                zArr[i11] = true;
            });
            logger.info(String.format("After filtering, %d out of %d samples remain...", Integer.valueOf(countNumberPassingFilter(zArr)), Integer.valueOf(rowDimension)));
        }
        int[] array2 = IntStream.range(0, columnDimension).filter(i12 -> {
            return !zArr2[i12];
        }).toArray();
        RealMatrix subMatrix = realMatrix.getSubMatrix(IntStream.range(0, rowDimension).filter(i13 -> {
            return !zArr[i13];
        }).toArray(), array2);
        double[] array3 = IntStream.range(0, columnDimension).filter(i14 -> {
            return !zArr2[i14];
        }).mapToDouble(i15 -> {
            return columnMedians[i15];
        }).toArray();
        logHeapUsage();
        logger.info("Performing garbage collection...");
        System.gc();
        logHeapUsage();
        if (z) {
            final double[] array4 = IntStream.range(0, array2.length).mapToObj(i16 -> {
                return Arrays.stream(subMatrix.getColumn(i16)).filter(d6 -> {
                    return d6 > StandardCallerArgumentCollection.DEFAULT_CONTAMINATION_FRACTION;
                }).toArray();
            }).mapToDouble(dArr2 -> {
                return new Median().evaluate(dArr2);
            }).toArray();
            final int[] iArr = {0};
            subMatrix.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() { // from class: org.broadinstitute.hellbender.tools.copynumber.denoising.SVDDenoisingUtils.3
                public double visit(int i17, int i18, double d6) {
                    if (d6 != StandardCallerArgumentCollection.DEFAULT_CONTAMINATION_FRACTION) {
                        return d6;
                    }
                    int[] iArr2 = iArr;
                    iArr2[0] = iArr2[0] + 1;
                    return array4[i18];
                }
            });
            logger.info(String.format("%d zero-coverage values were imputed to the median of the non-zero values in the corresponding interval...", Integer.valueOf(iArr[0])));
        } else {
            logger.info("Skipping imputation of zero-coverage values...");
        }
        if (d5 == StandardCallerArgumentCollection.DEFAULT_CONTAMINATION_FRACTION) {
            logger.info(String.format("A value of 0 was provided for argument %s, so the corresponding truncation step will be skipped...", CreateReadCountPanelOfNormals.EXTREME_OUTLIER_TRUNCATION_PERCENTILE_LONG_NAME));
        } else {
            double[] concat = Doubles.concat(subMatrix.getData());
            final double evaluate4 = new Percentile(d5).evaluate(concat);
            final double evaluate5 = new Percentile(100.0d - d5).evaluate(concat);
            final int[] iArr2 = {0};
            subMatrix.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() { // from class: org.broadinstitute.hellbender.tools.copynumber.denoising.SVDDenoisingUtils.4
                public double visit(int i17, int i18, double d6) {
                    if (d6 < evaluate4) {
                        int[] iArr3 = iArr2;
                        iArr3[0] = iArr3[0] + 1;
                        return evaluate4;
                    }
                    if (d6 <= evaluate5) {
                        return d6;
                    }
                    int[] iArr4 = iArr2;
                    iArr4[0] = iArr4[0] + 1;
                    return evaluate5;
                }
            });
            logger.info(String.format("%d values below the %.2f percentile or above the %.2f percentile were truncated to the corresponding value...", Integer.valueOf(iArr2[0]), Double.valueOf(d5), Double.valueOf(100.0d - d5)));
        }
        return new PreprocessedStandardizedResult(subMatrix, array3, zArr, zArr2);
    }

    private static void logHeapUsage() {
        Runtime runtime = Runtime.getRuntime();
        logger.info("Heap utilization statistics [MB]:");
        logger.info("Used memory: " + ((runtime.totalMemory() - runtime.freeMemory()) / 1048576));
        logger.info("Free memory: " + (runtime.freeMemory() / 1048576));
        logger.info("Total memory: " + (runtime.totalMemory() / 1048576));
        logger.info("Maximum memory: " + (runtime.maxMemory() / 1048576));
    }

    /* JADX WARN: Type inference failed for: r2v1, types: [double[], double[][]] */
    private static RealMatrix preprocessAndStandardizeSample(SVDReadCountPanelOfNormals sVDReadCountPanelOfNormals, double[] dArr) {
        Array2DRowRealMatrix array2DRowRealMatrix = new Array2DRowRealMatrix((double[][]) new double[]{dArr});
        logger.info("Preprocessing read counts...");
        transformToFractionalCoverage(array2DRowRealMatrix);
        performOptionalGCBiasCorrection(array2DRowRealMatrix, sVDReadCountPanelOfNormals.getOriginalIntervalGCContent());
        logger.info("Subsetting sample intervals to post-filter panel intervals...");
        HashSet hashSet = new HashSet(sVDReadCountPanelOfNormals.getPanelIntervals());
        RealMatrix subMatrix = array2DRowRealMatrix.getSubMatrix(new int[]{0}, IntStream.range(0, sVDReadCountPanelOfNormals.getOriginalIntervals().size()).filter(i -> {
            return hashSet.contains(sVDReadCountPanelOfNormals.getOriginalIntervals().get(i));
        }).toArray());
        logger.info("Dividing by interval medians from the panel of normals...");
        final double[] panelIntervalFractionalMedians = sVDReadCountPanelOfNormals.getPanelIntervalFractionalMedians();
        subMatrix.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() { // from class: org.broadinstitute.hellbender.tools.copynumber.denoising.SVDDenoisingUtils.5
            public double visit(int i2, int i3, double d) {
                return d / panelIntervalFractionalMedians[i3];
            }
        });
        logger.info("Sample read counts preprocessed.");
        logger.info("Standardizing read counts...");
        divideBySampleMedianAndTransformToLog2(subMatrix);
        logger.info("Subtracting sample median...");
        final double[] rowMedians = MatrixSummaryUtils.getRowMedians(subMatrix);
        subMatrix.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() { // from class: org.broadinstitute.hellbender.tools.copynumber.denoising.SVDDenoisingUtils.6
            public double visit(int i2, int i3, double d) {
                return d - rowMedians[i2];
            }
        });
        logger.info("Sample read counts standardized.");
        return subMatrix;
    }

    private static RealMatrix subtractProjection(RealMatrix realMatrix, double[][] dArr, int i) {
        int length = dArr.length;
        int length2 = dArr[0].length;
        logger.info("Distributing the standardized read counts...");
        logger.info("Composing eigensample matrix for the requested number of eigensamples and transposing them...");
        Array2DRowRealMatrix array2DRowRealMatrix = i == length2 ? new Array2DRowRealMatrix(dArr, false) : new Array2DRowRealMatrix(dArr, false).getSubMatrix(0, length - 1, 0, i - 1);
        logger.info("Computing projection...");
        RealMatrix multiply = realMatrix.multiply(array2DRowRealMatrix).multiply(array2DRowRealMatrix.transpose());
        logger.info("Subtracting projection...");
        return realMatrix.subtract(multiply);
    }

    private static int countNumberPassingFilter(boolean[] zArr) {
        int count = (int) IntStream.range(0, zArr.length).filter(i -> {
            return !zArr[i];
        }).count();
        if (count == 0) {
            throw new UserException.BadInput("Filtering removed all samples or intervals.  Select less strict filtering criteria.");
        }
        return count;
    }

    private static void transformToFractionalCoverage(RealMatrix realMatrix) {
        logger.info("Transforming read counts to fractional coverage...");
        final double[] rowSums = MathUtils.rowSums(realMatrix);
        realMatrix.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() { // from class: org.broadinstitute.hellbender.tools.copynumber.denoising.SVDDenoisingUtils.7
            public double visit(int i, int i2, double d) {
                return d / rowSums[i];
            }
        });
    }

    private static void performOptionalGCBiasCorrection(RealMatrix realMatrix, double[] dArr) {
        if (dArr != null) {
            logger.info("Performing GC-bias correction...");
            GCBiasCorrector.correctGCBias(realMatrix, dArr);
        }
    }

    private static void divideBySampleMedianAndTransformToLog2(RealMatrix realMatrix) {
        logger.info("Dividing by sample medians and transforming to log2 space...");
        final double[] rowMedians = MatrixSummaryUtils.getRowMedians(realMatrix);
        IntStream.range(0, rowMedians.length).forEach(i -> {
            ParamUtils.isPositive(rowMedians[i], rowMedians.length == 1 ? "Sample does not have a non-negative sample median." : String.format("Sample at index %s does not have a non-negative sample median.", Integer.valueOf(i)));
        });
        realMatrix.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() { // from class: org.broadinstitute.hellbender.tools.copynumber.denoising.SVDDenoisingUtils.8
            public double visit(int i2, int i3, double d) {
                return SVDDenoisingUtils.safeLog2(d / rowMedians[i2]);
            }
        });
    }

    private static int calculateMaximumZerosCount(int i, double d) {
        return (int) Math.ceil((i * d) / 100.0d);
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static double safeLog2(double d) {
        return d < EPSILON ? LN2_EPSILON : Math.log(d) * INV_LN2;
    }
}
