package org.broadinstitute.hellbender.tools.walkers.readorientation;

import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.metrics.MetricsFile;
import htsjdk.samtools.util.CloserUtil;
import htsjdk.samtools.util.Histogram;
import htsjdk.samtools.util.IOUtil;
import htsjdk.samtools.util.SequenceUtil;
import java.io.BufferedReader;
import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.hellbender.cmdline.CommandLineProgram;
import org.broadinstitute.hellbender.cmdline.programgroups.ShortVariantDiscoveryProgramGroup;
import org.broadinstitute.hellbender.utils.Nucleotide;
import org.broadinstitute.hellbender.utils.Utils;

/**
 * Estimates artifact prior probabilities for the orientation bias mixture model filter from tables of F1R2 counts
 * (see {@link CollectF1R2Counts}).
 *
 * <p>For each canonical reference k-mer in {@link F1R2FilterConstants#CANONICAL_KMERS}, the data for the k-mer and
 * for its reverse complement are pooled:
 * <ul>
 *     <li>ref-site depth histograms;</li>
 *     <li>optional depth-1 alt histograms;</li>
 *     <li>alt-site records.</li>
 * </ul>
 * {@link LearnReadOrientationModelEngine} is then run on the pooled data. Contexts with an all-zero pooled ref
 * histogram or no alt-site records are skipped with an info log message. The collected priors are written to
 * {@code output}.
 */
@CommandLineProgramProperties(summary = "Get the maximum likelihood estimates of artifact prior probabilities in the orientation bias mixture model filter", oneLineSummary = "Get the maximum likelihood estimates of artifact prior probabilities in the orientation bias mixture model filter", programGroup = ShortVariantDiscoveryProgramGroup.class)
public class LearnReadOrientationModel extends CommandLineProgram {
    public static final double DEFAULT_CONVERGENCE_THRESHOLD = 1.0E-4d;
    public static final int DEFAULT_MAX_ITERATIONS = 20;
    public static final String EM_CONVERGENCE_THRESHOLD_LONG_NAME = "convergence-threshold";
    public static final String MAX_EM_ITERATIONS_LONG_NAME = "num-em-iterations";
    public static final String MAX_DEPTH_LONG_NAME = "max-depth";

    @Argument(fullName = CollectF1R2Counts.REF_SITE_METRICS_LONG_NAME, doc = "histograms of depths over ref sites for each reference context")
    private File refHistogramTable;

    @Argument(fullName = CollectF1R2Counts.ALT_DATA_TABLE_LONG_NAME, doc = "a table of F1R2 and depth counts")
    private File altDataTable;

    @Argument(fullName = "output", shortName = "O", doc = "table of artifact priors")
    private File output;

    /** Ref-site depth histograms read from {@code refHistogramTable} in {@link #onStartup()}. */
    List<Histogram<Integer>> refHistograms;

    /** Depth-1 alt histograms read from {@code altHistogramTable}, or an empty list when that input is absent. */
    List<Histogram<Integer>> altHistograms;

    @Argument(fullName = CollectF1R2Counts.ALT_DEPTH1_HISTOGRAM_LONG_NAME, doc = "histograms of depth 1 alt sites", optional = true)
    private File altHistogramTable = null;

    @Argument(fullName = EM_CONVERGENCE_THRESHOLD_LONG_NAME, doc = "Stop the EM when the distance between parameters between iterations falls below this value", optional = true)
    private double convergenceThreshold = DEFAULT_CONVERGENCE_THRESHOLD;

    @Argument(fullName = MAX_EM_ITERATIONS_LONG_NAME, doc = "give up on EM after this many iterations", optional = true)
    private int maxEMIterations = DEFAULT_MAX_ITERATIONS;

    @Argument(fullName = MAX_DEPTH_LONG_NAME, doc = "sites with depth higher than this value will be grouped", optional = true)
    private int maxDepth = 200;

    /** Accumulates the learned prior of every processed reference context; written out at the end of doWork. */
    final ArtifactPriorCollection artifactPriorCollection = new ArtifactPriorCollection();

    /**
     * Loads the ref-site histograms and, if provided, the depth-1 alt histograms.
     */
    @Override
    public void onStartup() {
        refHistograms = readMetricsFile(refHistogramTable).getAllHistograms();
        altHistograms = altHistogramTable != null
                ? readMetricsFile(altHistogramTable).getAllHistograms()
                : Collections.emptyList();
    }

    /**
     * Pools each canonical context with its reverse complement, learns its priors, and writes all priors to
     * {@code output}.
     *
     * @return the string {@code "SUCCESS"}
     */
    @Override
    public Object doWork() {
        // NOTE(review): the second argument looks like an initial list-size hint for the reader — confirm.
        final Map<String, List<AltSiteRecord>> altSitesByContext =
                AltSiteRecord.readAltSiteRecords(altDataTable, 1000000).stream()
                        .collect(Collectors.groupingBy(AltSiteRecord::getReferenceContext));

        for (final String refContext : F1R2FilterConstants.CANONICAL_KMERS) {
            final String revComp = SequenceUtil.reverseComplement(refContext);

            final Histogram<Integer> combinedRefHistogram = combineRefHistogramWithRC(
                    refContext, findRefHistogram(refContext), findRefHistogram(revComp), maxDepth);

            final List<Histogram<Integer>> combinedAltHistograms = combineAltDepthOneHistogramWithRC(
                    altHistogramsStartingWith(refContext), altHistogramsStartingWith(revComp), maxDepth);

            // groupingBy yields mutable lists, and the default is a fresh ArrayList, so merging in place is safe
            final List<AltSiteRecord> altSites = altSitesByContext.getOrDefault(refContext, new ArrayList<>());
            mergeDesignMatrices(altSites, altSitesByContext.getOrDefault(revComp, Collections.emptyList()));

            if (combinedRefHistogram.getSumOfValues() == 0.0d || altSites.isEmpty()) {
                logger.info(String.format("Skipping the reference context %s as we didn't find either the ref or alt table for the context", refContext));
                continue;
            }

            final LearnReadOrientationModelEngine engine = new LearnReadOrientationModelEngine(
                    combinedRefHistogram, combinedAltHistograms, altSites,
                    convergenceThreshold, maxEMIterations, maxDepth, logger);
            artifactPriorCollection.set(engine.learnPriorForArtifactStates());
        }

        artifactPriorCollection.writeArtifactPriors(output);
        return "SUCCESS";
    }

    /**
     * Returns the loaded ref histogram labelled {@code refContext}, or a freshly created one if none was loaded.
     */
    private Histogram<Integer> findRefHistogram(final String refContext) {
        return refHistograms.stream()
                .filter(h -> h.getValueLabel().equals(refContext))
                .findFirst()
                .orElseGet(() -> F1R2FilterUtils.createRefHistogram(refContext, maxDepth));
    }

    /** Returns the loaded alt histograms whose value label starts with {@code refContext}. */
    private List<Histogram<Integer>> altHistogramsStartingWith(final String refContext) {
        return altHistograms.stream()
                .filter(h -> h.getValueLabel().startsWith(refContext))
                .collect(Collectors.toList());
    }

    /**
     * Sums a ref-site histogram with the histogram of its reverse-complement context.
     *
     * <p>Bins missing from either input are treated as zero. Bins present in only one input are carried over.
     *
     * @param refContext          the canonical reference context; must equal {@code refHistogram}'s value label
     * @param refHistogram        histogram for {@code refContext}
     * @param refHistogramRevComp histogram whose value label is the reverse complement of {@code refHistogram}'s
     * @param maxDepth            passed to {@link F1R2FilterUtils#createRefHistogram} for the result
     * @return a new histogram labelled {@code refContext} holding the bin-wise sum
     * @throws IllegalArgumentException if the labels are inconsistent
     */
    @VisibleForTesting
    public static Histogram<Integer> combineRefHistogramWithRC(final String refContext,
                                                               final Histogram<Integer> refHistogram,
                                                               final Histogram<Integer> refHistogramRevComp,
                                                               final int maxDepth) {
        Utils.validateArg(refHistogram.getValueLabel().equals(SequenceUtil.reverseComplement(refHistogramRevComp.getValueLabel())), "ref context = " + refHistogram.getValueLabel() + ", rev comp = " + refHistogramRevComp.getValueLabel());
        Utils.validateArg(refHistogram.getValueLabel().equals(refContext), "this better match");

        final Histogram<Integer> combined = F1R2FilterUtils.createRefHistogram(refContext, maxDepth);
        addCounts(combined, refHistogram);
        addCounts(combined, refHistogramRevComp);
        return combined;
    }

    /**
     * Combines the depth-1 alt histograms of a canonical context with those of its reverse complement.
     *
     * <p>For each alt base other than the context's middle base, and for each read orientation, the canonical
     * histogram is summed with the reverse-complement histogram for the complementary alt base and the other
     * orientation. Histograms that are absent contribute nothing. Bins missing from either side count as zero.
     *
     * @param altHistograms        alt histograms whose labels belong to the canonical context
     * @param altHistogramsRevComp alt histograms whose labels belong to the reverse-complement context
     * @param maxDepth             passed to {@link F1R2FilterUtils#createAltHistogram} for each result
     * @return an empty list if both inputs are empty, otherwise one combined histogram per (alt base, orientation)
     * @throws IllegalArgumentException if the inferred context is not canonical
     */
    @VisibleForTesting
    public static List<Histogram<Integer>> combineAltDepthOneHistogramWithRC(final List<Histogram<Integer>> altHistograms,
                                                                              final List<Histogram<Integer>> altHistogramsRevComp,
                                                                              final int maxDepth) {
        if (altHistograms.isEmpty() && altHistogramsRevComp.isEmpty()) {
            return Collections.emptyList();
        }

        // Recover the canonical context from whichever side has at least one histogram
        final String refContext = !altHistograms.isEmpty()
                ? F1R2FilterUtils.labelToTriplet(altHistograms.get(0).getValueLabel()).getLeft()
                : SequenceUtil.reverseComplement(F1R2FilterUtils.labelToTriplet(altHistogramsRevComp.get(0).getValueLabel()).getLeft());
        Utils.validateArg(F1R2FilterConstants.CANONICAL_KMERS.contains(refContext), "refContext must be the canonical representation but got " + refContext);

        final String revComp = SequenceUtil.reverseComplement(refContext);
        final Nucleotide refBase = F1R2FilterUtils.getMiddleBase(refContext);
        final List<Histogram<Integer>> combinedHistograms = new ArrayList<>(F1R2FilterConstants.numAltHistogramsPerContext);

        for (final Nucleotide altAllele : Nucleotide.STANDARD_BASES) {
            if (altAllele == refBase) {
                continue;
            }
            final Nucleotide altAlleleRevComp = Nucleotide.valueOf(SequenceUtil.reverseComplement(altAllele.toString()));
            for (final ReadOrientation orientation : ReadOrientation.values()) {
                // An F1R2 artifact on one strand appears as F2R1 on the reverse-complement strand, and vice versa
                final ReadOrientation otherOrientation = ReadOrientation.getOtherOrientation(orientation);
                final Histogram<Integer> combined = F1R2FilterUtils.createAltHistogram(refContext, altAllele, orientation, maxDepth);
                addMatchingHistogram(combined, altHistograms, refContext, altAllele, orientation);
                addMatchingHistogram(combined, altHistogramsRevComp, revComp, altAlleleRevComp, otherOrientation);
                combinedHistograms.add(combined);
            }
        }
        return combinedHistograms;
    }

    /**
     * Adds into {@code target} the first histogram in {@code candidates} whose label matches the given triplet,
     * if one exists.
     */
    private static void addMatchingHistogram(final Histogram<Integer> target,
                                             final List<Histogram<Integer>> candidates,
                                             final String refContext,
                                             final Nucleotide altAllele,
                                             final ReadOrientation orientation) {
        final String label = F1R2FilterUtils.tripletToLabel(refContext, altAllele, orientation);
        candidates.stream()
                .filter(h -> h.getValueLabel().equals(label))
                .findFirst()
                .ifPresent(h -> addCounts(target, h));
    }

    /** Increments each bin of {@code target} by the corresponding bin value of {@code source}. */
    private static void addCounts(final Histogram<Integer> target, final Histogram<Integer> source) {
        for (final Integer depth : source.keySet()) {
            target.increment(depth, source.get(depth).getValue());
        }
    }

    /**
     * Appends the reverse complement of every record in {@code altDesignMatrixRevComp} to {@code altDesignMatrix}.
     * The first list is modified in place.
     *
     * @param altDesignMatrix        records for a canonical context; must be mutable
     * @param altDesignMatrixRevComp records for the reverse-complement context; not modified
     * @throws IllegalArgumentException if {@code altDesignMatrix}'s context is not canonical, or if both lists are
     *                                  non-empty and their contexts are not reverse complements of each other
     */
    @VisibleForTesting
    public static void mergeDesignMatrices(final List<AltSiteRecord> altDesignMatrix,
                                           final List<AltSiteRecord> altDesignMatrixRevComp) {
        if (altDesignMatrix.isEmpty() && altDesignMatrixRevComp.isEmpty()) {
            return;
        }
        Utils.validateArg(altDesignMatrix.isEmpty() || F1R2FilterConstants.CANONICAL_KMERS.contains(altDesignMatrix.get(0).getReferenceContext()), "altDesignMatrix must have the canonical representation");

        if (!altDesignMatrix.isEmpty() && !altDesignMatrixRevComp.isEmpty()) {
            final String refContext = altDesignMatrix.get(0).getReferenceContext();
            final String revComp = altDesignMatrixRevComp.get(0).getReferenceContext();
            Utils.validateArg(refContext.equals(SequenceUtil.reverseComplement(revComp)), "ref context and its rev comp don't match");
        }

        // Materialize first, so the append is safe even if both arguments are the same list
        final List<AltSiteRecord> reverseComplemented = altDesignMatrixRevComp.stream()
                .map(AltSiteRecord::getReverseComplementOfRecord)
                .collect(Collectors.toList());
        altDesignMatrix.addAll(reverseComplemented);
    }

    /**
     * Reads a metrics file.
     *
     * <p>The reader is closed even if parsing fails.
     */
    private MetricsFile<?, Integer> readMetricsFile(final File file) {
        final MetricsFile<?, Integer> metricsFile = new MetricsFile<>();
        final BufferedReader reader = IOUtil.openFileForBufferedReading(file);
        try {
            metricsFile.read(reader);
        } finally {
            CloserUtil.close(reader);
        }
        return metricsFile;
    }
}
