package gov.sandia.cognition.statistics.bayesian.conjugate;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.math.MultivariateStatisticsUtil;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.statistics.bayesian.AbstractBayesianParameter;
import gov.sandia.cognition.statistics.bayesian.BayesianParameter;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.statistics.distribution.MultivariateStudentTDistribution;
import gov.sandia.cognition.statistics.distribution.NormalInverseWishartDistribution;
import gov.sandia.cognition.util.Pair;
import java.util.Arrays;

/**
 * Conjugate Bayesian estimator for the unknown mean and covariance of a
 * {@link MultivariateGaussian}, using a {@link NormalInverseWishartDistribution}
 * prior. The posterior predictive distribution is a
 * {@link MultivariateStudentTDistribution}.
 */
@PublicationReferences(
    references = {
        @PublicationReference(
            author = {"Andrew Gelman", "John B. Carlin", "Hal S. Stern", "Donald B. Rubin"},
            title = "Bayesian Data Analysis, Second Edition",
            type = PublicationType.Book,
            year = 2004,
            pages = {87, 88}),
        @PublicationReference(
            author = {"Wikipedia"},
            title = "Conjugate Prior",
            type = PublicationType.WebPage,
            year = 2009,
            url = "http://en.wikipedia.org/wiki/Conjugate_prior")
    })
public class MultivariateGaussianMeanCovarianceBayesianEstimator
    extends AbstractConjugatePriorBayesianEstimator<Vector, Matrix, MultivariateGaussian, NormalInverseWishartDistribution>
    implements ConjugatePriorBayesianEstimatorPredictor<Vector, Matrix, MultivariateGaussian, NormalInverseWishartDistribution> {

    /**
     * Bayesian parameter that packs the mean and covariance of the conditional
     * {@link MultivariateGaussian} into one (dim x dim+1) matrix: column 0 holds
     * the mean and columns 1..dim hold the covariance.
     */
    public static class Parameter extends AbstractBayesianParameter<Matrix, MultivariateGaussian, NormalInverseWishartDistribution> {

        /** Name of this parameter. */
        public static final String NAME = "meanAndCovariance";

        /**
         * Creates a new parameter.
         *
         * @param conditional the Gaussian whose mean and covariance are estimated
         * @param prior the conjugate prior over that mean and covariance
         */
        public Parameter(MultivariateGaussian conditional, NormalInverseWishartDistribution prior) {
            super(conditional, NAME, prior);
        }

        /**
         * Sets the mean and covariance of the conditional Gaussian from a packed
         * matrix.
         *
         * @param value packed (dim x dim+1) matrix; column 0 is the mean and
         *     columns 1..dim are the covariance
         * @throws IllegalArgumentException if {@code value} is not (dim x dim+1)
         */
        @Override
        public void setValue(Matrix value) {
            final MultivariateGaussian gaussian = this.conditionalDistribution;
            final int dim = gaussian.getInputDimensionality();
            if (value.getNumRows() != dim || value.getNumColumns() != dim + 1) {
                throw new IllegalArgumentException("Expected (dim x dim+1) Matrix");
            }
            final Vector mean = value.getColumn(0);
            // Rows 0..dim-1 and columns 1..dim (bounds used inclusively here).
            final Matrix covariance = value.getSubMatrix(0, dim - 1, 1, dim);
            gaussian.setMean(mean);
            gaussian.setCovariance(covariance);
        }

        /**
         * Returns the mean and covariance of the conditional Gaussian packed into
         * a newly created (dim x dim+1) matrix.
         *
         * @return packed matrix; column 0 is the mean, columns 1..dim the covariance
         */
        @Override
        public Matrix getValue() {
            final MultivariateGaussian gaussian = this.conditionalDistribution;
            final int dim = gaussian.getInputDimensionality();
            final Matrix packed = MatrixFactory.getDefault().createMatrix(dim, dim + 1);
            packed.setColumn(0, gaussian.getMean());
            packed.setSubMatrix(0, 1, gaussian.getCovariance());
            return packed;
        }

        /**
         * Decompiler-generated alias of {@link #getValue()}, kept only so that any
         * code compiled against the renamed method keeps working.
         *
         * @return the same value as {@link #getValue()}
         * @deprecated use {@link #getValue()}
         */
        @Deprecated
        public Matrix m349getValue() {
            return getValue();
        }
    }

    /** Creates an estimator with a default {@link NormalInverseWishartDistribution} prior. */
    public MultivariateGaussianMeanCovarianceBayesianEstimator() {
        this(new NormalInverseWishartDistribution());
    }

    /**
     * Creates an estimator whose prior is built with the given dimensionality.
     *
     * @param dimensionality dimensionality passed to the prior's constructor
     */
    public MultivariateGaussianMeanCovarianceBayesianEstimator(int dimensionality) {
        this(new NormalInverseWishartDistribution(dimensionality));
    }

    /**
     * Creates an estimator with the given prior and a fresh conditional Gaussian
     * of matching input dimensionality.
     *
     * @param prior conjugate prior over the mean and covariance
     */
    public MultivariateGaussianMeanCovarianceBayesianEstimator(NormalInverseWishartDistribution prior) {
        this(new MultivariateGaussian(prior.getInputDimensionality()), prior);
    }

    /**
     * Creates an estimator with the given conditional distribution and prior.
     *
     * @param conditional Gaussian whose mean and covariance are estimated
     * @param prior conjugate prior over the mean and covariance
     */
    public MultivariateGaussianMeanCovarianceBayesianEstimator(MultivariateGaussian conditional, NormalInverseWishartDistribution prior) {
        this(new Parameter(conditional, prior));
    }

    /**
     * Creates an estimator around an existing Bayesian parameter.
     *
     * @param parameter parameter linking the conditional Gaussian and its prior
     */
    protected MultivariateGaussianMeanCovarianceBayesianEstimator(BayesianParameter<Matrix, MultivariateGaussian, NormalInverseWishartDistribution> parameter) {
        super(parameter);
    }

    @Override
    public Parameter createParameter(MultivariateGaussian conditional, NormalInverseWishartDistribution prior) {
        return new Parameter(conditional, prior);
    }

    /**
     * Updates the given normal-inverse-Wishart distribution in place with a
     * single observation.
     *
     * @param target distribution to update in place
     * @param data single observation
     */
    @Override
    public void update(NormalInverseWishartDistribution target, Vector data) {
        update(target, Arrays.asList(data));
    }

    /**
     * Updates the given normal-inverse-Wishart distribution in place with a batch
     * of observations, following Gelman et al. (BDA, 2nd ed., pp. 87-88):
     * <pre>
     *   mu_n     = (kappa_0 mu_0 + n xbar) / kappa_n
     *   kappa_n  = kappa_0 + n
     *   nu_n     = nu_0 + n
     *   Lambda_n = Lambda_0 + S + (kappa_0 n / kappa_n)(xbar - mu_0)(xbar - mu_0)^T
     * </pre>
     * where {@code S = sum_i (x_i - xbar)(x_i - xbar)^T} is the scatter matrix.
     * An empty batch leaves {@code target} unchanged.
     *
     * @param target distribution to update in place
     * @param data observations
     */
    @Override
    public void update(NormalInverseWishartDistribution target, Iterable<? extends Vector> data) {
        final int n = CollectionUtil.size(data);
        if (n <= 0) {
            // No observations: the posterior is the prior. Returning here also
            // avoids the division by n below.
            return;
        }

        final Pair<Vector, Matrix> meanAndCovariance =
            MultivariateStatisticsUtil.computeMeanAndCovariance(data);
        final Vector sampleMean = meanAndCovariance.getFirst();
        final Matrix sampleCovariance = meanAndCovariance.getSecond();

        // Prior hyperparameters.
        final Vector mu0 = target.getGaussian().getMean();
        final double kappa0 = target.getCovarianceDivisor();
        final int nu0 = target.getInverseWishart().getDegreesOfFreedom();
        final Matrix lambda0 = target.getInverseWishart().getInverseScale();

        final int nuN = nu0 + n;
        final double kappaN = kappa0 + n;

        // mu_n = (kappa_0 / n * mu_0 + xbar) * n / kappa_n
        final Vector muN = mu0.scale(kappa0 / n);
        muN.plusEquals(sampleMean);
        muN.scaleEquals(n / kappaN);

        // Reuse the freshly computed sample mean to hold (xbar - mu_0) and save
        // an allocation.
        final Vector deviation = sampleMean;
        deviation.minusEquals(mu0);

        // NOTE(review): assumes computeMeanAndCovariance returns the unbiased
        // covariance S / (n - 1), dividing only when n > 1 (which this guard
        // mirrors). The scatter matrix S is therefore recovered by scaling by
        // n - 1. The previous code scaled by n, which inflated S by n / (n - 1).
        final Matrix lambdaN = sampleCovariance;
        if (n > 1) {
            lambdaN.scaleEquals(n - 1);
        }
        lambdaN.plusEquals(lambda0);
        lambdaN.plusEquals(deviation.outerProduct(deviation.scale((n * kappa0) / kappaN)));

        target.getGaussian().setMean(muN);
        target.setCovarianceDivisor(kappaN);
        target.getInverseWishart().setDegreesOfFreedom(nuN);
        target.getInverseWishart().setInverseScale(lambdaN);
    }

    /**
     * Returns the equivalent sample size of the given distribution, which is its
     * covariance divisor.
     *
     * @param belief distribution to inspect
     * @return the covariance divisor of {@code belief}
     */
    @Override
    public double computeEquivalentSampleSize(NormalInverseWishartDistribution belief) {
        return belief.getCovarianceDivisor();
    }

    /**
     * Creates the posterior predictive Student-t distribution. It has
     * {@code nu - d + 1} degrees of freedom, location {@code mu}, and scale
     * {@code Lambda (kappa + 1) / (kappa (nu - d + 1))}. The inverse of that scale
     * is what gets passed to the constructor (presumably as a precision; confirm
     * against {@link MultivariateStudentTDistribution}).
     *
     * @param posterior normal-inverse-Wishart posterior
     * @return predictive distribution for a new observation
     */
    @Override
    public MultivariateStudentTDistribution createPredictiveDistribution(NormalInverseWishartDistribution posterior) {
        final int nu = posterior.getInverseWishart().getDegreesOfFreedom();
        final int dim = posterior.getInverseWishart().getInputDimensionality();
        final double degreesOfFreedom = (nu - dim) + 1.0d;
        final double kappa = posterior.getCovarianceDivisor();
        final Vector location = posterior.getGaussian().getMean();
        final Matrix scale = posterior.getInverseWishart().getInverseScale()
            .scale((kappa + 1.0d) / (kappa * degreesOfFreedom));
        return new MultivariateStudentTDistribution(degreesOfFreedom, location, scale.inverse());
    }
}
