package org.broadinstitute.hellbender.tools.walkers.vqsr;

import Jama.Matrix;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.apache.commons.math3.special.Gamma;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.MathUtils;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:org/broadinstitute/hellbender/tools/walkers/vqsr/MultivariateGaussian.class */
public class MultivariateGaussian {
    /** Natural log of 10, used to convert natural-log quantities to log10. */
    private static final double LN_10 = Math.log(10.0);

    /** Log10 mixture weight of this component. Not read or written by this class itself. */
    public double pMixtureLog10;
    /** Total of the per-datum weights accumulated during the last maximization/evaluation pass. */
    public double sumProb;
    /** Mean vector; its length is the dimensionality (number of annotations) of this Gaussian. */
    public final double[] mu;
    /** Covariance-like matrix of size {@code mu.length x mu.length}. */
    public final Matrix sigma;
    /** Set by {@link #maximizeGaussian} to {@code sumProb + degreesOfFreedom}. */
    public double hyperParameter_a;
    /** Set by {@link #maximizeGaussian} to {@code sumProb + shrinkage}. */
    public double hyperParameter_b;
    /** Set by {@link #maximizeGaussian} to {@code sumProb + dirichletParameter}. */
    public double hyperParameter_lambda;
    /** Log10 constant term added to every datum score in {@link #evaluateDatumLog10}. */
    private double cachedDenomLog10;
    /** Inverse of {@link #sigma}; scaled by {@code hyperParameter_a} on the variational-Bayes path. */
    private Matrix cachedSigmaInverse;
    /** Per-datum weights for this component (presumably posterior responsibilities — confirm with caller). */
    private final double[] pVarInGaussian;
    /** Next write position in {@link #pVarInGaussian}. */
    int pVarInGaussianIndex = 0;

    /**
     * Creates a zero-mean Gaussian with an all-zero {@link #sigma}.
     *
     * @param numVariants    number of per-datum weight slots (one per {@code VariantDatum} later passed in)
     * @param numAnnotations dimensionality of {@link #mu} and {@link #sigma}
     */
    public MultivariateGaussian(final int numVariants, final int numAnnotations) {
        this.mu = new double[numAnnotations];
        this.sigma = new Matrix(numAnnotations, numAnnotations);
        this.pVarInGaussian = new double[numVariants];
    }

    /** Sets every entry of {@link #mu} to 0. */
    public void zeroOutMu() {
        Arrays.fill(mu, 0.0);
    }

    /**
     * Sets every entry of {@link #sigma} to 0.
     *
     * <p>Jama's {@code getArray()} returns the matrix's internal storage. Filling it in place avoids
     * allocating and copying a new k x k matrix on every call. It also works for k = 0, where
     * {@code new Matrix(new double[0][0])} would throw.
     */
    public void zeroOutSigma() {
        for (final double[] row : sigma.getArray()) {
            Arrays.fill(row, 0.0);
        }
    }

    /**
     * Draws each mean coordinate uniformly from [-4, 4).
     *
     * @param rand source of randomness
     */
    public void initializeRandomMu(final Random rand) {
        for (int i = 0; i < mu.length; i++) {
            mu[i] = -4.0 + 8.0 * rand.nextDouble();
        }
    }

    /**
     * Replaces {@link #sigma} with {@code L * L^T}, where {@code L} is a random lower-triangular
     * matrix. Each entry on or below the diagonal has magnitude in [0.55, 1.8) and a random sign.
     *
     * <p>Because every diagonal entry of {@code L} is non-zero, the product is symmetric and
     * positive-definite. The random draws are consumed column by column, top to bottom, as
     * (nextDouble, nextBoolean) pairs.
     *
     * @param rand source of randomness
     */
    public void initializeRandomSigma(final Random rand) {
        final int k = mu.length;
        // Java zero-initializes the array, so the strict upper triangle of L stays 0.
        final double[][] lower = new double[k][k];
        for (int col = 0; col < k; col++) {
            for (int row = col; row < k; row++) {
                double value = 0.55 + 1.25 * rand.nextDouble();
                if (rand.nextBoolean()) {
                    value *= -1.0;
                }
                lower[row][col] = value;
            }
        }
        final Matrix lowerMatrix = new Matrix(lower);
        sigma.setMatrix(0, k - 1, 0, k - 1, lowerMatrix.times(lowerMatrix.transpose()));
    }

    /**
     * @param datum datum whose annotations are compared with {@link #mu}
     * @return squared Euclidean distance between the datum's annotations and {@link #mu}
     */
    public double calculateDistanceFromMeanSquared(final VariantDatum datum) {
        return MathUtils.distanceSquared(datum.annotations, mu);
    }

    /** Adds the datum's annotation vector to {@link #mu} with weight 1. */
    public void incrementMu(final VariantDatum datum) {
        incrementMu(datum, 1.0);
    }

    /**
     * Adds {@code weight * datum.annotations} to {@link #mu}, element-wise.
     *
     * @param datum  datum whose first {@code mu.length} annotations are added
     * @param weight multiplier applied to each annotation
     */
    public void incrementMu(final VariantDatum datum, final double weight) {
        for (int i = 0; i < mu.length; i++) {
            mu[i] += weight * datum.annotations[i];
        }
    }

    /** Divides every entry of {@link #mu} by {@code divisor}. */
    public void divideEqualsMu(final double divisor) {
        for (int i = 0; i < mu.length; i++) {
            mu[i] /= divisor;
        }
    }

    /**
     * Caches the inverse of {@link #sigma}.
     *
     * @throws UserException wrapping the Jama failure (e.g. a singular matrix)
     */
    private void precomputeInverse() {
        try {
            cachedSigmaInverse = sigma.inverse();
        } catch (final RuntimeException e) {
            throw new UserException("Error during clustering. Most likely there are too few variants used during Gaussian mixture modeling. Please consider raising the number of variants used to train the negative model (via --percentBadVariants 0.05, for example) or lowering the maximum number of Gaussians to use in the model (via --maxGaussians 4, for example).", e);
        }
    }

    /**
     * Caches sigma's inverse and {@code log10((2*pi)^(-k/2) * det(sigma)^(-1/2))}, the
     * normalizing term of a k-dimensional Gaussian density.
     *
     * <p>The term is computed in log space, so {@code (2*pi)^(-k/2)} cannot underflow to 0
     * (which would make its log10 {@code -Infinity}) when k is large.
     */
    public void precomputeDenominatorForEvaluation() {
        precomputeInverse();
        cachedDenomLog10 = -0.5 * mu.length * Math.log10(2.0 * Math.PI) - 0.5 * Math.log10(sigma.det());
    }

    /**
     * Caches {@code hyperParameter_a * inverse(sigma)} and the log10 constant used on the
     * variational-Bayes scoring path.
     *
     * <p>The constant is the sum of three natural-log terms converted to log10:
     * <ul>
     *   <li>{@code digamma(lambda) - digamma(sumHyperParameterLambda)}</li>
     *   <li>{@code 0.5 * (sum_j digamma((a + 1 - j) / 2) - ln det(sigma) + k ln 2)}</li>
     *   <li>{@code -k / (2b)}</li>
     * </ul>
     * NOTE(review): these appear to match the standard variational-Bayes Gaussian-mixture
     * expectations (e.g. Bishop, PRML §10.2); confirm against the design notes.
     *
     * @param sumHyperParameterLambda sum of {@code hyperParameter_lambda} across mixture components
     *                                (presumably — verify with the caller)
     */
    public void precomputeDenominatorForVariationalBayes(final double sumHyperParameterLambda) {
        precomputeInverse();
        cachedSigmaInverse.timesEquals(hyperParameter_a);
        double digammaSum = 0.0;
        for (int j = 1; j <= mu.length; j++) {
            digammaSum += Gamma.digamma((hyperParameter_a + 1.0 - j) / 2.0);
        }
        final double logPi = Gamma.digamma(hyperParameter_lambda) - Gamma.digamma(sumHyperParameterLambda);
        final double logLambda = 0.5 * (digammaSum - Math.log(sigma.det()) + Math.log(2.0) * mu.length);
        final double beta = (-1.0 * mu.length) / (2.0 * hyperParameter_b);
        cachedDenomLog10 = logPi / LN_10 + logLambda / LN_10 + beta / LN_10;
    }

    /**
     * Scores a datum as {@code -0.5 * (x - mu)^T * S * (x - mu) / ln(10) + cachedDenomLog10}.
     * {@code S} is the cached (possibly scaled) inverse of {@link #sigma}.
     * One of the {@code precomputeDenominatorFor...} methods must be called first.
     *
     * @param datum datum to score
     * @return log10 score of the datum under this component
     */
    public double evaluateDatumLog10(final VariantDatum datum) {
        final int k = mu.length;
        final double[] delta = new double[k];
        for (int i = 0; i < k; i++) {
            delta[i] = datum.annotations[i] - mu[i];
        }
        double quadraticForm = 0.0;
        for (int i = 0; i < k; i++) {
            // i-th entry of delta^T * S, accumulated in the same order as before.
            double projected = 0.0;
            for (int j = 0; j < k; j++) {
                projected += delta[j] * cachedSigmaInverse.get(j, i);
            }
            quadraticForm += projected * delta[i];
        }
        return (-0.5 * quadraticForm) / LN_10 + cachedDenomLog10;
    }

    /**
     * Stores {@code weight} in the next free slot of the per-datum weight array.
     *
     * @throws ArrayIndexOutOfBoundsException if more weights are assigned than {@code numVariants}
     */
    public void assignPVarInGaussian(final double weight) {
        pVarInGaussian[pVarInGaussianIndex++] = weight;
    }

    /** Clears all per-datum weights and rewinds the write position to 0. */
    public void resetPVarInGaussian() {
        Arrays.fill(pVarInGaussian, 0.0);
        pVarInGaussianIndex = 0;
    }

    /**
     * Re-estimates {@link #mu}, {@link #sigma} and the hyperparameters from the assigned per-datum
     * weights, shrinking the estimate toward the supplied empirical priors. Resets the weights
     * afterwards.
     *
     * @param data               data in the same order their weights were assigned
     * @param empiricalMu        prior mean, length {@code mu.length}
     * @param empiricalSigma     matrix added to {@link #sigma} as a prior term
     * @param shrinkage          prior strength on the mean; also added into {@code hyperParameter_b}
     * @param dirichletParameter added into {@code hyperParameter_lambda}
     * @param degreesOfFreedom   added into {@code hyperParameter_a}
     */
    public void maximizeGaussian(final List<VariantDatum> data, final double[] empiricalMu, final Matrix empiricalSigma,
                                 final double shrinkage, final double dirichletParameter, final double degreesOfFreedom) {
        // A tiny non-zero start, presumably so the divisions below stay finite when every weight is 0.
        sumProb = 1.0E-10;
        zeroOutMu();
        zeroOutSigma();
        accumulateWeightedMean(data);
        divideEqualsMu(sumProb);

        final int k = mu.length;
        final double shrinkageFactor = (shrinkage * sumProb) / (shrinkage + sumProb);
        final Matrix wishart = new Matrix(k, k);
        for (int i = 0; i < k; i++) {
            final double deltaMu = shrinkageFactor * (mu[i] - empiricalMu[i]);
            for (int j = 0; j < k; j++) {
                wishart.set(i, j, deltaMu * (mu[j] - empiricalMu[j]));
            }
        }

        // Order of additions into sigma is kept: data scatter, then prior sigma, then the mean-shift term.
        accumulateWeightedScatter(data);
        sigma.plusEquals(empiricalSigma);
        sigma.plusEquals(wishart);

        for (int i = 0; i < k; i++) {
            mu[i] = (sumProb * mu[i] + shrinkage * empiricalMu[i]) / (sumProb + shrinkage);
        }

        hyperParameter_a = sumProb + degreesOfFreedom;
        hyperParameter_b = sumProb + shrinkage;
        hyperParameter_lambda = sumProb + dirichletParameter;
        resetPVarInGaussian();
    }

    /**
     * Sets {@link #mu} to the weighted mean and {@link #sigma} to the weighted covariance of the data,
     * using the assigned per-datum weights without priors. Resets the weights afterwards.
     *
     * <p>NOTE(review): if all weights are 0, {@code sumProb} is 0 and the divisions here produce
     * NaN/Infinity entries; callers presumably guarantee a positive total.
     *
     * @param data data in the same order their weights were assigned
     */
    public void evaluateFinalModelParameters(final List<VariantDatum> data) {
        sumProb = 0.0;
        zeroOutMu();
        zeroOutSigma();
        accumulateWeightedMean(data);
        divideEqualsMu(sumProb);
        accumulateWeightedScatter(data);
        sigma.timesEquals(1.0 / sumProb);
        resetPVarInGaussian();
    }

    /** Adds each datum's weight to {@link #sumProb} and its weighted annotations to {@link #mu}. */
    private void accumulateWeightedMean(final List<VariantDatum> data) {
        int datumIndex = 0;
        for (final VariantDatum datum : data) {
            final double weight = pVarInGaussian[datumIndex++];
            sumProb += weight;
            incrementMu(datum, weight);
        }
    }

    /**
     * Adds {@code weight * (x - mu)(x - mu)^T} for every datum directly into {@link #sigma}'s storage.
     * Each entry is computed as {@code (weight * (x_i - mu_i)) * (x_j - mu_j)}.
     */
    private void accumulateWeightedScatter(final List<VariantDatum> data) {
        final int k = mu.length;
        final double[][] sigmaEntries = sigma.getArray();
        int datumIndex = 0;
        for (final VariantDatum datum : data) {
            final double weight = pVarInGaussian[datumIndex++];
            for (int i = 0; i < k; i++) {
                final double weightedDelta = weight * (datum.annotations[i] - mu[i]);
                for (int j = 0; j < k; j++) {
                    sigmaEntries[i][j] += weightedDelta * (datum.annotations[j] - mu[j]);
                }
            }
        }
    }
}
