package org.broadinstitute.hellbender.tools.walkers.vqsr;

import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.variant.variantcontext.Allele;
import htsjdk.variant.variantcontext.VariantContext;
import htsjdk.variant.variantcontext.VariantContextBuilder;
import htsjdk.variant.variantcontext.writer.VariantContextWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.cmdline.ExomeStandardArgumentDefinitions;
import org.broadinstitute.hellbender.engine.FeatureContext;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.walkers.vqsr.VariantRecalibratorArgumentCollection;
import org.broadinstitute.hellbender.tools.walkers.vqsr.scalable.ScoreVariantAnnotations;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.variant.GATKVCFConstants;
import org.broadinstitute.hellbender.utils.variant.GATKVariantContextUtils;

/* loaded from: input_file:org/broadinstitute/hellbender/tools/walkers/vqsr/VariantDataManager.class */
/**
 * Holds the per-variant data ({@link VariantDatum}) used for variant recalibration, the
 * per-annotation normalization parameters, and the registered training/truth resources.
 *
 * <p>This source was recovered from bytecode. Decompiler artifacts have been replaced with source-level
 * equivalents: the synthetic enum switch-map class, and constants that were inlined or mis-attributed.
 */
public class VariantDataManager {
    protected static final Logger logger = LogManager.getLogger(VariantDataManager.class);

    /** Offset applied to both bounds of the MQ logit transform so that values at 0 or MQ_CAP stay finite. */
    private static final double SAFETY_OFFSET = 0.01d;
    /** Tolerance used when checking whether an annotation value sits on a sentinel value. */
    private static final double PRECISION = 0.01d;
    /** Scale of the Gaussian jitter added to annotation values that sit on a sentinel value. */
    private static final double SENTINEL_JITTER_SCALE = 0.01d;
    /** ln(2); SOR values within {@link #PRECISION} of this are jittered. */
    private static final double LOG_OF_TWO = 0.6931472d;
    /** Standard deviations below this are treated as zero variance in {@link #normalizeData}. */
    private static final double MIN_STANDARD_DEVIATION = 1.0E-5d;
    /** Scale of the Gaussian noise substituted for missing annotation values during normalization. */
    private static final double MISSING_VALUE_NOISE_SCALE = 0.1d;

    private List<VariantDatum> data = Collections.emptyList();
    /** Per-annotation means, indexed like {@link #annotationKeys}. */
    private double[] meanVector;
    /**
     * Per-annotation spread used for normalization, indexed like {@link #annotationKeys}.
     * Despite the name, {@link #normalizeData} stores standard deviations here, not variances.
     */
    private double[] varianceVector;
    /** Annotation names, de-duplicated at construction and re-ordered by {@link #normalizeData}. */
    public List<String> annotationKeys;
    private final VariantRecalibratorArgumentCollection VRAC;
    protected final List<TrainingSet> trainingSets;

    /**
     * Creates a manager for the given annotations.
     *
     * @param annotationKeys annotation names to use; duplicates are dropped (first occurrence kept) with a warning
     * @param vrac           recalibration arguments consulted by the training/decoding logic
     */
    public VariantDataManager(final List<String> annotationKeys, final VariantRecalibratorArgumentCollection vrac) {
        final List<String> uniqueKeys = annotationKeys.stream().distinct().collect(Collectors.toList());
        if (annotationKeys.size() != uniqueKeys.size()) {
            logger.warn("Ignoring duplicate annotations for recalibration {}", Utils.getDuplicatedItems(annotationKeys));
        }
        this.annotationKeys = new ArrayList<>(uniqueKeys);
        this.VRAC = vrac;
        this.meanVector = new double[this.annotationKeys.size()];
        this.varianceVector = new double[this.annotationKeys.size()];
        this.trainingSets = new ArrayList<>();
    }

    /** Replaces the managed data with {@code data} (stored by reference, not copied). */
    public void setData(final List<VariantDatum> data) {
        this.data = data;
    }

    /**
     * Sets the normalization parameters from precomputed per-annotation values.
     *
     * @param means              mean per annotation name
     * @param standardDeviations standard deviation per annotation name
     * @throws IllegalArgumentException if either map lacks a value for one of {@link #annotationKeys}
     */
    public void setNormalization(final Map<String, Double> means, final Map<String, Double> standardDeviations) {
        for (int i = 0; i < annotationKeys.size(); i++) {
            final String key = annotationKeys.get(i);
            meanVector[i] = requireNormalizationValue(means, key, "mean");
            varianceVector[i] = requireNormalizationValue(standardDeviations, key, "standard deviation");
        }
    }

    /** Looks up {@code key} in {@code values}, failing with a descriptive message instead of an unboxing NPE. */
    private static double requireNormalizationValue(final Map<String, Double> values, final String key, final String description) {
        final Double value = values.get(key);
        if (value == null) {
            throw new IllegalArgumentException("No " + description + " provided for annotation " + key);
        }
        return value;
    }

    /** Returns the managed data list itself (not a copy). */
    public List<VariantDatum> getData() {
        return data;
    }

    /**
     * Z-scores every datum's annotations in place and then re-orders all annotation-indexed state.
     *
     * <p>Missing values ({@code isNull}) are replaced with Gaussian noise of standard deviation 0.1
     * instead of being normalized. Data whose normalized annotations exceed {@code STD_THRESHOLD} in
     * absolute value are flagged via {@code failingSTDThreshold}.
     *
     * @param calculateMeansAndVariances if true, compute mean and standard deviation from training-site data
     *                                   and store them; otherwise use the values already stored
     * @param theOrder                   annotation order to apply; if null it is computed by
     *                                   {@link #calculateSortOrder}
     * @throws UserException.BadInput if an annotation has no value at any training site, or if any
     *                                annotation's standard deviation is below 1e-5
     */
    public void normalizeData(final boolean calculateMeansAndVariances, List<Integer> theOrder) {
        boolean foundZeroVarianceAnnotation = false;
        for (int i = 0; i < meanVector.length; i++) {
            final double theMean;
            final double theSTD;
            if (calculateMeansAndVariances) {
                theMean = mean(i, true);
                theSTD = standardDeviation(theMean, i, true);
                if (Double.isNaN(theMean)) {
                    throw new UserException.BadInput("Values for " + annotationKeys.get(i) + " annotation not detected for ANY training variant in the input callset. VariantAnnotator may be used to add these annotations.");
                }
                // Keep going so every annotation's statistics are logged before failing.
                foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || theSTD < MIN_STANDARD_DEVIATION;
                meanVector[i] = theMean;
                varianceVector[i] = theSTD;
            } else {
                theMean = meanVector[i];
                theSTD = varianceVector[i];
            }
            logger.info(annotationKeys.get(i) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD));
            for (final VariantDatum datum : data) {
                datum.annotations[i] = datum.isNull[i]
                        ? MISSING_VALUE_NOISE_SCALE * Utils.getRandomGenerator().nextGaussian()
                        : (datum.annotations[i] - theMean) / theSTD;
            }
        }
        if (foundZeroVarianceAnnotation) {
            throw new UserException.BadInput("Found annotations with zero variance. They must be excluded before proceeding.");
        }

        // Flag data with any normalized annotation beyond the configured number of standard deviations.
        for (final VariantDatum datum : data) {
            datum.failingSTDThreshold = Arrays.stream(datum.annotations).anyMatch(v -> Math.abs(v) > VRAC.STD_THRESHOLD);
        }

        // NOTE(review): at this point data.annotations are already normalized, so calculateSortOrder compares the
        // stored raw means against means of normalized non-training values. This matches the decompiled
        // behavior; confirm it is intended.
        if (theOrder == null) {
            theOrder = calculateSortOrder(meanVector);
        }
        annotationKeys = reorderList(annotationKeys, theOrder);
        varianceVector = reorderDoubles(varianceVector, theOrder);
        meanVector = reorderDoubles(meanVector, theOrder);
        for (final VariantDatum datum : data) {
            datum.annotations = reorderDoubles(datum.annotations, theOrder);
            datum.isNull = reorderBooleans(datum.isNull, theOrder);
        }
        logger.info("Annotation order is: " + annotationKeys.toString());
    }

    /** Returns the internal mean array (not a copy). */
    public double[] getMeanVector() {
        return meanVector;
    }

    /** Returns the internal standard-deviation array (not a copy); see {@link #varianceVector}. */
    public double[] getVarianceVector() {
        return varianceVector;
    }

    /**
     * Computes an annotation order by decreasing {@code |inputVector[i] - mean(i, false)|}.
     *
     * <p>The sort is stable and uses {@link Double#compare}, so ties (including NaN keys) keep their
     * original relative order.
     *
     * @param inputVector one value per annotation
     * @return original annotation indices in sorted order
     */
    protected List<Integer> calculateSortOrder(final double[] inputVector) {
        final double[] sortKeys = new double[inputVector.length];
        final List<Integer> order = new ArrayList<>(inputVector.length);
        for (int i = 0; i < inputVector.length; i++) {
            // Negated so that an ascending sort yields decreasing absolute distance.
            sortKeys[i] = -1.0d * Math.abs(inputVector[i] - mean(i, false));
            order.add(i);
        }
        order.sort((a, b) -> Double.compare(sortKeys[a], sortKeys[b]));
        return order;
    }

    /** Returns a new list whose k-th element is {@code list.get(order.get(k))}. */
    private static <T> List<T> reorderList(final List<T> list, final List<Integer> order) {
        final List<T> reordered = new ArrayList<>(list.size());
        for (final int index : order) {
            reordered.add(list.get(index));
        }
        return reordered;
    }

    /** Returns a new array whose k-th element is {@code values[order.get(k)]}. */
    private static double[] reorderDoubles(final double[] values, final List<Integer> order) {
        final double[] reordered = new double[order.size()];
        for (int k = 0; k < reordered.length; k++) {
            reordered[k] = values[order.get(k)];
        }
        return reordered;
    }

    /** Returns a new array whose k-th element is {@code values[order.get(k)]}. */
    private static boolean[] reorderBooleans(final boolean[] values, final List<Integer> order) {
        final boolean[] reordered = new boolean[order.size()];
        for (int k = 0; k < reordered.length; k++) {
            reordered[k] = values[order.get(k)];
        }
        return reordered;
    }

    /** Maps a normalized value of annotation {@code i} back to its original scale. */
    public double denormalizeDatum(final double normalizedValue, final int i) {
        return (normalizedValue * varianceVector[i]) + meanVector[i];
    }

    /** Registers a training/truth/known/anti-training resource. */
    public void addTrainingSet(final TrainingSet trainingSet) {
        trainingSets.add(trainingSet);
    }

    /** Returns the current annotation order (the live list, not a copy). */
    public List<String> getAnnotationKeys() {
        return annotationKeys;
    }

    /** Returns true if any registered resource is marked as training. */
    public boolean checkHasTrainingSet() {
        return trainingSets.stream().anyMatch(ts -> ts.isTraining);
    }

    /** Returns true if any registered resource is marked as truth. */
    public boolean checkHasTruthSet() {
        return trainingSets.stream().anyMatch(ts -> ts.isTruth);
    }

    /**
     * Collects training-site data that pass the standard-deviation threshold.
     *
     * <p>Logs a warning if fewer than {@code MIN_NUM_BAD_VARIANTS} remain. If more than
     * {@code MAX_NUM_TRAINING_DATA} remain, they are shuffled and downsampled to that many.
     */
    public List<VariantDatum> getTrainingData() {
        final List<VariantDatum> trainingData = new ArrayList<>();
        for (final VariantDatum datum : data) {
            if (datum.atTrainingSite && !datum.failingSTDThreshold) {
                trainingData.add(datum);
            } else if (datum.failingSTDThreshold && VRAC.debugStdevThresholding) {
                logger.warn("Datum at " + datum.loc + " with ref " + datum.referenceAllele + " and alt " + datum.alternateAllele + " failing std thresholding: " + Arrays.toString(datum.annotations));
            }
        }
        logger.info("Training with " + trainingData.size() + " variants after standard deviation thresholding.");
        if (trainingData.size() < VRAC.MIN_NUM_BAD_VARIANTS) {
            logger.warn("WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable.");
        } else if (trainingData.size() > VRAC.MAX_NUM_TRAINING_DATA) {
            logger.warn("WARNING: Very large training set detected. Downsampling to " + VRAC.MAX_NUM_TRAINING_DATA + " training variants.");
            Collections.shuffle(trainingData, Utils.getRandomGenerator());
            return trainingData.subList(0, VRAC.MAX_NUM_TRAINING_DATA);
        }
        return trainingData;
    }

    /**
     * Selects data with a finite LOD strictly below {@code BAD_LOD_CUTOFF} that pass the
     * standard-deviation threshold, and marks each selected datum as an anti-training site.
     */
    public List<VariantDatum> selectWorstVariants() {
        final List<VariantDatum> worst = new ArrayList<>();
        for (final VariantDatum datum : data) {
            if (datum != null && !datum.failingSTDThreshold && !Double.isInfinite(datum.lod) && datum.lod < VRAC.BAD_LOD_CUTOFF) {
                datum.atAntiTrainingSite = true;
                worst.add(datum);
            }
        }
        // The selection above is a strict '<' comparison; the message reflects that.
        logger.info("Selected worst " + worst.size() + " scoring variants --> variants with LOD < " + String.format(ScoreVariantAnnotations.DEFAULT_DOUBLE_FORMAT, VRAC.BAD_LOD_CUTOFF) + ".");
        return worst;
    }

    /** Returns non-null data that pass the threshold and are neither training nor anti-training sites. */
    public List<VariantDatum> getEvaluationData() {
        final List<VariantDatum> evaluationData = new ArrayList<>();
        for (final VariantDatum datum : data) {
            if (datum != null && !datum.failingSTDThreshold && !datum.atTrainingSite && !datum.atAntiTrainingSite) {
                evaluationData.add(datum);
            }
        }
        return evaluationData;
    }

    /** Removes data flagged {@code isAggregate} from the managed list, in place. */
    public void dropAggregateData() {
        // An explicit iterator (rather than removeIf) only calls remove() when something matches,
        // so immutable lists without aggregate data keep working.
        final Iterator<VariantDatum> it = data.iterator();
        while (it.hasNext()) {
            if (it.next().isAggregate) {
                it.remove();
            }
        }
    }

    /**
     * Builds a random sample for plotting: up to {@code numToAdd} data from each input list, shuffled together.
     *
     * <p>Side effect: the three input lists are shuffled in place.
     */
    public List<VariantDatum> getRandomDataForPlotting(final int numToAdd, final List<VariantDatum> setA,
                                                       final List<VariantDatum> setB, final List<VariantDatum> setC) {
        final List<VariantDatum> sample = new ArrayList<>();
        Collections.shuffle(setA, Utils.getRandomGenerator());
        Collections.shuffle(setB, Utils.getRandomGenerator());
        Collections.shuffle(setC, Utils.getRandomGenerator());
        sample.addAll(setA.subList(0, Math.min(numToAdd, setA.size())));
        sample.addAll(setB.subList(0, Math.min(numToAdd, setB.size())));
        sample.addAll(setC.subList(0, Math.min(numToAdd, setC.size())));
        Collections.shuffle(sample, Utils.getRandomGenerator());
        return sample;
    }

    /**
     * Mean of annotation {@code index} over non-missing data whose {@code atTrainingSite} equals
     * {@code trainingData}. Returns NaN when no datum qualifies (0/0).
     */
    protected double mean(final int index, final boolean trainingData) {
        double sum = 0.0d;
        int count = 0;
        for (final VariantDatum datum : data) {
            if (trainingData == datum.atTrainingSite && !datum.isNull[index]) {
                sum += datum.annotations[index];
                count++;
            }
        }
        return sum / count;
    }

    /**
     * Population standard deviation of annotation {@code index} around {@code mean}, over the same
     * data subset as {@link #mean}. Returns NaN when no datum qualifies.
     */
    protected double standardDeviation(final double mean, final int index, final boolean trainingData) {
        double sumOfSquares = 0.0d;
        int count = 0;
        for (final VariantDatum datum : data) {
            if (trainingData == datum.atTrainingSite && !datum.isNull[index]) {
                final double deviation = datum.annotations[index] - mean;
                sumOfSquares += deviation * deviation;
                count++;
            }
        }
        return Math.sqrt(sumOfSquares / count);
    }

    /**
     * Fills {@code datum.annotations} and {@code datum.isNull} from {@code vc}.
     * An annotation that decodes to NaN is flagged as missing.
     *
     * @param jitter whether to add random jitter to values that sit on known sentinel values
     */
    public void decodeAnnotations(final VariantDatum datum, final VariantContext vc, final boolean jitter) {
        final int numAnnotations = annotationKeys.size();
        final double[] annotations = new double[numAnnotations];
        final boolean[] isNull = new boolean[numAnnotations];
        int i = 0;
        for (final String key : annotationKeys) {
            annotations[i] = decodeAnnotation(key, vc, jitter, VRAC, datum);
            isNull[i] = Double.isNaN(annotations[i]);
            i++;
        }
        datum.annotations = annotations;
        datum.isNull = isNull;
    }

    /** Computes {@code ln((x - lower) / (upper - x))}. */
    private static double logitTransform(final double x, final double lower, final double upper) {
        return Math.log((x - lower) / (upper - x));
    }

    /**
     * Reads one annotation value from {@code vc}.
     *
     * <p>Missing, unparseable and infinite values decode to NaN. In allele-specific mode, keys starting
     * with "AS_" are read from the per-alt-allele list at {@code datum.alternateAllele}'s position.
     *
     * @throws IllegalStateException in allele-specific mode, if {@code vc} does not contain {@code datum.alternateAllele}
     */
    private static double decodeAnnotation(final String annotationKey, final VariantContext vc, final boolean jitter,
                                           final VariantRecalibratorArgumentCollection vrac, final VariantDatum datum) {
        double value;
        try {
            if (vrac.useASannotations && annotationKey.startsWith("AS_")) {
                value = decodeAlleleSpecificAnnotation(annotationKey, vc, datum);
            } else {
                value = vc.getAttributeAsDouble(annotationKey, Double.NaN);
            }
            if (Double.isInfinite(value)) {
                value = Double.NaN;
            }
            if (jitter) {
                value = applyJitter(annotationKey, value, vrac);
            }
        } catch (final NumberFormatException e) {
            // Unparseable values are treated as missing (NaN), like absent attributes.
            value = Double.NaN;
        }
        return value;
    }

    /**
     * Returns the value of allele-specific annotation {@code annotationKey} for {@code datum.alternateAllele},
     * or NaN if the annotation has no entry for that allele.
     */
    private static double decodeAlleleSpecificAnnotation(final String annotationKey, final VariantContext vc, final VariantDatum datum) {
        if (!vc.hasAllele(datum.alternateAllele)) {
            throw new IllegalStateException("VariantDatum allele " + datum.alternateAllele + " is not contained in the input VariantContext.");
        }
        final List<Object> perAltValues = vc.getAttributeAsList(annotationKey);
        // getAlleleIndex counts the reference as allele 0, while the per-alt list starts at the first alt.
        final int altIndex = vc.getAlleleIndex(datum.alternateAllele) - 1;
        if (altIndex < 0 || altIndex >= perAltValues.size()) {
            // Absent or too-short annotation list: treat as missing rather than throwing IndexOutOfBoundsException.
            return Double.NaN;
        }
        // String.valueOf tolerates values already stored as numbers; bad text still throws NumberFormatException.
        return Double.parseDouble(String.valueOf(perAltValues.get(altIndex)));
    }

    /**
     * Adds Gaussian jitter to {@code value} when the annotation sits on a known sentinel value.
     * The MQ annotation is also logit-transformed when {@code MQ_CAP > 0}.
     *
     * <p>The checks are independent {@code if}s evaluated in a fixed order, so random-number
     * consumption matches the original sequence.
     */
    private static double applyJitter(final String annotationKey, final double value, final VariantRecalibratorArgumentCollection vrac) {
        double result = value;
        if (annotationKey.equalsIgnoreCase(GATKVCFConstants.HAPLOTYPE_SCORE_KEY) && isNear(result, 0.0d)) {
            result += SENTINEL_JITTER_SCALE * Utils.getRandomGenerator().nextGaussian();
        }
        // NOTE(review): AS_FILTER_STATUS_KEY is paired with FISHER_STRAND_KEY here; if the allele-specific
        // Fisher-strand key was intended, confirm against GATKVCFConstants.
        if ((annotationKey.equalsIgnoreCase(GATKVCFConstants.FISHER_STRAND_KEY) || annotationKey.equalsIgnoreCase(GATKVCFConstants.AS_FILTER_STATUS_KEY))
                && isNear(result, 0.0d)) {
            result += SENTINEL_JITTER_SCALE * Utils.getRandomGenerator().nextGaussian();
        }
        if (annotationKey.equalsIgnoreCase(GATKVCFConstants.INBREEDING_COEFFICIENT_KEY) && isNear(result, 0.0d)) {
            result += SENTINEL_JITTER_SCALE * Utils.getRandomGenerator().nextGaussian();
        }
        if ((annotationKey.equalsIgnoreCase(GATKVCFConstants.STRAND_ODDS_RATIO_KEY) || annotationKey.equalsIgnoreCase(GATKVCFConstants.AS_STRAND_ODDS_RATIO_KEY))
                && isNear(result, LOG_OF_TWO)) {
            result += SENTINEL_JITTER_SCALE * Utils.getRandomGenerator().nextGaussian();
        }
        if (annotationKey.equalsIgnoreCase("MQ")) {
            if (vrac.MQ_CAP > 0) {
                result = logitTransform(result, -SAFETY_OFFSET, vrac.MQ_CAP + SAFETY_OFFSET);
                if (isNear(result, logitTransform(vrac.MQ_CAP, -SAFETY_OFFSET, vrac.MQ_CAP + SAFETY_OFFSET))) {
                    result += vrac.MQ_JITTER * Utils.getRandomGenerator().nextGaussian();
                }
            } else if (isNear(result, vrac.MQ_CAP)) {
                result += vrac.MQ_JITTER * Utils.getRandomGenerator().nextGaussian();
            }
        }
        if (annotationKey.equalsIgnoreCase(GATKVCFConstants.AS_RMS_MAPPING_QUALITY_KEY)) {
            result += vrac.MQ_JITTER * Utils.getRandomGenerator().nextGaussian();
        }
        return result;
    }

    /** True if {@code a} and {@code b} compare equal within {@link #PRECISION}. */
    private static boolean isNear(final double a, final double b) {
        return MathUtils.compareDoubles(a, b, PRECISION) == 0;
    }

    /**
     * Sets the known/truth/training/anti-training flags and the prior of {@code datum} from the
     * resource records overlapping the current feature context.
     *
     * <p>Flags are OR'd across matching resources and the prior is the maximum of 2.0 and the matching
     * resources' priors. The anti-training flag only requires an overlapping, allele-compatible record;
     * the other flags also require {@link #isValidVariant}.
     */
    public void parseTrainingSets(final FeatureContext context, final VariantContext evalVC, final VariantDatum datum, final boolean trustAllPolymorphic) {
        datum.isKnown = false;
        datum.atTruthSite = false;
        datum.atTrainingSite = false;
        datum.atAntiTrainingSite = false;
        datum.prior = 2.0d;
        for (final TrainingSet trainingSet : trainingSets) {
            for (final VariantContext trainVC : context.getValues(trainingSet.variantSource, context.getInterval().getStart())) {
                if (VRAC.useASannotations && !doAllelesMatch(trainVC, datum)) {
                    continue;
                }
                if (isValidVariant(evalVC, trainVC, trustAllPolymorphic)) {
                    datum.isKnown = datum.isKnown || trainingSet.isKnown;
                    datum.atTruthSite = datum.atTruthSite || trainingSet.isTruth;
                    datum.atTrainingSite = datum.atTrainingSite || trainingSet.isTraining;
                    datum.prior = Math.max(datum.prior, trainingSet.prior);
                }
                if (trainVC != null) {
                    datum.atAntiTrainingSite = datum.atAntiTrainingSite || trainingSet.isAntiTraining;
                }
            }
        }
    }

    /**
     * True if {@code trainVC} is an unfiltered variant of a class compatible with {@code evalVC}.
     * It must also either be polymorphic in its samples, have no genotypes, or be trusted via {@code trustAllPolymorphic}.
     */
    private boolean isValidVariant(final VariantContext evalVC, final VariantContext trainVC, final boolean trustAllPolymorphic) {
        return trainVC != null && trainVC.isNotFiltered() && trainVC.isVariant() && checkVariationClass(evalVC, trainVC)
                && (trustAllPolymorphic || !trainVC.hasGenotypes() || trainVC.isPolymorphicInSamples());
    }

    /**
     * True if {@code datum}'s ref/alt pair occurs in {@code trainVC}, or if the datum has no alternate allele.
     *
     * @throws IllegalStateException wrapping the allele-comparison failure, annotated with the position
     */
    private boolean doAllelesMatch(final VariantContext trainVC, final VariantDatum datum) {
        if (datum.alternateAllele == null) {
            return true;
        }
        try {
            return GATKVariantContextUtils.isAlleleInList(datum.referenceAllele, datum.alternateAllele, trainVC.getReference(), trainVC.getAlternateAlleles());
        } catch (final IllegalStateException e) {
            throw new IllegalStateException("Reference allele mismatch at position " + trainVC.getContig() + ":" + trainVC.getStart() + " : ", e);
        }
    }

    /**
     * Checks that {@code evalVC} belongs to the variation class implied by {@code trainVC}'s type.
     * SNP/MNP map to {@link VariantRecalibratorArgumentCollection.Mode#SNP}; INDEL/MIXED/SYMBOLIC map to
     * {@link VariantRecalibratorArgumentCollection.Mode#INDEL}; any other type returns false.
     */
    protected static boolean checkVariationClass(final VariantContext evalVC, final VariantContext trainVC) {
        switch (trainVC.getType()) {
            case SNP:
            case MNP:
                return checkVariationClass(evalVC, VariantRecalibratorArgumentCollection.Mode.SNP);
            case INDEL:
            case MIXED:
            case SYMBOLIC:
                return checkVariationClass(evalVC, VariantRecalibratorArgumentCollection.Mode.INDEL);
            default:
                return false;
        }
    }

    /**
     * True if {@code evalVC} matches {@code mode}: SNP/MNP for SNP mode; structural indel, indel, mixed or
     * symbolic for INDEL mode; anything for BOTH.
     */
    public static boolean checkVariationClass(final VariantContext evalVC, final VariantRecalibratorArgumentCollection.Mode mode) {
        switch (mode) {
            case SNP:
                return evalVC.isSNP() || evalVC.isMNP();
            case INDEL:
                return evalVC.isStructuralIndel() || evalVC.isIndel() || evalVC.isMixed() || evalVC.isSymbolic();
            case BOTH:
                return true;
            default:
                throw new IllegalStateException("Encountered unknown recal mode: " + mode);
        }
    }

    /**
     * Per-allele variant of {@link #checkVariationClass(VariantContext, VariantRecalibratorArgumentCollection.Mode)}.
     * SNP mode requires {@code allele} to have the reference's length; INDEL mode requires a different
     * length or a symbolic allele; BOTH accepts anything.
     */
    public static boolean checkVariationClass(final VariantContext evalVC, final Allele allele, final VariantRecalibratorArgumentCollection.Mode mode) {
        switch (mode) {
            case SNP:
                return evalVC.getReference().length() == allele.length();
            case INDEL:
                return evalVC.getReference().length() != allele.length() || allele.isSymbolic();
            case BOTH:
                return true;
            default:
                throw new IllegalStateException("Encountered unknown recal mode: " + mode);
        }
    }

    /**
     * Writes one record per datum to {@code recalWriter}, carrying the LOD, culprit annotation and
     * positive/negative training labels.
     *
     * <p>Side effect: the managed data list is sorted in place using {@code seqDictionary}'s coordinate order.
     * Outside allele-specific mode every record uses the placeholder alleles {@code N} (reference) and
     * {@code <VQSR>}.
     */
    public void writeOutRecalibrationTable(final VariantContextWriter recalWriter, final SAMSequenceDictionary seqDictionary) {
        Collections.sort(data, VariantDatum.getComparator(seqDictionary));
        // The reference allele is the literal "N"; the decompiled code had substituted an unrelated
        // command-line constant that merely happened to share this value.
        final List<Allele> placeholderAlleles = Arrays.asList(Allele.create("N", true), Allele.create("<VQSR>", false));
        for (final VariantDatum datum : data) {
            final List<Allele> alleles = VRAC.useASannotations
                    ? Arrays.asList(datum.referenceAllele, datum.alternateAllele)
                    : placeholderAlleles;
            final VariantContextBuilder builder = new VariantContextBuilder("VQSR", datum.loc.getContig(), datum.loc.getStart(), datum.loc.getEnd(), alleles);
            builder.attribute("END", datum.loc.getEnd());
            builder.attribute(GATKVCFConstants.VQS_LOD_KEY, String.format(ScoreVariantAnnotations.DEFAULT_DOUBLE_FORMAT, datum.lod));
            builder.attribute(GATKVCFConstants.CULPRIT_KEY, datum.worstAnnotation != -1 ? annotationKeys.get(datum.worstAnnotation) : "NULL");
            if (datum.atTrainingSite) {
                builder.attribute(GATKVCFConstants.POSITIVE_LABEL_KEY, true);
            }
            if (datum.atAntiTrainingSite) {
                builder.attribute(GATKVCFConstants.NEGATIVE_LABEL_KEY, true);
            }
            recalWriter.add(builder.make());
        }
    }
}
