package org.broadinstitute.hellbender.tools.walkers.vqsr;

import Jama.Matrix;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.Utils;

/**
 * Mixture of {@link MultivariateGaussian} components over the annotation vectors of
 * {@link VariantDatum}s.
 *
 * <p>Lifecycle suggested by the public methods: {@link #initializeRandomModel} (random means
 * refined by k-means), repeated {@link #expectationStep} / {@link #maximizationStep} rounds,
 * {@link #evaluateFinalModelParameters}, then {@link #precomputeDenominatorForEvaluation}
 * before scoring data with the {@code evaluateDatum*} methods.
 *
 * <p>No synchronization is performed. Fitting mutates the component Gaussians, and
 * {@link #evaluateDatumMarginalized} overwrites missing annotation slots of the datum it scores.
 */
class GaussianMixtureModel {
    protected static final Logger logger = LogManager.getLogger(GaussianMixtureModel.class);

    /** Random draws taken per missing annotation when marginalizing over that annotation. */
    private static final int NUM_DRAWS_PER_MISSING_ANNOTATION = 20;

    /** The initial empirical covariance prior is the inverse of this value times the identity. */
    private static final double EMPIRICAL_SIGMA_PRIOR_SCALE = 200.0d;

    private final List<MultivariateGaussian> gaussians;
    private final double shrinkage;
    private final double dirichletParameter;
    private final double priorCounts;
    /** Empirical mean (all zeros) handed to every component's maximization step. */
    private final double[] empiricalMu;
    /** Empirical covariance handed to every component's maximization step. */
    private final Matrix empiricalSigma;
    /** Set to true by {@link #precomputeDenominatorForEvaluation()}. */
    public boolean isModelReadyForEvaluation;
    /** Not written by this class; presumably set by the caller driving the fit. */
    public boolean failedToConverge = false;

    /**
     * Creates a model with {@code numGaussians} freshly constructed components.
     *
     * @param numGaussians number of mixture components
     * @param numVariants first argument forwarded to each {@code MultivariateGaussian}
     *     constructor (presumably the number of data points; confirm against MultivariateGaussian)
     * @param numAnnotations dimensionality of the annotation vectors
     * @param shrinkage copied into each component's {@code hyperParameter_b} at initialization
     * @param dirichletParameter copied into each component's {@code hyperParameter_lambda}
     * @param priorCounts copied into each component's {@code hyperParameter_a}
     */
    public GaussianMixtureModel(int numGaussians, int numVariants, int numAnnotations,
                                double shrinkage, double dirichletParameter, double priorCounts) {
        this.gaussians = new ArrayList<>(numGaussians);
        for (int k = 0; k < numGaussians; k++) {
            this.gaussians.add(new MultivariateGaussian(numVariants, numAnnotations));
        }
        this.shrinkage = shrinkage;
        this.dirichletParameter = dirichletParameter;
        this.priorCounts = priorCounts;
        this.empiricalMu = new double[numAnnotations]; // Java zero-initializes the array
        this.empiricalSigma = initialEmpiricalSigma(numAnnotations);
        this.isModelReadyForEvaluation = false;
    }

    /**
     * Wraps an existing list of components without copying it.
     *
     * <p>The dimensionality is taken from the first component's mean vector.
     *
     * @param modelGaussians the mixture components; must be non-empty
     * @param shrinkage see the other constructor
     * @param dirichletParameter see the other constructor
     * @param priorCounts see the other constructor
     * @throws IllegalArgumentException if {@code modelGaussians} is empty
     */
    public GaussianMixtureModel(List<MultivariateGaussian> modelGaussians, double shrinkage,
                                double dirichletParameter, double priorCounts) {
        if (modelGaussians.isEmpty()) {
            throw new IllegalArgumentException("A Gaussian mixture model needs at least one Gaussian");
        }
        this.gaussians = modelGaussians;
        final int numAnnotations = modelGaussians.get(0).mu.length;
        this.shrinkage = shrinkage;
        this.dirichletParameter = dirichletParameter;
        this.priorCounts = priorCounts;
        this.empiricalMu = new double[numAnnotations];
        this.empiricalSigma = initialEmpiricalSigma(numAnnotations);
        this.isModelReadyForEvaluation = false;
    }

    /** Returns the starting empirical covariance: {@code (EMPIRICAL_SIGMA_PRIOR_SCALE * I)^-1}. */
    private static Matrix initialEmpiricalSigma(int numAnnotations) {
        return Matrix.identity(numAnnotations, numAnnotations).times(EMPIRICAL_SIGMA_PRIOR_SCALE).inverse();
    }

    /**
     * Randomizes the component means, refines them with k-means, then resets every component.
     *
     * <p>The reset gives each component uniform mixture weight, a random covariance, and the
     * model's hyper-parameters.
     *
     * @param data data used for the k-means refinement
     * @param numKMeansIterations number of k-means iterations to run
     */
    public void initializeRandomModel(List<VariantDatum> data, int numKMeansIterations) {
        for (MultivariateGaussian gaussian : gaussians) {
            gaussian.initializeRandomMu(Utils.getRandomGenerator());
        }
        logger.info("Initializing model with " + numKMeansIterations + " k-means iterations...");
        initializeMeansUsingKMeans(data, numKMeansIterations);

        // Uniform starting weights are the same for every component, so compute them once.
        final double uniformWeight = 1.0d / gaussians.size();
        final double uniformWeightLog10 = Math.log10(uniformWeight);
        for (MultivariateGaussian gaussian : gaussians) {
            gaussian.pMixtureLog10 = uniformWeightLog10;
            gaussian.sumProb = uniformWeight;
            gaussian.initializeRandomSigma(Utils.getRandomGenerator());
            gaussian.hyperParameter_a = priorCounts;
            gaussian.hyperParameter_b = shrinkage;
            gaussian.hyperParameter_lambda = dirichletParameter;
        }
    }

    /** Runs {@code numIterations} rounds of k-means (assign to nearest mean, then recompute means). */
    private void initializeMeansUsingKMeans(List<VariantDatum> data, int numIterations) {
        for (int iteration = 0; iteration < numIterations; iteration++) {
            assignToNearestGaussian(data);
            recomputeMeans(data);
        }
    }

    /**
     * Sets each datum's {@code assignment} to the component with the smallest squared distance.
     *
     * <p>The assignment stays null if no distance is below {@code Double.MAX_VALUE}.
     */
    private void assignToNearestGaussian(List<VariantDatum> data) {
        for (VariantDatum datum : data) {
            double minDistance = Double.MAX_VALUE;
            MultivariateGaussian nearest = null;
            datum.assignment = null;
            for (MultivariateGaussian gaussian : gaussians) {
                final double distance = gaussian.calculateDistanceFromMeanSquared(datum);
                if (distance < minDistance) {
                    minDistance = distance;
                    nearest = gaussian;
                }
            }
            datum.assignment = nearest;
        }
    }

    /** Sets each component's mean to the average of its assigned data; empty clusters are re-randomized. */
    private void recomputeMeans(List<VariantDatum> data) {
        for (MultivariateGaussian gaussian : gaussians) {
            gaussian.zeroOutMu();
            int numAssigned = 0;
            for (VariantDatum datum : data) {
                // Compare from the component's side: a datum's assignment may be null (see
                // assignToNearestGaussian) and must simply not count toward any cluster.
                if (gaussian.equals(datum.assignment)) {
                    numAssigned++;
                    gaussian.incrementMu(datum);
                }
            }
            if (numAssigned != 0) {
                gaussian.divideEqualsMu(numAssigned);
            } else {
                gaussian.initializeRandomMu(Utils.getRandomGenerator());
            }
        }
    }

    /**
     * E-step: computes each component's normalized responsibility for every datum.
     *
     * <p>Responsibilities are stored on the components via {@code assignPVarInGaussian}.
     *
     * @param data the training data, visited in order
     */
    public void expectationStep(List<VariantDatum> data) {
        // Loop-invariant: the sum of lambdas is the same for every component. This assumes
        // precomputeDenominatorForVariationalBayes does not modify hyperParameter_lambda.
        final double sumHyperParameterLambda = getSumHyperParameterLambda();
        for (MultivariateGaussian gaussian : gaussians) {
            gaussian.precomputeDenominatorForVariationalBayes(sumHyperParameterLambda);
        }
        for (VariantDatum datum : data) {
            final double[] likelihoodsLog10 = gaussians.stream()
                    .mapToDouble(gaussian -> gaussian.evaluateDatumLog10(datum))
                    .toArray();
            final double[] responsibilities = MathUtils.normalizeLog10DeleteMePlease(likelihoodsLog10, false);
            int k = 0;
            for (MultivariateGaussian gaussian : gaussians) {
                gaussian.assignPVarInGaussian(responsibilities[k++]);
            }
        }
    }

    /** M-step: lets every component re-estimate itself from the data and the model's priors. */
    public void maximizationStep(List<VariantDatum> data) {
        gaussians.forEach(gaussian -> gaussian.maximizeGaussian(
                data, empiricalMu, empiricalSigma, shrinkage, dirichletParameter, priorCounts));
    }

    /** Returns the sum of {@code hyperParameter_lambda} over all components. */
    private double getSumHyperParameterLambda() {
        return gaussians.stream().mapToDouble(gaussian -> gaussian.hyperParameter_lambda).sum();
    }

    /** Finalizes each component's parameters from the data, then renormalizes the mixture weights. */
    public void evaluateFinalModelParameters(List<VariantDatum> data) {
        gaussians.forEach(gaussian -> gaussian.evaluateFinalModelParameters(data));
        normalizePMixtureLog10();
    }

    /**
     * Recomputes each component's log10 mixture weight from its {@code sumProb}.
     *
     * @return sum over components of the absolute change in {@code pMixtureLog10}
     */
    public double normalizePMixtureLog10() {
        final double sumProbLog10 = Math.log10(gaussians.stream().mapToDouble(g -> g.sumProb).sum());
        double[] weightsLog10 = gaussians.stream()
                .mapToDouble(g -> Math.log10(g.sumProb) - sumProbLog10)
                .toArray();
        // Use the helper's return value (as expectationStep does). This is correct whether or not
        // the helper normalizes in place; ignoring it would drop the result if it returns a copy.
        weightsLog10 = MathUtils.normalizeLog10DeleteMePlease(weightsLog10, true);
        double totalChange = 0.0d;
        int k = 0;
        for (MultivariateGaussian gaussian : gaussians) {
            totalChange += Math.abs(weightsLog10[k] - gaussian.pMixtureLog10);
            gaussian.pMixtureLog10 = weightsLog10[k];
            k++;
        }
        return totalChange;
    }

    /** Prepares every component for scoring and marks the model as ready for evaluation. */
    public void precomputeDenominatorForEvaluation() {
        for (MultivariateGaussian gaussian : gaussians) {
            gaussian.precomputeDenominatorForEvaluation();
        }
        isModelReadyForEvaluation = true;
    }

    /** Returns NaN if any input is NaN; otherwise {@code MathUtils.log10sumLog10} of the values. */
    private double nanTolerantLog10SumLog10(double[] valuesLog10) {
        for (double value : valuesLog10) {
            if (Double.isNaN(value)) {
                return Double.NaN;
            }
        }
        return MathUtils.log10sumLog10(valuesLog10);
    }

    /**
     * Computes the mixture log10 likelihood of the datum.
     *
     * <p>That is {@code log10(sum_k 10^(pMixtureLog10_k + log10 p(datum | k)))}, or NaN if any term
     * is NaN.
     *
     * @param datum the datum to score
     * @param scratch buffer of length {@code gaussians.size()}; overwritten
     */
    private double mixtureLog10Likelihood(VariantDatum datum, double[] scratch) {
        int k = 0;
        for (MultivariateGaussian gaussian : gaussians) {
            scratch[k++] = gaussian.pMixtureLog10 + gaussian.evaluateDatumLog10(datum);
        }
        return nanTolerantLog10SumLog10(scratch);
    }

    /**
     * Returns the mixture log10 likelihood of a datum.
     *
     * <p>If any annotation is marked missing, this delegates to {@link #evaluateDatumMarginalized}.
     */
    public double evaluateDatum(VariantDatum datum) {
        for (boolean isMissing : datum.isNull) {
            if (isMissing) {
                return evaluateDatumMarginalized(datum);
            }
        }
        return mixtureLog10Likelihood(datum, new double[gaussians.size()]);
    }

    /**
     * Returns the log10 likelihood of one annotation under the per-component 1-D normals.
     *
     * <p>Each component contributes a normal using its mean and diagonal covariance entry for that
     * annotation.
     *
     * @param datum the datum to score
     * @param annotationIndex which annotation to score
     * @return the log10 likelihood, or {@code null} if that annotation is missing
     */
    public Double evaluateDatumInOneDimension(VariantDatum datum, int annotationIndex) {
        if (datum.isNull[annotationIndex]) {
            return null;
        }
        final double[] termsLog10 = new double[gaussians.size()];
        int k = 0;
        for (MultivariateGaussian gaussian : gaussians) {
            termsLog10[k] = gaussian.pMixtureLog10;
            // Zero-weight components stay at -Infinity; don't add a density term to them.
            if (gaussian.pMixtureLog10 != Double.NEGATIVE_INFINITY) {
                // NOTE(review): sigma.get(i, i) is a variance; confirm that normalDistributionLog10's
                // second parameter expects a variance rather than a standard deviation.
                termsLog10[k] += MathUtils.normalDistributionLog10(
                        gaussian.mu[annotationIndex],
                        gaussian.sigma.get(annotationIndex, annotationIndex),
                        datum.annotations[annotationIndex]);
            }
            k++;
        }
        return nanTolerantLog10SumLog10(termsLog10);
    }

    /**
     * Approximates the log10 likelihood of a datum with missing annotations by Monte Carlo.
     *
     * <p>Each missing annotation is replaced {@value #NUM_DRAWS_PER_MISSING_ANNOTATION} times by a
     * standard-normal draw. The returned value is the log10 of the average of the resulting mixture
     * likelihoods. Side effect: missing slots of {@code datum.annotations} are left holding the last
     * draw.
     *
     * <p>If nothing is missing, this returns the plain mixture likelihood instead of NaN.
     */
    public double evaluateDatumMarginalized(VariantDatum datum) {
        final double[] scratch = new double[gaussians.size()];
        double sumLikelihood = 0.0d;
        int numRandomDraws = 0;
        for (int a = 0; a < datum.annotations.length; a++) {
            if (!datum.isNull[a]) {
                continue;
            }
            for (int draw = 0; draw < NUM_DRAWS_PER_MISSING_ANNOTATION; draw++) {
                datum.annotations[a] = Utils.getRandomGenerator().nextGaussian();
                sumLikelihood += Math.pow(10.0d, mixtureLog10Likelihood(datum, scratch));
                numRandomDraws++;
            }
        }
        if (numRandomDraws == 0) {
            // Marginalizing over zero dimensions is the plain density; 0/0 below would give NaN.
            return mixtureLog10Likelihood(datum, scratch);
        }
        return Math.log10(sumLikelihood / numRandomDraws);
    }

    /** Returns a read-only view of the mixture components. */
    public List<MultivariateGaussian> getModelGaussians() {
        return Collections.unmodifiableList(gaussians);
    }

    /** Returns the dimensionality of the annotation vectors this model was built for. */
    protected int getNumAnnotations() {
        return empiricalMu.length;
    }
}
