package gov.sandia.cognition.statistics.bayesian;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.IncrementalLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.AbstractSufficientStatistic;
import gov.sandia.cognition.statistics.ClosedFormDistribution;
import gov.sandia.cognition.statistics.Distribution;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;

/**
 * Bayesian linear regression with a Gaussian prior over the weight vector and
 * Gaussian observation noise of fixed variance on a scalar output.
 * <p>
 * With prior {@code N(m0, S0)}, output variance {@code s2} and each example's
 * weight {@code w_i} (as reported by {@code DatasetUtil.getWeight}), the
 * posterior over the weights is {@code N(mN, SN)} where
 * {@code SN^-1 = S0^-1 + sum_i (w_i / s2) x_i x_i^T} and
 * {@code mN = SN (S0^-1 m0 + sum_i (w_i / s2) y_i x_i)}.
 * <p>
 * NOTE(review): this file is decompiler output. Methods named
 * {@code mo334clone}, {@code m335create} and
 * {@code createConditionalDistribution2} carry decompiler-assigned names
 * (originally {@code clone}, {@code create} and
 * {@code createConditionalDistribution}). They are kept so other decompiled
 * callers still resolve.
 */
@PublicationReferences(references = {@PublicationReference(author = {"Christopher M. Bishop"}, title = "Pattern Recognition and Machine Learning", type = PublicationType.Book, year = 2006, pages = {152, 159}), @PublicationReference(author = {"Hanna M. Wallach"}, title = "Introduction to Gaussian Process Regression", type = PublicationType.Misc, year = 2005, url = "http://www.cs.umass.edu/~wallach/talks/gp_intro.pdf"), @PublicationReference(author = {"Wikipedia"}, title = "Bayesian linear regression", type = PublicationType.WebPage, year = 2010, url = "http://en.wikipedia.org/wiki/Bayesian_linear_regression")})
public class BayesianLinearRegression extends AbstractCloneableSerializable implements BayesianRegression<Double, MultivariateGaussian> {

    /** Output (noise) variance used by the dimensionality-only constructor. */
    public static final double DEFAULT_OUTPUT_VARIANCE = 1.0d;

    /** Diagonal variance of the weight prior used by the dimensionality-only constructor. */
    public static final double DEFAULT_WEIGHT_VARIANCE = 1.0d;

    /** Variance of the Gaussian noise on the output; validated to be > 0 by {@link #setOutputVariance}. */
    protected double outputVariance;

    /** Gaussian prior over the weight vector. */
    protected MultivariateGaussian weightPrior;

    /**
     * Incremental (online) estimator.
     * <p>
     * It accumulates the posterior's precision matrix and precision-weighted
     * mean in a {@link SufficientStatistic}, one input-output pair at a time.
     * The decompiler-generated bridge methods that used to live here were
     * removed. They redeclared {@code mo334clone()} twice and overrode
     * {@code createPredictiveDistribution} with an incompatible return type.
     * The inherited base-class methods provide the same behavior.
     */
    public static class IncrementalEstimator extends BayesianLinearRegression implements IncrementalLearner<InputOutputPair<? extends Vectorizable, Double>, SufficientStatistic> {

        /**
         * Running sufficient statistic for the weight posterior.
         * <p>
         * It stores the precision (inverse covariance) matrix and the vector
         * {@code z = precision * mean}, both of which are additive over
         * examples. The output variance of the enclosing estimator is read at
         * each {@link #update} call.
         */
        public class SufficientStatistic extends AbstractSufficientStatistic<InputOutputPair<? extends Vectorizable, Double>, MultivariateGaussian> {

            /** Precision-weighted mean; null until a prior or first example is seen. */
            private Vector z;

            /** Posterior precision matrix; null until a prior or first example is seen. */
            private Matrix covarianceInverse;

            /**
             * Creates a statistic seeded from the given prior.
             *
             * @param multivariateGaussian prior over the weights. It may be
             *        null, in which case the statistic starts empty with a
             *        count of 0. Otherwise the prior is counted once.
             */
            public SufficientStatistic(MultivariateGaussian multivariateGaussian) {
                if (multivariateGaussian == null) {
                    this.covarianceInverse = null;
                    this.z = null;
                    this.count = 0L;
                } else {
                    // Clone so that later in-place updates never touch the prior itself.
                    this.covarianceInverse = multivariateGaussian.getCovarianceInverse().clone();
                    this.z = this.covarianceInverse.times(multivariateGaussian.getMean());
                    this.count = 1L;
                }
            }

            /**
             * Folds one weighted input-output pair into the statistic.
             *
             * @param inputOutputPair example whose input is converted to a
             *        vector and whose output is unboxed to a double
             */
            @Override // gov.sandia.cognition.statistics.SufficientStatistic
            public void update(InputOutputPair<? extends Vectorizable, Double> inputOutputPair) {
                this.count++;
                final Vector input = inputOutputPair.getInput().convertToVector();
                final double output = inputOutputPair.getOutput().doubleValue();

                // Precision contributed by this example: its weight over the noise variance.
                final double precision = DatasetUtil.getWeight(inputOutputPair) / IncrementalEstimator.this.outputVariance;
                final Vector scaled = input.clone();
                if (precision != 1.0d) {
                    scaled.scaleEquals(precision);
                }

                // covarianceInverse += precision * x x^T
                final Matrix outer = input.outerProduct(scaled);
                if (this.covarianceInverse == null) {
                    this.covarianceInverse = outer;
                } else {
                    this.covarianceInverse.plusEquals(outer);
                }

                // z += precision * y * x (reuses the already precision-scaled copy).
                if (output != 1.0d) {
                    scaled.scaleEquals(output);
                }
                if (this.z == null) {
                    this.z = scaled;
                } else {
                    this.z.plusEquals(scaled);
                }
            }

            /**
             * Creates a new Gaussian PDF holding the current weight posterior.
             * <p>
             * The decompiled name is {@code m335create}; the original method
             * was {@code create()}.
             *
             * @return a freshly allocated posterior PDF
             */
            public MultivariateGaussian.PDF m335create() {
                MultivariateGaussian.PDF pdf = new MultivariateGaussian.PDF(getDimensionality());
                create((MultivariateGaussian) pdf);
                return pdf;
            }

            /**
             * Writes the current posterior mean and precision into the given distribution.
             *
             * @param multivariateGaussian distribution to overwrite
             */
            @Override // gov.sandia.cognition.statistics.SufficientStatistic
            public void create(MultivariateGaussian multivariateGaussian) {
                multivariateGaussian.setMean(getMean());
                // Hand over a copy. update() mutates covarianceInverse in place
                // (plusEquals). Unless setCovarianceInverse copies its argument,
                // sharing the matrix would silently alter distributions created
                // earlier from this statistic.
                multivariateGaussian.setCovarianceInverse(getCovarianceInverse().clone());
            }

            /** @return the internal precision matrix (not a copy); null if no prior and no data */
            public Matrix getCovarianceInverse() {
                return this.covarianceInverse;
            }

            /** @return the internal precision-weighted mean vector (not a copy); null if no prior and no data */
            public Vector getZ() {
                return this.z;
            }

            /** @return the posterior mean, computed as {@code covarianceInverse^-1 * z} */
            public Vector getMean() {
                return this.covarianceInverse.inverse().times(this.z);
            }

            /**
             * @return dimensionality of {@code z}. Throws
             *         NullPointerException if neither a prior nor any data
             *         has been seen.
             */
            public int getDimensionality() {
                return getZ().getDimensionality();
            }
        }

        /**
         * Creates an estimator with the default output variance and a default
         * isotropic weight prior.
         *
         * @param i dimensionality of the weight vector
         */
        public IncrementalEstimator(int i) {
            super(i);
        }

        /**
         * @param d output variance; must be > 0
         * @param multivariateGaussian prior over the weights
         */
        public IncrementalEstimator(double d, MultivariateGaussian multivariateGaussian) {
            super(d, multivariateGaussian);
        }

        /** @return a new statistic seeded with this estimator's weight prior */
        @Override // gov.sandia.cognition.learning.algorithm.IncrementalLearner
        public SufficientStatistic createInitialLearnedObject() {
            return new SufficientStatistic(getWeightPrior());
        }

        /**
         * Computes the weight posterior by feeding every example through a
         * fresh {@link SufficientStatistic}.
         *
         * @param collection training examples
         * @return the posterior over the weights
         */
        @Override // gov.sandia.cognition.statistics.bayesian.BayesianLinearRegression, gov.sandia.cognition.learning.algorithm.BatchLearner
        public MultivariateGaussian.PDF learn(Collection<? extends InputOutputPair<? extends Vectorizable, Double>> collection) {
            final SufficientStatistic statistic = createInitialLearnedObject();
            update(statistic, (Iterable<? extends InputOutputPair<? extends Vectorizable, Double>>) collection);
            return statistic.m335create();
        }

        /** Folds a single example into {@code sufficientStatistic}. */
        @Override // gov.sandia.cognition.learning.algorithm.IncrementalLearner
        public void update(SufficientStatistic sufficientStatistic, InputOutputPair<? extends Vectorizable, Double> inputOutputPair) {
            sufficientStatistic.update(inputOutputPair);
        }

        /** Folds every example of {@code iterable} into {@code sufficientStatistic}. */
        @Override // gov.sandia.cognition.learning.algorithm.IncrementalLearner
        public void update(SufficientStatistic sufficientStatistic, Iterable<? extends InputOutputPair<? extends Vectorizable, Double>> iterable) {
            sufficientStatistic.update(iterable);
        }
    }

    /**
     * Posterior predictive distribution of the scalar output.
     * <p>
     * For input {@code x} it is
     * {@code N(x . mean, x^T Sigma x + outputVariance)}. The enclosing
     * instance's output variance is read on every {@link #evaluate} call, so
     * later changes to it are reflected.
     */
    @PublicationReference(author = {"Christopher M. Bishop"}, title = "Pattern Recognition and Machine Learning", type = PublicationType.Book, year = 2006, pages = {156})
    public class PredictiveDistribution extends AbstractCloneableSerializable implements Evaluator<Vectorizable, UnivariateGaussian.PDF> {

        /** Posterior over the weights. */
        private MultivariateGaussian posterior;

        /** @param multivariateGaussian posterior over the weights (stored by reference) */
        public PredictiveDistribution(MultivariateGaussian multivariateGaussian) {
            this.posterior = multivariateGaussian;
        }

        /**
         * @param vectorizable input point
         * @return predictive Gaussian over the output at that point
         */
        public UnivariateGaussian.PDF evaluate(Vectorizable vectorizable) {
            final Vector x = vectorizable.convertToVector();
            final double mean = x.dotProduct(this.posterior.getMean());
            // Parameter uncertainty (x^T Sigma x) plus observation noise.
            final double variance = x.times(this.posterior.getCovariance()).dotProduct(x) + BayesianLinearRegression.this.outputVariance;
            return new UnivariateGaussian.PDF(mean, variance);
        }
    }

    /**
     * Creates a regression with {@link #DEFAULT_OUTPUT_VARIANCE} and a
     * Gaussian weight prior.
     * <p>
     * The prior's mean is the vector from
     * {@code VectorFactory.createVector(i)} and its covariance is
     * {@link #DEFAULT_WEIGHT_VARIANCE} times the identity.
     *
     * @param i dimensionality of the weight vector
     */
    public BayesianLinearRegression(int i) {
        this(DEFAULT_OUTPUT_VARIANCE, new MultivariateGaussian(VectorFactory.getDefault().createVector(i), MatrixFactory.getDefault().createIdentity(i, i).scale(DEFAULT_WEIGHT_VARIANCE)));
    }

    /**
     * @param d output variance; must be > 0
     * @param multivariateGaussian prior over the weights
     * @throws IllegalArgumentException if {@code d} is not > 0
     */
    public BayesianLinearRegression(double d, MultivariateGaussian multivariateGaussian) {
        setOutputVariance(d);
        setWeightPrior(multivariateGaussian);
    }

    /**
     * Returns a copy of this object whose weight prior is cloned as well.
     * <p>
     * The decompiled name is {@code mo334clone}; the original method was
     * {@code clone()}. NOTE(review): the decompiler did not record which
     * method the {@code @Override} targets.
     *
     * @return the copy
     */
    @Override
    public BayesianLinearRegression mo334clone() {
        final BayesianLinearRegression copy = (BayesianLinearRegression) super.clone();
        copy.setWeightPrior((MultivariateGaussian) ObjectUtil.cloneSafe(getWeightPrior()));
        return copy;
    }

    /**
     * Computes the weight posterior in closed form from the prior and all examples.
     *
     * @param collection training examples
     * @return the posterior over the weights
     */
    @Override // gov.sandia.cognition.learning.algorithm.BatchLearner
    public MultivariateGaussian.PDF learn(Collection<? extends InputOutputPair<? extends Vectorizable, Double>> collection) {
        final MultivariateGaussian prior = getWeightPrior();

        // Both sums start from the prior: S0^-1 and S0^-1 * m0.
        final Matrix priorPrecision = prior.getCovarianceInverse().clone();
        final RingAccumulator<Matrix> precisionSum = new RingAccumulator<Matrix>();
        precisionSum.accumulate(priorPrecision);
        final RingAccumulator<Vector> weightedMeanSum = new RingAccumulator<Vector>();
        weightedMeanSum.accumulate(priorPrecision.times(prior.getMean()));

        for (InputOutputPair<? extends Vectorizable, Double> pair : collection) {
            final Vector input = pair.getInput().convertToVector();
            final Vector scaled = input.clone();
            final double precision = DatasetUtil.getWeight(pair) / this.outputVariance;
            if (precision != 1.0d) {
                scaled.scaleEquals(precision);
            }
            precisionSum.accumulate(input.outerProduct(scaled));

            final double output = pair.getOutput().doubleValue();
            if (output != 1.0d) {
                scaled.scaleEquals(output);
            }
            weightedMeanSum.accumulate(scaled);
        }

        final Matrix covariance = precisionSum.getSum().inverse();
        return new MultivariateGaussian.PDF(covariance.times(weightedMeanSum.getSum()), covariance);
    }

    /**
     * Returns the distribution of the output for a given input and weight
     * vector: {@code N(x . w, outputVariance)}.
     * <p>
     * The decompiled name is {@code createConditionalDistribution2}; the
     * original method was {@code createConditionalDistribution}.
     *
     * @param vectorizable input point
     * @param vector weight vector
     * @return conditional output distribution
     */
    @Override // gov.sandia.cognition.statistics.bayesian.BayesianRegression
    public Distribution<Double> createConditionalDistribution2(Vectorizable vectorizable, Vector vector) {
        return new UnivariateGaussian(vectorizable.convertToVector().dotProduct(vector), getOutputVariance());
    }

    /** @return the prior over the weights */
    public MultivariateGaussian getWeightPrior() {
        return this.weightPrior;
    }

    /** @param multivariateGaussian new prior over the weights (stored by reference) */
    public void setWeightPrior(MultivariateGaussian multivariateGaussian) {
        this.weightPrior = multivariateGaussian;
    }

    /** @return the output (noise) variance */
    public double getOutputVariance() {
        return this.outputVariance;
    }

    /**
     * @param d new output variance
     * @throws IllegalArgumentException if {@code d} is not strictly
     *         positive, including NaN
     */
    public void setOutputVariance(double d) {
        // Written as !(d > 0) so NaN is rejected too; "d <= 0" is false for NaN.
        if (!(d > 0.0d)) {
            throw new IllegalArgumentException("outputVariance must be > 0.0");
        }
        this.outputVariance = d;
    }

    /**
     * @param multivariateGaussian posterior over the weights
     * @return predictive distribution bound to this instance's output variance
     */
    @Override // gov.sandia.cognition.statistics.bayesian.BayesianRegression
    public PredictiveDistribution createPredictiveDistribution(MultivariateGaussian multivariateGaussian) {
        return new PredictiveDistribution(multivariateGaussian);
    }
}
