package gov.sandia.cognition.statistics.bayesian;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.IncrementalLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.AbstractSufficientStatistic;
import gov.sandia.cognition.statistics.ClosedFormDistribution;
import gov.sandia.cognition.statistics.Distribution;
import gov.sandia.cognition.statistics.distribution.InverseGammaDistribution;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussianInverseGammaDistribution;
import gov.sandia.cognition.statistics.distribution.StudentTDistribution;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;

/**
 * Bayesian linear regression in which both the weight vector and the output
 * (noise) variance are unknown.
 * <p>
 * The weights have a multivariate Gaussian prior ({@link #getWeightPrior()})
 * and the output variance has an inverse-gamma prior
 * ({@link #getOutputVariance()}). {@link #learn(Collection)} combines these
 * priors with the data into a {@link MultivariateGaussianInverseGammaDistribution}
 * posterior. Predictions made from that posterior by
 * {@link #createPredictiveDistribution} are Student-t distributed.
 * <p>
 * Each datum may carry a weight (read through {@link DatasetUtil#getWeight});
 * the weight scales that datum's contribution to the precision matrix, the
 * mean accumulator, and the sum of squared outputs.
 */
@PublicationReferences(
    references = {
        @PublicationReference(
            author = {"Christopher M. Bishop"},
            title = "Pattern Recognition and Machine Learning",
            type = PublicationType.Book,
            year = 2006,
            pages = {152, 159}),
        @PublicationReference(
            author = {"Jan Drugowitsch"},
            title = "Bayesian Linear Regression",
            type = PublicationType.Misc,
            year = 2009,
            url = "http://www.bcs.rochester.edu/people/jdrugowitsch/code/bayes_linear_notes_0.1.1.pdf")
    })
public class BayesianRobustLinearRegression
    extends AbstractCloneableSerializable
    implements BayesianRegression<Double, MultivariateGaussianInverseGammaDistribution> {

    /** Diagonal value of the isotropic weight-prior covariance used by {@link #BayesianRobustLinearRegression(int)}, {@value}. */
    public static final double DEFAULT_WEIGHT_VARIANCE = 1.0d;

    /** Gaussian prior over the regression weights. */
    private MultivariateGaussian weightPrior;

    /** Inverse-gamma prior over the output (noise) variance. */
    private InverseGammaDistribution outputVariance;

    /**
     * Incremental version of the regression. The posterior is kept as a
     * {@link SufficientStatistic} that is updated one datum at a time and can
     * be turned into a {@link MultivariateGaussianInverseGammaDistribution} at
     * any point.
     */
    public static class IncrementalEstimator
        extends BayesianRobustLinearRegression
        implements IncrementalLearner<InputOutputPair<? extends Vectorizable, Double>, BayesianRobustLinearRegression.IncrementalEstimator.SufficientStatistic> {

        /**
         * Running sufficient statistics of the posterior: the accumulated
         * precision matrix, the precision-times-mean vector {@code z}, the
         * (prior-seeded) weighted sum of squared outputs, and the datum count.
         */
        public class SufficientStatistic
            extends AbstractSufficientStatistic<InputOutputPair<? extends Vectorizable, Double>, MultivariateGaussianInverseGammaDistribution> {

            /**
             * Seeded with {@code 2 * b0 + mu0' * Lambda0 * mu0} from the prior,
             * then accumulates {@code weight * y^2} for each datum.
             */
            private double outputSumSquared;

            /** Seeded with {@code Lambda0 * mu0}, then accumulates {@code weight * y * x}. */
            private Vector z;

            /** Seeded with {@code Lambda0} (copied), then accumulates {@code weight * x * x'}. */
            private Matrix covarianceInverse;

            /**
             * Correction, in (-0.5, 0], that is added to {@code count / 2} in
             * {@link #getShape()}. The prior shape is stored as a whole number
             * of counts ({@code ceil(2 * shape)}), so without this correction
             * a non half-integer prior shape would be rounded up.
             */
            private double shapeCorrection;

            /**
             * Creates the statistic, seeded from a prior.
             *
             * @param prior
             *      Prior over weights and output variance. If null, every
             *      statistic starts empty and the dimensionality is fixed by
             *      the first update.
             */
            public SufficientStatistic(MultivariateGaussianInverseGammaDistribution prior) {
                if (prior == null) {
                    this.covarianceInverse = null;
                    this.z = null;
                    this.count = 0L;
                    this.outputSumSquared = 0.0d;
                    this.shapeCorrection = 0.0d;
                    return;
                }

                final Vector priorMean = prior.getMean();
                // Copy so that later updates never mutate the prior's matrix.
                this.covarianceInverse = prior.getGaussian().getCovarianceInverse().clone();
                this.z = this.covarianceInverse.times(priorMean);

                final double shape = prior.getInverseGamma().getShape();
                final double scale = prior.getInverseGamma().getScale();
                this.count = (long) Math.ceil(2.0d * shape);
                this.shapeCorrection = shape - this.count / 2.0d;
                this.outputSumSquared = (2.0d * scale) + priorMean.dotProduct(this.z);
            }

            /**
             * Folds one (possibly weighted) input/output pair into the statistics.
             *
             * @param value
             *      The datum to add.
             */
            @Override
            public void update(InputOutputPair<? extends Vectorizable, Double> value) {
                this.count++;
                final Vector input = value.getInput().convertToVector();
                final double output = value.getOutput().doubleValue();
                final double weight = DatasetUtil.getWeight(value);

                // weightedInput = weight * x, giving the weight * x * x' term.
                final Vector weightedInput = input.clone();
                if (weight != 1.0d) {
                    weightedInput.scaleEquals(weight);
                }
                final Matrix outer = input.outerProduct(weightedInput);
                if (this.covarianceInverse == null) {
                    this.covarianceInverse = outer;
                } else {
                    this.covarianceInverse.plusEquals(outer);
                }

                // Reuse the same vector for the weight * y * x term.
                if (output != 1.0d) {
                    weightedInput.scaleEquals(output);
                }
                if (this.z == null) {
                    this.z = weightedInput;
                } else {
                    this.z.plusEquals(weightedInput);
                }

                // y^2 must carry the same weight as the other terms. An
                // unweighted y^2 can make the posterior scale negative.
                this.outputSumSquared += weight * output * output;
            }

            /**
             * Creates a new posterior distribution from the current statistics.
             *
             * @return
             *      A new distribution populated by {@link #create(MultivariateGaussianInverseGammaDistribution)}.
             */
            @Override
            public MultivariateGaussianInverseGammaDistribution create() {
                final MultivariateGaussianInverseGammaDistribution result =
                    new MultivariateGaussianInverseGammaDistribution(this.getDimensionality());
                this.create(result);
                return result;
            }

            /**
             * Writes the current posterior parameters into the given distribution.
             *
             * @param distribution
             *      Distribution whose Gaussian and inverse-gamma parts are overwritten.
             */
            @Override
            public void create(MultivariateGaussianInverseGammaDistribution distribution) {
                // Compute the mean once; it needs a matrix inversion.
                final Vector mean = this.getMean();
                distribution.getGaussian().setMean(mean);
                // Pass a copy in case the Gaussian keeps the reference; otherwise
                // later updates would silently change an already-created posterior.
                distribution.getGaussian().setCovarianceInverse(this.getCovarianceInverse().clone());
                distribution.getInverseGamma().setShape(this.getShape());
                distribution.getInverseGamma().setScale(this.computeScale(mean));
            }

            /**
             * Gets the accumulated precision (inverse covariance) matrix.
             *
             * @return
             *      The internal matrix (not a copy); null if nothing has been accumulated.
             */
            public Matrix getCovarianceInverse() {
                return this.covarianceInverse;
            }

            /**
             * Gets the accumulated precision-times-mean vector.
             *
             * @return
             *      The internal vector (not a copy); null if nothing has been accumulated.
             */
            public Vector getZ() {
                return this.z;
            }

            /**
             * Computes the posterior weight mean, {@code covarianceInverse^-1 * z}.
             *
             * @return
             *      The posterior mean of the weights.
             */
            public Vector getMean() {
                return this.covarianceInverse.inverse().times(this.z);
            }

            /**
             * Gets the dimensionality of the weight vector.
             *
             * @return
             *      Dimensionality of {@link #getZ()}. This throws if nothing has been accumulated yet.
             */
            public int getDimensionality() {
                return this.getZ().getDimensionality();
            }

            /**
             * Gets the prior-seeded, weighted sum of squared outputs.
             *
             * @return
             *      The accumulated sum.
             */
            public double getOutputSumSquared() {
                return this.outputSumSquared;
            }

            /**
             * Gets the posterior inverse-gamma shape.
             *
             * @return
             *      {@code count / 2}, plus the correction that keeps a
             *      non half-integer prior shape exact.
             */
            public double getShape() {
                return this.getCount() / 2.0d + this.shapeCorrection;
            }

            /**
             * Gets the posterior inverse-gamma scale.
             *
             * @return
             *      {@code 0.5 * (outputSumSquared - mean' * covarianceInverse * mean)}.
             */
            public double getScale() {
                return this.computeScale(this.getMean());
            }

            /** Computes the scale from an already-computed posterior mean. */
            private double computeScale(Vector mean) {
                return 0.5d * (this.outputSumSquared - mean.times(this.covarianceInverse).dotProduct(mean));
            }
        }

        /**
         * Creates an estimator with default priors of the given dimensionality.
         *
         * @param dimensionality
         *      Number of regression weights.
         */
        public IncrementalEstimator(int dimensionality) {
            super(dimensionality);
        }

        /**
         * Creates an estimator with the given priors.
         *
         * @param outputVariance
         *      Prior over the output variance.
         * @param weightPrior
         *      Prior over the weights.
         */
        public IncrementalEstimator(InverseGammaDistribution outputVariance, MultivariateGaussian weightPrior) {
            super(outputVariance, weightPrior);
        }

        /**
         * Creates a statistic seeded with the current weight and output-variance priors.
         *
         * @return
         *      A fresh sufficient statistic.
         */
        @Override
        public SufficientStatistic createInitialLearnedObject() {
            return new SufficientStatistic(
                new MultivariateGaussianInverseGammaDistribution(this.getWeightPrior(), this.getOutputVariance()));
        }

        /**
         * Learns the posterior by feeding every datum through a fresh statistic.
         *
         * @param data
         *      Training data.
         * @return
         *      Posterior over the weights and the output variance.
         */
        @Override
        public MultivariateGaussianInverseGammaDistribution learn(
            Collection<? extends InputOutputPair<? extends Vectorizable, Double>> data) {
            final SufficientStatistic statistic = this.createInitialLearnedObject();
            this.update(statistic, data);
            return statistic.create();
        }

        /**
         * Adds one datum to the statistic.
         *
         * @param statistic
         *      Statistic to update.
         * @param data
         *      Datum to add.
         */
        @Override
        public void update(SufficientStatistic statistic, InputOutputPair<? extends Vectorizable, Double> data) {
            statistic.update(data);
        }

        /**
         * Adds every datum to the statistic.
         *
         * @param statistic
         *      Statistic to update.
         * @param data
         *      Data to add.
         */
        @Override
        public void update(SufficientStatistic statistic, Iterable<? extends InputOutputPair<? extends Vectorizable, Double>> data) {
            statistic.update(data);
        }
    }

    /**
     * Maps an input to the Student-t predictive distribution of its output
     * under a fixed posterior.
     */
    public class PredictiveDistribution
        extends AbstractCloneableSerializable
        implements Evaluator<Vectorizable, StudentTDistribution> {

        /** Posterior over the weights and the output variance. */
        private MultivariateGaussianInverseGammaDistribution posterior;

        /**
         * Creates the predictive distribution.
         *
         * @param posterior
         *      Posterior over the weights and the output variance.
         */
        public PredictiveDistribution(MultivariateGaussianInverseGammaDistribution posterior) {
            this.posterior = posterior;
        }

        /**
         * Computes the predictive distribution for one input.
         *
         * @param input
         *      Input to predict at.
         * @return
         *      A Student-t with {@code 2 * shape} degrees of freedom, mean
         *      {@code x' * mean}, and precision
         *      {@code (shape / scale) / (1 + x' * covariance * x)}.
         */
        @Override
        public StudentTDistribution evaluate(Vectorizable input) {
            final Vector x = input.convertToVector();
            final double mean = x.dotProduct(this.posterior.getMean());
            final double shape = this.posterior.getInverseGamma().getShape();
            final double scale = this.posterior.getInverseGamma().getScale();
            final double spread = 1.0d + x.times(this.posterior.getGaussian().getCovariance()).dotProduct(x);
            return new StudentTDistribution(shape * 2.0d, mean, (shape / scale) / spread);
        }
    }

    /**
     * Creates a regression with a default inverse-gamma output-variance prior
     * and a zero-mean Gaussian weight prior whose covariance is the identity
     * scaled by {@link #DEFAULT_WEIGHT_VARIANCE}.
     *
     * @param dimensionality
     *      Number of regression weights.
     */
    public BayesianRobustLinearRegression(int dimensionality) {
        this(new InverseGammaDistribution(),
            new MultivariateGaussian(
                VectorFactory.getDefault().createVector(dimensionality),
                MatrixFactory.getDefault().createIdentity(dimensionality, dimensionality).scale(DEFAULT_WEIGHT_VARIANCE)));
    }

    /**
     * Creates a regression with the given priors.
     *
     * @param outputVariance
     *      Prior over the output variance.
     * @param weightPrior
     *      Prior over the weights.
     */
    public BayesianRobustLinearRegression(InverseGammaDistribution outputVariance, MultivariateGaussian weightPrior) {
        this.setWeightPrior(weightPrior);
        this.setOutputVariance(outputVariance);
    }

    /**
     * Returns a copy whose priors are cloned (when non-null).
     *
     * @return
     *      The clone.
     */
    @Override
    public BayesianRobustLinearRegression clone() {
        final BayesianRobustLinearRegression copy = (BayesianRobustLinearRegression) super.clone();
        copy.setWeightPrior((MultivariateGaussian) ObjectUtil.cloneSafe(this.getWeightPrior()));
        copy.setOutputVariance((InverseGammaDistribution) ObjectUtil.cloneSafe(this.getOutputVariance()));
        return copy;
    }

    /**
     * Gets the prior over the weights.
     *
     * @return
     *      The weight prior.
     */
    public MultivariateGaussian getWeightPrior() {
        return this.weightPrior;
    }

    /**
     * Sets the prior over the weights.
     *
     * @param weightPrior
     *      The weight prior.
     */
    public void setWeightPrior(MultivariateGaussian weightPrior) {
        this.weightPrior = weightPrior;
    }

    /**
     * Computes the conjugate posterior over the weights and the output
     * variance in a single pass:
     * <pre>
     *   Lambda_N = Lambda_0 + sum w x x'
     *   mu_N     = Lambda_N^-1 (Lambda_0 mu_0 + sum w y x)
     *   a_N      = a_0 + N / 2
     *   b_N      = b_0 + 0.5 (mu_0' Lambda_0 mu_0 + sum w y^2 - mu_N' Lambda_N mu_N)
     * </pre>
     *
     * @param data
     *      Training data; each pair may carry a weight.
     * @return
     *      Posterior over the weights and the output variance.
     */
    @Override
    public MultivariateGaussianInverseGammaDistribution learn(
        Collection<? extends InputOutputPair<? extends Vectorizable, Double>> data) {
        final MultivariateGaussian prior = this.getWeightPrior();
        final Matrix priorPrecision = prior.getCovarianceInverse();
        final Vector priorZ = priorPrecision.times(prior.getMean());

        final RingAccumulator<Matrix> precision = new RingAccumulator<Matrix>();
        precision.accumulate(priorPrecision);
        final RingAccumulator<Vector> z = new RingAccumulator<Vector>();
        z.accumulate(priorZ);

        final InverseGammaDistribution variancePrior = this.getOutputVariance();
        double shape = variancePrior.getShape();
        // The prior mean contributes mu_0' Lambda_0 mu_0 to the residual term;
        // dropping it is only harmless when the prior mean is zero.
        double outputSumSquared = prior.getMean().dotProduct(priorZ);

        for (InputOutputPair<? extends Vectorizable, Double> pair : data) {
            final Vector input = pair.getInput().convertToVector();
            final double output = pair.getOutput().doubleValue();
            final double weight = DatasetUtil.getWeight(pair);

            final Vector weightedInput = input.clone();
            if (weight != 1.0d) {
                weightedInput.scaleEquals(weight);
            }
            precision.accumulate(input.outerProduct(weightedInput));

            if (output != 1.0d) {
                weightedInput.scaleEquals(output);
            }
            z.accumulate(weightedInput);

            // Weighted like the other terms, so the scale cannot go negative.
            outputSumSquared += weight * output * output;
            shape += 0.5d;
        }

        final Matrix posteriorPrecision = precision.getSum();
        final Matrix posteriorCovariance = posteriorPrecision.inverse();
        final Vector posteriorMean = posteriorCovariance.times(z.getSum());
        final double scale = variancePrior.getScale()
            + 0.5d * (outputSumSquared - posteriorMean.times(posteriorPrecision).dotProduct(posteriorMean));

        return new MultivariateGaussianInverseGammaDistribution(
            new MultivariateGaussian(posteriorMean, posteriorCovariance),
            new InverseGammaDistribution(shape, scale));
    }

    /**
     * Gets the prior over the output variance.
     *
     * @return
     *      The output-variance prior.
     */
    public InverseGammaDistribution getOutputVariance() {
        return this.outputVariance;
    }

    /**
     * Sets the prior over the output variance.
     *
     * @param outputVariance
     *      The output-variance prior.
     */
    public void setOutputVariance(InverseGammaDistribution outputVariance) {
        this.outputVariance = outputVariance;
    }

    /**
     * Creates the output distribution for a fixed weight vector.
     *
     * @param input
     *      Input to evaluate.
     * @param weights
     *      Regression weights.
     * @return
     *      A Gaussian with mean {@code x' * weights} and variance equal to the
     *      mean of the output-variance prior.
     */
    @Override
    public UnivariateGaussian createConditionalDistribution(Vectorizable input, Vector weights) {
        final double mean = input.convertToVector().dotProduct(weights);
        final double variance = this.getOutputVariance().getMean().doubleValue();
        return new UnivariateGaussian(mean, variance);
    }

    /**
     * Creates the Student-t predictive distribution for the given posterior.
     *
     * @param posterior
     *      Posterior over the weights and the output variance.
     * @return
     *      The predictive distribution.
     */
    @Override
    public PredictiveDistribution createPredictiveDistribution(MultivariateGaussianInverseGammaDistribution posterior) {
        return new PredictiveDistribution(posterior);
    }
}
