package org.broadinstitute.hellbender.tools.walkers.contamination;

import htsjdk.samtools.util.OverlapDetector;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.lang.mutable.MutableDouble;
import org.apache.commons.lang3.Range;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.distribution.BinomialDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.commons.math3.stat.descriptive.rank.Median;
import org.apache.commons.math3.util.FastMath;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.CommandLineProgram;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.tools.copynumber.utils.segmentation.KernelSegmenter;
import org.broadinstitute.hellbender.tools.walkers.contamination.ContaminationRecord;
import org.broadinstitute.hellbender.tools.walkers.genotyper.StandardCallerArgumentCollection;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.OptimizationUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import picard.cmdline.programgroups.DiagnosticsAndQCProgramGroup;

/**
 * Estimates the fraction of reads in a sample that come from cross-sample contamination.
 *
 * <p>Outline of {@link #doWork()}:
 * <ol>
 *   <li>Read the pileup summaries of the input sample (and, optionally, of a matched normal). Drop sites with
 *       coverage of at most {@code MIN_COVERAGE}, sites whose coverage falls outside the configured ratios of the
 *       median/mean coverage, and sites whose alt fraction is at most {@code ALT_FRACTION_OF_DEFINITE_HOM_REF}.</li>
 *   <li>Segment each contig by the minor allele fraction of its likely-het sites using a kernel segmenter.</li>
 *   <li>Alternate between selecting likely hom-alt sites given the current contamination estimate and re-estimating
 *       the contamination from those sites. Stop when the estimate moves by less than
 *       {@code CONTAMINATION_CONVERGENCE_THRESHOLD} or after {@code MAX_ITERATIONS} iterations.</li>
 *   <li>Estimate contamination from the input sample's reads at the selected hom-alt loci and write it as a single
 *       whole-BAM {@link ContaminationRecord}.</li>
 * </ol>
 */
@CommandLineProgramProperties(summary = "Calculate the fraction of reads coming from cross-sample contamination", oneLineSummary = "Calculate the fraction of reads coming from cross-sample contamination", programGroup = DiagnosticsAndQCProgramGroup.class)
@DocumentedFeature
public class CalculateContamination extends CommandLineProgram {
    private static final Logger logger = LogManager.getLogger(CalculateContamination.class);

    public static final String MATCHED_NORMAL_LONG_NAME = "matched-normal";
    public static final String MATCHED_NORMAL_SHORT_NAME = "matched";
    public static final String LOW_COVERAGE_RATIO_THRESHOLD_NAME = "low-coverage-ratio-threshold";
    public static final String HIGH_COVERAGE_RATIO_THRESHOLD_NAME = "high-coverage-ratio-threshold";

    // Hom-alt selection: keep relaxing the per-segment minor-allele-fraction requirement by this step
    // until at least this many hom-alt sites are found.
    public static final int DESIRED_MINIMUM_HOM_ALT_COUNT = 50;
    public static final double MINOR_ALLELE_FRACTION_STEP_SIZE = 0.05d;
    private static final double STRICT_LOH_MAF_THRESHOLD = 0.4d;

    // Iterative contamination estimation.
    private static final double INITIAL_CONTAMINATION_GUESS = 0.05d;
    private static final int MAX_ITERATIONS = 10;
    private static final double CONTAMINATION_CONVERGENCE_THRESHOLD = 0.001d;

    // Site filtering.
    private static final int MIN_COVERAGE = 10;
    private static final double ALT_FRACTION_OF_DEFINITE_HOM_REF = 0.05d;
    private static final double DEFAULT_LOW_COVERAGE_RATIO_THRESHOLD = 0.5d;
    private static final double DEFAULT_HIGH_COVERAGE_RATIO_THRESHOLD = 3.0d;

    // Kernel segmentation of each contig.
    private static final int MAX_CHANGEPOINTS_PER_CHROMOSOME = 10;
    private static final int MIN_SITES_PER_SEGMENT = 5;
    private static final double KERNEL_SEGMENTER_LINEAR_COST = 1.0d;
    private static final double KERNEL_SEGMENTER_LOG_LINEAR_COST = 1.0d;
    private static final int KERNEL_SEGMENTER_DIMENSION = 100;
    private static final int POINTS_PER_SEGMENTATION_WINDOW = 50;
    private static final double SEGMENTATION_KERNEL_VARIANCE = 0.025d;

    // Per-segment minor-allele-fraction fit in segmentHomAlts (arguments to OptimizationUtils.argmax).
    private static final double MAX_MINOR_ALLELE_FRACTION = 0.5d;
    private static final double INITIAL_MINOR_ALLELE_FRACTION_GUESS = 0.4d;
    private static final double MAF_OPTIMIZATION_TOLERANCE = 0.01d;
    private static final int MAF_OPTIMIZATION_MAX_EVALUATIONS = 20;

    // A site is treated as hom-alt when its hom-alt-vs-het posterior exceeds this value.
    private static final double HOM_ALT_POSTERIOR_THRESHOLD = 0.5d;

    // NOTE(review): 1.5 = 3/2 presumably extrapolates the error rate observed on the two "other" bases
    // (neither ref nor alt) to all three non-true bases -- confirm.
    private static final double ERROR_RATE_SCALE = 1.5d;

    // NOTE(review): presumably errors are assumed uniform over the three bases other than the true one,
    // so one third of them land on the ref base -- confirm.
    private static final double POSSIBLE_ERROR_BASES = 3.0d;

    // Sites whose alt fraction lies in this range are treated as likely hets for segmentation.
    private static final Range<Double> ALT_FRACTIONS_FOR_SEGMENTATION = Range.between(0.1d, 0.9d);

    /** Gaussian kernel on the folded (minor) allele fractions of two sites. */
    private static final BiFunction<PileupSummary, PileupSummary, Double> SEGMENTATION_KERNEL = (ps1, ps2) -> {
        final double maf1 = FastMath.min(ps1.getAltFraction(), 1.0d - ps1.getAltFraction());
        final double maf2 = FastMath.min(ps2.getAltFraction(), 1.0d - ps2.getAltFraction());
        return FastMath.exp(-MathUtils.square(maf1 - maf2) / (2 * SEGMENTATION_KERNEL_VARIANCE));
    };

    @Argument(fullName = StandardArgumentDefinitions.INPUT_LONG_NAME, shortName = StandardArgumentDefinitions.INPUT_SHORT_NAME, doc = "The input table")
    private File inputPileupSummariesTable;

    @Argument(fullName = MATCHED_NORMAL_LONG_NAME, shortName = MATCHED_NORMAL_SHORT_NAME, doc = "The matched normal input table", optional = true)
    private File matchedPileupSummariesTable = null;

    @Argument(fullName = "output", shortName = "O", doc = "The output table")
    private File outputTable = null;

    // These fields are deliberately non-final. A final field with a constant initializer is a compile-time
    // constant: javac inlines the default at every use, so a value assigned by the argument parser would be ignored.
    @Argument(fullName = LOW_COVERAGE_RATIO_THRESHOLD_NAME, doc = "The minimum coverage relative to the median.", optional = true)
    private double lowCoverageRatioThreshold = DEFAULT_LOW_COVERAGE_RATIO_THRESHOLD;

    @Argument(fullName = HIGH_COVERAGE_RATIO_THRESHOLD_NAME, doc = "The maximum coverage relative to the mean.", optional = true)
    private double highCoverageRatioThreshold = DEFAULT_HIGH_COVERAGE_RATIO_THRESHOLD;

    /**
     * Estimates contamination for the input sample and writes it to {@link #outputTable}.
     *
     * <p>When a matched normal is supplied, segmentation and hom-alt selection use the normal's sites. The final
     * estimate always uses the input sample's reads at the selected hom-alt loci.
     *
     * @return the string {@code "SUCCESS"}
     */
    @Override
    public Object doWork() {
        final List<PileupSummary> sites = filterSites(PileupSummary.readFromFile(inputPileupSummariesTable));
        final List<PileupSummary> genotypingSites = matchedPileupSummariesTable == null ? sites
                : filterSites(PileupSummary.readFromFile(matchedPileupSummariesTable));
        final List<List<PileupSummary>> segments = findSegments(genotypingSites);
        // Loop-invariant: the genotyping sites do not change between iterations.
        final double genotypingErrorRate = errorRate(genotypingSites);

        List<PileupSummary> homAltSites = new ArrayList<>();
        double contamination = INITIAL_CONTAMINATION_GUESS;
        for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++) {
            homAltSites = findHomAltSites(segments, contamination);
            final double newContamination = calculateContamination(homAltSites, genotypingErrorRate).getLeft();
            if (Math.abs(newContamination - contamination) < CONTAMINATION_CONVERGENCE_THRESHOLD) {
                break;
            }
            contamination = newContamination;
        }

        final Pair<Double, Double> result = calculateContamination(subsetSites(sites, homAltSites), errorRate(sites));
        ContaminationRecord.writeToFile(Arrays.asList(new ContaminationRecord(ContaminationRecord.Level.WHOLE_BAM.toString(),
                result.getLeft(), result.getRight())), outputTable);
        return "SUCCESS";
    }

    /**
     * Selects likely hom-alt sites across all segments for a given contamination estimate.
     *
     * <p>The per-segment minor-allele-fraction requirement starts at {@code STRICT_LOH_MAF_THRESHOLD}. It is
     * lowered by {@link #MINOR_ALLELE_FRACTION_STEP_SIZE} until at least {@link #DESIRED_MINIMUM_HOM_ALT_COUNT}
     * sites are found or the requirement is no longer positive. The sites from the last pass are returned.
     */
    private List<PileupSummary> findHomAltSites(final List<List<PileupSummary>> segments, final double contamination) {
        List<List<PileupSummary>> homAltSitesBySegment = Collections.emptyList();
        double minimumMinorAlleleFraction = STRICT_LOH_MAF_THRESHOLD;
        while (homAltSitesBySegment.stream().mapToInt(List::size).sum() < DESIRED_MINIMUM_HOM_ALT_COUNT
                && minimumMinorAlleleFraction > 0.0d) {
            final double threshold = minimumMinorAlleleFraction;   // effectively-final copy for the lambda
            homAltSitesBySegment = segments.stream()
                    .map(segment -> segmentHomAlts(segment, contamination, threshold))
                    .collect(Collectors.toList());
            minimumMinorAlleleFraction -= MINOR_ALLELE_FRACTION_STEP_SIZE;
        }
        return homAltSitesBySegment.stream().flatMap(List::stream).collect(Collectors.toList());
    }

    /**
     * Groups sites by contig, segments each contig, and keeps segments with at least
     * {@code MIN_SITES_PER_SEGMENT} sites.
     */
    private List<List<PileupSummary>> findSegments(final List<PileupSummary> sites) {
        final Map<String, List<PileupSummary>> sitesByContig = sites.stream()
                .collect(Collectors.groupingBy(PileupSummary::getContig));
        return sitesByContig.values().stream()
                .flatMap(contigSites -> findContigSegments(contigSites).stream())
                .filter(segment -> segment.size() >= MIN_SITES_PER_SEGMENT)
                .collect(Collectors.toList());
    }

    /**
     * Estimates the per-base sequencing error rate from reads supporting neither the ref nor the alt allele.
     *
     * @return {@code ERROR_RATE_SCALE * otherAltCount / totalCount} summed over sites, or 0 when there is no depth
     */
    private double errorRate(final List<PileupSummary> sites) {
        // Sum in long and divide in double: int/int division truncated the rate to 0 and threw on an empty list.
        final long totalBases = sites.stream().mapToLong(PileupSummary::getTotalCount).sum();
        final long otherAltBases = sites.stream().mapToLong(PileupSummary::getOtherAltCount).sum();
        return totalBases == 0 ? 0.0d : ERROR_RATE_SCALE * ((double) otherAltBases / totalBases);
    }

    /** Returns the elements of {@code sites} that overlap any locus in {@code loci}. */
    private static List<PileupSummary> subsetSites(final List<PileupSummary> sites, final List<PileupSummary> loci) {
        final OverlapDetector<PileupSummary> detector = OverlapDetector.create(loci);
        return sites.stream().filter(detector::overlapsAny).collect(Collectors.toList());
    }

    /**
     * Returns the sites of one segment whose hom-alt-vs-het posterior exceeds {@code HOM_ALT_POSTERIOR_THRESHOLD}.
     *
     * <p>Returns an empty list when the segment's fitted minor allele fraction is below
     * {@code minimumMinorAlleleFraction}.
     *
     * @param segment                    sites of one segment
     * @param contamination              current contamination estimate
     * @param minimumMinorAlleleFraction segments whose fitted minor allele fraction falls below this are skipped
     */
    private List<PileupSummary> segmentHomAlts(final List<PileupSummary> segment, final double contamination, final double minimumMinorAlleleFraction) {
        final List<PileupSummary> hets = getLikelyHetsBasedOnAlleleFraction(segment);
        final double minorAlleleFraction = OptimizationUtils.argmax(maf -> logLikelihoodOfHetsInSegment(hets, maf),
                ALT_FRACTIONS_FOR_SEGMENTATION.getMinimum(), MAX_MINOR_ALLELE_FRACTION, INITIAL_MINOR_ALLELE_FRACTION_GUESS,
                MAF_OPTIMIZATION_TOLERANCE, MAF_OPTIMIZATION_TOLERANCE, MAF_OPTIMIZATION_MAX_EVALUATIONS);
        if (minorAlleleFraction < minimumMinorAlleleFraction) {
            return Collections.emptyList();
        }
        return segment.stream()
                .filter(site -> homAltProbability(site, minorAlleleFraction, contamination) > HOM_ALT_POSTERIOR_THRESHOLD)
                .collect(Collectors.toList());
    }

    /**
     * Log-likelihood of the given het sites at minor allele fraction {@code minorAlleleFraction}.
     *
     * <p>Each site's alt count is modeled as an equal-weight mixture of Binomial(n, maf) and Binomial(n, 1 - maf).
     */
    private final double logLikelihoodOfHetsInSegment(final List<PileupSummary> hets, final double minorAlleleFraction) {
        return hets.stream().mapToDouble(site -> {
            final int totalCount = site.getTotalCount();
            final int altCount = site.getAltCount();
            return MathUtils.logSumLog(
                    new BinomialDistribution((RandomGenerator) null, totalCount, minorAlleleFraction).logProbability(altCount),
                    new BinomialDistribution((RandomGenerator) null, totalCount, 1.0d - minorAlleleFraction).logProbability(altCount))
                    + MathUtils.LOG_ONE_HALF;
        }).sum();
    }

    /**
     * Segments one contig by the minor allele fraction of its likely-het sites.
     *
     * <p>Each segment spans from one het to the het at the next changepoint. It contains every site of the contig
     * (het or not) overlapping that span, sorted by start.
     */
    private List<List<PileupSummary>> findContigSegments(final List<PileupSummary> sites) {
        final List<PileupSummary> hets = getLikelyHetsBasedOnAlleleFraction(sites);
        if (hets.isEmpty()) {
            return Collections.emptyList();
        }

        // Segment i covers het indices (boundaries[i], boundaries[i + 1]]; the sentinel -1 opens the first segment.
        final List<Integer> boundaries = new ArrayList<>();
        boundaries.add(-1);
        boundaries.addAll(new KernelSegmenter<>(hets).findChangepoints(MAX_CHANGEPOINTS_PER_CHROMOSOME, SEGMENTATION_KERNEL,
                KERNEL_SEGMENTER_DIMENSION, Arrays.asList(POINTS_PER_SEGMENTATION_WINDOW),
                KERNEL_SEGMENTER_LINEAR_COST, KERNEL_SEGMENTER_LOG_LINEAR_COST, KernelSegmenter.ChangepointSortOrder.INDEX));
        boundaries.add(hets.size() - 1);

        final OverlapDetector<PileupSummary> siteDetector = OverlapDetector.create(sites);
        return IntStream.range(0, boundaries.size() - 1)
                .mapToObj(n -> {
                    final PileupSummary first = hets.get(boundaries.get(n) + 1);
                    final PileupSummary last = hets.get(boundaries.get(n + 1));
                    return new SimpleInterval(first.getContig(), first.getStart(), last.getEnd());
                })
                .map(interval -> siteDetector.getOverlaps(interval).stream()
                        .sorted(Comparator.comparingInt(PileupSummary::getStart))
                        .collect(Collectors.toList()))
                .collect(Collectors.toList());
    }

    /** Returns the sites whose alt fraction lies within {@code ALT_FRACTIONS_FOR_SEGMENTATION}. */
    private List<PileupSummary> getLikelyHetsBasedOnAlleleFraction(final List<PileupSummary> sites) {
        return sites.stream()
                .filter(site -> ALT_FRACTIONS_FOR_SEGMENTATION.contains(site.getAltFraction()))
                .collect(Collectors.toList());
    }

    /**
     * Estimates contamination from reference reads at hom-alt sites.
     *
     * <p>Reference reads attributed to sequencing error are subtracted first. The remainder is divided by the depth
     * weighted by each site's population reference-allele frequency.
     *
     * @param homAltSites likely hom-alt sites
     * @param errorRate   per-base error rate, as returned by {@code errorRate}
     * @return (contamination, standard error); (0, 1) when there are no sites or no expected contaminant ref reads
     */
    private static Pair<Double, Double> calculateContamination(final List<PileupSummary> homAltSites, final double errorRate) {
        if (homAltSites.isEmpty()) {
            logger.warn("No hom alt sites found!  Perhaps GetPileupSummaries was run on too small of an interval, or perhaps the sample was extremely inbred or haploid.");
            return Pair.of(0.0d, 1.0d);
        }

        final long totalReadCount = homAltSites.stream().mapToLong(PileupSummary::getTotalCount).sum();
        final long totalRefCount = homAltSites.stream().mapToLong(PileupSummary::getRefCount).sum();
        final long errorRefCount = Math.round((totalReadCount * errorRate) / POSSIBLE_ERROR_BASES);
        final long contaminationRefCount = Math.max(totalRefCount - errorRefCount, 0L);
        final double totalDepthWeightedByRefFrequency = homAltSites.stream()
                .mapToDouble(site -> site.getTotalCount() * (1.0d - site.getAlleleFrequency()))
                .sum();
        if (totalDepthWeightedByRefFrequency <= 0.0d) {
            // Would otherwise yield Inf/NaN, which then breaks the binomial models on the next iteration.
            logger.warn("Population allele frequencies at the hom alt sites predict no reference reads from a contaminant; contamination cannot be estimated.");
            return Pair.of(0.0d, 1.0d);
        }

        final double contamination = contaminationRefCount / totalDepthWeightedByRefFrequency;
        final double standardError = Math.sqrt(contamination / totalDepthWeightedByRefFrequency);
        logger.info(String.format("In %d homozygous variant sites we find %d reference reads due to contamination and %d due to sequencing error out of a total %d reads.", homAltSites.size(), contaminationRefCount, errorRefCount, totalReadCount));
        logger.info(String.format("Based on population data, we would expect %d reference reads in a contaminant with equal depths at these sites.", (long) totalDepthWeightedByRefFrequency));
        logger.info(String.format("Therefore, we estimate a contamination of %.3f.", contamination));
        logger.info(String.format("The error bars on this estimate are %.5f.", standardError));
        return Pair.of(contamination, standardError);
    }

    /**
     * Applies the site filters in order.
     *
     * <ol>
     *   <li>Keep sites with more than {@code MIN_COVERAGE} reads.</li>
     *   <li>Compute the median and mean coverage of those sites.</li>
     *   <li>Keep sites with coverage strictly between {@code median * lowCoverageRatioThreshold} and
     *       {@code mean * highCoverageRatioThreshold}.</li>
     *   <li>Keep sites with alt fraction above {@code ALT_FRACTION_OF_DEFINITE_HOM_REF}.</li>
     * </ol>
     */
    private List<PileupSummary> filterSites(final List<PileupSummary> sites) {
        final List<PileupSummary> coveredSites = sites.stream()
                .filter(site -> site.getTotalCount() > MIN_COVERAGE)
                .collect(Collectors.toList());
        final double[] coverage = coveredSites.stream().mapToDouble(PileupSummary::getTotalCount).toArray();
        final double minCoverage = new Median().evaluate(coverage) * lowCoverageRatioThreshold;
        final double maxCoverage = new Mean().evaluate(coverage) * highCoverageRatioThreshold;
        return coveredSites.stream()
                .filter(site -> site.getTotalCount() > minCoverage && site.getTotalCount() < maxCoverage)
                .filter(site -> site.getAltFraction() > ALT_FRACTION_OF_DEFINITE_HOM_REF)
                .collect(Collectors.toList());
    }

    /**
     * Posterior probability that a site is hom-alt rather than het.
     *
     * <p>Priors come from the site's population allele frequency. The hom-alt likelihood is
     * Binomial(n, 1 - contamination); the het likelihood is Binomial(n, 1 - minorAlleleFraction).
     *
     * @return 0 when fewer than half the ref+alt reads are alt, or when both unnormalized terms are 0
     */
    private double homAltProbability(final PileupSummary site, final double minorAlleleFraction, final double contamination) {
        final double alleleFrequency = site.getAlleleFrequency();
        final double homAltPrior = MathUtils.square(alleleFrequency);
        final double hetPrior = 2.0d * alleleFrequency * (1.0d - alleleFrequency);
        final int altCount = site.getAltCount();
        final int totalCount = altCount + site.getRefCount();
        if (altCount < totalCount / 2) {
            return 0.0d;
        }

        final double homAltLikelihood = new BinomialDistribution((RandomGenerator) null, totalCount, 1.0d - contamination).probability(altCount);
        final double hetLikelihood = new BinomialDistribution((RandomGenerator) null, totalCount, 1.0d - minorAlleleFraction).probability(altCount);
        final double unnormalizedHomAlt = homAltPrior * homAltLikelihood;
        final double normalization = unnormalizedHomAlt + hetPrior * hetLikelihood;
        // Guard 0/0 (e.g. allele frequency 0 or underflowed likelihoods), which would otherwise yield NaN.
        return normalization == 0.0d ? 0.0d : unnormalizedHomAlt / normalization;
    }
}
