package org.broadinstitute.hellbender.tools.walkers.contamination;

import htsjdk.samtools.util.OverlapDetector;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.function.DoubleUnaryOperator;
import java.util.function.ToDoubleFunction;
import java.util.function.ToIntFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.lang.mutable.MutableDouble;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.optim.univariate.UnivariatePointValuePair;
import org.apache.commons.math3.util.FastMath;
import org.broadinstitute.hellbender.utils.IndexRange;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.OptimizationUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;

/* loaded from: input_file:org/broadinstitute/hellbender/tools/walkers/contamination/ContaminationModel.class */
public class ContaminationModel {
    /** Starting minor-allele-fraction threshold for the threshold scans below; lowered by {@link #MAF_STEP_SIZE} each step. */
    public static final double INITIAL_MAF_THRESHOLD = 0.45d;
    public static final double MAF_STEP_SIZE = 0.02d;
    private final double contamination;
    private final double errorRate;
    private final List<Double> minorAlleleFractions;
    private final List<List<PileupSummary>> segments;
    /** Index of the hom-ref genotype in {@link #genotypeLikelihoods}. */
    public static final int HOM_REF = 0;
    /** Index of the hom-alt genotype in {@link #genotypeLikelihoods}. */
    public static final int HOM_ALT = 3;
    private static final int NUM_ITERATIONS = 3;
    private static final double MIN_FRACTION_OF_SITES_TO_USE = 0.25d;
    private static final double MIN_RELATIVE_ERROR = 0.2d;
    private static final List<Double> CONTAMINATION_INITIAL_GUESSES = Arrays.asList(0.02d, 0.05d, 0.1d, 0.2d);

    /**
     * Fits the model to {@code sites}.
     *
     * <p>The sites are partitioned into segments by {@code ContaminationSegmenter.findSegments}.
     * The model then alternates {@link #NUM_ITERATIONS} times between two steps:
     * <ul>
     *   <li>estimate each segment's minor allele fraction given the current contamination
     *       (starting from 0);</li>
     *   <li>re-estimate the contamination from the segments whose MAF passes the threshold
     *       scan in {@link #getNonLOHSegments}.</li>
     * </ul>
     *
     * @param sites pileup summaries to fit
     */
    public ContaminationModel(final List<PileupSummary> sites) {
        final double errorRateEstimate = calculateErrorRate(sites);
        final List<List<PileupSummary>> mafSegments = ContaminationSegmenter.findSegments(sites);
        final int numSegments = mafSegments.size();

        final List<Double> mafs = new ArrayList<>(Collections.nCopies(numSegments, 0.5d));
        double contaminationEstimate = 0.0d;

        for (int iteration = 0; iteration < NUM_ITERATIONS; iteration++) {
            // Re-estimate every segment's MAF using the current contamination estimate.
            // (This step was missing; without it the MAFs never left their initial 0.5.)
            for (int n = 0; n < numSegments; n++) {
                mafs.set(n, calculateMinorAlleleFraction(contaminationEstimate, errorRateEstimate, mafSegments.get(n)));
            }

            // Use the MAF estimates to re-estimate contamination, restricted to non-LOH segments.
            final Pair<List<List<PileupSummary>>, List<Double>> nonLOH = getNonLOHSegments(mafSegments, mafs);
            contaminationEstimate = calculateContamination(errorRateEstimate, nonLOH.getLeft(), nonLOH.getRight());
        }

        this.errorRate = errorRateEstimate;
        this.segments = mafSegments;
        this.minorAlleleFractions = mafs;
        this.contamination = contaminationEstimate;
    }

    /** Returns the indices {@code n} for which {@code mafs.get(n) > minMaf}, in increasing order. */
    private static int[] indicesAboveMaf(final List<Double> mafs, final double minMaf) {
        return IntStream.range(0, mafs.size()).filter(n -> mafs.get(n) > minMaf).toArray();
    }

    /** Returns the elements of {@code items} at {@code indices}, in the order given. */
    private static <T> List<T> select(final List<T> items, final int[] indices) {
        return Arrays.stream(indices).mapToObj(items::get).collect(Collectors.toList());
    }

    /**
     * Selects the segments whose MAF exceeds a threshold.
     *
     * <p>The threshold starts at {@link #INITIAL_MAF_THRESHOLD} and is lowered by
     * {@link #MAF_STEP_SIZE}. The first threshold whose selected segments contain more than
     * {@link #MIN_FRACTION_OF_SITES_TO_USE} of all sites wins. If no threshold above 0
     * qualifies, all segments are returned.
     *
     * @param segments segments of sites
     * @param mafs     one minor allele fraction per segment
     * @return the selected segments (left) and their MAFs (right)
     */
    private static Pair<List<List<PileupSummary>>, List<Double>> getNonLOHSegments(final List<List<PileupSummary>> segments, final List<Double> mafs) {
        final int totalSites = segments.stream().mapToInt(List::size).sum();
        for (double minMaf = INITIAL_MAF_THRESHOLD; minMaf > 0.0d; minMaf -= MAF_STEP_SIZE) {
            final int[] indices = indicesAboveMaf(mafs, minMaf);
            final List<List<PileupSummary>> nonLOHSegments = select(segments, indices);
            final int nonLOHSites = nonLOHSegments.stream().mapToInt(List::size).sum();
            // Floating-point division is essential here.
            // With int/int the fraction is 0 unless every site qualifies.
            if ((double) nonLOHSites / totalSites > MIN_FRACTION_OF_SITES_TO_USE) {
                return ImmutablePair.of(nonLOHSegments, select(mafs, indices));
            }
        }
        return ImmutablePair.of(segments, mafs);
    }

    /**
     * Estimates contamination from sites the fitted model calls homozygous.
     *
     * <p>MAF thresholds are tried from {@link #INITIAL_MAF_THRESHOLD} downward in steps of
     * {@link #MAF_STEP_SIZE}. The first result whose standard error is below
     * {@link #MIN_RELATIVE_ERROR} times its estimate is returned. Otherwise the result for
     * threshold 0 is returned.
     *
     * @param sites pileup summaries to estimate from
     * @return contamination estimate (left, capped at 1) and its standard error (right)
     */
    public Pair<Double, Double> calculateContaminationFromHoms(final List<PileupSummary> sites) {
        for (double minMaf = INITIAL_MAF_THRESHOLD; minMaf > 0.0d; minMaf -= MAF_STEP_SIZE) {
            final Pair<Double, Double> result = calculateContaminationFromHoms(sites, minMaf);
            if (result.getRight() < result.getLeft() * MIN_RELATIVE_ERROR) {
                return result;
            }
        }
        return calculateContaminationFromHoms(sites, 0.0d);
    }

    /**
     * Computes hom-alt and hom-ref based estimates at one MAF threshold and picks one.
     *
     * <p>The hom-ref result is used when the hom-alt estimate is NaN, or when its standard
     * error both exceeds half its estimate and exceeds the hom-ref standard error. Otherwise
     * the hom-alt result is used.
     */
    private Pair<Double, Double> calculateContaminationFromHoms(final List<PileupSummary> sites, final double minMaf) {
        final Pair<Double, Double> homAltResult = calculateContamination(true, sites, minMaf);
        final Pair<Double, Double> homRefResult = calculateContamination(false, sites, minMaf);
        final double homAltEstimate = homAltResult.getLeft();
        final double homAltError = homAltResult.getRight();
        final boolean homAltUnusable = Double.isNaN(homAltEstimate)
                || (homAltError > homAltEstimate / 2.0d && homAltError > homRefResult.getRight());
        return homAltUnusable ? homRefResult : homAltResult;
    }

    /**
     * Estimates contamination from the "opposite" reads at homozygous sites.
     *
     * <p>Opposite reads are ref reads at hom-alt sites, or alt reads at hom-ref sites.
     *
     * @param useHomAlt {@code true} to use hom-alt sites, {@code false} for hom-ref sites
     * @param sites     pileup summaries, intersected with the model's genotype calls
     * @param minMaf    only segments with MAF above this value contribute genotype calls
     * @return contamination estimate (left, capped at 1; NaN when no reads are expected) and
     *         standard error (right; 1 when no sites remain)
     */
    private Pair<Double, Double> calculateContamination(final boolean useHomAlt, final List<PileupSummary> sites, final double minMaf) {
        final List<PileupSummary> genotypes = subsetSites(sites, useHomAlt ? homAlts(minMaf) : homRefs(minMaf));
        final double siteErrorRate = calculateErrorRate(sites);

        final ToIntFunction<PileupSummary> oppositeCount = useHomAlt ? PileupSummary::getRefCount : PileupSummary::getAltCount;
        final ToDoubleFunction<PileupSummary> oppositeAlleleFrequency = useHomAlt ? PileupSummary::getRefFrequency : PileupSummary::getAlleleFrequency;

        final long totalReadCount = genotypes.stream().mapToLong(PileupSummary::getTotalCount).sum();
        // Subtract the opposite reads attributed to sequencing error.
        // NOTE(review): the /3 presumably spreads errors evenly over the three wrong bases.
        final long oppositeReadCount = Math.max(genotypes.stream().mapToLong(oppositeCount::applyAsInt).sum()
                - Math.round(totalReadCount * siteErrorRate / 3.0d), 0L);
        final double expectedOppositeReadsPerUnitContamination = genotypes.stream()
                .mapToDouble(site -> site.getTotalCount() * oppositeAlleleFrequency.applyAsDouble(site))
                .sum();
        final double contaminationEstimate = oppositeReadCount / expectedOppositeReadsPerUnitContamination;

        final double stdError = genotypes.isEmpty() ? 1.0d : Math.sqrt(genotypes.stream().mapToDouble(site -> {
            final double depth = site.getTotalCount();
            final double f = 1.0d - oppositeAlleleFrequency.applyAsDouble(site);
            return (1.0d - f) * depth * contaminationEstimate * ((1.0d - contaminationEstimate) + (f * depth * contaminationEstimate));
        }).sum()) / expectedOppositeReadsPerUnitContamination;

        return Pair.of(Math.min(contaminationEstimate, 1.0d), stdError);
    }

    /**
     * Returns the sites called as {@code genotype} by the fitted model.
     *
     * <p>Only segments whose MAF exceeds {@code minMaf} are considered. A site is included
     * when the posterior probability of {@code genotype} exceeds 0.5. Segment order and
     * within-segment site order are preserved.
     */
    private List<PileupSummary> getType(final int genotype, final double minMaf) {
        final List<PileupSummary> calls = new ArrayList<>();
        for (final int n : indicesAboveMaf(minorAlleleFractions, minMaf)) {
            final double maf = minorAlleleFractions.get(n);
            for (final PileupSummary site : segments.get(n)) {
                if (probability(site, contamination, errorRate, maf, genotype) > 0.5d) {
                    calls.add(site);
                }
            }
        }
        return calls;
    }

    private List<PileupSummary> homAlts(final double minMaf) {
        return getType(HOM_ALT, minMaf);
    }

    private List<PileupSummary> homRefs(final double minMaf) {
        return getType(HOM_REF, minMaf);
    }

    /**
     * Builds one record per segment.
     *
     * <p>Each record spans from the first site's start to the last site's end, using the
     * first site's contig, together with the segment's fitted MAF.
     * NOTE(review): assumes every segment is non-empty and lies on a single contig.
     */
    public List<MinorAlleleFractionRecord> segmentationRecords() {
        final List<MinorAlleleFractionRecord> records = new ArrayList<>(segments.size());
        for (int n = 0; n < segments.size(); n++) {
            final List<PileupSummary> segment = segments.get(n);
            final PileupSummary first = segment.get(0);
            final PileupSummary last = segment.get(segment.size() - 1);
            final double maf = minorAlleleFractions.get(n);
            records.add(new MinorAlleleFractionRecord(new SimpleInterval(first.getContig(), first.getStart(), last.getEnd()), maf));
        }
        return records;
    }

    /**
     * Estimates the per-base error rate.
     *
     * <p>The estimate is 1.5 times the fraction of all bases that are neither ref nor alt.
     * NOTE(review): the 1.5 presumably extrapolates from 2 to all 3 wrong bases.
     * Sums are long and the division is floating-point; the previous int/int division always
     * yielded 0. Returns 0 when there are no bases at all.
     */
    private static double calculateErrorRate(final List<PileupSummary> sites) {
        final long totalBases = sites.stream().mapToLong(PileupSummary::getTotalCount).sum();
        final long otherAltBases = sites.stream().mapToLong(PileupSummary::getOtherAltCount).sum();
        return totalBases == 0L ? 0.0d : 1.5d * ((double) otherAltBases / totalBases);
    }

    /**
     * Maximizes {@link #segmentLogLikelihood} over the segment's minor allele fraction.
     *
     * <p>NOTE(review): the arguments to OptimizationUtils.max are presumably
     * (lower bound 0.1, upper bound 0.5, initial guess 0.4, tolerances, max evaluations) —
     * confirm against OptimizationUtils.
     */
    private static double calculateMinorAlleleFraction(final double contamination, final double errorRate, final List<PileupSummary> segment) {
        final DoubleUnaryOperator objective = maf -> segmentLogLikelihood(segment, contamination, errorRate, maf);
        return OptimizationUtils.max(objective, 0.1d, 0.5d, 0.4d, 0.01d, 0.01d, 20).getPoint();
    }

    /**
     * Maximizes {@link #modelLogLikelihood} over contamination.
     *
     * <p>One optimization is run from each of {@link #CONTAMINATION_INITIAL_GUESSES}, and the
     * point with the highest objective value is returned (the first one on ties).
     */
    private static double calculateContamination(final double errorRateEstimate, final List<List<PileupSummary>> segments, final List<Double> mafs) {
        final DoubleUnaryOperator objective = c -> modelLogLikelihood(segments, c, errorRateEstimate, mafs);
        final List<UnivariatePointValuePair> optima = CONTAMINATION_INITIAL_GUESSES.stream()
                .map(guess -> OptimizationUtils.max(objective, 0.0d, 0.5d, guess, 1.0E-4d, 1.0E-4d, 20))
                .collect(Collectors.toList());
        return Collections.max(optima, Comparator.comparingDouble(UnivariatePointValuePair::getValue)).getPoint();
    }

    /**
     * Computes unnormalized joint probabilities (prior times binomial likelihood) for four
     * genotype states.
     *
     * <p>The states are indexed 0..3, with {@link #HOM_REF} = 0 and {@link #HOM_ALT} = 3:
     * <ul>
     *   <li>0 (hom-ref): alt fraction {@code errorRate/3};</li>
     *   <li>1 (het): alt fraction {@code maf};</li>
     *   <li>2 (het): alt fraction {@code 1 - maf};</li>
     *   <li>3 (hom-alt): alt fraction {@code 1 - errorRate}.</li>
     * </ul>
     * Each alt fraction is mixed with the population allele frequency in proportion to
     * {@code contamination}. Priors come from the site's allele frequency, and the het prior
     * is split evenly between states 1 and 2.
     */
    private static double[] genotypeLikelihoods(final PileupSummary site, final double contamination, final double errorRate, final double minorAlleleFraction) {
        final double alleleFrequency = site.getAlleleFrequency();
        final int altCount = site.getAltCount();
        final int totalCount = altCount + site.getRefCount();
        final double[] genotypePriors = {
                (1.0d - alleleFrequency) * (1.0d - alleleFrequency),
                alleleFrequency * (1.0d - alleleFrequency),
                alleleFrequency * (1.0d - alleleFrequency),
                alleleFrequency * alleleFrequency};
        final double[] uncontaminatedAltFractions = {errorRate / 3.0d, minorAlleleFraction, 1.0d - minorAlleleFraction, 1.0d - errorRate};
        return new IndexRange(0, 4).mapToDouble(n -> {
            final double altFraction = ((1.0d - contamination) * uncontaminatedAltFractions[n]) + (contamination * alleleFrequency);
            return genotypePriors[n] * MathUtils.binomialProbability(totalCount, altCount, altFraction);
        });
    }

    /** Posterior probability of genotype state {@code genotype} (0..3) at {@code site}. */
    private static double probability(final PileupSummary site, final double contamination, final double errorRate, final double maf, final int genotype) {
        final double[] likelihoods = genotypeLikelihoods(site, contamination, errorRate, maf);
        return likelihoods[genotype] / MathUtils.sum(likelihoods);
    }

    /**
     * Sum over sites of the log of the total genotype likelihood.
     * Kept public because the existing (decompiled) class exposes it.
     */
    public static double segmentLogLikelihood(final List<PileupSummary> segment, final double contamination, final double errorRate, final double maf) {
        return segment.stream()
                .mapToDouble(site -> FastMath.log(MathUtils.sum(genotypeLikelihoods(site, contamination, errorRate, maf))))
                .sum();
    }

    /**
     * Sum of {@link #segmentLogLikelihood} over all segments, each with its own MAF.
     * Kept public because the existing (decompiled) class exposes it.
     */
    public static double modelLogLikelihood(final List<List<PileupSummary>> segments, final double contamination, final double errorRate, final List<Double> mafs) {
        Utils.validate(segments.size() == mafs.size(), " Must have one MAF per segment");
        return new IndexRange(0, segments.size()).sum(n -> segmentLogLikelihood(segments.get(n), contamination, errorRate, mafs.get(n)));
    }

    /** Returns the elements of {@code sites} that overlap any element of {@code targets}, in original order. */
    private static List<PileupSummary> subsetSites(final List<PileupSummary> sites, final List<PileupSummary> targets) {
        final OverlapDetector<PileupSummary> detector = OverlapDetector.create(targets);
        return sites.stream().filter(detector::overlapsAny).collect(Collectors.toList());
    }
}
