package org.broadinstitute.hellbender.tools.walkers.readorientation;

import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.util.Histogram;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.commons.lang3.tuple.ImmutableTriple;
import org.apache.commons.lang3.tuple.Triple;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.util.MathArrays;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.tools.walkers.validation.basicshortmutpileup.BetaBinomialDistribution;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.NaturalLogUtils;
import org.broadinstitute.hellbender.utils.Nucleotide;
import org.broadinstitute.hellbender.utils.Utils;

/* loaded from: input_file:org/broadinstitute/hellbender/tools/walkers/readorientation/LearnReadOrientationModelEngine.class */
/**
 * Learns, for a single canonical reference context, a prior distribution over {@link ArtifactState}s by
 * expectation-maximization (EM). Three kinds of evidence are used:
 * <ul>
 *     <li>{@code refHistogram}: counts of sites with no alt reads, keyed by depth;</li>
 *     <li>{@code altDepthOneHistograms}: counts of sites with exactly one alt read, keyed by depth, one histogram
 *     per (alt allele, read orientation) as encoded in the histogram's value label;</li>
 *     <li>{@code altDesignMatrix}: individual alt site records (alt allele, alt count, alt F1R2 count, depth).</li>
 * </ul>
 * Instances hold mutable EM state (responsibilities, effective counts, iteration counter).
 */
public class LearnReadOrientationModelEngine {
    // Beta pseudocounts (alpha = alt, beta = ref) for the allele fraction of each state.
    private static final double ALT_PSEUDOCOUNT = 1.0d;
    private static final double REF_PSEUDOCOUNT = 9.0d;
    private static final double PSEUDOCOUNT_OF_HOM_LIKELY = 10000.0d;
    private static final double PSEUDOCOUNT_OF_HOM_UNLIKELY = 3.0d;
    private static final double BALANCED_HET_PSEUDOCOUNT = 5.0d;
    private static final double PSEUDOCOUNT_OF_SOMATIC_ALT = 2.0d;
    private static final double PSEUDOCOUNT_OF_SOMATIC_REF = 5.0d;

    // Beta pseudocounts (alpha = F1R2, beta = F2R1) for the fraction of alt reads in F1R2 orientation.
    private static final double BALANCED_F1R2_PRIOR = 10.0d;
    private static final double PSEUDOCOUNT_OF_LIKELY_OUTCOME = 100.0d;
    private static final double PSEUDOCOUNT_OF_RARE_OUTCOME = 1.0d;

    private static final int REFERENCE_CONTEXT_LENGTH = 3;

    private static final Map<ArtifactState, BetaDistributionShape> alleleFractionPseudoCounts = getPseudoCountsForAlleleFraction();
    private static final Map<ArtifactState, BetaDistributionShape> altF1R2FractionPseudoCounts = getPseudoCountsForAltF1R2Fraction();

    private final double convergenceThreshold;
    private final int maxEMIterations;
    private final String referenceContext;
    private final Nucleotide refAllele;
    private final Histogram<Integer> refHistogram;
    private final List<Histogram<Integer>> altDepthOneHistograms;
    private final List<AltSiteRecord> altDesignMatrix;
    // one row per alt design-matrix record
    private final RealMatrix altResponsibilities;
    // keyed by (depth, alt allele, orientation of the single alt read)
    private final Map<Triple<Integer, Nucleotide, ReadOrientation>, double[]> responsibilitiesOfAltDepth1Sites;
    // row d holds the responsibilities of a ref site at depth d + 1
    private final RealMatrix refResponsibilities;
    private final int numAltExamples;
    private final int numRefExamples;
    private final int numExamples;
    private final Logger logger;
    private final int maxDepth;
    private RealVector effectiveCounts = new ArrayRealVector(F1R2FilterConstants.NUM_STATES);
    private final MutableInt numIterations = new MutableInt();

    /**
     * @param refHistogram          depth histogram of sites with no alt reads; its value label is the reference context
     * @param altDepthOneHistograms depth histograms of sites with a single alt read, one per (alt allele, orientation)
     * @param altDesignMatrix       records of the remaining alt sites
     * @param convergenceThreshold  EM stops once the L2 distance between successive priors is at most this value
     * @param maxEMIterations       upper bound on the number of EM iterations
     * @param maxDepth              largest depth represented in the histograms
     * @param logger                receives a convergence summary
     */
    public LearnReadOrientationModelEngine(final Histogram<Integer> refHistogram,
                                           final List<Histogram<Integer>> altDepthOneHistograms,
                                           final List<AltSiteRecord> altDesignMatrix,
                                           final double convergenceThreshold,
                                           final int maxEMIterations,
                                           final int maxDepth,
                                           final Logger logger) {
        this.refHistogram = Utils.nonNull(refHistogram);
        this.altDepthOneHistograms = Utils.nonNull(altDepthOneHistograms);
        this.altDesignMatrix = Utils.nonNull(altDesignMatrix);
        this.referenceContext = refHistogram.getValueLabel();
        Utils.validate(referenceContext.length() == REFERENCE_CONTEXT_LENGTH,
                String.format("reference context must have length %d but got %s", REFERENCE_CONTEXT_LENGTH, referenceContext));
        Utils.validate(F1R2FilterConstants.CANONICAL_KMERS.contains(referenceContext),
                referenceContext + " is not in the set of canonical kmers");
        this.numAltExamples = altDesignMatrix.size()
                + altDepthOneHistograms.stream().mapToInt(h -> (int) h.getSumOfValues()).sum();
        this.numRefExamples = (int) refHistogram.getSumOfValues();
        this.numExamples = numAltExamples + numRefExamples;
        this.refResponsibilities = new Array2DRowRealMatrix(maxDepth, F1R2FilterConstants.NUM_STATES);
        this.altResponsibilities = new Array2DRowRealMatrix(altDesignMatrix.size(), F1R2FilterConstants.NUM_STATES);
        this.responsibilitiesOfAltDepth1Sites = new HashMap<>();
        this.refAllele = F1R2FilterUtils.getMiddleBase(referenceContext);
        this.convergenceThreshold = convergenceThreshold;
        this.maxEMIterations = maxEMIterations;
        this.maxDepth = maxDepth;
        this.logger = logger;
    }

    /**
     * Runs EM, starting from the flat prior, until the prior converges or {@code maxEMIterations} is reached.
     * At least one iteration is always performed. Each call starts a fresh iteration count.
     *
     * @return the learned prior for this reference context
     */
    public ArtifactPrior learnPriorForArtifactStates() {
        final double[] flatPrior = getFlatPrior(refAllele);
        double[] statePrior = Arrays.copyOf(flatPrior, F1R2FilterConstants.NUM_STATES);

        // Reset so that repeated calls on the same engine neither inherit nor misreport the iteration count.
        numIterations.setValue(0);
        boolean converged;
        do {
            // takeEstep only reads the prior and takeMstep returns a fresh array, so keeping the reference is safe.
            final double[] previousPrior = statePrior;
            takeEstep(statePrior);
            // the flat prior doubles as the pseudocounts added in the M step
            statePrior = takeMstep(flatPrior);
            numIterations.increment();
            converged = MathArrays.distance(previousPrior, statePrior) <= convergenceThreshold;
        } while (!converged && numIterations.intValue() < maxEMIterations);

        // Report based on the convergence test itself, so that converging on the last allowed step counts as success.
        if (converged) {
            logger.info(String.format("Context %s: with %s ref and %s alt examples, EM converged in %d steps",
                    referenceContext, numRefExamples, numAltExamples, numIterations.intValue()));
        } else {
            logger.info(String.format("Context %s: with %s ref and %s alt examples, EM failed to converge within %d steps",
                    referenceContext, numRefExamples, numAltExamples, maxEMIterations));
        }
        return new ArtifactPrior(referenceContext, statePrior, numExamples, numAltExamples);
    }

    /**
     * E step: recomputes the per-state responsibilities of ref sites (each depth), alt design-matrix sites,
     * and alt depth-one sites (each depth, non-ref allele and orientation) under {@code statePrior}.
     */
    private void takeEstep(final double[] statePrior) {
        for (int row = 0; row < maxDepth; row++) {
            final int depth = row + 1;
            refResponsibilities.setRow(row,
                    computeResponsibilities(refAllele, refAllele, 0, 0, depth, statePrior, false));
        }

        for (int n = 0; n < altDesignMatrix.size(); n++) {
            final AltSiteRecord record = altDesignMatrix.get(n);
            altResponsibilities.setRow(n, computeResponsibilities(refAllele, record.getAltAllele(),
                    record.getAltCount(), record.getAltF1R2(), record.getDepth(), statePrior, false));
        }

        for (int depth = 1; depth <= maxDepth; depth++) {
            for (final Nucleotide altAllele : Nucleotide.STANDARD_BASES) {
                if (altAllele == refAllele) {
                    continue;
                }
                for (final ReadOrientation orientation : ReadOrientation.values()) {
                    // a single alt read has an F1R2 count of 1 exactly when it is in F1R2 orientation
                    final int altF1R2 = orientation == ReadOrientation.F1R2 ? 1 : 0;
                    responsibilitiesOfAltDepth1Sites.put(createKey(depth, altAllele, orientation),
                            computeResponsibilities(refAllele, altAllele, 1, altF1R2, depth, statePrior, false));
                }
            }
        }
    }

    /**
     * M step: sums the responsibilities weighted by site counts into {@link #effectiveCounts}, adds
     * {@code pseudocounts}, and normalizes the result into the next prior.
     */
    private double[] takeMstep(final double[] pseudocounts) {
        final double[] countsFromDesignMatrix = MathUtils.sumArrayFunction(0, altDesignMatrix.size(),
                n -> altResponsibilities.getRow(n));

        double[] countsFromAltDepthOne = new double[F1R2FilterConstants.NUM_STATES];
        for (final Histogram<Integer> histogram : altDepthOneHistograms) {
            final Triple<String, Nucleotide, ReadOrientation> triplet = F1R2FilterUtils.labelToTriplet(histogram.getValueLabel());
            final Nucleotide altAllele = triplet.getMiddle();
            final ReadOrientation orientation = triplet.getRight();
            countsFromAltDepthOne = MathArrays.ebeAdd(countsFromAltDepthOne, MathUtils.sumArrayFunction(0, maxDepth, row -> {
                final int depth = row + 1;
                return MathArrays.scale(countAtDepth(histogram, depth),
                        responsibilitiesOfAltDepth1Sites.get(createKey(depth, altAllele, orientation)));
            }));
        }

        final double[] countsFromRef = MathUtils.sumArrayFunction(0, maxDepth,
                row -> MathArrays.scale(countAtDepth(refHistogram, row + 1), refResponsibilities.getRow(row)));

        effectiveCounts = new ArrayRealVector(
                MathArrays.ebeAdd(MathArrays.ebeAdd(countsFromDesignMatrix, countsFromAltDepthOne), countsFromRef));
        return MathUtils.normalizeSumToOne(effectiveCounts.add(new ArrayRealVector(pseudocounts)).toArray());
    }

    /**
     * Computes the normalized (linear-space) posterior responsibility of every artifact state for one site.
     * Ref-to-ref artifact states, and artifact states whose alt allele differs from {@code altAllele}, get
     * zero responsibility.
     *
     * @param refAllele     reference base of the context
     * @param altAllele     alt base observed at the site
     * @param altCount      number of alt reads
     * @param altF1R2       number of alt reads in F1R2 orientation
     * @param depth         total depth of the site
     * @param statePrior    current prior over states, indexed by {@link ArtifactState#ordinal()}
     * @param excludeHomRef if true, {@link ArtifactState#HOM_REF} is given zero responsibility
     * @return responsibilities indexed by {@link ArtifactState#ordinal()}
     */
    public static double[] computeResponsibilities(final Nucleotide refAllele, final Nucleotide altAllele,
                                                   final int altCount, final int altF1R2, final int depth,
                                                   final double[] statePrior, final boolean excludeHomRef) {
        final double[] logPosteriors = new double[F1R2FilterConstants.NUM_STATES];
        final List<ArtifactState> refToRefArtifacts = ArtifactState.getRefToRefArtifacts(refAllele);

        for (final ArtifactState state : ArtifactState.values()) {
            final int index = state.ordinal();
            final boolean compatible = !refToRefArtifacts.contains(state)
                    && (!ArtifactState.artifactStates.contains(state) || state.getAltAlleleOfArtifact() == altAllele);
            logPosteriors[index] = compatible
                    ? computeLogPosterior(altCount, altF1R2, depth, statePrior[index],
                            alleleFractionPseudoCounts.get(state), altF1R2FractionPseudoCounts.get(state))
                    : Double.NEGATIVE_INFINITY;
        }

        if (excludeHomRef) {
            logPosteriors[ArtifactState.HOM_REF.ordinal()] = Double.NEGATIVE_INFINITY;
        }
        return NaturalLogUtils.normalizeFromLogToLinearSpace(logPosteriors);
    }

    /**
     * Unnormalized log posterior of one state: log prior + beta-binomial log likelihood of the alt count
     * given the depth + beta-binomial log likelihood of the alt F1R2 count given the alt count.
     */
    private static double computeLogPosterior(final int altCount, final int altF1R2, final int depth,
                                              final double statePrior,
                                              final BetaDistributionShape alleleFractionShape,
                                              final BetaDistributionShape altF1R2FractionShape) {
        Utils.validateArg(MathUtils.isValidProbability(statePrior),
                String.format("statePrior must be a probability but got %f", statePrior));
        final double logAltCountLikelihood = new BetaBinomialDistribution(null,
                alleleFractionShape.getAlpha(), alleleFractionShape.getBeta(), depth).logProbability(altCount);
        final double logAltF1R2Likelihood = new BetaBinomialDistribution(null,
                altF1R2FractionShape.getAlpha(), altF1R2FractionShape.getBeta(), altCount).logProbability(altF1R2);
        return Math.log(statePrior) + logAltCountLikelihood + logAltF1R2Likelihood;
    }

    /** Beta shapes for the allele fraction of each state (alpha = alt pseudocount, beta = ref pseudocount). */
    private static Map<ArtifactState, BetaDistributionShape> getPseudoCountsForAlleleFraction() {
        final Map<ArtifactState, BetaDistributionShape> shapes = new HashMap<>(ArtifactState.values().length);
        ArtifactState.getF1R2ArtifactStates().forEach(s -> shapes.put(s, new BetaDistributionShape(ALT_PSEUDOCOUNT, REF_PSEUDOCOUNT)));
        ArtifactState.getF2R1ArtifactStates().forEach(s -> shapes.put(s, new BetaDistributionShape(ALT_PSEUDOCOUNT, REF_PSEUDOCOUNT)));
        shapes.put(ArtifactState.HOM_REF, new BetaDistributionShape(PSEUDOCOUNT_OF_HOM_UNLIKELY, PSEUDOCOUNT_OF_HOM_LIKELY));
        shapes.put(ArtifactState.GERMLINE_HET, new BetaDistributionShape(BALANCED_HET_PSEUDOCOUNT, BALANCED_HET_PSEUDOCOUNT));
        shapes.put(ArtifactState.SOMATIC_HET, new BetaDistributionShape(PSEUDOCOUNT_OF_SOMATIC_ALT, PSEUDOCOUNT_OF_SOMATIC_REF));
        shapes.put(ArtifactState.HOM_VAR, new BetaDistributionShape(PSEUDOCOUNT_OF_HOM_LIKELY, PSEUDOCOUNT_OF_HOM_UNLIKELY));
        return shapes;
    }

    /** Beta shapes for the fraction of alt reads in F1R2 orientation (alpha = F1R2, beta = F2R1). */
    private static Map<ArtifactState, BetaDistributionShape> getPseudoCountsForAltF1R2Fraction() {
        final Map<ArtifactState, BetaDistributionShape> shapes = new HashMap<>(ArtifactState.values().length);
        ArtifactState.getF1R2ArtifactStates().forEach(s -> shapes.put(s,
                new BetaDistributionShape(PSEUDOCOUNT_OF_LIKELY_OUTCOME, PSEUDOCOUNT_OF_RARE_OUTCOME)));
        ArtifactState.getF2R1ArtifactStates().forEach(s -> shapes.put(s,
                new BetaDistributionShape(PSEUDOCOUNT_OF_RARE_OUTCOME, PSEUDOCOUNT_OF_LIKELY_OUTCOME)));
        ArtifactState.getNonArtifactStates().forEach(s -> shapes.put(s,
                new BetaDistributionShape(BALANCED_F1R2_PRIOR, BALANCED_F1R2_PRIOR)));
        return shapes;
    }

    /** Responsibilities of a ref site at depth {@code row + 1} from the most recent E step. */
    @VisibleForTesting
    public double[] getRefResonsibilities(final int row) {
        return refResponsibilities.getRow(row);
    }

    /** Responsibilities of the {@code n}-th alt design-matrix record from the most recent E step. */
    @VisibleForTesting
    public double[] getAltResonsibilities(final int n) {
        return altResponsibilities.getRow(n);
    }

    /** NOTE(review): unimplemented stub that always returns {@code null}; kept for interface compatibility. */
    @VisibleForTesting
    public double[] getAltDepth1Resonsibilities(final int i) {
        return null;
    }

    /** Effective counts per state (before pseudocounts) from the most recent M step. */
    @VisibleForTesting
    public RealVector getEffectiveCounts() {
        return effectiveCounts;
    }

    /** Effective count of {@code artifactState} from the most recent M step. */
    @VisibleForTesting
    public double getEffectiveCounts(final ArtifactState artifactState) {
        return effectiveCounts.getEntry(artifactState.ordinal());
    }

    /**
     * Uniform prior over all states except the ref-to-ref artifacts for {@code refAllele}, which get zero.
     */
    public static double[] getFlatPrior(final Nucleotide refAllele) {
        final List<ArtifactState> refToRefArtifacts = ArtifactState.getRefToRefArtifacts(refAllele);
        final double[] prior = new double[F1R2FilterConstants.NUM_STATES];
        Arrays.fill(prior, 1.0d / (F1R2FilterConstants.NUM_STATES - refToRefArtifacts.size()));
        for (final ArtifactState state : refToRefArtifacts) {
            prior[state.ordinal()] = 0.0d;
        }
        return prior;
    }

    /** Returns the value of {@code histogram}'s bin for {@code depth}, treating an absent bin as zero. */
    private static double countAtDepth(final Histogram<Integer> histogram, final int depth) {
        return histogram.get(depth) == null ? 0.0d : histogram.get(depth).getValue();
    }

    private static Triple<Integer, Nucleotide, ReadOrientation> createKey(final int depth, final Nucleotide altAllele,
                                                                         final ReadOrientation orientation) {
        return new ImmutableTriple<>(depth, altAllele, orientation);
    }
}
