package gov.sandia.cognition.text.topic;

import gov.sandia.cognition.algorithm.IterativeAlgorithm;
import gov.sandia.cognition.algorithm.event.AbstractIterativeAlgorithmListener;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorFactoryContainer;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.math.matrix.VectorOutputEvaluator;
import gov.sandia.cognition.math.matrix.VectorUtil;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.bayesian.AdaptiveRejectionSampling;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.Randomized;
import java.io.PrintStream;
import java.text.DecimalFormat;
import java.util.Collection;
import java.util.Random;

@PublicationReferences(references = {@PublicationReference(author = {"Thomas Hofmann"}, title = "Probabilistic Latent Semantic Analysis", year = 1999, type = PublicationType.Conference, publication = "Proceedings of the Fifteenth Conference on Uncertainty in Artificial Intelligence (UAI)", pages = {289, 296}, url = "http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.33.1187"), @PublicationReference(author = {"Thomas Hofmann"}, title = "Probabilistic Latent Semantic Indexing", year = 1999, type = PublicationType.Conference, publication = "Proceedings of the 22nd Conference of the ACM Special Interest Group on Information Retreival (SIGIR)", pages = {AdaptiveRejectionSampling.DEFAULT_MAX_NUM_POINTS, 57}, url = "http://portal.acm.org/citation.cfm?id=312649"), @PublicationReference(author = {"Thomas Hofmann"}, title = "Unsupervised Learning by Probabilistic Latent Semantic Analysis", year = 2001, type = PublicationType.Journal, publication = "Machine Learning", pages = {177, 196}, url = "http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.130.6341")})
/* loaded from: input_file:gov/sandia/cognition/text/topic/ProbabilisticLatentSemanticAnalysis.class */
public class ProbabilisticLatentSemanticAnalysis extends AbstractAnytimeBatchLearner<Collection<? extends Vectorizable>, Result> implements Randomized, VectorFactoryContainer {
    /** Default requested rank, i.e. the default number of latent variables: 10. */
    public static final int DEFAULT_REQUESTED_RANK = 10;
    /** Default iteration cap (250); passed to the superclass constructor and used as the initial {@code Result.maxIterations}. */
    public static final int DEFAULT_MAX_ITERATIONS = 250;
    /** Default convergence threshold on the change in log-likelihood: 1e-10. */
    public static final double DEFAULT_MINIMUM_CHANGE = 1.0E-10d;
    /** Requested number of latent variables; validated as positive by {@code setRequestedRank}. The count actually used is capped at the document count during initialization. */
    protected int requestedRank;
    /** Learning continues only while the absolute change in log-likelihood between steps exceeds this value; validated as non-negative by {@code setMinimumChange}. */
    protected double minimumChange;
    /** Random source used to draw the initial p(term|latent) and p(document|latent) vectors. */
    protected Random random;
    /** Factory used to create the initial random probability vectors. */
    protected VectorFactory<? extends Vector> vectorFactory;
    /** Factory used to build the document-by-term matrix and the per-latent posterior matrices. */
    protected MatrixFactory<? extends Matrix> matrixFactory;
    /** Training data copied into a matrix with one row per input vector; set during initialization and cleared in {@code cleanupAlgorithm}. */
    protected transient Matrix documentsByTerms;
    /** Dimensionality of the training data as reported by {@code DatasetUtil.getDimensionality}. */
    protected transient int termCount;
    /** Number of items in the training collection. */
    protected transient int documentCount;
    /** Number of latent variables actually used: the smaller of the document count and the requested rank. */
    protected transient int latentCount;
    /** Per-latent parameters, one entry per latent variable; cleared in {@code cleanupAlgorithm} (the result keeps its own reference). */
    protected transient LatentData[] latents;
    /** Log-likelihood of the training data computed at the end of the most recent step. */
    protected transient double logLikelihood;
    /** Difference between the most recent log-likelihood and the value held before that step. */
    protected transient double changeOfLogLikelihood;
    /** Model object created during initialization; it shares the latent-data array that the steps update. */
    protected transient Result result;

    /* loaded from: input_file:gov/sandia/cognition/text/topic/ProbabilisticLatentSemanticAnalysis$LatentData.class */
    /**
     * Parameters attached to one latent variable of the model.
     *
     * <p>Implements {@link java.io.Serializable} because {@code Result}, which is
     * serializable through {@code AbstractCloneableSerializable}, stores an array of
     * these objects in a non-transient field; without it, serializing a
     * {@code Result} would fail on the first element.
     */
    public static class LatentData implements java.io.Serializable {
        private static final long serialVersionUID = 1L;

        /** Position of this latent variable; used as the row/element index for per-latent data. */
        int index;
        /** p(latent | document, term): a documents-by-terms matrix filled in during each learning step. */
        Matrix pLatentGivenDocumentTerm;
        /** p(term | latent): one element per term, normalized to unit 1-norm. */
        Vector pTermGivenLatent;
        /** p(document | latent): one element per training document, normalized to unit 1-norm. */
        Vector pDocumentGivenLatent;
        /** p(latent): prior weight of this latent variable. */
        double pLatent;
    }

    /* loaded from: input_file:gov/sandia/cognition/text/topic/ProbabilisticLatentSemanticAnalysis$Result.class */
    /**
     * The model produced by the learner. Given a document (as a term vector) it
     * estimates the document's weights over the latent variables by iterating an
     * expectation-maximization style update while the learned term distributions
     * and latent priors stay fixed.
     */
    public static class Result extends AbstractCloneableSerializable implements Evaluator<Vectorizable, Vector>, VectorInputEvaluator<Vectorizable, Vector>, VectorOutputEvaluator<Vectorizable, Vector> {
        /** Number of terms, i.e. the expected input dimensionality. */
        protected int termCount;
        /** Number of latent variables, i.e. the output dimensionality. */
        protected int latentCount;
        /** Learned per-latent parameters; only {@code index}, {@code pLatent} and {@code pTermGivenLatent} are read here. */
        protected LatentData[] latents;
        /** Maximum number of update steps performed per call to {@link #evaluate}. */
        protected int maxIterations = ProbabilisticLatentSemanticAnalysis.DEFAULT_MAX_ITERATIONS;
        /** Iteration stops once the log-likelihood changes by no more than this amount between steps. */
        protected double minimumChange = ProbabilisticLatentSemanticAnalysis.DEFAULT_MINIMUM_CHANGE;

        /**
         * Creates a result over the given latent data.
         *
         * @param i the number of terms
         * @param latentDataArr the per-latent parameters; the array is stored, not copied
         */
        public Result(int i, LatentData[] latentDataArr) {
            this.termCount = i;
            this.latentCount = latentDataArr.length;
            this.latents = latentDataArr;
        }

        /**
         * Estimates the latent-variable weights of a document.
         *
         * @param vectorizable the document, convertible to a vector of term weights
         * @return a vector with one weight per latent variable
         */
        @Override
        public Vector evaluate(Vectorizable vectorizable) {
            Vector documentTerms = vectorizable.convertToVector();
            Matrix pLatentGivenTerm = MatrixFactory.getDefault().createMatrix(this.latentCount, this.termCount);
            // Start from a uniform weighting over the latent variables.
            Vector latentWeights = VectorFactory.getDefault().createVector(this.latentCount, 1.0d / this.latentCount);
            // No previous likelihood exists yet. NEGATIVE_INFINITY guarantees the first
            // comparison never reports convergence; Double.MIN_VALUE is a tiny POSITIVE
            // number and would end the loop early whenever the first likelihood is ~0.
            double logLikelihood = Double.NEGATIVE_INFINITY;
            for (int iteration = 1; iteration <= this.maxIterations; iteration++) {
                double previous = logLikelihood;
                logLikelihood = step(documentTerms, pLatentGivenTerm, latentWeights);
                if (Math.abs(logLikelihood - previous) <= this.minimumChange) {
                    break;
                }
            }
            return latentWeights;
        }

        /**
         * Performs one update of the latent weights for a single document.
         *
         * @param vector the document's term weights
         * @param matrix scratch matrix (latents by terms) overwritten with p(latent | term)
         * @param vector2 the current latent weights; updated in place
         * @return the log-likelihood of the document under the updated weights
         */
        protected double step(Vector vector, Matrix matrix, Vector vector2) {
            // Expectation: p(latent | term) proportional to p(latent) * weight(latent) * p(term | latent).
            for (int term = 0; term < this.termCount; term++) {
                double normalizer = 0.0d;
                for (LatentData latent : this.latents) {
                    int row = latent.index;
                    double joint = latent.pLatent * vector2.getElement(row) * latent.pTermGivenLatent.getElement(term);
                    matrix.setElement(row, term, joint);
                    normalizer += joint;
                }
                // A zero total leaves the (all-zero) column untouched rather than dividing by zero.
                if (normalizer != 0.0d) {
                    for (LatentData latent : this.latents) {
                        int row = latent.index;
                        matrix.setElement(row, term, matrix.getElement(row, term) / normalizer);
                    }
                }
            }

            // Maximization: new weight(latent) = sum over terms of count(term) * p(latent | term).
            for (LatentData latent : this.latents) {
                int row = latent.index;
                double weight = 0.0d;
                for (int term = 0; term < this.termCount; term++) {
                    weight += vector.getElement(term) * matrix.getElement(row, term);
                }
                vector2.setElement(row, weight);
            }
            VectorUtil.divideByNorm1Equals(vector2);

            // Log-likelihood of the document under the updated mixture; terms with zero
            // probability are skipped so that log(0) never enters the sum.
            double logLikelihood = 0.0d;
            for (int term = 0; term < this.termCount; term++) {
                double pTerm = 0.0d;
                for (LatentData latent : this.latents) {
                    pTerm += latent.pLatent * latent.pTermGivenLatent.getElement(term) * vector2.getElement(latent.index);
                }
                if (pTerm != 0.0d) {
                    logLikelihood += vector.getElement(term) * Math.log(pTerm);
                }
            }
            return logLikelihood;
        }

        /**
         * @return the number of terms expected in an input vector
         */
        @Override
        public int getInputDimensionality() {
            return this.termCount;
        }

        /**
         * @return the number of latent variables, i.e. the size of the output vector
         */
        @Override
        public int getOutputDimensionality() {
            return this.latents.length;
        }
    }

    /* loaded from: input_file:gov/sandia/cognition/text/topic/ProbabilisticLatentSemanticAnalysis$StatusPrinter.class */
    /**
     * Listener that writes a textual progress report of a running
     * {@code ProbabilisticLatentSemanticAnalysis} to a print stream: the latent
     * parameters before each step and the log-likelihood figures after it.
     */
    public static class StatusPrinter extends AbstractIterativeAlgorithmListener {
        /** Stream that receives the report. */
        protected PrintStream out;

        /** Creates a printer that writes to standard output. */
        public StatusPrinter() {
            this(System.out);
        }

        /**
         * Creates a printer that writes to the given stream.
         *
         * @param printStream the stream to report to
         */
        public StatusPrinter(PrintStream printStream) {
            this.out = printStream;
        }

        /**
         * Prints the iteration number followed by the current parameters of every latent variable.
         *
         * @param iterativeAlgorithm the algorithm about to take a step; cast to
         *        {@code ProbabilisticLatentSemanticAnalysis}
         */
        @Override
        public void stepStarted(IterativeAlgorithm iterativeAlgorithm) {
            final ProbabilisticLatentSemanticAnalysis plsa = (ProbabilisticLatentSemanticAnalysis) iterativeAlgorithm;
            final DecimalFormat twoDecimals = new DecimalFormat("0.00");
            out.println("Iteration " + plsa.getIteration());
            final LatentData[] currentLatents = plsa.result.latents;
            for (int k = 0; k < currentLatents.length; k++) {
                final LatentData latent = currentLatents[k];
                out.println("    Latent " + latent.index);
                out.println("        p(z)   = " + twoDecimals.format(latent.pLatent));
                out.println("        p(t|z) = " + latent.pTermGivenLatent.toString(twoDecimals));
                out.println("        p(d|z) = " + latent.pDocumentGivenLatent.toString(twoDecimals));
            }
        }

        /**
         * Prints the log-likelihood reached by the step and its change from the previous value.
         *
         * @param iterativeAlgorithm the algorithm that just finished a step; cast to
         *        {@code ProbabilisticLatentSemanticAnalysis}
         */
        @Override
        public void stepEnded(IterativeAlgorithm iterativeAlgorithm) {
            final ProbabilisticLatentSemanticAnalysis plsa = (ProbabilisticLatentSemanticAnalysis) iterativeAlgorithm;
            out.println("Log likelihood: " + plsa.logLikelihood);
            out.println("Change: " + plsa.changeOfLogLikelihood);
        }
    }

    /**
     * Creates an analysis with the default requested rank
     * ({@link #DEFAULT_REQUESTED_RANK}).
     */
    public ProbabilisticLatentSemanticAnalysis() {
        this(DEFAULT_REQUESTED_RANK);
    }

    /**
     * Creates an analysis with the default requested rank and minimum change,
     * drawing its initial values from the supplied random number generator.
     *
     * @param random the random number generator to use
     */
    public ProbabilisticLatentSemanticAnalysis(Random random) {
        this(DEFAULT_REQUESTED_RANK, DEFAULT_MINIMUM_CHANGE, random);
    }

    /**
     * Creates an analysis with the given requested rank, the default minimum
     * change, and a freshly constructed {@link Random}.
     *
     * @param i the requested rank (number of latent variables)
     */
    public ProbabilisticLatentSemanticAnalysis(int i) {
        this(i, DEFAULT_MINIMUM_CHANGE, new Random());
    }

    /**
     * Creates an analysis with explicit settings. The iteration cap is
     * {@link #DEFAULT_MAX_ITERATIONS} and the vector and matrix factories are
     * the library defaults.
     *
     * @param i the requested rank (number of latent variables); checked by {@code setRequestedRank}
     * @param d the minimum change in log-likelihood that keeps learning going; checked by {@code setMinimumChange}
     * @param random the random number generator used for initialization
     */
    public ProbabilisticLatentSemanticAnalysis(int i, double d, Random random) {
        super(DEFAULT_MAX_ITERATIONS);

        // Learning parameters (the rank is validated before the minimum change).
        this.setRequestedRank(i);
        this.setRandom(random);
        this.setMinimumChange(d);

        // Default linear-algebra factories.
        this.setVectorFactory(VectorFactory.getDefault());
        this.setMatrixFactory(MatrixFactory.getDefault());
    }

    /**
     * Prepares the learner for a run: copies the training data into a
     * document-by-term matrix, sizes the model, and gives every latent variable
     * random, normalized starting distributions and a uniform prior.
     *
     * @return always {@code true}
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected boolean initializeAlgorithm() {
        final Collection<? extends Vectorizable> data = getData();
        this.documentsByTerms = getMatrixFactory().copyRowVectors(data);
        this.termCount = DatasetUtil.getDimensionality(data);
        this.documentCount = data.size();
        // Never use more latent variables than there are documents.
        this.latentCount = Math.min(this.documentCount, getRequestedRank());
        this.latents = new LatentData[this.latentCount];
        for (int k = 0; k < this.latentCount; k++) {
            final LatentData latent = new LatentData();
            this.latents[k] = latent;
            latent.index = k;
            latent.pLatentGivenDocumentTerm = getMatrixFactory().createMatrix(this.documentCount, this.termCount);
            // The term vector is drawn before the document vector; keep this order so
            // that a seeded Random reproduces the same starting point.
            latent.pTermGivenLatent = getVectorFactory().createUniformRandom(this.termCount, 0.0d, 1.0d, getRandom());
            VectorUtil.divideByNorm1Equals(latent.pTermGivenLatent);
            latent.pDocumentGivenLatent = getVectorFactory().createUniformRandom(this.documentCount, 0.0d, 1.0d, getRandom());
            VectorUtil.divideByNorm1Equals(latent.pDocumentGivenLatent);
            latent.pLatent = 1.0d / this.latentCount;
        }
        // There is no previous likelihood yet. NEGATIVE_INFINITY makes the first step's
        // change infinite so it can never be mistaken for convergence; Double.MIN_VALUE
        // is the smallest POSITIVE double and behaves like 0 in that comparison.
        this.logLikelihood = Double.NEGATIVE_INFINITY;
        this.changeOfLogLikelihood = 0.0d;
        this.result = new Result(this.termCount, this.latents);
        return true;
    }

    /**
     * Runs one expectation-maximization iteration over the training data and
     * records the resulting log-likelihood and its change.
     *
     * @return {@code true} while the absolute change in log-likelihood exceeds
     *         the minimum change, i.e. while learning should continue
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected boolean step() {
        // Expectation: p(latent | document, term) is proportional to
        // p(latent) * p(document | latent) * p(term | latent).
        for (int doc = 0; doc < this.documentCount; doc++) {
            for (int term = 0; term < this.termCount; term++) {
                double normalizer = 0.0d;
                for (LatentData latent : this.latents) {
                    final double joint = latent.pLatent * latent.pDocumentGivenLatent.getElement(doc) * latent.pTermGivenLatent.getElement(term);
                    latent.pLatentGivenDocumentTerm.setElement(doc, term, joint);
                    normalizer += joint;
                }
                if (normalizer == 0.0d) {
                    // Nothing to normalize; every entry for this cell is zero.
                    continue;
                }
                for (LatentData latent : this.latents) {
                    final Matrix posterior = latent.pLatentGivenDocumentTerm;
                    posterior.setElement(doc, term, posterior.getElement(doc, term) / normalizer);
                }
            }
        }

        // Maximization: weight the observed counts by the posteriors and
        // re-derive each latent variable's distributions and prior.
        double priorTotal = 0.0d;
        for (LatentData latent : this.latents) {
            final Matrix weightedCounts = this.documentsByTerms.dotTimes(latent.pLatentGivenDocumentTerm);
            latent.pTermGivenLatent = weightedCounts.sumOfRows();
            latent.pDocumentGivenLatent = weightedCounts.sumOfColumns();
            // The prior is taken from the un-normalized document vector.
            latent.pLatent = latent.pDocumentGivenLatent.sum();
            VectorUtil.divideByNorm1Equals(latent.pTermGivenLatent);
            VectorUtil.divideByNorm1Equals(latent.pDocumentGivenLatent);
            priorTotal += latent.pLatent;
        }
        if (priorTotal != 0.0d) {
            for (LatentData latent : this.latents) {
                latent.pLatent /= priorTotal;
            }
        }

        // Log-likelihood of the training data under the updated model; cells
        // with zero probability are skipped so log(0) never enters the sum.
        final double previousLogLikelihood = this.logLikelihood;
        double total = 0.0d;
        for (int doc = 0; doc < this.documentCount; doc++) {
            for (int term = 0; term < this.termCount; term++) {
                double pCell = 0.0d;
                for (LatentData latent : this.latents) {
                    pCell += latent.pLatent * latent.pDocumentGivenLatent.getElement(doc) * latent.pTermGivenLatent.getElement(term);
                }
                if (pCell != 0.0d) {
                    total += this.documentsByTerms.getElement(doc, term) * Math.log(pCell);
                }
            }
        }
        this.logLikelihood = total;
        this.changeOfLogLikelihood = total - previousLogLikelihood;
        return Math.abs(this.changeOfLogLikelihood) > getMinimumChange();
    }

    /**
     * Drops the learner's references to the training matrix and the latent
     * data once a run is over. The {@code Result} created during
     * initialization holds its own reference to the latent data array.
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected void cleanupAlgorithm() {
        this.documentsByTerms = null;
        this.latents = null;
    }

    /**
     * Returns the model built so far.
     *
     * <p>This is the {@code AnytimeAlgorithm} accessor. The decompiler had
     * renamed it to {@code getResult2} (to avoid clashing with the synthetic
     * bridge method of the covariant return type) while keeping
     * {@code @Override}, which left the class without an implementation of
     * {@code getResult()}.
     *
     * @return the current result, or {@code null} if no run has been initialized
     */
    @Override // gov.sandia.cognition.algorithm.AnytimeAlgorithm
    public Result getResult() {
        return this.result;
    }

    /**
     * Alias retained for code that still calls the decompiler-generated name.
     *
     * @return the same value as {@link #getResult()}
     * @deprecated use {@link #getResult()} instead
     */
    @Deprecated
    public Result getResult2() {
        return getResult();
    }

    /**
     * Gets the random number generator used to initialize the model.
     *
     * @return the random number generator
     */
    @Override // gov.sandia.cognition.util.Randomized
    public Random getRandom() {
        return random;
    }

    /**
     * Sets the random number generator used to initialize the model.
     * The value is stored as given; no validation is performed.
     *
     * @param random the random number generator
     */
    @Override // gov.sandia.cognition.util.Randomized
    public void setRandom(Random random) {
        this.random = random;
    }

    /**
     * Gets the factory used to create the model's probability vectors.
     *
     * @return the vector factory
     */
    @Override // gov.sandia.cognition.math.matrix.VectorFactoryContainer
    public VectorFactory<? extends Vector> getVectorFactory() {
        return vectorFactory;
    }

    /**
     * Sets the factory used to create the model's probability vectors.
     * The value is stored as given; no validation is performed.
     *
     * @param vectorFactory the vector factory
     */
    public void setVectorFactory(VectorFactory<? extends Vector> vectorFactory) {
        this.vectorFactory = vectorFactory;
    }

    /**
     * Gets the factory used to create the document-by-term and posterior matrices.
     *
     * @return the matrix factory
     */
    public MatrixFactory<? extends Matrix> getMatrixFactory() {
        return matrixFactory;
    }

    /**
     * Sets the factory used to create the document-by-term and posterior matrices.
     * The value is stored as given; no validation is performed.
     *
     * @param matrixFactory the matrix factory
     */
    public void setMatrixFactory(MatrixFactory<? extends Matrix> matrixFactory) {
        this.matrixFactory = matrixFactory;
    }

    /**
     * Gets the requested number of latent variables. The number actually used
     * in a run may be smaller, since it is capped at the document count.
     *
     * @return the requested rank
     */
    public int getRequestedRank() {
        return requestedRank;
    }

    /**
     * Sets the requested number of latent variables.
     *
     * @param i the requested rank; checked with {@code ArgumentChecker.assertIsPositive}
     */
    public void setRequestedRank(int i) {
        ArgumentChecker.assertIsPositive("requestedRank", i);
        requestedRank = i;
    }

    /**
     * Gets the convergence threshold: learning continues only while the
     * absolute change in log-likelihood between steps exceeds this value.
     *
     * @return the minimum change
     */
    public double getMinimumChange() {
        return minimumChange;
    }

    /**
     * Sets the convergence threshold on the change in log-likelihood.
     *
     * @param d the minimum change; checked with {@code ArgumentChecker.assertIsNonNegative}
     */
    public void setMinimumChange(double d) {
        ArgumentChecker.assertIsNonNegative("minimumChange", d);
        minimumChange = d;
    }
}
