package org.broadinstitute.hellbender.tools.walkers.readorientation;

import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.metrics.Header;
import htsjdk.samtools.metrics.MetricsFile;
import htsjdk.samtools.util.CloserUtil;
import htsjdk.samtools.util.Histogram;
import htsjdk.samtools.util.IOUtil;
import htsjdk.samtools.util.SequenceUtil;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.Pair;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.CommandLineProgram;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.cmdline.programgroups.ShortVariantDiscoveryProgramGroup;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.Nucleotide;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.io.IOUtils;

/**
 * Command-line tool that learns, per sample and per canonical reference context, the prior
 * probabilities of read-orientation artifact states.
 *
 * <p>Inputs are one or more {@code .tar.gz} archives (documented on the input argument as outputs of
 * CollectF1R2Counts). Each archive is extracted to a temporary directory; the ref histograms, alt
 * histograms and alt tables found there are grouped by sample, every canonical context is merged with
 * its reverse complement, and a {@link LearnReadOrientationModelEngine} is run on the merged data.
 * The resulting prior tables (one per sample) are packed into the output {@code .tar.gz}.
 */
@CommandLineProgramProperties(summary = "Get the maximum likelihood estimates of artifact prior probabilities in the orientation bias mixture model filter", oneLineSummary = "Get the maximum likelihood estimates of artifact prior probabilities in the orientation bias mixture model filter", programGroup = ShortVariantDiscoveryProgramGroup.class)
@DocumentedFeature
public class LearnReadOrientationModel extends CommandLineProgram {
    public static final double DEFAULT_CONVERGENCE_THRESHOLD = 1.0E-4d;
    public static final int DEFAULT_MAX_ITERATIONS = 20;
    /** Initial capacity hint passed to {@code AltSiteRecord.readAltSiteRecords}. */
    private static final int DEFAULT_INITIAL_LIST_SIZE = 1000000;
    public static final String EM_CONVERGENCE_THRESHOLD_LONG_NAME = "convergence-threshold";
    public static final String MAX_EM_ITERATIONS_LONG_NAME = "num-em-iterations";
    public static final String MAX_DEPTH_LONG_NAME = "max-depth";
    /** File-name suffix of each per-sample prior table written into the output archive. */
    public static final String ARTIFACT_PRIOR_EXTENSION = ".orientation_priors";

    @Argument(fullName = StandardArgumentDefinitions.INPUT_LONG_NAME, shortName = StandardArgumentDefinitions.INPUT_SHORT_NAME, doc = "One or more .tar.gz containing outputs of CollectF1R2Counts")
    private List<File> inputTarGzs;

    @Argument(fullName = "output", shortName = "O", doc = "tar.gz of artifact prior tables")
    private File outputTarGz;

    @Argument(fullName = EM_CONVERGENCE_THRESHOLD_LONG_NAME, doc = "Stop the EM when the distance between parameters between iterations falls below this value", optional = true)
    private double convergenceThreshold = DEFAULT_CONVERGENCE_THRESHOLD;

    @Argument(fullName = MAX_EM_ITERATIONS_LONG_NAME, doc = "give up on EM after this many iterations", optional = true)
    private int maxEMIterations = DEFAULT_MAX_ITERATIONS;

    @Argument(fullName = MAX_DEPTH_LONG_NAME, doc = "sites with depth higher than this value will be grouped", optional = true)
    private int maxDepth = 200;

    /** Summed ref histograms, keyed by sample (the string form of the first metrics-file header). */
    private Map<String, List<Histogram<Integer>>> refHistogramsBySample;
    /** Summed alt histograms, keyed the same way; empty when no alt histogram files were found. */
    private Map<String, List<Histogram<Integer>>> altHistogramsBySample;

    /**
     * Runs the tool: extracts the inputs, aggregates counts per sample, learns the priors and writes
     * the output archive.
     *
     * @return the string {@code "SUCCESS"}
     * @throws UserException.CouldNotCreateOutputFile if the output path does not end in
     *         {@code .tar.gz} or the archive cannot be written
     */
    @Override // org.broadinstitute.hellbender.cmdline.CommandLineProgram
    public Object doWork() {
        if (!outputTarGz.getAbsolutePath().endsWith(".tar.gz")) {
            throw new UserException.CouldNotCreateOutputFile(outputTarGz, "Output file must end in .tar.gz");
        }

        // One temporary directory per input archive; all directories are created before any extraction.
        final List<File> extractedDirs = IntStream.range(0, inputTarGzs.size())
                .mapToObj(n -> IOUtils.createTempDir(Integer.toString(n)))
                .collect(Collectors.toList());
        IntStream.range(0, inputTarGzs.size())
                .forEach(n -> IOUtils.extractTarGz(inputTarGzs.get(n).toPath(), extractedDirs.get(n).toPath()));

        final List<File> refHistogramFiles = extractedDirs.stream()
                .flatMap(dir -> F1R2CountsCollector.getRefHistogramsFromExtractedTar(dir).stream())
                .collect(Collectors.toList());
        final List<File> altHistogramFiles = extractedDirs.stream()
                .flatMap(dir -> F1R2CountsCollector.getAltHistogramsFromExtractedTar(dir).stream())
                .collect(Collectors.toList());
        final List<File> altTableFiles = extractedDirs.stream()
                .flatMap(dir -> F1R2CountsCollector.getAltTablesFromExtractedTar(dir).stream())
                .collect(Collectors.toList());

        final Map<String, List<MetricsFile<?, Integer>>> refMetricsBySample = readMetricsBySample(refHistogramFiles);
        final Map<String, List<MetricsFile<?, Integer>>> altMetricsBySample = readMetricsBySample(altHistogramFiles);

        final Set<String> refSamples = refMetricsBySample.keySet();
        final Set<String> altSamples = altMetricsBySample.keySet();

        // Having no alt histograms at all is tolerated; otherwise ref and alt must describe the same samples.
        Utils.validate(altSamples.isEmpty() || (refSamples.containsAll(altSamples) && altSamples.containsAll(refSamples)),
                "ref and alt histograms must have same samples");
        Utils.validate(altSamples.isEmpty() || refSamples.stream().allMatch(
                sample -> refMetricsBySample.get(sample).size() == altMetricsBySample.get(sample).size()),
                "Each sample must have the same number of alt and ref histograms");

        refHistogramsBySample = refSamples.stream().collect(Collectors.toMap(
                sample -> sample,
                sample -> sumHistogramsFromFiles(refMetricsBySample.get(sample), true)));
        altHistogramsBySample = altSamples.stream().collect(Collectors.toMap(
                sample -> sample,
                sample -> sumHistogramsFromFiles(altMetricsBySample.get(sample), false)));

        final Map<String, List<AltSiteRecord>> altSiteRecordsBySample = gatherAltSiteRecords(altTableFiles);

        final Map<String, ArtifactPriorCollection> priorsBySample = new HashMap<>();
        for (final Map.Entry<String, List<AltSiteRecord>> entry : altSiteRecordsBySample.entrySet()) {
            priorsBySample.put(entry.getKey(), learnPriorsForSample(entry.getKey(), entry.getValue()));
        }

        // Write one prior table per sample into a scratch directory, then archive that directory's files.
        final File priorsDir = IOUtils.createTempDir("priors");
        for (final Map.Entry<String, ArtifactPriorCollection> entry : priorsBySample.entrySet()) {
            final File priorFile = new File(priorsDir, IOUtils.urlEncode(entry.getKey()) + ARTIFACT_PRIOR_EXTENSION);
            entry.getValue().writeArtifactPriors(priorFile);
        }

        try {
            IOUtils.writeTarGz(outputTarGz.getAbsolutePath(), priorsDir.listFiles());
        } catch (IOException e) {
            throw new UserException.CouldNotCreateOutputFile("Could not create output .tar.gz file.", e);
        }
        return "SUCCESS";
    }

    /** Reads every metrics file and groups the results by the string form of the file's first header. */
    private static Map<String, List<MetricsFile<?, Integer>>> readMetricsBySample(final List<File> metricsFiles) {
        return metricsFiles.stream()
                .map(LearnReadOrientationModel::readMetricsFile)
                .collect(Collectors.groupingBy(metricsFile -> metricsFile.getHeaders().get(0).toString()));
    }

    /**
     * Learns the artifact priors of one sample for every canonical reference context, skipping
     * contexts whose combined ref histogram sums to zero or that have no alt site records.
     */
    private ArtifactPriorCollection learnPriorsForSample(final String sample, final List<AltSiteRecord> altSiteRecords) {
        final Map<String, List<AltSiteRecord>> recordsByContext = altSiteRecords.stream()
                .collect(Collectors.groupingBy(AltSiteRecord::getReferenceContext));
        final ArtifactPriorCollection priors = new ArtifactPriorCollection(sample);

        final List<Histogram<Integer>> refHistograms = refHistogramsBySample.get(sample);
        final List<Histogram<Integer>> altHistograms = altHistogramsBySample.getOrDefault(sample, Collections.emptyList());

        for (final String refContext : F1R2FilterConstants.CANONICAL_KMERS) {
            final String revCompContext = SequenceUtil.reverseComplement(refContext);

            final Histogram<Integer> combinedRefHistogram = combineRefHistogramWithRC(
                    refContext,
                    findRefHistogram(refHistograms, refContext),
                    findRefHistogram(refHistograms, revCompContext),
                    maxDepth);

            final List<Histogram<Integer>> combinedAltHistograms = combineAltDepthOneHistogramWithRC(
                    filterByLabelPrefix(altHistograms, refContext),
                    filterByLabelPrefix(altHistograms, revCompContext),
                    maxDepth);

            // The fallback list must be mutable: mergeDesignMatrices appends the reverse-complemented records to it.
            final List<AltSiteRecord> altDesignMatrix = recordsByContext.getOrDefault(refContext, new ArrayList<>());
            final List<AltSiteRecord> altDesignMatrixRevComp = recordsByContext.getOrDefault(revCompContext, Collections.emptyList());
            mergeDesignMatrices(altDesignMatrix, altDesignMatrixRevComp);

            if (combinedRefHistogram.getSumOfValues() == 0.0d || altDesignMatrix.isEmpty()) {
                logger.info(String.format("Skipping the reference context %s as we didn't find either the ref or alt table for the context", refContext));
                continue;
            }

            final LearnReadOrientationModelEngine engine = new LearnReadOrientationModelEngine(
                    combinedRefHistogram, combinedAltHistograms, altDesignMatrix,
                    convergenceThreshold, maxEMIterations, maxDepth, logger);
            priors.set(engine.learnPriorForArtifactStates());
        }
        return priors;
    }

    /** Returns the first histogram labelled exactly {@code context}, or a freshly created ref histogram if none exists. */
    private Histogram<Integer> findRefHistogram(final List<Histogram<Integer>> histograms, final String context) {
        return histograms.stream()
                .filter(histogram -> histogram.getValueLabel().equals(context))
                .findFirst()
                .orElseGet(() -> F1R2FilterUtils.createRefHistogram(context, maxDepth));
    }

    /** Returns the histograms whose value label starts with {@code prefix}. */
    private static List<Histogram<Integer>> filterByLabelPrefix(final List<Histogram<Integer>> histograms, final String prefix) {
        return histograms.stream()
                .filter(histogram -> histogram.getValueLabel().startsWith(prefix))
                .collect(Collectors.toList());
    }

    /**
     * Adds the ref histogram of a context to that of its reverse complement.
     *
     * @param refContext           the context the result is labelled with; must equal the label of {@code refHistogram}
     * @param refHistogram         histogram of {@code refContext}
     * @param refHistogramRevComp  histogram of the reverse complement of {@code refContext}
     * @param maxDepth             passed to {@code F1R2FilterUtils.createRefHistogram} for the result
     * @return a new histogram holding, for every key of {@code refHistogram}, the sum of both inputs
     */
    @VisibleForTesting
    public static Histogram<Integer> combineRefHistogramWithRC(final String refContext,
                                                               final Histogram<Integer> refHistogram,
                                                               final Histogram<Integer> refHistogramRevComp,
                                                               final int maxDepth) {
        Utils.validateArg(refHistogram.getValueLabel().equals(SequenceUtil.reverseComplement(refHistogramRevComp.getValueLabel())),
                "ref context = " + refHistogram.getValueLabel() + ", rev comp = " + refHistogramRevComp.getValueLabel());
        Utils.validateArg(refHistogram.getValueLabel().equals(refContext), "this better match");

        final Histogram<Integer> combined = F1R2FilterUtils.createRefHistogram(refContext, maxDepth);
        // NOTE(review): only keys of refHistogram are visited, and refHistogramRevComp is assumed to have a
        // bin for each of them (get() would otherwise return null) - confirm both inputs share the same bins.
        for (final Integer depth : refHistogram.keySet()) {
            combined.increment(depth, refHistogram.get(depth).getValue() + refHistogramRevComp.get(depth).getValue());
        }
        return combined;
    }

    /**
     * Merges the alt histograms of a canonical context with the matching histograms of its reverse
     * complement. A (context, alt base, orientation) histogram is paired with the
     * (reverse-complement context, complemented alt base, other orientation) histogram; a missing
     * histogram on either side is replaced by a freshly created one.
     *
     * @param altDepthOneHistograms         alt histograms of the canonical context; may be empty
     * @param altDepthOneHistogramsRevComp  alt histograms of its reverse complement; may be empty
     * @param maxDepth                      passed to {@code F1R2FilterUtils.createAltHistogram}
     * @return one combined histogram per (non-reference alt base, orientation) pair, or an empty list
     *         when both inputs are empty
     */
    @VisibleForTesting
    public static List<Histogram<Integer>> combineAltDepthOneHistogramWithRC(final List<Histogram<Integer>> altDepthOneHistograms,
                                                                             final List<Histogram<Integer>> altDepthOneHistogramsRevComp,
                                                                             final int maxDepth) {
        if (altDepthOneHistograms.isEmpty() && altDepthOneHistogramsRevComp.isEmpty()) {
            return Collections.emptyList();
        }

        // Derive the canonical context from whichever side has data.
        final String refContext = !altDepthOneHistograms.isEmpty()
                ? (String) F1R2FilterUtils.labelToTriplet(altDepthOneHistograms.get(0).getValueLabel()).getLeft()
                : SequenceUtil.reverseComplement((String) F1R2FilterUtils.labelToTriplet(altDepthOneHistogramsRevComp.get(0).getValueLabel()).getLeft());
        Utils.validateArg(F1R2FilterConstants.CANONICAL_KMERS.contains(refContext),
                "refContext must be the canonical representation but got " + refContext);

        final String revCompContext = SequenceUtil.reverseComplement(refContext);
        final Nucleotide refBase = F1R2FilterUtils.getMiddleBase(refContext);

        final List<Histogram<Integer>> combinedHistograms = new ArrayList<>(F1R2FilterConstants.numAltHistogramsPerContext);
        for (final Nucleotide altAllele : Nucleotide.STANDARD_BASES) {
            if (altAllele == refBase) {
                continue;
            }
            final Nucleotide altAlleleRevComp = Nucleotide.valueOf(SequenceUtil.reverseComplement(altAllele.toString()));

            for (final ReadOrientation orientation : ReadOrientation.values()) {
                final ReadOrientation otherOrientation = ReadOrientation.getOtherOrientation(orientation);

                final String forwardLabel = F1R2FilterUtils.tripletToLabel(refContext, altAllele, orientation);
                final Histogram<Integer> forward = altDepthOneHistograms.stream()
                        .filter(histogram -> histogram.getValueLabel().equals(forwardLabel))
                        .findFirst()
                        .orElseGet(() -> F1R2FilterUtils.createAltHistogram(refContext, altAllele, orientation, maxDepth));

                final String revCompLabel = F1R2FilterUtils.tripletToLabel(revCompContext, altAlleleRevComp, otherOrientation);
                final Histogram<Integer> revComp = altDepthOneHistogramsRevComp.stream()
                        .filter(histogram -> histogram.getValueLabel().equals(revCompLabel))
                        .findFirst()
                        .orElseGet(() -> F1R2FilterUtils.createAltHistogram(revCompContext, altAlleleRevComp, otherOrientation, maxDepth));

                final Histogram<Integer> combined = F1R2FilterUtils.createAltHistogram(refContext, altAllele, orientation, maxDepth);
                for (final Integer depth : forward.keySet()) {
                    combined.increment(depth, forward.get(depth).getValue() + revComp.get(depth).getValue());
                }
                combinedHistograms.add(combined);
            }
        }
        return combinedHistograms;
    }

    /**
     * Appends the reverse complement of every record in {@code altDesignMatrixRevComp} to
     * {@code altDesignMatrix}, which is modified in place and must therefore be mutable.
     *
     * @param altDesignMatrix         records of a canonical reference context; receives the merged result
     * @param altDesignMatrixRevComp  records of the reverse-complement context; not modified
     * @throws IllegalArgumentException presumably (via {@code Utils.validateArg}) when the first list is
     *         not canonical or the two contexts are not reverse complements of each other
     */
    @VisibleForTesting
    public static void mergeDesignMatrices(final List<AltSiteRecord> altDesignMatrix, final List<AltSiteRecord> altDesignMatrixRevComp) {
        if (altDesignMatrix.isEmpty() && altDesignMatrixRevComp.isEmpty()) {
            return;
        }

        Utils.validateArg(altDesignMatrix.isEmpty() || F1R2FilterConstants.CANONICAL_KMERS.contains(altDesignMatrix.get(0).getReferenceContext()),
                "altDesignMatrix must have the canonical representation");

        // The contexts can only be cross-checked when both sides have at least one record.
        if (!altDesignMatrix.isEmpty() && !altDesignMatrixRevComp.isEmpty()) {
            final String refContext = altDesignMatrix.get(0).getReferenceContext();
            final String revCompContext = altDesignMatrixRevComp.get(0).getReferenceContext();
            Utils.validateArg(refContext.equals(SequenceUtil.reverseComplement(revCompContext)),
                    "ref context and its rev comp don't match");
        }

        altDesignMatrix.addAll(altDesignMatrixRevComp.stream()
                .map(AltSiteRecord::getReverseComplementOfRecord)
                .collect(Collectors.toList()));
    }

    /**
     * Parses a metrics file from disk.
     *
     * @param file the metrics file to read
     * @return the parsed metrics file
     */
    public static MetricsFile<?, Integer> readMetricsFile(File file) {
        final MetricsFile<?, Integer> metricsFile = new MetricsFile<>();
        final BufferedReader reader = IOUtil.openFileForBufferedReading(file);
        try {
            metricsFile.read(reader);
        } finally {
            // Close in finally so the file handle is released even when parsing fails.
            CloserUtil.close(reader);
        }
        return metricsFile;
    }

    /**
     * Sums the histograms of several metrics files, matching histograms by value label.
     *
     * <p>The histograms of the first file serve as accumulators and are modified in place; the
     * returned list is that first file's histogram list.
     *
     * @param metricsFiles metrics files to sum; must not be null
     * @param ref          {@code true} to validate the first file as ref histograms (one per k-mer, all
     *                     labels known k-mers), {@code false} to validate it as alt histograms
     * @return the summed histograms, or an empty list when {@code metricsFiles} is empty
     */
    public static List<Histogram<Integer>> sumHistogramsFromFiles(final List<MetricsFile<?, Integer>> metricsFiles, final boolean ref) {
        Utils.nonNull(metricsFiles, "files may not be null");
        if (metricsFiles.isEmpty()) {
            return Collections.emptyList();
        }

        final List<Histogram<Integer>> summedHistograms = metricsFiles.get(0).getAllHistograms();
        if (ref) {
            Utils.validate(summedHistograms.size() == F1R2FilterConstants.NUM_KMERS,
                    "The list of ref histograms need to include all kmers as enforced by CollectF1R2Counts");
            Utils.validate(summedHistograms.stream().allMatch(histogram -> F1R2FilterConstants.ALL_KMERS.contains(histogram.getValueLabel())),
                    "a histogram contains an unsupported, non-kmer header");
        } else {
            Utils.validate(summedHistograms.size() == F1R2FilterConstants.NUM_KMERS * F1R2FilterConstants.numAltHistogramsPerContext,
                    "The list of alt histograms missing some (kmer, alt allele, f1r2) triple");
        }

        // The first file is already the accumulator, so start adding from the second one.
        for (int i = 1; i < metricsFiles.size(); i++) {
            for (final Histogram<Integer> histogram : metricsFiles.get(i).getAllHistograms()) {
                final String label = histogram.getValueLabel();
                final Optional<Histogram<Integer>> accumulator = summedHistograms.stream()
                        .filter(candidate -> candidate.getValueLabel().equals(label))
                        .findAny();
                Utils.validate(accumulator.isPresent(), "Missing histogram header for: " + label);
                accumulator.get().addHistogram(histogram);
            }
        }
        return summedHistograms;
    }

    /**
     * Reads every alt table and concatenates the records per sample.
     *
     * @param tables alt table files to read
     * @return alt site records keyed by the sample name reported by
     *         {@code AltSiteRecord.readAltSiteRecords}; records of a sample appear in file order
     */
    @VisibleForTesting
    static Map<String, List<AltSiteRecord>> gatherAltSiteRecords(final List<File> tables) {
        final Map<String, List<AltSiteRecord>> recordsBySample = new HashMap<>();
        for (final File table : tables) {
            final Pair<String, List<AltSiteRecord>> sampleAndRecords =
                    AltSiteRecord.readAltSiteRecords(table.toPath(), DEFAULT_INITIAL_LIST_SIZE);
            final String sample = sampleAndRecords.getLeft();
            final List<AltSiteRecord> records = sampleAndRecords.getRight();
            if (recordsBySample.containsKey(sample)) {
                recordsBySample.get(sample).addAll(records);
            } else {
                // The first list seen for a sample is stored as-is and later files are appended to it.
                recordsBySample.put(sample, records);
            }
        }
        return recordsBySample;
    }
}
