package org.broadinstitute.hellbender.tools.walkers.mutect;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMReadGroupRecord;
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.SAMSequenceRecord;
import htsjdk.variant.variantcontext.Allele;
import htsjdk.variant.variantcontext.VariantContext;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.Triple;
import org.apache.commons.math3.util.FastMath;
import org.broadinstitute.hellbender.engine.ReferenceContext;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.walkers.ReferenceConfidenceVariantContextMerger;
import org.broadinstitute.hellbender.tools.walkers.annotator.AssemblyComplexity;
import org.broadinstitute.hellbender.tools.walkers.annotator.ReferenceBases;
import org.broadinstitute.hellbender.tools.walkers.mutect.M2ArgumentCollection;
import org.broadinstitute.hellbender.tools.walkers.mutect.filtering.Mutect2FilteringEngine;
import org.broadinstitute.hellbender.utils.IndexRange;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.codecs.gtf.GencodeGtfFeature;
import org.broadinstitute.hellbender.utils.genotyper.AlleleLikelihoods;
import org.broadinstitute.hellbender.utils.haplotype.Haplotype;
import org.broadinstitute.hellbender.utils.read.Fragment;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.variant.GATKVCFConstants;
import org.broadinstitute.hellbender.utils.variant.VariantContextGetters;

/* loaded from: input_file:org/broadinstitute/hellbender/tools/walkers/mutect/Mutect3DatasetEngine.class */
/**
 * Writes labelled, per-alt-allele records for Mutect3 to a plain-text dataset file.
 *
 * <p>Each call to {@link #addData} appends one record per alt allele that is not labelled
 * {@link Label#IGNORE}. When the engine is closed, two side tables are written next to the dataset
 * file: {@code contigs.table} (contig name and sequence index) and {@code read-groups.table} (read
 * group id and the index assigned to it by this engine).
 *
 * <p>In training mode the engine keeps, per {@link VariantType}, a bounded queue of artifact alt
 * counts. Every alt labelled {@link Label#ARTIFACT} enqueues {@code nonArtifactPerArtifact} copies
 * of its tumor alt count, and every alt labelled {@link Label#VARIANT} consumes one queued count.
 *
 * <p>This class holds open file handles and must be closed.
 */
public class Mutect3DatasetEngine implements AutoCloseable {
    /** Capacity of each per-variant-type queue of unmatched artifact alt counts. */
    public static final int CAPACITY = 100000;

    /** Length of the feature vector produced by {@link #variantFeatures}. */
    private static final int NUM_EXTRA_FEATURES = 9;

    // POPAF cutoffs. NOTE(review): the comparisons below treat larger POPAF as "rarer", which
    // suggests a negative-log10 population allele frequency -- confirm against the annotation spec.
    private static final double RARE_POPAF_THRESHOLD = 5.9d;
    private static final double COMMON_POPAF_THRESHOLD = 1.0d;

    /** Alts whose tumor log odds fall below this value are treated as likely sequencing errors. */
    private static final double TLOD_THRESHOLD = 6.0d;

    /** Not referenced anywhere in this class; kept so the declared constants are unchanged. */
    private static final int MIN_REF = 5;

    // NOTE(review): this separator is borrowed from an unrelated GTF codec class; its value is not
    // visible here. It is kept as-is so the emitted file format cannot change.
    private static final String VALUE_SEPARATOR = GencodeGtfFeature.EXTRA_FIELD_KEY_VALUE_SPLITTER;

    private final SAMSequenceDictionary sequenceDictionary;
    private final Map<String, Integer> readGroupIndices = new HashMap<>();
    private final int maxRefCount;
    private final int maxAltCount;
    private final PrintWriter printWriter;
    private final PrintWriter contigPrintWriter;
    private final PrintWriter readGroupPrintWriter;
    private final int nonArtifactPerArtifact;
    private final boolean trainingMode;
    private final Set<String> normalSamples;
    private final Map<VariantType, ArrayBlockingQueue<Integer>> unmatchedArtifactAltCounts;

    /**
     * Label attached to each alt allele of a processed variant.
     */
    public enum Label {
        /** Training mode: probable artifact. Otherwise: alt absent from the supplied truth alleles. */
        ARTIFACT,
        /** Alt present in the truth alleles, or (training mode without truth) passing the strict variant criteria. */
        VARIANT,
        /** No label could be assigned, but the record is still written. */
        UNLABELED,
        /** The alt is skipped entirely; no record is written for it. */
        IGNORE
    }

    /**
     * Coarse variant class derived from (alt length - ref length): zero is {@link #SNV}, positive is
     * {@link #INSERTION}, negative is {@link #DELETION}.
     */
    public enum VariantType {
        SNV,
        INSERTION,
        DELETION
    }

    /**
     * Opens the dataset file and the two sibling table files, and indexes the read groups of the header.
     *
     * @param file                  dataset output file; must not be null. {@code contigs.table} and
     *                              {@code read-groups.table} are created in the same directory.
     * @param z                     training mode flag
     * @param i                     maximum ref read count, forwarded to the read featurizer
     * @param i2                    maximum alt read count, forwarded to the read featurizer
     * @param i3                    number of alt counts enqueued per artifact (non-artifacts per artifact)
     * @param set                   names of the normal samples; must not be null
     * @param sAMFileHeader         header whose read groups are numbered in listing order
     * @param sAMSequenceDictionary dictionary used to turn contig names into indices
     * @throws UserException.BadInput if any of the three output files cannot be opened
     */
    public Mutect3DatasetEngine(File file, boolean z, int i, int i2, int i3, Set<String> set, SAMFileHeader sAMFileHeader, SAMSequenceDictionary sAMSequenceDictionary) {
        // Everything that can fail without touching the file system runs first, so that a bad
        // argument can no longer leave half-opened writers behind.
        Utils.nonNull(file);
        this.normalSamples = Utils.nonNull(set);
        this.trainingMode = z;
        this.nonArtifactPerArtifact = i3;
        this.maxRefCount = i;
        this.maxAltCount = i2;
        this.sequenceDictionary = sAMSequenceDictionary;

        final List<SAMReadGroupRecord> readGroups = sAMFileHeader.getReadGroups();
        for (int index = 0; index < readGroups.size(); index++) {
            this.readGroupIndices.put(readGroups.get(index).getReadGroupId(), index);
        }

        this.unmatchedArtifactAltCounts = new EnumMap<>(VariantType.class);
        for (final VariantType variantType : VariantType.values()) {
            this.unmatchedArtifactAltCounts.put(variantType, new ArrayBlockingQueue<>(CAPACITY));
        }

        PrintWriter datasetWriter = null;
        PrintWriter contigWriter = null;
        PrintWriter readGroupWriter = null;
        try {
            datasetWriter = new PrintWriter(new FileWriter(file));
            contigWriter = new PrintWriter(new FileWriter(file.toPath().resolveSibling("contigs.table").toFile()));
            readGroupWriter = new PrintWriter(new FileWriter(file.toPath().resolveSibling("read-groups.table").toFile()));
        } catch (IOException e) {
            // Release whatever was opened before the failure, and keep the cause for diagnosis.
            closeIfOpen(datasetWriter);
            closeIfOpen(contigWriter);
            closeIfOpen(readGroupWriter);
            throw new UserException.BadInput("Could not create dataset file writer", e);
        }
        this.printWriter = datasetWriter;
        this.contigPrintWriter = contigWriter;
        this.readGroupPrintWriter = readGroupWriter;
    }

    /** Closes {@code writer} unless it was never opened. */
    private static void closeIfOpen(final PrintWriter writer) {
        if (writer != null) {
            writer.close();
        }
    }

    /**
     * Labels every alt allele of {@code variantContext} and appends one record per non-ignored alt.
     *
     * <p>A record consists of, in order: the label; {@code contigIndex:start,ref->alt}; the reference
     * bases string; the variant feature vector; a line with the tumor ref and alt read-vector counts
     * followed by two zeros; the tumor ref read vectors; the tumor alt read vectors; a line with
     * tumor depth, tumor alt count, normal depth and normal alt count; and two likelihood lines
     * derived from the tumor and normal-artifact log odds.
     *
     * <p>In training mode nothing is written when every alt is labelled {@link Label#IGNORE}.
     *
     * @param referenceContext    reference context used to obtain the reference bases string
     * @param variantContext      the variant whose alt alleles are labelled and written
     * @param optional            truth variant contexts, if available; unfiltered ones define the truth alleles
     * @param alleleLikelihoods   read-allele likelihoods; their sample list, minus the normal samples,
     *                            defines the tumor samples
     * @param alleleLikelihoods2  fragment-haplotype likelihoods, used for assembly complexity and read featurization
     * @param alleleLikelihoods3  fragment-allele likelihoods; not used by this method
     * @param mutect3DatasetMode  dataset mode forwarded to the read featurizer
     */
    public void addData(ReferenceContext referenceContext, VariantContext variantContext, Optional<List<VariantContext>> optional, AlleleLikelihoods<GATKRead, Allele> alleleLikelihoods, AlleleLikelihoods<Fragment, Haplotype> alleleLikelihoods2, AlleleLikelihoods<Fragment, Allele> alleleLikelihoods3, M2ArgumentCollection.Mutect3DatasetMode mutect3DatasetMode) {
        final String refBases = ReferenceBases.annotate(referenceContext, variantContext);
        final String refAlleleBases = variantContext.getReference().getBaseString();
        final int contigIndex = sequenceDictionary.getSequenceIndex(variantContext.getContig());
        final int position = variantContext.getStart();
        final Set<String> tumorSamples = alleleLikelihoods.samples().stream()
                .filter(sample -> !normalSamples.contains(sample))
                .collect(Collectors.toSet());
        final int numAlts = variantContext.getNAlleles() - 1;
        final double[] popafs = VariantContextGetters.getAttributeAsDoubleArray(variantContext, GATKVCFConstants.POPULATION_AF_KEY);
        final double[] tumorLods = Mutect2FilteringEngine.getTumorLogOdds(variantContext);
        final int[] tumorADs = sumADsOverSamples(variantContext, tumorSamples);
        final int[] normalADs = sumADsOverSamples(variantContext, normalSamples);
        final int tumorDepth = (int) MathUtils.sum(tumorADs);
        final int normalDepth = (int) MathUtils.sum(normalADs);
        final boolean hasNormal = normalDepth > 0;
        final boolean hasTruth = optional.isPresent();

        // Alleles of this variant and of the truth variants are compared after remapping all of
        // them onto the longest reference allele seen (first one wins on ties).
        final List<Allele> refAlleles = new ArrayList<>();
        refAlleles.add(variantContext.getReference());
        optional.ifPresent(truthVCs -> truthVCs.forEach(truthVC -> refAlleles.add(truthVC.getReference())));
        final Allele longestRef = Collections.max(refAlleles, Comparator.comparingInt(Allele::length));

        // skip(1) drops the remapped reference allele so that index n lines up with alt n
        final List<Allele> remappedAlts = ReferenceConfidenceVariantContextMerger.remapAlleles(variantContext, longestRef)
                .stream().skip(1L).toList();
        final Set<Allele> truthAlleles = optional
                .map(truthVCs -> truthVCs.stream()
                        .filter(truthVC -> !truthVC.isFiltered())
                        .flatMap(truthVC -> ReferenceConfidenceVariantContextMerger.remapAlleles(truthVC, longestRef).stream())
                        .collect(Collectors.toSet()))
                .orElse(Collections.emptySet());

        final List<Label> labels = new ArrayList<>(numAlts);
        // alt allele -> artifact alt count taken from the queue; handed to the tumor read featurizer
        final Map<Allele, Integer> altDownsampleMap = new HashMap<>();

        for (int altIndex = 0; altIndex < numAlts; altIndex++) {
            // Floating-point division: with int/int every fraction collapses to 0 or 1 and a zero
            // depth throws. A zero depth now yields NaN, which fails every comparison below.
            final double tumorAF = tumorADs[altIndex + 1] / (double) tumorDepth;
            final double normalAF = hasNormal ? normalADs[altIndex + 1] / (double) normalDepth : 0.0d;
            final Allele altAllele = variantContext.getAlternateAllele(altIndex);
            final Allele remappedAlt = remappedAlts.get(altIndex);
            final int lengthDifference = altAllele.getBaseString().length() - refAlleleBases.length();
            final VariantType variantType = lengthDifference == 0 ? VariantType.SNV
                    : lengthDifference > 0 ? VariantType.INSERTION : VariantType.DELETION;

            if (!trainingMode) {
                if (hasTruth) {
                    labels.add(truthAlleles.contains(remappedAlt) ? Label.VARIANT : Label.ARTIFACT);
                } else {
                    labels.add(Label.UNLABELED);
                }
                continue;
            }

            final ArrayBlockingQueue<Integer> unmatchedQueue = unmatchedArtifactAltCounts.get(variantType);
            final boolean likelySeqError = tumorLods[altIndex] < TLOD_THRESHOLD;

            // With truth, membership decides. Without it, require strong evidence, a POPAF below
            // the "common" cutoff, and a high allele fraction in tumor (and in normal, if present).
            final boolean trueVariant = hasTruth
                    ? truthAlleles.contains(remappedAlt)
                    : !likelySeqError && popafs[altIndex] < COMMON_POPAF_THRESHOLD && tumorAF > 0.35d
                            && (!hasNormal || normalAF > 0.35d);

            // Artifacts are the complement of the truth set or, without truth, low tumor allele
            // fraction combined with a POPAF above the "rare" cutoff. The previous form of this
            // test was negated: it labelled truth alleles as ARTIFACT, contradicting both
            // trueVariant above and the non-training branch.
            final boolean probableArtifact = !likelySeqError
                    && (hasTruth
                            ? !truthAlleles.contains(remappedAlt)
                            : tumorAF < 0.2d && popafs[altIndex] > RARE_POPAF_THRESHOLD);

            if (probableArtifact) {
                if (unmatchedQueue.size() > 0.9 * CAPACITY) {
                    // queue nearly full: drop the artifact instead of overflowing the queue
                    labels.add(Label.IGNORE);
                } else {
                    labels.add(Label.ARTIFACT);
                    unmatchedQueue.addAll(Collections.nCopies(nonArtifactPerArtifact, tumorADs[altIndex + 1]));
                }
            } else if (trueVariant && !unmatchedQueue.isEmpty()) {
                // a variant is only emitted when an artifact alt count is available to pair it with
                labels.add(Label.VARIANT);
                altDownsampleMap.put(altAllele, unmatchedQueue.poll());
            } else if (tumorLods[altIndex] > 4.0d && tumorAF < 0.3d) {
                labels.add(Label.UNLABELED);
            } else {
                labels.add(Label.IGNORE);
            }
        }

        Utils.validate(labels.size() == numAlts, "We have not labeled every alt, or have labeled too much");
        if (trainingMode && labels.stream().allMatch(label -> label == Label.IGNORE)) {
            return;
        }

        final Triple<int[], int[], double[]> assemblyComplexity = AssemblyComplexity.annotate(variantContext, alleleLikelihoods2, false);

        // Normal read vectors are not written (the normal counts below are emitted as 0). The call
        // is retained because whether the featurizer has side effects cannot be seen from here.
        FeaturizedReadSets.getReadVectors(variantContext, normalSamples, alleleLikelihoods, alleleLikelihoods2,
                maxRefCount, maxAltCount, mutect3DatasetMode, readGroupIndices);
        final List<List<List<Integer>>> tumorReadVectorsByAllele = FeaturizedReadSets.getReadVectors(variantContext, tumorSamples,
                alleleLikelihoods, alleleLikelihoods2, maxRefCount, maxAltCount, altDownsampleMap, mutect3DatasetMode, readGroupIndices);
        final List<List<Integer>> tumorRefReads = tumorReadVectorsByAllele.get(0);

        for (int altIndex = 0; altIndex < numAlts; altIndex++) {
            final Label label = labels.get(altIndex);
            if (label == Label.IGNORE) {
                continue;
            }
            final String altAlleleBases = variantContext.getAlternateAllele(altIndex).getBaseString();
            final List<Double> features = variantFeatures(altIndex, assemblyComplexity, refBases);
            final List<List<Integer>> tumorAltReads = tumorReadVectorsByAllele.get(altIndex + 1);

            printWriter.println(label);
            printWriter.printf("%d:%d,%s->%s%n", contigIndex, position, refAlleleBases, altAlleleBases);
            printWriter.println(refBases);
            printWriter.println(numberString(features, "%.2f", VALUE_SEPARATOR));
            // tumor ref count, tumor alt count, then two zeros where normal read counts would go
            printWriter.printf("%d %d %d %d%n", tumorRefReads.size(), tumorAltReads.size(), 0, 0);
            tumorRefReads.forEach(read -> printWriter.println(integerString(read)));
            tumorAltReads.forEach(read -> printWriter.println(integerString(read)));
            printWriter.printf("%d %d %d %d%n", tumorDepth, tumorADs[altIndex + 1], normalDepth, normalADs[altIndex + 1]);

            // log odds converted from log10 to natural log, negated, minus log(depth + 1)
            final double tlod = variantContext.getAttributeAsDoubleList(GATKVCFConstants.TUMOR_LOG_10_ODDS_KEY, 0.0d).get(altIndex);
            printWriter.printf("%.3f%n", -MathUtils.log10ToLog(tlod) - Math.log(tumorDepth + 1));

            // same for the normal; the log odds are taken as 0 when there are no normal samples
            final double nalod = normalSamples.isEmpty()
                    ? 0.0d
                    : variantContext.getAttributeAsDoubleList(GATKVCFConstants.NORMAL_ARTIFACT_LOG_10_ODDS_KEY, 0.0d).get(altIndex);
            printWriter.printf("%.3f%n", -MathUtils.log10ToLog(nalod) - Math.log(normalDepth + 1));
        }
    }

    /** Formats integers with {@code %d}, joined by the dataset value separator. */
    private String integerString(final List<Integer> values) {
        return numberString(values, "%d", VALUE_SEPARATOR);
    }

    /**
     * Formats each number with {@code formatString} and joins the results with {@code separator}.
     * When the format ends in {@code f} the value is narrowed to {@code float} before formatting.
     */
    private String numberString(final List<? extends Number> values, final String formatString, final String separator) {
        final boolean floatingPoint = formatString.endsWith("f");
        return values.stream()
                .map(value -> {
                    final Object argument = floatingPoint ? Float.valueOf(value.floatValue()) : value;
                    return String.format(formatString, argument);
                })
                .collect(Collectors.joining(separator));
    }

    /**
     * Builds the {@value #NUM_EXTRA_FEATURES}-element feature vector for one alt allele.
     *
     * <p>Elements 0 and 1 are the second and third entries of the left array divided by the sum of
     * the left array (0 when the array is too short), element 2 is the middle array's entry for
     * this alt, element 3 is the right array's entry for this alt, and elements 4-8 are the repeat
     * counts of the reference bases for unit lengths 1 through 5.
     *
     * @param altIndex           zero-based alt allele index
     * @param assemblyComplexity result of {@code AssemblyComplexity.annotate}; the meaning of its
     *                           three components is defined there
     * @param refBases           reference bases around the variant
     */
    private List<Double> variantFeatures(final int altIndex, final Triple<int[], int[], double[]> assemblyComplexity, final String refBases) {
        final int[] leftCounts = assemblyComplexity.getLeft();
        final int middleForAlt = assemblyComplexity.getMiddle()[altIndex];
        final double rightForAlt = assemblyComplexity.getRight()[altIndex];

        final List<Double> features = new ArrayList<>(NUM_EXTRA_FEATURES);
        final double total = MathUtils.sum(leftCounts);
        features.add(leftCounts.length < 2 ? 0.0d : leftCounts[1] / total);
        features.add(leftCounts.length < 3 ? 0.0d : leftCounts[2] / total);
        features.add((double) middleForAlt);
        features.add(rightForAlt);

        // encode once rather than once per repeat length
        final byte[] bases = refBases.getBytes();
        IntStream.range(1, 6).forEach(unitLength -> features.add((double) countRepeats(bases, unitLength)));

        Utils.validate(features.size() == NUM_EXTRA_FEATURES, "produced a variant feature vector of wrong size");
        return features;
    }

    /**
     * Counts how many consecutive repeats of period {@code unitLength} surround the middle base.
     * Two runs are measured, one anchored on the unit starting at the middle base and one anchored
     * on the unit ending there, and the larger count is returned.
     * Example: {@code GACTACTACTG} with unit length 3 gives 3.
     *
     * @throws IllegalArgumentException if {@code unitLength} exceeds the middle index
     *                                  (assuming {@code Utils.validateArg} throws it -- confirm)
     */
    private int countRepeats(final byte[] bases, final int unitLength) {
        final int numBases = bases.length;
        final int middle = (numBases - 1) / 2;
        Utils.validateArg(unitLength <= middle, "Too few ref bases for given repeat length");

        // run anchored on the unit [middle, middle + unitLength)
        int forwardEnd = middle + unitLength;
        while (forwardEnd < numBases && bases[forwardEnd] == bases[forwardEnd - unitLength]) {
            forwardEnd++;
        }
        int forwardStart = middle - 1;
        while (forwardStart >= 0 && bases[forwardStart] == bases[forwardStart + unitLength]) {
            forwardStart--;
        }
        final int forwardRepeats = (forwardEnd - forwardStart - 1) / unitLength;

        // run anchored on the unit (middle - unitLength, middle]
        int backwardStart = middle - unitLength;
        while (backwardStart >= 0 && bases[backwardStart] == bases[backwardStart + unitLength]) {
            backwardStart--;
        }
        int backwardEnd = middle + 1;
        while (backwardEnd < numBases && bases[backwardEnd] == bases[backwardEnd - unitLength]) {
            backwardEnd++;
        }
        final int backwardRepeats = (backwardEnd - backwardStart - 1) / unitLength;

        return FastMath.max(forwardRepeats, backwardRepeats);
    }

    /**
     * Sums the AD (allelic depth) arrays of the given samples' genotypes, allele by allele.
     *
     * @return an array with one total per allele of {@code variantContext}, reference first
     */
    private int[] sumADsOverSamples(final VariantContext variantContext, final Set<String> samples) {
        final int numAlleles = variantContext.getNAlleles();
        final int[] totals = new int[numAlleles];
        variantContext.getGenotypes(samples).stream()
                .map(genotype -> genotype.getAD())
                .forEach(ad -> new IndexRange(0, numAlleles).forEach(allele -> totals[allele] += ad[allele]));
        return totals;
    }

    /**
     * Closes the dataset file, then writes and closes the contig and read-group tables.
     * Each table row is {@code name<TAB>index}.
     */
    @Override // java.lang.AutoCloseable
    public void close() {
        printWriter.close();
        try {
            for (final SAMSequenceRecord sequence : sequenceDictionary.getSequences()) {
                contigPrintWriter.printf("%s\t%d%n", sequence.getContig(), sequence.getSequenceIndex());
            }
            for (final Map.Entry<String, Integer> entry : readGroupIndices.entrySet()) {
                readGroupPrintWriter.printf("%s\t%d%n", entry.getKey(), entry.getValue());
            }
        } finally {
            // release both table files even if writing a table fails part-way
            contigPrintWriter.close();
            readGroupPrintWriter.close();
        }
    }
}
