package gov.sandia.cognition.statistics.bayesian;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.DefaultCluster;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.statistics.bayesian.conjugate.MultivariateGaussianMeanBayesianEstimator;
import gov.sandia.cognition.statistics.bayesian.conjugate.MultivariateGaussianMeanCovarianceBayesianEstimator;
import gov.sandia.cognition.statistics.distribution.BetaDistribution;
import gov.sandia.cognition.statistics.distribution.ChineseRestaurantProcess;
import gov.sandia.cognition.statistics.distribution.GammaDistribution;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.statistics.distribution.NormalInverseWishartDistribution;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.Random;

@PublicationReferences(references = {@PublicationReference(author = {"Radform M. Neal"}, title = "Markov Chain Sampling Methods for Dirichlet Process Mixture Models", type = PublicationType.Journal, year = 2000, publication = "Journal of Computational and Graphical Statistics, Vol. 9, No. 2", pages = {249, 265}, notes = {"Based in part on Algorithm 2 from Neal"}), @PublicationReference(author = {"Michael D. Escobar", "Mike West"}, title = "Bayesian Density Estimation and Inference Using Mixtures", type = PublicationType.Journal, publication = "Journal of the American Statistical Association", year = 1995)})
/* loaded from: input_file:gov/sandia/cognition/statistics/bayesian/DirichletProcessMixtureModel.class */
public class DirichletProcessMixtureModel<ObservationType> extends AbstractMarkovChainMonteCarlo<ObservationType, Sample<ObservationType>> {
    /** Default initial value of the Dirichlet-process concentration parameter alpha. */
    public static final double DEFAULT_ALPHA = 1.0d;
    /** Default number of clusters in the initial sample. */
    public static final int DEFAULT_NUM_INITIAL_CLUSTERS = 2;
    /** Default for whether alpha is resampled on each MCMC update. */
    public static final boolean DEFAULT_REESTIMATE_ALPHA = true;
    /** Creates the prior-predictive and per-cluster probability functions; null unless set. */
    protected Updater<ObservationType> updater;
    /** Number of clusters built by createInitialLearnedObject(). */
    private int numInitialClusters;
    /** When true, mcmcUpdate() resamples alpha via updateAlpha() after updating clusters. */
    protected boolean reestimateAlpha;
    /** Alpha given to the initial sample. */
    protected double initialAlpha;
    /**
     * Prior predictive over the data, built lazily by mcmcUpdate() when null and then reused.
     * NOTE(review): nothing in this class resets it if the data change between runs --
     * confirm the base class handles that.
     */
    protected transient ProbabilityFunction<ObservationType> conditionalPriorPredictive;
    /** Scratch buffer of K + 1 cluster weights, reallocated only when K changes. */
    protected transient double[] clusterWeights;
    /** Beta sampler for the auxiliary variable eta in updateAlpha(); created lazily. */
    protected transient BetaDistribution etaSampler;
    /** Gamma sampler used by updateAlpha(); created lazily. */
    protected transient GammaDistribution alphaInverseSampler;

    /**
     * A mixture component: a set of member observations together with the
     * probability function (sampled cluster parameters) used to score
     * observations against this cluster.
     *
     * @param <ObservationType> type of the clustered observations
     */
    public static class DPMMCluster<ObservationType> extends DefaultCluster<ObservationType> {
        /** Scores observations under this cluster's parameters. */
        private ProbabilityFunction<? super ObservationType> probabilityFunction;

        /**
         * Creates a cluster.
         *
         * @param members observations belonging to the cluster
         * @param function probability function of the cluster
         */
        public DPMMCluster(Collection<? extends ObservationType> members, ProbabilityFunction<? super ObservationType> function) {
            super(members);
            this.setProbabilityFunction(function);
        }

        /**
         * Clones the cluster and replaces the copy's probability function with
         * the result of ObjectUtil.cloneSafe on this cluster's function.
         *
         * @return the cloned cluster
         */
        @Override // gov.sandia.cognition.learning.algorithm.clustering.cluster.DefaultCluster
        /* renamed from: clone, reason: merged with bridge method [inline-methods] */
        public DPMMCluster<ObservationType> mo24clone() {
            final DPMMCluster<ObservationType> copy = (DPMMCluster) super.mo24clone();
            final ProbabilityFunction<? super ObservationType> source = this.getProbabilityFunction();
            copy.setProbabilityFunction((ProbabilityFunction) ObjectUtil.cloneSafe(source));
            return copy;
        }

        /**
         * Returns the cluster's probability function.
         *
         * @return the probability function
         */
        public ProbabilityFunction<? super ObservationType> getProbabilityFunction() {
            return this.probabilityFunction;
        }

        /**
         * Replaces the cluster's probability function.
         *
         * @param function the new probability function
         */
        public void setProbabilityFunction(ProbabilityFunction<? super ObservationType> function) {
            this.probabilityFunction = function;
        }
    }

    /**
     * Mutable accumulator threaded through assignObservationsToClusters and
     * assignObservationToCluster so that the per-observation log conditional
     * terms of one sweep can be summed. Starts at zero.
     */
    public static class DPMMLogConditional extends AbstractCloneableSerializable {
        /** Running sum of log conditional terms for the current sweep. */
        double logConditional = 0.0;
    }

    /**
     * Updater for Vector observations whose clusters have an unknown mean and
     * covariance, backed by a MultivariateGaussianMeanCovarianceBayesianEstimator
     * whose posterior is a NormalInverseWishartDistribution.
     */
    public static class MultivariateMeanCovarianceUpdater extends AbstractCloneableSerializable implements Updater<Vector> {
        /** Conjugate estimator; null when the no-argument constructor is used. */
        private MultivariateGaussianMeanCovarianceBayesianEstimator estimator;

        /** Creates an updater with a null estimator (it must be set before use). */
        public MultivariateMeanCovarianceUpdater() {
            this((MultivariateGaussianMeanCovarianceBayesianEstimator) null);
        }

        /**
         * Creates an updater with a freshly constructed estimator.
         *
         * @param dimension passed to the estimator's int constructor
         *     (presumably the data dimensionality -- confirm)
         */
        public MultivariateMeanCovarianceUpdater(int dimension) {
            this(new MultivariateGaussianMeanCovarianceBayesianEstimator(dimension));
        }

        /**
         * Creates an updater around the given estimator.
         *
         * @param estimator the conjugate estimator to use
         */
        public MultivariateMeanCovarianceUpdater(MultivariateGaussianMeanCovarianceBayesianEstimator estimator) {
            this.estimator = estimator;
        }

        /**
         * Clones the updater, cloning the estimator via ObjectUtil.cloneSafe.
         *
         * @return the cloned updater
         */
        /* renamed from: clone, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
        public MultivariateMeanCovarianceUpdater m299clone() {
            final MultivariateMeanCovarianceUpdater copy = (MultivariateMeanCovarianceUpdater) super.clone();
            copy.estimator = (MultivariateGaussianMeanCovarianceBayesianEstimator) ObjectUtil.cloneSafe(this.estimator);
            return copy;
        }

        /**
         * Learns the posterior from the data and returns the probability
         * function of the estimator's predictive distribution for it.
         */
        @Override // gov.sandia.cognition.statistics.bayesian.DirichletProcessMixtureModel.Updater
        /* renamed from: createPriorPredictive, reason: merged with bridge method [inline-methods] */
        public ProbabilityFunction<Vector> createPriorPredictive2(Iterable<? extends Vector> data) {
            final NormalInverseWishartDistribution posterior =
                (NormalInverseWishartDistribution) this.estimator.learn((Iterable) data);
            return this.estimator.createPredictiveDistribution(posterior).getProbabilityFunction();
        }

        /**
         * Learns the posterior from the members, draws one parameter sample from
         * it, and returns the probability function of the conditional
         * distribution at that sample.
         */
        @Override // gov.sandia.cognition.statistics.bayesian.DirichletProcessMixtureModel.Updater
        /* renamed from: createClusterPosterior, reason: merged with bridge method [inline-methods] */
        public ProbabilityFunction<Vector> createClusterPosterior2(Iterable<? extends Vector> members, Random random) {
            final NormalInverseWishartDistribution posterior =
                (NormalInverseWishartDistribution) this.estimator.learn((Iterable) members);
            return this.estimator.createConditionalDistribution(posterior.sample(random)).getProbabilityFunction();
        }
    }

    /**
     * Updater for Vector observations whose clusters have an unknown mean,
     * backed by a MultivariateGaussianMeanBayesianEstimator whose posterior is
     * a MultivariateGaussian.
     */
    public static class MultivariateMeanUpdater extends AbstractCloneableSerializable implements Updater<Vector> {
        /** Conjugate estimator; null when the no-argument constructor is used. */
        protected MultivariateGaussianMeanBayesianEstimator estimator;

        /** Creates an updater with a null estimator (it must be set before use). */
        public MultivariateMeanUpdater() {
            this((MultivariateGaussianMeanBayesianEstimator) null);
        }

        /**
         * Creates an updater with a freshly constructed estimator.
         *
         * @param dimension passed to the estimator's int constructor
         *     (presumably the data dimensionality -- confirm)
         */
        public MultivariateMeanUpdater(int dimension) {
            this(new MultivariateGaussianMeanBayesianEstimator(dimension));
        }

        /**
         * Creates an updater around the given estimator.
         *
         * @param estimator the conjugate estimator to use
         */
        public MultivariateMeanUpdater(MultivariateGaussianMeanBayesianEstimator estimator) {
            this.estimator = estimator;
        }

        /**
         * Clones the updater, cloning the estimator via ObjectUtil.cloneSafe.
         *
         * @return the cloned updater
         */
        /* renamed from: clone, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
        public MultivariateMeanUpdater m301clone() {
            final MultivariateMeanUpdater copy = (MultivariateMeanUpdater) super.clone();
            copy.estimator = (MultivariateGaussianMeanBayesianEstimator) ObjectUtil.cloneSafe(this.estimator);
            return copy;
        }

        /**
         * Learns the posterior over the mean from the data and returns the
         * probability function of the estimator's predictive distribution.
         */
        @Override // gov.sandia.cognition.statistics.bayesian.DirichletProcessMixtureModel.Updater
        /* renamed from: createPriorPredictive */
        public ProbabilityFunction<Vector> createPriorPredictive2(Iterable<? extends Vector> data) {
            final MultivariateGaussian posterior = (MultivariateGaussian) this.estimator.learn((Iterable) data);
            return this.estimator.createPredictiveDistribution(posterior).getProbabilityFunction();
        }

        /**
         * Learns the posterior over the mean from the members, samples a mean
         * from it, and returns the probability function of the conditional
         * distribution at that mean.
         */
        @Override // gov.sandia.cognition.statistics.bayesian.DirichletProcessMixtureModel.Updater
        /* renamed from: createClusterPosterior */
        public ProbabilityFunction<Vector> createClusterPosterior2(Iterable<? extends Vector> members, Random random) {
            final MultivariateGaussian posterior = (MultivariateGaussian) this.estimator.learn((Iterable) members);
            return this.estimator.createConditionalDistribution(posterior.sample(random)).getProbabilityFunction();
        }
    }

    /**
     * One state of the Markov chain: the concentration parameter alpha, the
     * current clusters, and a cached posterior log-likelihood (null until it
     * has been computed and stored).
     *
     * @param <ObservationType> type of the clustered observations
     */
    public static class Sample<ObservationType> extends AbstractCloneableSerializable {
        /** Dirichlet-process concentration parameter. */
        protected double alpha;
        /** Clusters of this sample. */
        protected ArrayList<DPMMCluster<ObservationType>> clusters;
        /** Cached posterior log-likelihood, or null when not yet computed. */
        private Double posteriorLogLikelihood;

        /**
         * Creates a sample.
         *
         * @param alpha concentration parameter; must be strictly positive
         * @param clusters clusters of the sample (stored by reference)
         * @throws IllegalArgumentException if alpha is not strictly positive or is NaN
         */
        public Sample(double alpha, ArrayList<DPMMCluster<ObservationType>> clusters) {
            setAlpha(alpha);
            setClusters(clusters);
            setPosteriorLogLikelihood(null);
        }

        /**
         * Clones the sample. The cluster list is cloned via
         * ObjectUtil.cloneSmartElementsAsArrayList and the copy's cached
         * posterior log-likelihood is cleared.
         *
         * @return the cloned sample
         */
        /* renamed from: clone, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
        @SuppressWarnings("unchecked")
        public Sample<ObservationType> m303clone() {
            final Sample<ObservationType> copy = (Sample<ObservationType>) super.clone();
            copy.setClusters(ObjectUtil.cloneSmartElementsAsArrayList(getClusters()));
            // The copy's clusters can diverge from ours, so the cache must not carry over.
            copy.setPosteriorLogLikelihood(null);
            return copy;
        }

        /**
         * Computes an (unnormalized) posterior log-likelihood of the data under
         * this sample.
         * <p>
         * For each observation x, adds log(1e-100 + sum_k n_k * p_k(x)), where n_k is
         * the size of cluster k and p_k its probability function. The 1e-100 floor
         * avoids log(0). The Chinese Restaurant Process log-probability of the
         * cluster sizes is then added.
         *
         * @param data observations to score; iterated twice (once to count, once to score)
         * @return the posterior log-likelihood
         */
        public double computePosteriorLogLikelihood(Iterable<? extends ObservationType> data) {
            final int numClusters = getNumClusters();
            final int numData = CollectionUtil.size(data);
            double logConditional = 0.0;
            for (ObservationType observation : data) {
                double mixture = 1.0E-100;
                for (int k = 0; k < numClusters; k++) {
                    final DPMMCluster<ObservationType> cluster = this.clusters.get(k);
                    mixture += cluster.getMembers().size()
                        * ((Double) cluster.getProbabilityFunction().evaluate(observation)).doubleValue();
                }
                logConditional += Math.log(mixture);
            }
            return computePosteriorLogLikelihood(numData, logConditional);
        }

        /**
         * Adds the Chinese Restaurant Process log-probability of the current
         * cluster sizes to a precomputed log conditional.
         *
         * @param numData total number of observations
         * @param logConditional precomputed log conditional of the data
         * @return the posterior log-likelihood
         */
        public double computePosteriorLogLikelihood(int numData, double logConditional) {
            final ChineseRestaurantProcess.PMF pmf = new ChineseRestaurantProcess.PMF(getAlpha(), numData);
            return pmf.logEvaluate(createClusterSizeVector()) + logConditional;
        }

        /** Builds a vector whose k-th element is the member count of cluster k. */
        private Vector createClusterSizeVector() {
            final int numClusters = getNumClusters();
            final Vector sizes = VectorFactory.getDefault().createVector(numClusters);
            for (int k = 0; k < numClusters; k++) {
                sizes.setElement(k, this.clusters.get(k).getMembers().size());
            }
            return sizes;
        }

        /** Removes, in place, every cluster that has no members. */
        public void removeUnusedClusters() {
            // Walk backwards so removals do not shift the indices still to be visited.
            for (int k = getNumClusters() - 1; k >= 0; k--) {
                if (this.clusters.get(k).getMembers().size() <= 0) {
                    this.clusters.remove(k);
                }
            }
        }

        /**
         * Returns the concentration parameter.
         *
         * @return alpha
         */
        public double getAlpha() {
            return this.alpha;
        }

        /**
         * Sets the concentration parameter.
         *
         * @param alpha new alpha; must be strictly positive
         * @throws IllegalArgumentException if alpha is not strictly positive or is NaN
         */
        protected void setAlpha(double alpha) {
            // Written as !(alpha > 0) so that NaN is rejected too.
            if (!(alpha > 0.0d)) {
                throw new IllegalArgumentException("Alpha must be > 0.0 ");
            }
            this.alpha = alpha;
        }

        /**
         * Returns the number of clusters.
         *
         * @return the cluster count
         */
        public int getNumClusters() {
            return this.clusters.size();
        }

        /**
         * Returns the cluster list (not a copy).
         *
         * @return the clusters
         */
        public ArrayList<DPMMCluster<ObservationType>> getClusters() {
            return this.clusters;
        }

        /**
         * Replaces the cluster list (stored by reference).
         *
         * @param clusters the new clusters
         */
        protected void setClusters(ArrayList<DPMMCluster<ObservationType>> clusters) {
            this.clusters = clusters;
        }

        /**
         * Returns the cached posterior log-likelihood.
         *
         * @return the cached value, or null if not computed
         */
        public Double getPosteriorLogLikelihood() {
            return this.posteriorLogLikelihood;
        }

        /**
         * Stores the posterior log-likelihood.
         *
         * @param posteriorLogLikelihood the value to cache; null clears it
         */
        public void setPosteriorLogLikelihood(Double posteriorLogLikelihood) {
            this.posteriorLogLikelihood = posteriorLogLikelihood;
        }
    }

    /**
     * Strategy that builds the probability functions the sampler needs from
     * observations.
     *
     * @param <ObservationType> type of the clustered observations
     */
    public interface Updater<ObservationType> extends CloneableSerializable {
        /**
         * Creates the prior predictive probability function. The sampler calls
         * this with all of the data and uses the result to weight the
         * "new cluster" option.
         *
         * @param data observations to condition on
         * @return the prior predictive probability function
         */
        /* renamed from: createPriorPredictive */
        ProbabilityFunction<ObservationType> createPriorPredictive2(Iterable<? extends ObservationType> data);

        /**
         * Creates a cluster's probability function from its members. The
         * implementations in this file do this by sampling parameters from a
         * posterior.
         *
         * @param members observations assigned to the cluster
         * @param random source of randomness for sampling
         * @return the cluster's probability function
         */
        /* renamed from: createClusterPosterior */
        ProbabilityFunction<ObservationType> createClusterPosterior2(Iterable<? extends ObservationType> members, Random random);
    }

    /**
     * Creates a model with the default settings:
     * <ul>
     * <li>alpha re-estimation set to {@link #DEFAULT_REESTIMATE_ALPHA};</li>
     * <li>initial alpha {@link #DEFAULT_ALPHA};</li>
     * <li>{@link #DEFAULT_NUM_INITIAL_CLUSTERS} initial clusters.</li>
     * </ul>
     * This constructor does not set an updater; call setUpdater before learning.
     */
    public DirichletProcessMixtureModel() {
        setReestimateAlpha(DEFAULT_REESTIMATE_ALPHA);
        setInitialAlpha(DEFAULT_ALPHA);
        setNumInitialClusters(DEFAULT_NUM_INITIAL_CLUSTERS);
    }

    /**
     * Clones the learner through the superclass clone. The copy's updater is
     * then replaced with ObjectUtil.cloneSafe of this learner's updater, rather
     * than sharing the same reference.
     *
     * @return the cloned learner
     */
    @Override // gov.sandia.cognition.statistics.bayesian.AbstractMarkovChainMonteCarlo, gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    /* renamed from: clone */
    public DirichletProcessMixtureModel<ObservationType> mo1clone() {
        final DirichletProcessMixtureModel<ObservationType> copy = (DirichletProcessMixtureModel) super.mo1clone();
        final Updater<ObservationType> updaterCopy = (Updater) ObjectUtil.cloneSafe(this.getUpdater());
        copy.setUpdater(updaterCopy);
        return copy;
    }

    /**
     * Performs one sampling sweep over the current sample:
     * <ol>
     * <li>reassigns every observation to a cluster slot;</li>
     * <li>fills in the previous sample's posterior log-likelihood, if it is not
     *     yet cached, using the log conditional accumulated in step 1;</li>
     * <li>rebuilds the clusters from the new assignments;</li>
     * <li>resamples alpha when getReestimateAlpha() is true.</li>
     * </ol>
     * The prior predictive is created from the data on first use and cached.
     */
    @Override // gov.sandia.cognition.statistics.bayesian.AbstractMarkovChainMonteCarlo
    @SuppressWarnings("unchecked")
    protected void mcmcUpdate() {
        if (this.conditionalPriorPredictive == null) {
            this.conditionalPriorPredictive = this.updater.createPriorPredictive2((Iterable) this.data);
        }
        final Sample<ObservationType> current = (Sample<ObservationType>) this.currentParameter;
        final int numClusters = current.getNumClusters();
        final DPMMLogConditional logConditional = new DPMMLogConditional();
        final ArrayList<Collection<ObservationType>> assignments =
            assignObservationsToClusters(numClusters, logConditional);
        final int numData = CollectionUtil.size((Collection) this.data);

        // A reference null check: there is no previous sample to score when it is absent.
        final Sample<ObservationType> previous = (Sample<ObservationType>) this.previousParameter;
        if (previous != null && previous.getPosteriorLogLikelihood() == null) {
            previous.setPosteriorLogLikelihood(Double.valueOf(
                previous.computePosteriorLogLikelihood(numData, logConditional.logConditional)));
        }

        current.clusters = updateClusters(assignments);
        if (getReestimateAlpha()) {
            current.alpha = updateAlpha(current.alpha, numData);
        }
    }

    /**
     * Builds the next sample's clusters from the assignment buckets, in bucket
     * order.
     * <p>
     * Buckets with at most one member are dropped, as are buckets for which
     * createCluster returns null. assignObservationToCluster weights existing
     * clusters by (size - 1), so a singleton would receive zero weight anyway.
     *
     * @param assignments one collection of observations per cluster slot
     * @return the new cluster list
     */
    protected ArrayList<DPMMCluster<ObservationType>> updateClusters(ArrayList<Collection<ObservationType>> assignments) {
        final ArrayList<DPMMCluster<ObservationType>> rebuilt = new ArrayList<>(assignments.size());
        for (final Collection<ObservationType> members : assignments) {
            if (members.size() <= 1) {
                continue;
            }
            final DPMMCluster<ObservationType> cluster = createCluster(members, this.updater);
            if (cluster != null) {
                rebuilt.add(cluster);
            }
        }
        return rebuilt;
    }

    /**
     * Samples a cluster slot for every observation, in data order.
     * <p>
     * Slots 0..numClusters-1 are the existing clusters and slot numClusters is
     * a new cluster. The shared clusterWeights buffer is reallocated only when
     * its length differs from numClusters + 1.
     *
     * @param numClusters number of clusters in the current sample
     * @param logConditional accumulator that receives each observation's log conditional term
     * @return numClusters + 1 collections holding the observations assigned to each slot
     */
    @SuppressWarnings("unchecked")
    protected ArrayList<Collection<ObservationType>> assignObservationsToClusters(int numClusters, DPMMLogConditional logConditional) {
        final int numSlots = numClusters + 1;
        final boolean bufferMismatch = this.clusterWeights == null || this.clusterWeights.length != numSlots;
        if (bufferMismatch) {
            this.clusterWeights = new double[numSlots];
        }
        final ArrayList<Collection<ObservationType>> buckets = new ArrayList<>(numSlots);
        while (buckets.size() < numSlots) {
            buckets.add(new LinkedList<ObservationType>());
        }
        for (final ObservationType observation : (Collection<? extends ObservationType>) this.data) {
            final int slot = assignObservationToCluster(observation, this.clusterWeights, logConditional);
            buckets.get(slot).add(observation);
        }
        return buckets;
    }

    /**
     * Samples a cluster slot for one observation given the current sample.
     * <p>
     * Let K be the current number of clusters. The weights are:
     * <ul>
     * <li>slot k &lt; K: (n_k - 1) * p_k(x) for a non-empty cluster of size n_k and
     *     probability function p_k, and 0 for an empty cluster;</li>
     * <li>slot K (new cluster): alpha * p0(x), where p0 is the cached prior predictive.</li>
     * </ul>
     * NOTE(review): the "- 1" presumably discounts x's own membership, as in
     * Neal's n_{-i,c}, but it is applied to every cluster, not only the one
     * containing x.
     * <p>
     * As a side effect, log(1e-100 + sum_k n_k p_k(x)) is added to the
     * accumulator.
     *
     * @param observationtype the observation to assign
     * @param dArr scratch buffer (at least K + 1 long) that receives the weights
     * @param dPMMLogConditional accumulator for the log conditional
     * @return the sampled slot in [0, K]; K means "new cluster"
     * @throws IllegalArgumentException if no slot can be selected, e.g. when the
     *     total weight is NaN or infinite
     */
    @SuppressWarnings("unchecked")
    public int assignObservationToCluster(ObservationType observationtype, double[] dArr, DPMMLogConditional dPMMLogConditional) {
        final Sample<ObservationType> sample = (Sample<ObservationType>) this.currentParameter;
        final double alpha = sample.alpha;
        final int numClusters = sample.getNumClusters();

        final double newClusterWeight =
            alpha * ((Double) this.conditionalPriorPredictive.evaluate(observationtype)).doubleValue();
        dArr[numClusters] = newClusterWeight;
        double totalWeight = newClusterWeight;
        double mixtureLikelihood = 1.0E-100;
        for (int k = 0; k < numClusters; k++) {
            final DPMMCluster<ObservationType> cluster = sample.clusters.get(k);
            final int size = cluster.getMembers().size();
            if (size > 0) {
                final double likelihood =
                    ((Double) cluster.getProbabilityFunction().evaluate(observationtype)).doubleValue();
                final double weight = (size - 1) * likelihood;
                dArr[k] = weight;
                totalWeight += weight;
                mixtureLikelihood += size * likelihood;
            } else {
                dArr[k] = 0.0d;
            }
        }
        dPMMLogConditional.logConditional += Math.log(mixtureLikelihood);

        double remaining = totalWeight * this.random.nextDouble();
        int lastPositive = -1;
        for (int k = 0; k <= numClusters; k++) {
            remaining -= dArr[k];
            if (remaining <= 0.0d) {
                return k;
            }
            if (dArr[k] > 0.0d) {
                lastPositive = k;
            }
        }
        // totalWeight was summed in a different order than the weights are
        // subtracted here, so rounding can leave a tiny positive remainder even
        // though the draw lies inside the total. For a finite positive total,
        // that can only be rounding, so take the last slot with positive weight.
        if (lastPositive >= 0 && totalWeight > 0.0d && !Double.isInfinite(totalWeight)) {
            return lastPositive;
        }
        throw new IllegalArgumentException("Did not select cluster: " + totalWeight);
    }

    /**
     * Creates a cluster whose probability function is produced by the
     * updater's createClusterPosterior2 from the members, using this learner's
     * random source.
     *
     * @param collection observations belonging to the cluster
     * @param updater updater used to build the cluster's probability function
     * @return the new cluster, or null when the collection is null or empty
     */
    public DPMMCluster<ObservationType> createCluster(Collection<ObservationType> collection, Updater<ObservationType> updater) {
        if (collection == null || collection.size() <= 0) {
            return null;
        }
        final ProbabilityFunction<ObservationType> posterior = updater.createClusterPosterior2(collection, this.random);
        return new DPMMCluster<>(collection, posterior);
    }

    /**
     * Resamples the Dirichlet-process concentration parameter alpha with the
     * auxiliary-variable scheme of Escobar and West (1995). The prior on alpha
     * is Gamma(a = 1, b = 1), with shape a and rate b:
     * <ol>
     * <li>eta ~ Beta(alpha + 1, n);</li>
     * <li>with probability pi, alpha ~ Gamma(a + k, rate b - log eta); otherwise
     *     alpha ~ Gamma(a + k - 1, rate b - log eta). Here
     *     pi / (1 - pi) = (a + k - 1) / (n (b - log eta)).</li>
     * </ol>
     * k is the number of clusters in the current sample.
     *
     * @param alpha current value of alpha
     * @param numData number of observations n
     * @return the resampled alpha
     */
    protected double updateAlpha(double alpha, int numData) {
        // Hyperparameters of the Gamma prior on alpha (shape a, rate b).
        final double a = 1.0d;
        final double b = 1.0d;

        if (this.etaSampler == null) {
            this.etaSampler = new BetaDistribution();
        }
        this.etaSampler.setAlpha(alpha + 1.0d);
        this.etaSampler.setBeta(numData);
        final double logEta = Math.log(((Double) this.etaSampler.sample(this.random)).doubleValue());

        final int numClusters = ((Sample) this.currentParameter).getNumClusters();
        final double rate = b - logEta;

        // Escobar & West give the mixture weight as ODDS. Convert it to the
        // probability pi = (a + k - 1) / ((a + k - 1) + n (b - log eta)).
        final double lowerShape = a + numClusters - 1.0d;
        final double mixingProbability = lowerShape / (lowerShape + numData * rate);
        final double u = this.random.nextDouble();

        if (this.alphaInverseSampler == null) {
            this.alphaInverseSampler = new GammaDistribution();
        }
        this.alphaInverseSampler.setShape(u < mixingProbability ? lowerShape + 1.0d : lowerShape);
        // The conditional has RATE (b - log eta). GammaDistribution is configured
        // by scale (assumed mean = shape * scale -- confirm), so pass 1 / rate and
        // return the draw itself. Inverting a Gamma(shape, scale = rate) draw
        // would give an inverse-gamma variate instead.
        this.alphaInverseSampler.setScale(1.0d / rate);
        return ((Double) this.alphaInverseSampler.sample(this.random)).doubleValue();
    }

    /**
     * Builds the starting sample. It has getNumInitialClusters() clusters, each
     * constructed over the same list of all data and sharing one probability
     * function from the updater's cluster posterior given all data. Alpha starts
     * at getInitialAlpha().
     *
     * @return the initial sample
     */
    @Override // gov.sandia.cognition.statistics.bayesian.AbstractMarkovChainMonteCarlo
    public Sample<ObservationType> createInitialLearnedObject() {
        final ArrayList<DPMMCluster<ObservationType>> initialClusters = new ArrayList<>(getNumInitialClusters());
        final ProbabilityFunction<ObservationType> sharedFunction =
            this.updater.createClusterPosterior2((Iterable) this.data, this.random);
        final ArrayList<ObservationType> allData = CollectionUtil.asArrayList((Iterable) this.data);
        int created = 0;
        while (created < getNumInitialClusters()) {
            initialClusters.add(new DPMMCluster<>(allData, sharedFunction));
            created++;
        }
        return new Sample<>(getInitialAlpha(), initialClusters);
    }

    /**
     * Returns the updater that builds prior-predictive and cluster probability
     * functions.
     *
     * @return the updater; null unless one has been set
     */
    public Updater<ObservationType> getUpdater() {
        return this.updater;
    }

    /**
     * Sets the updater that builds prior-predictive and cluster probability
     * functions.
     *
     * @param newUpdater the updater to use
     */
    public void setUpdater(final Updater<ObservationType> newUpdater) {
        this.updater = newUpdater;
    }

    /**
     * Returns how many clusters createInitialLearnedObject() builds.
     *
     * @return number of initial clusters
     */
    public int getNumInitialClusters() {
        return this.numInitialClusters;
    }

    /**
     * Sets how many clusters createInitialLearnedObject() builds. No validation
     * is done here.
     *
     * @param count number of initial clusters
     */
    public void setNumInitialClusters(final int count) {
        this.numInitialClusters = count;
    }

    /**
     * Reports whether mcmcUpdate() resamples alpha on each sweep.
     *
     * @return true if alpha is re-estimated
     */
    public boolean getReestimateAlpha() {
        return this.reestimateAlpha;
    }

    /**
     * Controls whether mcmcUpdate() resamples alpha via updateAlpha().
     *
     * @param enabled true to re-estimate alpha
     */
    public void setReestimateAlpha(final boolean enabled) {
        this.reestimateAlpha = enabled;
    }

    /**
     * Returns the alpha given to the initial sample.
     *
     * @return initial alpha
     */
    public double getInitialAlpha() {
        return this.initialAlpha;
    }

    /**
     * Sets the alpha given to the initial sample. No validation happens here.
     * A non-positive value is rejected later, when createInitialLearnedObject()
     * constructs the Sample.
     *
     * @param value initial alpha
     */
    public void setInitialAlpha(final double value) {
        this.initialAlpha = value;
    }
}
