package org.broadinstitute.hellbender.tools.spark.sv.evidence;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.learner.ObjFunction;
import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.CigarElement;
import htsjdk.samtools.CigarOperator;
import htsjdk.samtools.TextCigarCodec;
import htsjdk.samtools.util.IOUtil;
import htsjdk.tribble.bed.BEDFeature;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.broadinstitute.hellbender.engine.FeatureDataSource;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.tools.funcotator.vcfOutput.VcfOutputRenderer;
import org.broadinstitute.hellbender.tools.spark.sv.StructuralVariationDiscoveryArgumentCollection;
import org.broadinstitute.hellbender.tools.spark.sv.evidence.BreakpointEvidence;
import org.broadinstitute.hellbender.tools.spark.sv.evidence.EvidenceOverlapChecker;
import org.broadinstitute.hellbender.tools.spark.sv.utils.SVInterval;
import org.broadinstitute.hellbender.tools.spark.sv.utils.SVIntervalTree;
import org.broadinstitute.hellbender.tools.spark.utils.IntHistogram;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.gcs.BucketUtils;
import org.broadinstitute.hellbender.utils.io.IOUtils;

/* loaded from: input_file:org/broadinstitute/hellbender/tools/spark/sv/evidence/XGBoostEvidenceFilter.class */
public final class XGBoostEvidenceFilter implements Iterator<BreakpointEvidence> {
    /** Setting for {@code ObjFunction.useFastMathExp}; {@code loadPredictor} switches fast-math exp on. */
    private static final boolean USE_FAST_MATH_EXP = true;
    /** Evidence classes in the order that defines their numeric code: a class's index in this list is its code. */
    private static final List<Class<?>> DEFAULT_EVIDENCE_TYPE_ORDER = Arrays.asList(BreakpointEvidence.TemplateSizeAnomaly.class, BreakpointEvidence.MateUnmapped.class, BreakpointEvidence.InterContigPair.class, BreakpointEvidence.SplitRead.class, BreakpointEvidence.LargeIndel.class, BreakpointEvidence.WeirdTemplateSize.class, BreakpointEvidence.SameStrandPair.class, BreakpointEvidence.OutiesPair.class);
    /** Unmodifiable lookup from evidence class to its index in DEFAULT_EVIDENCE_TYPE_ORDER; read by getFeatures. */
    private static final Map<Class<?>, Integer> evidenceTypeMap = evidenceTypeOrderToImmutableMap(DEFAULT_EVIDENCE_TYPE_ORDER);
    /** Classpath resource that loadPredictor reads when it is given a null model file location. */
    private static final String DEFAULT_PREDICTOR_RESOURCE_PATH = "/large/sv_evidence_classifier.bin";
    // NOTE(review): presumably the named form of the 0.0 that getFeatures supplies for the
    // gap-overlap feature when no gaps data source is configured -- confirm.
    private static final double DEFAULT_GOOD_GAP_OVERLAP = 0.0d;
    /** Mappability feature value that getFeatures supplies when no umap-s100 data source is configured. */
    private static final double DEFAULT_GOOD_MAPPABILITY = 1.0d;
    // NOTE(review): presumably the named form of the 60 that getMappingQualityForOverlap
    // returns for evidence that is not ReadEvidence -- confirm.
    private static final int DEFAULT_GOOD_MAPPING_QUALITY = 60;
    /** Mapping-quality feature value that getMappingQuality returns for evidence that is not ReadEvidence. */
    private static final double NON_READ_MAPPING_QUALITY = 60.0d;
    // NOTE(review): presumably the named form of the 0.0 that CigarQualityInfo stores for
    // evidence that is not ReadEvidence -- confirm.
    private static final double NON_READ_CIGAR_LENGTHS = 0.0d;
    /** Consulted by hasNext: intervals it reports as on a boundary are emitted without running the classifier. */
    private final PartitionCrossingChecker partitionCrossingChecker;
    /** Classifier obtained from loadPredictor; scores the feature vector of one piece of evidence. */
    private final Predictor predictor;
    /** A predicted probability must be strictly greater than this for evidence to pass (see anyPassesFilter). */
    private final double thresholdProbability;
    /** Source of coverage, contig names and library statistics used while building features. */
    private final ReadMetadata readMetadata;
    /** Built from the input evidence in the constructor; supplies the tree iterator and overlapper queries. */
    private final EvidenceOverlapChecker evidenceOverlapChecker;
    /** Per-evidence unscaled overlap statistics, filled on demand by cacheOverlapInfo. */
    private final Map<BreakpointEvidence, UnscaledOverlapInfo> rawFeatureCache;
    /** Iterator over the overlap checker's tree entries; hasNext advances it to the next entry to emit. */
    private Iterator<SVIntervalTree.Entry<List<BreakpointEvidence>>> treeItr;
    /** Iterator over the evidence list of the entry currently being emitted; null when there is none. */
    private Iterator<BreakpointEvidence> listItr;
    /** Genome-gap features, or null when the filter is run without the gaps annotation. */
    private final FeatureDataSource<BEDFeature> genomeGaps;
    /** Umap-s100 mappability features, or null when the filter is run without that annotation. */
    private final FeatureDataSource<BEDFeature> umapS100Mappability;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/broadinstitute/hellbender/tools/spark/sv/evidence/XGBoostEvidenceFilter$CigarQualityInfo.class */
    /**
     * CIGAR-derived features of one piece of evidence: the number of bases in CIGAR elements
     * that consume both reference and read bases, and the number of bases in elements that
     * consume reference bases. Both are zero for evidence that is not ReadEvidence.
     */
    public static class CigarQualityInfo {
        final double basesMatched;
        final double referenceLength;

        /**
         * Decodes the evidence's CIGAR string (when it is ReadEvidence) and tallies the two lengths.
         *
         * @param breakpointEvidence evidence to summarise
         */
        CigarQualityInfo(BreakpointEvidence breakpointEvidence) {
            int matchedBases = 0;
            int referenceBases = 0;
            if (breakpointEvidence instanceof BreakpointEvidence.ReadEvidence) {
                final BreakpointEvidence.ReadEvidence readEvidence = (BreakpointEvidence.ReadEvidence) breakpointEvidence;
                for (final CigarElement element : TextCigarCodec.decode(readEvidence.getCigarString()).getCigarElements()) {
                    final CigarOperator op = element.getOperator();
                    if (!op.consumesReferenceBases()) {
                        // Elements that do not touch the reference contribute to neither tally.
                        continue;
                    }
                    final int elementLength = element.getLength();
                    referenceBases += elementLength;
                    if (op.consumesReadBases()) {
                        matchedBases += elementLength;
                    }
                }
            }
            this.basesMatched = matchedBases;
            this.referenceLength = referenceBases;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/broadinstitute/hellbender/tools/spark/sv/evidence/XGBoostEvidenceFilter$CoverageScaledOverlapInfo.class */
    /**
     * Overlap statistics with the counts and mapping-quality totals divided by coverage.
     * The mean overlap mapping quality is stored as given, without scaling.
     */
    public static class CoverageScaledOverlapInfo {
        final double numOverlap;
        final double totalOverlapMappingQuality;
        final double meanOverlapMappingQuality;
        final double numCoherent;
        final double totalCoherentMappingQuality;

        /**
         * @param numOverlap                  unscaled overlap count
         * @param numCoherent                 unscaled coherent count
         * @param totalOverlapMappingQuality  unscaled total mapping quality of overlappers
         * @param totalCoherentMappingQuality unscaled total mapping quality of coherent overlappers
         * @param meanOverlapMappingQuality   mean mapping quality, copied unchanged
         * @param coverage                    divisor applied to the four unscaled values
         */
        CoverageScaledOverlapInfo(int numOverlap, int numCoherent, int totalOverlapMappingQuality, int totalCoherentMappingQuality, double meanOverlapMappingQuality, double coverage) {
            this.meanOverlapMappingQuality = meanOverlapMappingQuality;
            this.numOverlap = numOverlap / coverage;
            this.numCoherent = numCoherent / coverage;
            this.totalOverlapMappingQuality = totalOverlapMappingQuality / coverage;
            this.totalCoherentMappingQuality = totalCoherentMappingQuality / coverage;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/broadinstitute/hellbender/tools/spark/sv/evidence/XGBoostEvidenceFilter$UnscaledOverlapInfo.class */
    /**
     * Raw (not coverage-scaled) overlap statistics for one piece of evidence, as accumulated
     * by cacheOverlapInfo: overlap and coherent counts, their mapping-quality totals, and the
     * mean mapping quality over the overlappers.
     */
    public static class UnscaledOverlapInfo {
        final int numOverlap;
        final int numCoherent;
        final int totalOverlapMappingQuality;
        final int totalCoherentMappingQuality;
        final double meanOverlapMappingQuality;

        /**
         * @param numOverlap                  number of overlapping pieces of evidence
         * @param numCoherent                 number of those flagged as coherent
         * @param totalOverlapMappingQuality  sum of mapping qualities over the overlappers
         * @param totalCoherentMappingQuality sum of mapping qualities over the coherent overlappers
         */
        UnscaledOverlapInfo(int numOverlap, int numCoherent, int totalOverlapMappingQuality, int totalCoherentMappingQuality) {
            this.numOverlap = numOverlap;
            this.numCoherent = numCoherent;
            this.totalOverlapMappingQuality = totalOverlapMappingQuality;
            this.totalCoherentMappingQuality = totalCoherentMappingQuality;
            // Divide in floating point: int/int would truncate the mean and throw
            // ArithmeticException when there are no overlappers. With zero overlappers the
            // mean is NaN (0.0 / 0).
            this.meanOverlapMappingQuality = ((double) totalOverlapMappingQuality) / numOverlap;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public XGBoostEvidenceFilter(Iterator<BreakpointEvidence> it, ReadMetadata readMetadata, StructuralVariationDiscoveryArgumentCollection.FindBreakpointEvidenceSparkArgumentCollection findBreakpointEvidenceSparkArgumentCollection, PartitionCrossingChecker partitionCrossingChecker) {
        if (findBreakpointEvidenceSparkArgumentCollection.svGenomeGapsFile == null && findBreakpointEvidenceSparkArgumentCollection.runWithoutGapsAnnotation) {
            this.genomeGaps = null;
        } else {
            if (findBreakpointEvidenceSparkArgumentCollection.svGenomeGapsFile == null || findBreakpointEvidenceSparkArgumentCollection.runWithoutGapsAnnotation) {
                throw new IllegalArgumentException("XGBoostEvidenceFilter requires specifying --sv-genome-gaps-file or passing --run-without-gaps-annotation (but not both)");
            }
            this.genomeGaps = new FeatureDataSource<>(findBreakpointEvidenceSparkArgumentCollection.svGenomeGapsFile);
        }
        if (findBreakpointEvidenceSparkArgumentCollection.svGenomeUmapS100File == null && findBreakpointEvidenceSparkArgumentCollection.runWithoutUmapS100Annotation) {
            this.umapS100Mappability = null;
        } else {
            if (findBreakpointEvidenceSparkArgumentCollection.svGenomeUmapS100File == null || findBreakpointEvidenceSparkArgumentCollection.runWithoutUmapS100Annotation) {
                throw new IllegalArgumentException("XGBoostEvidenceFilter requires specifying --sv-genome-umap-s100-file or passing --run-without-umap-s100-annotation (but not both)");
            }
            this.umapS100Mappability = new FeatureDataSource<>(findBreakpointEvidenceSparkArgumentCollection.svGenomeUmapS100File);
        }
        this.predictor = loadPredictor(findBreakpointEvidenceSparkArgumentCollection.svEvidenceFilterModelFile);
        this.partitionCrossingChecker = partitionCrossingChecker;
        this.thresholdProbability = findBreakpointEvidenceSparkArgumentCollection.svEvidenceFilterThresholdProbability;
        this.readMetadata = readMetadata;
        this.evidenceOverlapChecker = new EvidenceOverlapChecker(it, readMetadata, findBreakpointEvidenceSparkArgumentCollection.minEvidenceMapQ);
        this.rawFeatureCache = new HashMap();
        this.listItr = null;
        this.treeItr = this.evidenceOverlapChecker.getTreeIterator();
    }

    /**
     * Maps each class in the list to its position in the list.
     *
     * @param list evidence classes in the desired order
     * @return an unmodifiable map from class to list index; if a class appears more than
     *         once, the later index wins
     */
    private static Map<Class<?>, Integer> evidenceTypeOrderToImmutableMap(List<Class<?>> list) {
        final Map<Class<?>, Integer> indexByType = new HashMap<>();
        for (int index = 0; index < list.size(); index++) {
            indexByType.put(list.get(index), index);
        }
        return Collections.unmodifiableMap(indexByType);
    }

    /**
     * Loads the XGBoost predictor used to score evidence.
     *
     * @param str location of the classifier model file, or null to load the bundled
     *            default resource
     * @return the loaded predictor
     * @throws GATKException if the model cannot be opened or parsed; the underlying
     *                       exception is attached as the cause
     */
    public static Predictor loadPredictor(String str) {
        ObjFunction.useFastMathExp(USE_FAST_MATH_EXP);
        // try-with-resources closes the stream on every path, including when Predictor
        // construction throws. A null resource is permitted and simply not closed.
        try (InputStream modelStream = str == null ? resourcePathToInputStream(DEFAULT_PREDICTOR_RESOURCE_PATH) : BucketUtils.openFile(str)) {
            return new Predictor(modelStream);
        } catch (Exception e) {
            throw new GATKException("Unable to load predictor from classifier file " + (str == null ? DEFAULT_PREDICTOR_RESOURCE_PATH : str) + VcfOutputRenderer.DESCRIPTION_PREAMBLE_DELIMITER + e.getMessage(), e);
        }
    }

    /**
     * Opens a classpath resource, wrapping it in a decompressing stream when the path has
     * a block-compressed extension.
     *
     * @param str classpath resource path
     * @return a stream over the resource contents
     * @throws IOException if the resource does not exist or the decompressing stream
     *                     cannot be created
     */
    private static InputStream resourcePathToInputStream(String str) throws IOException {
        final InputStream resourceStream = XGBoostEvidenceFilter.class.getResourceAsStream(str);
        if (resourceStream == null) {
            // getResourceAsStream signals a missing resource with null; fail here with the
            // path rather than letting a null stream surface later as an opaque error.
            throw new IOException("Classpath resource not found: " + str);
        }
        if (IOUtil.hasBlockCompressedExtension(str)) {
            return IOUtils.makeZippedInputStream(new BufferedInputStream(resourceStream));
        }
        return resourceStream;
    }

    /**
     * Reports whether more evidence is available, advancing to the next tree entry that
     * should be emitted when the current list is used up.
     *
     * An entry is emitted when any of its evidence is already validated, when its interval
     * is on a partition boundary, or when any of its evidence passes the classifier; in
     * that last case every piece of evidence in the entry is marked validated.
     */
    @Override // java.util.Iterator
    public boolean hasNext() {
        final boolean currentListHasMore = this.listItr != null && this.listItr.hasNext();
        if (currentListHasMore) {
            return true;
        }
        this.listItr = null;
        while (this.treeItr.hasNext()) {
            final SVIntervalTree.Entry<List<BreakpointEvidence>> entry = this.treeItr.next();
            final SVInterval entryInterval = entry.getInterval();
            final List<BreakpointEvidence> evidenceList = entry.getValue();

            boolean emit = isValidated(evidenceList) || this.partitionCrossingChecker.onBoundary(entryInterval);
            if (!emit && anyPassesFilter(evidenceList)) {
                for (final BreakpointEvidence evidence : evidenceList) {
                    evidence.setValidated(true);
                }
                emit = true;
            }

            if (emit) {
                this.listItr = evidenceList.iterator();
                return true;
            }
        }
        return false;
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Returns the next piece of evidence that survives filtering.
     *
     * @throws NoSuchElementException if hasNext() reports that nothing is left
     */
    @Override // java.util.Iterator
    public BreakpointEvidence next() {
        if (!hasNext()) {
            throw new NoSuchElementException("No next element.");
        }
        return this.listItr.next();
    }

    /** Returns true when at least one piece of evidence in the list is already marked validated. */
    private boolean isValidated(List<BreakpointEvidence> list) {
        return list.stream().anyMatch(BreakpointEvidence::isValidated);
    }

    /**
     * Returns true when at least one piece of evidence in the list has a predicted
     * probability strictly greater than the threshold. Stops scoring at the first pass.
     */
    private boolean anyPassesFilter(List<BreakpointEvidence> list) {
        for (final BreakpointEvidence evidence : list) {
            final double probability = predictProbability(evidence);
            if (probability > this.thresholdProbability) {
                return true;
            }
        }
        return false;
    }

    /** Scores one piece of evidence by passing its feature vector to the predictor. */
    @VisibleForTesting
    double predictProbability(BreakpointEvidence breakpointEvidence) {
        final EvidenceFeatures features = getFeatures(breakpointEvidence);
        return this.predictor.predictSingle(features);
    }

    /**
     * Builds the classifier feature vector for one piece of evidence.
     *
     * Feature order: bases matched, reference length, evidence type code, mapping quality,
     * template size or read count, five individual overlap statistics, five cluster overlap
     * statistics, genome-gap overlap, umap-s100 mappability overlap.
     *
     * @param breakpointEvidence evidence to describe
     * @return the features wrapped in an EvidenceFeatures
     * @throws IllegalStateException if the evidence class has no entry in evidenceTypeMap,
     *                               or has no template-size/read-count feature
     */
    @VisibleForTesting
    EvidenceFeatures getFeatures(BreakpointEvidence breakpointEvidence) {
        final CigarQualityInfo cigarQualityInfo = new CigarQualityInfo(breakpointEvidence);

        // Look the type code up explicitly: unboxing a missing entry would otherwise fail
        // with a NullPointerException that says nothing about the offending class.
        final Integer typeCode = evidenceTypeMap.get(breakpointEvidence.getClass());
        if (typeCode == null) {
            throw new IllegalStateException("No evidence type code is defined for " + breakpointEvidence.getClass().getName());
        }
        final double evidenceType = typeCode;

        final double mappingQuality = getMappingQuality(breakpointEvidence);
        final CoverageScaledOverlapInfo individual = getIndividualOverlapInfo(breakpointEvidence);
        final CoverageScaledOverlapInfo cluster = getClusterOverlapInfo(breakpointEvidence);
        final double templateSizeOrReadCount = getTemplateSizeOrReadCount(breakpointEvidence);

        // Without an annotation source, fall back to the fixed value for that feature.
        final double gapOverlap = this.genomeGaps == null
                ? DEFAULT_GOOD_GAP_OVERLAP
                : getGenomeIntervalsOverlap(breakpointEvidence, this.genomeGaps, this.readMetadata);
        final double mappability = this.umapS100Mappability == null
                ? DEFAULT_GOOD_MAPPABILITY
                : getGenomeIntervalsOverlap(breakpointEvidence, this.umapS100Mappability, this.readMetadata);

        return new EvidenceFeatures(new double[]{
                cigarQualityInfo.basesMatched,
                cigarQualityInfo.referenceLength,
                evidenceType,
                mappingQuality,
                templateSizeOrReadCount,
                individual.numOverlap,
                individual.totalOverlapMappingQuality,
                individual.meanOverlapMappingQuality,
                individual.numCoherent,
                individual.totalCoherentMappingQuality,
                cluster.numOverlap,
                cluster.totalOverlapMappingQuality,
                cluster.meanOverlapMappingQuality,
                cluster.numCoherent,
                cluster.totalCoherentMappingQuality,
                gapOverlap,
                mappability
        });
    }

    /** Mapping-quality feature: the read's own value for ReadEvidence, otherwise NON_READ_MAPPING_QUALITY. */
    private double getMappingQuality(BreakpointEvidence breakpointEvidence) {
        if (breakpointEvidence instanceof BreakpointEvidence.ReadEvidence) {
            final BreakpointEvidence.ReadEvidence readEvidence = (BreakpointEvidence.ReadEvidence) breakpointEvidence;
            return readEvidence.getMappingQuality();
        }
        return NON_READ_MAPPING_QUALITY;
    }

    /** Mapping quality used when totalling overlappers: the read's own value for ReadEvidence, otherwise 60. */
    private int getMappingQualityForOverlap(BreakpointEvidence breakpointEvidence) {
        return breakpointEvidence instanceof BreakpointEvidence.ReadEvidence
                ? ((BreakpointEvidence.ReadEvidence) breakpointEvidence).getMappingQuality()
                : DEFAULT_GOOD_MAPPING_QUALITY;
    }

    /**
     * Shared feature slot: the template-size feature for ReadEvidence, the read-count
     * feature for TemplateSizeAnomaly.
     *
     * @throws IllegalStateException for any other kind of evidence
     */
    private double getTemplateSizeOrReadCount(BreakpointEvidence breakpointEvidence) {
        final double feature;
        if (breakpointEvidence instanceof BreakpointEvidence.ReadEvidence) {
            feature = getTemplateSize((BreakpointEvidence.ReadEvidence) breakpointEvidence);
        } else if (breakpointEvidence instanceof BreakpointEvidence.TemplateSizeAnomaly) {
            feature = getReadCounts((BreakpointEvidence.TemplateSizeAnomaly) breakpointEvidence);
        } else {
            throw new IllegalStateException("templateSizeOrReadCount feature is only defined for ReadEvidence and TemplateSizeAnomaly, not " + breakpointEvidence.getClass().getName());
        }
        return feature;
    }

    /**
     * Template-size feature: the CDF fraction, from the statistics of the read's library,
     * at the absolute template size, with the lookup index capped at the CDF's last index.
     */
    private double getTemplateSize(BreakpointEvidence.ReadEvidence readEvidence) {
        final int absoluteTemplateSize = Math.abs(readEvidence.getTemplateSize());
        final IntHistogram.CDF libraryCdf = this.readMetadata.getLibraryStatistics(this.readMetadata.getReadGroupToLibraryMap().get(readEvidence.getReadGroup())).getCDF();
        final int lastIndex = libraryCdf.size() - 1;
        final int lookupIndex = Math.min(absoluteTemplateSize, lastIndex);
        return libraryCdf.getFraction(lookupIndex);
    }

    /**
     * Read-count feature: the anomaly's read count divided by coverage.
     *
     * The count is widened to double first so the division is always carried out in double
     * precision, whatever numeric type getCoverage() returns. This matches
     * CoverageScaledOverlapInfo, which divides by coverage as a double.
     */
    private double getReadCounts(BreakpointEvidence.TemplateSizeAnomaly templateSizeAnomaly) {
        final double readCount = templateSizeAnomaly.getReadCount().intValue();
        return readCount / this.readMetadata.getCoverage();
    }

    /**
     * Coverage-scaled overlap statistics for the evidence itself, computing and caching the
     * unscaled statistics first when they are not in the cache yet.
     */
    private CoverageScaledOverlapInfo getIndividualOverlapInfo(BreakpointEvidence breakpointEvidence) {
        UnscaledOverlapInfo raw = this.rawFeatureCache.get(breakpointEvidence);
        if (raw == null) {
            // cacheOverlapInfo only ever stores non-null values, so null means "not cached".
            cacheOverlapInfo(breakpointEvidence);
            raw = this.rawFeatureCache.get(breakpointEvidence);
        }
        return new CoverageScaledOverlapInfo(
                raw.numOverlap,
                raw.numCoherent,
                raw.totalOverlapMappingQuality,
                raw.totalCoherentMappingQuality,
                raw.meanOverlapMappingQuality,
                this.readMetadata.getCoverage());
    }

    /**
     * Coverage-scaled cluster statistics: for each statistic, the maximum of the unscaled
     * value over every overlapper of the evidence other than the evidence itself. Each
     * maximum starts at zero, so evidence without such overlappers yields all zeros.
     */
    private CoverageScaledOverlapInfo getClusterOverlapInfo(BreakpointEvidence breakpointEvidence) {
        int maxNumOverlap = 0;
        int maxNumCoherent = 0;
        int maxTotalOverlapMapQ = 0;
        int maxTotalCoherentMapQ = 0;
        double maxMeanOverlapMapQ = 0.0d;
        final EvidenceOverlapChecker.OverlapperIterator overlapperItr = this.evidenceOverlapChecker.overlappers(breakpointEvidence);
        while (overlapperItr.hasNext()) {
            final BreakpointEvidence overlapper = overlapperItr.next();
            if (overlapper.equals(breakpointEvidence)) {
                continue;
            }
            if (!this.rawFeatureCache.containsKey(overlapper)) {
                cacheOverlapInfo(overlapper);
            }
            final UnscaledOverlapInfo raw = this.rawFeatureCache.get(overlapper);
            maxNumOverlap = Math.max(maxNumOverlap, raw.numOverlap);
            maxNumCoherent = Math.max(maxNumCoherent, raw.numCoherent);
            maxTotalOverlapMapQ = Math.max(maxTotalOverlapMapQ, raw.totalOverlapMappingQuality);
            maxTotalCoherentMapQ = Math.max(maxTotalCoherentMapQ, raw.totalCoherentMappingQuality);
            maxMeanOverlapMapQ = Math.max(maxMeanOverlapMapQ, raw.meanOverlapMappingQuality);
        }
        return new CoverageScaledOverlapInfo(maxNumOverlap, maxNumCoherent, maxTotalOverlapMapQ, maxTotalCoherentMapQ, maxMeanOverlapMapQ, this.readMetadata.getCoverage());
    }

    /**
     * Computes the unscaled overlap statistics for the evidence and stores them in
     * rawFeatureCache. Overlappers equal to the evidence itself are not counted.
     */
    private void cacheOverlapInfo(BreakpointEvidence breakpointEvidence) {
        int overlapCount = 0;
        int overlapMapQTotal = 0;
        int coherentCount = 0;
        int coherentMapQTotal = 0;
        final EvidenceOverlapChecker.OverlapAndCoherenceIterator pairItr = this.evidenceOverlapChecker.overlappersWithCoherence(breakpointEvidence);
        while (pairItr.hasNext()) {
            final ImmutablePair<BreakpointEvidence, Boolean> pair = pairItr.next();
            final BreakpointEvidence overlapper = pair.getLeft();
            if (overlapper.equals(breakpointEvidence)) {
                continue;
            }
            final int mapQ = getMappingQualityForOverlap(overlapper);
            overlapCount++;
            overlapMapQTotal += mapQ;
            // The right element flags whether this overlapper is coherent.
            if (pair.getRight()) {
                coherentCount++;
                coherentMapQTotal += mapQ;
            }
        }
        this.rawFeatureCache.put(breakpointEvidence, new UnscaledOverlapInfo(overlapCount, coherentCount, overlapMapQTotal, coherentMapQTotal));
    }

    /**
     * Ratio of summed feature overlap to the size of the evidence's location.
     *
     * The location is converted to a SimpleInterval (its end minus one), the data source is
     * queried with it, and each returned feature contributes the length of its intersection
     * with the query interval. Features are summed independently, so features that overlap
     * one another are counted more than once.
     *
     * @param breakpointEvidence evidence whose location is queried
     * @param featureDataSource  interval features to intersect with
     * @param readMetadata       supplies the contig name for the location's contig
     * @return summed overlap divided by the query interval size
     */
    private static double getGenomeIntervalsOverlap(BreakpointEvidence breakpointEvidence, FeatureDataSource<BEDFeature> featureDataSource, ReadMetadata readMetadata) {
        final SVInterval location = breakpointEvidence.getLocation();
        final SimpleInterval queryInterval = new SimpleInterval(readMetadata.getContigName(location.getContig()), location.getStart(), location.getEnd() - 1);
        int overlappedBases = 0;
        final Iterator<BEDFeature> features = featureDataSource.query(queryInterval);
        while (features.hasNext()) {
            final BEDFeature feature = features.next();
            final int overlapEnd = Math.min(queryInterval.getEnd(), feature.getEnd());
            final int overlapStart = Math.max(queryInterval.getStart(), feature.getStart());
            overlappedBases += (overlapEnd + 1) - overlapStart;
        }
        // Divide in floating point; int/int would truncate the ratio to a whole number.
        return ((double) overlappedBases) / queryInterval.size();
    }
}
