package org.broadinstitute.hellbender.tools.walkers.haplotypecaller;

import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.util.CollectionUtil;
import htsjdk.variant.variantcontext.Allele;
import java.io.FileWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.walkers.annotator.StrandOddsRatio;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs.InverseAllele;
import org.broadinstitute.hellbender.utils.BaseUtils;
import org.broadinstitute.hellbender.utils.genotyper.AlleleLikelihoods;
import org.broadinstitute.hellbender.utils.haplotype.Event;
import org.broadinstitute.hellbender.utils.haplotype.Haplotype;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.jgrapht.graph.DefaultDirectedWeightedGraph;
import org.jgrapht.graph.DefaultWeightedEdge;
import org.jgrapht.io.ComponentAttributeProvider;
import org.jgrapht.io.DOTExporter;
import org.jgrapht.io.IntegerComponentNameProvider;

/* loaded from: input_file:org/broadinstitute/hellbender/tools/walkers/haplotypecaller/AlleleFiltering.class */
/**
 * Base class for filtering weak alleles (and every haplotype that carries them) out of a haplotype
 * read-likelihood matrix.
 *
 * <p>Alleles are first grouped into clusters of alleles that never co-occur on a haplotype and are either
 * close to each other (at most 3bp apart) or within 20bp and equal up to an hmer change once inserted into the
 * reference haplotype. Within each cluster, alleles are scored by {@link #getAlleleLikelihoodVsInverse} and by a
 * strand odds ratio. Failing alleles are removed one at a time, together with their haplotypes, until no
 * further allele qualifies for removal.
 *
 * <p>Subclasses define the per-allele likelihood score.
 */
public abstract class AlleleFiltering {
    protected static final Logger logger = LogManager.getLogger(AlleleFiltering.class);
    protected final AssemblyBasedCallerArgumentCollection assemblyArgs;
    private final OutputStreamWriter assemblyDebugOutStream;

    /**
     * @param assemblyArgs           assembly arguments supplying the filtering thresholds and flags
     * @param assemblyDebugOutStream optional debug writer; when {@code null} no debug output is written
     */
    public AlleleFiltering(final AssemblyBasedCallerArgumentCollection assemblyArgs, final OutputStreamWriter assemblyDebugOutStream) {
        this.assemblyArgs = assemblyArgs;
        this.assemblyDebugOutStream = assemblyDebugOutStream;
    }

    /**
     * Removes haplotypes that carry weak alleles from the read-likelihood matrix.
     *
     * <p>As a side effect this records the number of removed haplotypes on {@code readLikelihoods} via
     * {@code setFilteredHaplotypeCount}. It also adds the start positions of alleles in clusters that had
     * filtering candidates to {@code suspiciousLocations}.
     *
     * @param readLikelihoods     read-by-haplotype likelihoods to filter
     * @param activeWindowStart   forwarded to the hmer-pair filter, which currently does not use it
     *                            (presumably the active window start — TODO confirm)
     * @param suspiciousLocations output set that receives suspicious allele start positions
     * @return a new likelihood matrix restricted to the surviving haplotypes
     * @throws UserException if writing to the debug stream fails
     */
    public AlleleLikelihoods<GATKRead, Haplotype> filterAlleles(final AlleleLikelihoods<GATKRead, Haplotype> readLikelihoods,
                                                                final int activeWindowStart,
                                                                final Set<Integer> suspiciousLocations) {
        logger.debug("SHA:: filter alleles - start");
        final AlleleLikelihoods<GATKRead, Haplotype> subsettedReadLikelihoods =
                subsetHaplotypesByAlleles(readLikelihoods, this.assemblyArgs, activeWindowStart, suspiciousLocations);
        logger.debug("SHA:: filter alleles - end");

        readLikelihoods.setFilteredHaplotypeCount(readLikelihoods.numberOfAlleles() - subsettedReadLikelihoods.numberOfAlleles());

        if (this.assemblyDebugOutStream != null) {
            try {
                this.assemblyDebugOutStream.write("\nThere were " + subsettedReadLikelihoods.alleles().size()
                        + " haplotypes found after subsetting by alleles. Here they are:\n");
                for (final Haplotype haplotype : subsettedReadLikelihoods.alleles()) {
                    this.assemblyDebugOutStream.write(haplotype.toString());
                    this.assemblyDebugOutStream.append("\n");
                }
            } catch (final IOException e) {
                throw new UserException("Error writing to debug output stream", e);
            }
        }
        return subsettedReadLikelihoods;
    }

    /**
     * Clusters interacting alleles, iteratively removes weak alleles from each cluster, and subsets the
     * likelihoods to the haplotypes that were never disabled.
     */
    private AlleleLikelihoods<GATKRead, Haplotype> subsetHaplotypesByAlleles(final AlleleLikelihoods<GATKRead, Haplotype> readLikelihoods,
                                                                             final AssemblyBasedCallerArgumentCollection args,
                                                                             final int activeWindowStart,
                                                                             final Set<Integer> suspiciousLocations) {
        // 1. Collect the events carried by each haplotype. Events are added one at a time, as in the original
        //    code, so a haplotype without events is never looked up here. getInteractionVector later iterates
        //    this map's key set, so which keys exist matters.
        final Map<Haplotype, Collection<Event>> haplotypeAlleleMap = new CollectionUtil.DefaultingMap<>(k -> new ArrayList<>(), true);
        readLikelihoods.alleles().forEach(h -> h.getEventMap().getEvents().forEach(e -> haplotypeAlleleMap.get(h).add(e)));

        // 2. Pair up alleles that never co-occur on a haplotype. Keep pairs that are close (<= 3bp) or that are
        //    within 20bp and equal up to an hmer change. The independent sets of these pairs are the clusters.
        final OccurrenceMatrix<Haplotype, Event> occurrenceMatrix = new OccurrenceMatrix<>(haplotypeAlleleMap);
        final List<Pair<Event, Event>> nonCoOccurringAlleles = occurrenceMatrix.nonCoOcurringColumns();
        final List<Pair<Event, Event>> closeNonCoOccurringAlleles = filterByDistance(nonCoOccurringAlleles, 0, 3);
        final List<Pair<Event, Event>> interactingPairs = filterSameUpToHmerPairs(
                filterByDistance(nonCoOccurringAlleles, 0, 20),
                findReferenceHaplotype(readLikelihoods.alleles()),
                activeWindowStart);
        interactingPairs.addAll(closeNonCoOccurringAlleles);

        // 3. For each cluster, remove weak alleles and collect the haplotypes that carry them.
        final Set<Haplotype> disabledHaplotypes = new HashSet<>();
        for (final Set<Event> alleleSet : occurrenceMatrix.getIndependentSets(interactingPairs)) {
            if (args.writeFilteringGraphs && alleleSet.size() > 1) {
                writeFilteringGraph(alleleSet, haplotypeAlleleMap, readLikelihoods);
            }
            removeWeakAlleles(alleleSet, readLikelihoods, haplotypeAlleleMap, args, disabledHaplotypes, suspiciousLocations);
        }

        logger.debug("----- SHA list of removed haplotypes start ----");
        for (final Haplotype removed : disabledHaplotypes) {
            logger.debug(() -> String.format("SHA :: Removed haplotype : %s ", removed.toString()));
        }
        logger.debug("----- SHA list of removed haplotypes end ----");

        final Set<Haplotype> eventualAlleles = readLikelihoods.alleles().stream()
                .filter(h -> !disabledHaplotypes.contains(h))
                .collect(Collectors.toCollection(HashSet::new));
        logger.debug("----- SHA list of remaining haplotypes start ----");
        for (final Haplotype remaining : eventualAlleles) {
            logger.debug(() -> String.format("SHA :: Remaining haplotype : %s ", remaining.toString()));
        }
        logger.debug("----- SHA list of remaining haplotypes end ----");

        final AlleleLikelihoods<GATKRead, Haplotype> subsettedReadLikelihoods = readLikelihoods.removeAllelesToSubset(eventualAlleles);

        logger.debug("----- SHA list of remaining alleles start ----");
        final Set<Event> remainingEvents = new HashSet<>();
        subsettedReadLikelihoods.alleles().forEach(h -> h.getEventMap().getEvents().forEach(remainingEvents::add));
        for (final Event event : remainingEvents) {
            logger.debug(() -> String.format("---- SHA :: %s ", event.toString()));
        }
        logger.debug("----- SHA list of remaining alleles end ----");
        return subsettedReadLikelihoods;
    }

    /** Computes the interaction matrix of one cluster and writes it as a DOT graph (debugging aid). */
    private void writeFilteringGraph(final Set<Event> alleleSet,
                                     final Map<Haplotype, Collection<Event>> haplotypeAlleleMap,
                                     final AlleleLikelihoods<GATKRead, Haplotype> readLikelihoods) {
        final Map<Event, Integer> initialRPLsMap = new HashMap<>();
        final Map<Event, Map<Event, Integer>> interactionMatrix =
                getInteractionMatrix(new ArrayList<>(alleleSet), haplotypeAlleleMap, readLikelihoods, initialRPLsMap);
        printInteractionGraph(interactionMatrixToGraph(interactionMatrix, initialRPLsMap), initialRPLsMap, alleleSet);
    }

    /**
     * Iteratively removes the worst allele of one cluster until no allele qualifies for removal.
     *
     * <p>Every haplotype that carries a removed allele is added to {@code disabledHaplotypes} and excluded from
     * later iterations. Each iteration removes at least one haplotype, so the loop terminates.
     */
    private void removeWeakAlleles(final Set<Event> alleleSet,
                                   final AlleleLikelihoods<GATKRead, Haplotype> readLikelihoods,
                                   final Map<Haplotype, Collection<Event>> haplotypeAlleleMap,
                                   final AssemblyBasedCallerArgumentCollection args,
                                   final Set<Haplotype> disabledHaplotypes,
                                   final Set<Integer> suspiciousLocations) {
        final Set<Haplotype> activeHaplotypes = new HashSet<>(readLikelihoods.alleles());
        boolean removedAlleles = true;
        while (removedAlleles) {
            removedAlleles = false;
            logger.debug("GAL::start of iteration");

            // Alleles of this cluster that still appear on some active haplotype.
            final List<Event> activeAlleles = activeHaplotypes.stream()
                    .flatMap(h -> h.getEventMap().getEvents().stream().filter(alleleSet::contains))
                    .distinct()
                    .collect(Collectors.toList());

            final Map<Event, List<Haplotype>> activeAllelesToHaplotypeMap = new CollectionUtil.DefaultingMap<>(k -> new ArrayList<>(), true);
            readLikelihoods.alleles().stream()
                    .filter(activeHaplotypes::contains)
                    .forEach(h -> h.getEventMap().getEvents().stream()
                            .filter(alleleSet::contains)
                            .forEach(e -> activeAllelesToHaplotypeMap.get(e).add(h)));
            logAlleleBlocks(activeAllelesToHaplotypeMap);

            final List<AlleleLikelihoods<GATKRead, Allele>> alleleMatrices = activeAlleles.stream()
                    .map(allele -> getAlleleLikelihoodMatrix(readLikelihoods, allele, haplotypeAlleleMap, activeHaplotypes))
                    .collect(Collectors.toList());
            final List<Integer> collectedRPLs = IntStream.range(0, activeAlleles.size())
                    .mapToObj(i -> getAlleleLikelihoodVsInverse(alleleMatrices.get(i), activeAlleles.get(i).altAllele()))
                    .collect(Collectors.toList());
            final List<Double> collectedSORs = IntStream.range(0, activeAlleles.size())
                    .mapToObj(i -> getAlleleSOR(alleleMatrices.get(i), activeAlleles.get(i).altAllele()))
                    .collect(Collectors.toList());

            final List<Event> filteringCandidates = identifyBadAlleles(collectedRPLs, collectedSORs, activeAlleles,
                    args.prefilterQualThreshold, args.prefilterSorThreshold);
            // Stringent pass: QUAL threshold 1 and an effectively disabled SOR threshold.
            final List<Event> stringentCandidates = identifyBadAlleles(collectedRPLs, collectedSORs, activeAlleles,
                    1.0, Integer.MAX_VALUE);

            if (!filteringCandidates.isEmpty() && !alleleSet.isEmpty()) {
                activeAlleles.forEach(allele -> suspiciousLocations.add(allele.getStart()));
            }

            // filterLoneAlleles is read from the field while thresholds come from `args`. Both are the same
            // object because filterAlleles passes this.assemblyArgs.
            final boolean removeAllele = (!filteringCandidates.isEmpty() && activeAlleles.size() > 1)
                    || (activeAlleles.size() == 1 && !stringentCandidates.isEmpty())
                    || (!filteringCandidates.isEmpty() && this.assemblyArgs.filterLoneAlleles);
            if (removeAllele) {
                if (!stringentCandidates.isEmpty() && filteringCandidates.isEmpty()) {
                    throw new GATKException.ShouldNeverReachHereException("The thresholds for stringent allele filtering should always be higher than for the relaxed one");
                }
                final Event worstAllele = filteringCandidates.get(0);
                logger.debug(() -> String.format("GAL:: Remove %s", worstAllele.toString()));
                removedAlleles = true;
                final List<Haplotype> haplotypesToRemove = activeAllelesToHaplotypeMap.get(worstAllele);
                disabledHaplotypes.addAll(haplotypesToRemove);
                activeHaplotypes.removeAll(haplotypesToRemove);
            }
            logger.debug("GAL::end of iteration");
        }
    }

    /** Debug-logs each active allele together with the haplotypes that carry it. */
    private static void logAlleleBlocks(final Map<Event, List<Haplotype>> alleleToHaplotypes) {
        logger.debug("AHM::printout start");
        for (final Event event : alleleToHaplotypes.keySet()) {
            logger.debug("AHM::allele block ---> ");
            for (final Haplotype haplotype : alleleToHaplotypes.get(event)) {
                logger.debug(() -> String.format("AHM:: (%d) %s/%s: %s", event.getStart(),
                        event.altAllele().getBaseString(), event.refAllele().getBaseString(), haplotype.getBaseString()));
            }
            logger.debug("AHM::allele block ---< ");
        }
        logger.debug("AHM::printout end");
    }

    /**
     * Selects alleles that fail the QUAL or SOR criteria.
     *
     * <p>Alleles are visited from the highest to the lowest RPL. The first phase takes alleles while their RPL
     * is greater than {@code -qualThreshold}. The second phase appends any remaining allele whose SOR exceeds
     * {@code sorThreshold}.
     *
     * @param collectedRPLs per-allele scores from {@link #getAlleleLikelihoodVsInverse}
     * @param collectedSORs per-allele strand odds ratios
     * @param alleles       alleles parallel to the two score lists
     * @param qualThreshold QUAL threshold (compared against the negated RPL)
     * @param sorThreshold  SOR threshold
     * @return candidates in selection order; the first element is the first one to be removed
     */
    @VisibleForTesting
    List<Event> identifyBadAlleles(final List<Integer> collectedRPLs, final List<Double> collectedSORs,
                                   final List<Event> alleles, final double qualThreshold, final double sorThreshold) {
        final int[] rplsIndexes = getSortedIndexList(collectedRPLs);
        final List<Event> result = new ArrayList<>();

        final double minimalRPL = -qualThreshold;
        for (int i = rplsIndexes.length - 1; i >= 0 && collectedRPLs.get(rplsIndexes[i]) > minimalRPL; i--) {
            result.add(alleles.get(rplsIndexes[i]));
        }
        final int lowQualCount = result.size();
        logger.debug(() -> String.format("SHA:: Have %d candidates with low QUAL", lowQualCount));

        for (int i = rplsIndexes.length - 1; i >= 0; i--) {
            if (collectedSORs.get(rplsIndexes[i]) > sorThreshold && !result.contains(alleles.get(rplsIndexes[i]))) {
                result.add(alleles.get(rplsIndexes[i]));
            }
        }
        logger.debug(() -> String.format("SHA:: Have %d candidates with high SOR", result.size() - lowQualCount));
        return result;
    }

    /**
     * Marginalizes haplotype likelihoods into a two-allele matrix: the event's alt allele vs. its inverse.
     *
     * @param readLikelihoods    haplotype likelihoods
     * @param allele             event to test
     * @param haplotypeAlleleMap events carried by each haplotype
     * @param enabledHaplotypes  haplotypes to take into account; all others are ignored
     * @return likelihoods over {alt allele, inverse of alt allele}
     */
    private AlleleLikelihoods<GATKRead, Allele> getAlleleLikelihoodMatrix(final AlleleLikelihoods<GATKRead, Haplotype> readLikelihoods,
                                                                          final Event allele,
                                                                          final Map<Haplotype, Collection<Event>> haplotypeAlleleMap,
                                                                          final Set<Haplotype> enabledHaplotypes) {
        final Map<Allele, List<Haplotype>> alleleHaplotypeMap = new CollectionUtil.DefaultingMap<>(k -> new ArrayList<>(), true);
        final Allele notAllele = InverseAllele.of(allele.altAllele(), true);

        readLikelihoods.alleles().stream()
                .filter(enabledHaplotypes::contains)
                .filter(h -> haplotypeAlleleMap.get(h).contains(allele))
                .forEach(alleleHaplotypeMap.get(allele.altAllele())::add);
        readLikelihoods.alleles().stream()
                .filter(enabledHaplotypes::contains)
                .filter(h -> !haplotypeAlleleMap.get(h).contains(allele))
                .forEach(alleleHaplotypeMap.get(notAllele)::add);

        final AlleleLikelihoods<GATKRead, Allele> alleleLikelihoods = readLikelihoods.marginalize(alleleHaplotypeMap);
        logger.debug(() -> String.format("GALM: %s %d %d", allele,
                alleleHaplotypeMap.get(allele.altAllele()).size(), alleleHaplotypeMap.get(notAllele).size()));
        return alleleLikelihoods;
    }

    /**
     * Scores how well reads support {@code allele} against its inverse in a two-allele likelihood matrix.
     * Callers treat a score above {@code -qualThreshold} as low quality. The scale is defined by the subclass.
     */
    abstract int getAlleleLikelihoodVsInverse(AlleleLikelihoods<GATKRead, Allele> alleleLikelihoods, Allele allele);

    /** Strand odds ratio of {@code allele} against its inverse, from a contingency table with minimum count 1. */
    private double getAlleleSOR(final AlleleLikelihoods<GATKRead, Allele> alleleLikelihoods, final Allele allele) {
        final int[][] contingencyTable = StrandOddsRatio.getContingencyTableWrtAll(
                alleleLikelihoods, InverseAllele.of(allele, true), Collections.singletonList(allele), 1);
        final double sor = StrandOddsRatio.calculateSOR(contingencyTable);
        logger.debug(() -> String.format("GAS:: %s: %f (%d %d %d %d)", allele.toString(), sor,
                contingencyTable[0][0], contingencyTable[0][1], contingencyTable[1][0], contingencyTable[1][1]));
        return sor;
    }

    /**
     * Keeps only the pairs whose start positions differ by at least {@code minDist} and at most {@code maxDist}.
     *
     * @return a new mutable list; the input is not modified
     */
    private List<Pair<Event, Event>> filterByDistance(final List<Pair<Event, Event>> allelePairs,
                                                      final int minDist, final int maxDist) {
        logger.debug(() -> String.format("FBD: input %d pairs ", allelePairs.size()));
        final List<Pair<Event, Event>> result = new ArrayList<>(allelePairs);
        result.removeIf(p -> Math.abs(p.getLeft().getStart() - p.getRight().getStart()) > maxDist);
        result.removeIf(p -> Math.abs(p.getLeft().getStart() - p.getRight().getStart()) < minDist);
        // Report the filtered size; the original logged the input size here.
        logger.debug(() -> String.format("FBD: output %d pairs ", result.size()));
        return result;
    }

    /**
     * Keeps only the pairs whose two events, each inserted into the reference haplotype, produce sequences that
     * {@link BaseUtils#equalUpToHmerChange} considers equal.
     *
     * <p>{@code activeWindowStart} is currently unused. NOTE(review): this dereferences {@code refHaplotype} and
     * the result of {@code insertAllele} without null checks. {@link #findReferenceHaplotype} can return
     * {@code null}. Confirm that callers guarantee a reference haplotype and a successful insertion.
     *
     * @return a new mutable list (callers append to it)
     */
    private List<Pair<Event, Event>> filterSameUpToHmerPairs(final List<Pair<Event, Event>> allelePairs,
                                                             final Haplotype refHaplotype,
                                                             final int activeWindowStart) {
        final List<Pair<Event, Event>> result = new ArrayList<>();
        for (final Pair<Event, Event> allelePair : allelePairs) {
            final Event left = allelePair.getLeft();
            final Event right = allelePair.getRight();
            final Pair<Haplotype, Haplotype> modifiedHaplotypes = new ImmutablePair<>(
                    refHaplotype.insertAllele(left.refAllele(), left.altAllele(), left.getStart()),
                    refHaplotype.insertAllele(right.refAllele(), right.altAllele(), right.getStart()));
            if (BaseUtils.equalUpToHmerChange(modifiedHaplotypes.getLeft().getBases(), modifiedHaplotypes.getRight().getBases())) {
                result.add(allelePair);
            }
        }
        return result;
    }

    /**
     * @return the first haplotype flagged as reference, or {@code null} if there is none
     */
    public static Haplotype findReferenceHaplotype(final List<Haplotype> haplotypeList) {
        return haplotypeList.stream().filter(Haplotype::isReference).findFirst().orElse(null);
    }

    /**
     * Returns the shorter allele length when the lengths differ, and 0 when they are equal.
     * Not called from within this class.
     */
    private int getCommonPrefixLength(final Allele al1, final Allele al2) {
        if (al1.length() != al2.length()) {
            return Math.min(al1.length(), al2.length());
        }
        return 0;
    }

    /**
     * @return indices into {@code values}, ordered by increasing value; ties keep their original index order
     */
    private int[] getSortedIndexList(final List<Integer> values) {
        return IntStream.range(0, values.size())
                .boxed()
                .sorted(Comparator.comparingInt((Integer idx) -> values.get(idx)))
                .mapToInt(Integer::intValue)
                .toArray();
    }

    /**
     * Returns the first candidate that either has low QUAL or whose presence strongly changes another allele's
     * score. Returns {@code null} when no candidate qualifies, or when there are at most one allele and the first
     * candidate is not low-QUAL. Not called from within this class.
     */
    private Event identifyStrongInteractingAllele(final List<Event> candidateList,
                                                  final float prefilterThreshold,
                                                  final List<Event> allAlleles,
                                                  final List<Integer> rpls,
                                                  final AlleleLikelihoods<GATKRead, Haplotype> readLikelihoods,
                                                  final Map<Haplotype, Collection<Event>> haplotypeAlleleMap,
                                                  final Map<Event, List<Haplotype>> alleleHaplotypeMap) {
        logger.debug("ISIA :: start");
        final Map<Event, Integer> initialRPLsMap = new HashMap<>();
        IntStream.range(0, allAlleles.size()).forEach(i -> initialRPLsMap.put(allAlleles.get(i), rpls.get(i)));

        for (final Event candidate : candidateList) {
            logger.debug(() -> String.format("ISIA :: test %s", candidate.toString()));
            if (initialRPLsMap.get(candidate) > -prefilterThreshold) {
                logger.debug(String.format("ISIA:: selected %s due to low QUAL", candidate));
                return candidate;
            }
            if (allAlleles.size() <= 1) {
                return null;
            }
            final Map<Event, Integer> interactionVector =
                    getInteractionVector(candidate, haplotypeAlleleMap, alleleHaplotypeMap, readLikelihoods, initialRPLsMap);
            for (final Event other : interactionVector.keySet()) {
                logger.debug(() -> String.format(" --- %s: %d", other.toString(),
                        initialRPLsMap.get(other) - interactionVector.get(other)));
                if (initialRPLsMap.get(other) - interactionVector.get(other) > prefilterThreshold) {
                    logger.debug(String.format("ISIA:: selected %s", candidate));
                    return candidate;
                }
            }
        }
        logger.debug("ISIA :: end");
        return null;
    }

    /**
     * Computes, for each allele of {@code alleles} present on some haplotype, its interaction vector.
     *
     * @param initialRPLsMap output map; filled with each tested allele's score over all haplotypes
     * @return allele -> (other allele -> score of the other allele with this allele's haplotypes excluded)
     */
    private Map<Event, Map<Event, Integer>> getInteractionMatrix(final List<Event> alleles,
                                                                 final Map<Haplotype, Collection<Event>> haplotypeAlleleMap,
                                                                 final AlleleLikelihoods<GATKRead, Haplotype> readLikelihoods,
                                                                 final Map<Event, Integer> initialRPLsMap) {
        final Map<Event, List<Haplotype>> alleleHaplotypeMap = new CollectionUtil.DefaultingMap<>(k -> new ArrayList<>(), true);
        final Set<Haplotype> allHaplotypes = new HashSet<>(readLikelihoods.alleles());
        readLikelihoods.alleles().forEach(h -> h.getEventMap().getEvents().stream()
                .filter(alleles::contains)
                .forEach(e -> alleleHaplotypeMap.get(e).add(h)));

        final List<Event> allelesToTest = new ArrayList<>(alleleHaplotypeMap.keySet());
        final List<AlleleLikelihoods<GATKRead, Allele>> alleleMatrices = allelesToTest.stream()
                .map(allele -> getAlleleLikelihoodMatrix(readLikelihoods, allele, haplotypeAlleleMap, allHaplotypes))
                .collect(Collectors.toList());
        final List<Integer> rpls = IntStream.range(0, allelesToTest.size())
                .mapToObj(i -> getAlleleLikelihoodVsInverse(alleleMatrices.get(i), allelesToTest.get(i).altAllele()))
                .collect(Collectors.toList());
        for (int i = 0; i < allelesToTest.size(); i++) {
            initialRPLsMap.put(allelesToTest.get(i), rpls.get(i));
        }

        final Map<Event, Map<Event, Integer>> result = new HashMap<>();
        for (final Event allele : allelesToTest) {
            result.put(allele, getInteractionVector(allele, haplotypeAlleleMap, alleleHaplotypeMap, readLikelihoods, initialRPLsMap));
        }
        return result;
    }

    /**
     * Re-scores every other allele in {@code initialRPLsMap}. Only haplotypes in {@code haplotypeAlleleMap}'s
     * key set that do not carry {@code allele} are used.
     * "Other" is decided by reference identity ({@code !=}), as in the original code.
     */
    private Map<Event, Integer> getInteractionVector(final Event allele,
                                                     final Map<Haplotype, Collection<Event>> haplotypeAlleleMap,
                                                     final Map<Event, List<Haplotype>> alleleHaplotypeMap,
                                                     final AlleleLikelihoods<GATKRead, Haplotype> readLikelihoods,
                                                     final Map<Event, Integer> initialRPLsMap) {
        final List<Event> otherAlleles = initialRPLsMap.keySet().stream()
                .filter(other -> other != allele)
                .collect(Collectors.toList());
        final Set<Haplotype> haplotypesWithoutAllele = haplotypeAlleleMap.keySet().stream()
                .filter(h -> !alleleHaplotypeMap.get(allele).contains(h))
                .collect(Collectors.toSet());
        final List<AlleleLikelihoods<GATKRead, Allele>> alleleMatrices = otherAlleles.stream()
                .map(other -> getAlleleLikelihoodMatrix(readLikelihoods, other, haplotypeAlleleMap, haplotypesWithoutAllele))
                .collect(Collectors.toList());
        final List<Integer> rpls = IntStream.range(0, otherAlleles.size())
                .mapToObj(i -> getAlleleLikelihoodVsInverse(alleleMatrices.get(i), otherAlleles.get(i).altAllele()))
                .collect(Collectors.toList());

        final Map<Event, Integer> result = new HashMap<>();
        IntStream.range(0, otherAlleles.size()).forEach(i -> result.put(otherAlleles.get(i), rpls.get(i)));
        return result;
    }

    /**
     * Builds a directed graph with one vertex per allele in {@code initialRPLs}. An edge A -> B is added,
     * weighted by the change in B's score, whenever B's score drops when A's haplotypes are excluded.
     */
    private DefaultDirectedWeightedGraph<Event, DefaultWeightedEdge> interactionMatrixToGraph(final Map<Event, Map<Event, Integer>> interactionMatrix,
                                                                                          final Map<Event, Integer> initialRPLs) {
        final DefaultDirectedWeightedGraph<Event, DefaultWeightedEdge> result = new DefaultDirectedWeightedGraph<>(DefaultWeightedEdge.class);
        initialRPLs.keySet().forEach(result::addVertex);
        for (final Map.Entry<Event, Map<Event, Integer>> row : interactionMatrix.entrySet()) {
            for (final Map.Entry<Event, Integer> cell : row.getValue().entrySet()) {
                final int delta = cell.getValue() - initialRPLs.get(cell.getKey());
                if (delta < 0) {
                    result.setEdgeWeight(result.addEdge(row.getKey(), cell.getKey()), delta);
                }
            }
        }
        return result;
    }

    /**
     * Writes the interaction graph to {@code allele.interaction.<contig>.<minStart>-<maxStart>.dot} in the working
     * directory. Vertices are labelled "allele = score" and edges are labelled with their weight.
     *
     * @param alleleSet non-empty cluster used to derive the file name
     * @throws RuntimeException if the file cannot be written (the {@link IOException} is kept as the cause)
     */
    void printInteractionGraph(final DefaultDirectedWeightedGraph<Event, DefaultWeightedEdge> interactionGraph,
                               final Map<Event, Integer> rpls,
                               final Set<Event> alleleSet) {
        final DOTExporter<Event, DefaultWeightedEdge> dotExporter = new DOTExporter<Event, DefaultWeightedEdge>(
                new IntegerComponentNameProvider<>(),
                v -> v.toString() + " = " + rpls.get(v),
                e -> String.valueOf(interactionGraph.getEdgeWeight(e)),
                null,
                null);
        final String contig = alleleSet.iterator().next().getContig();
        final int rangeStart = alleleSet.stream().mapToInt(al -> al.getStart()).min().getAsInt();
        final int rangeEnd = alleleSet.stream().mapToInt(al -> al.getStart()).max().getAsInt();
        final String fileName = String.format("allele.interaction.%s.%d-%d.dot", contig, rangeStart, rangeEnd);
        // try-with-resources closes the file; the original leaked the FileWriter.
        try (FileWriter writer = new FileWriter(fileName)) {
            dotExporter.exportGraph(interactionGraph, writer);
        } catch (final IOException e) {
            throw new RuntimeException("Unable to write a DOT file " + fileName, e);
        }
    }
}
