package gov.sandia.cognition.statistics.distribution;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.AbstractDistribution;
import gov.sandia.cognition.statistics.ClosedFormComputableDiscreteDistribution;
import gov.sandia.cognition.statistics.ProbabilityMassFunction;
import gov.sandia.cognition.statistics.ProbabilityMassFunctionUtil;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.Random;
import java.util.Set;

/**
 * A categorical distribution over {@code K >= 2} classes, where each outcome is
 * represented as a one-hot {@link Vector} (exactly one element equal to 1.0,
 * every other element equal to 0.0).
 * <p>
 * The distribution is parameterized by a vector of non-negative class weights
 * that does not need to sum to one: the probability of class {@code i} is
 * {@code parameters[i] / sum(parameters)}.
 */
@PublicationReference(author = {"Wikipedia"}, title = "Categoical Distribution", type = PublicationType.WebPage, year = 2011, url = "http://en.wikipedia.org/wiki/Categorical_distribution")
public class CategoricalDistribution extends AbstractDistribution<Vector> implements ClosedFormComputableDiscreteDistribution<Vector> {

    /** Number of classes used by the no-argument constructors. */
    public static final int DEFAULT_NUM_CLASSES = 2;

    /**
     * Non-negative, unnormalized class weights. The probability of class
     * {@code i} is {@code parameters[i] / sum(parameters)}.
     */
    protected Vector parameters;

    /**
     * Probability mass function of a {@link CategoricalDistribution}. Inputs must
     * be one-hot vectors with the same dimensionality as the parameters.
     * <p>
     * The covariant-return bridge methods for {@code getMean()} and the clone
     * method are synthesized by the compiler, so they are not declared here.
     */
    public static class PMF extends CategoricalDistribution implements ProbabilityMassFunction<Vector>, VectorInputEvaluator<Vector, Double> {

        /** Creates a PMF over {@link #DEFAULT_NUM_CLASSES} classes, each with weight 1.0. */
        public PMF() {
            super();
        }

        /**
         * Creates a PMF over {@code numClasses} classes, each with weight 1.0.
         *
         * @param numClasses number of classes
         */
        public PMF(int numClasses) {
            super(numClasses);
        }

        /**
         * Creates a PMF that uses the given class weights. The vector is stored
         * by reference, not copied.
         *
         * @param parameters non-negative class weights of dimensionality at least 2
         */
        public PMF(Vector parameters) {
            super(parameters);
        }

        /**
         * Creates a PMF whose parameters are a clone of another distribution's.
         *
         * @param other distribution to copy
         */
        public PMF(CategoricalDistribution other) {
            super(other);
        }

        @Override
        public double getEntropy() {
            return ProbabilityMassFunctionUtil.getEntropy(this);
        }

        /**
         * Returns the natural logarithm of {@link #evaluate(Vector)}.
         *
         * @param input one-hot input vector
         * @return log-probability of the selected class
         */
        @Override
        public double logEvaluate(Vector input) {
            return Math.log(evaluate(input).doubleValue());
        }

        /**
         * Returns the probability of the class selected by a one-hot input.
         * Dimensionality mismatches are rejected by
         * {@code assertSameDimensionality} before any element is read.
         *
         * @param input one-hot vector with the same dimensionality as the parameters
         * @return {@code parameters[k] / sum(parameters)}, where {@code k} is the
         *     index of the single 1.0 entry of {@code input}
         * @throws IllegalArgumentException if {@code input} contains an entry other
         *     than 0.0 or 1.0, or does not contain exactly one entry equal to 1.0
         */
        @Override
        public Double evaluate(Vector input) {
            this.parameters.assertSameDimensionality(input);
            final int dimensionality = getInputDimensionality();
            int selectedIndex = -1;
            double selectedWeight = 0.0;
            double totalWeight = 0.0;
            for (int i = 0; i < dimensionality; i++) {
                final double weight = this.parameters.getElement(i);
                totalWeight += weight;
                final double value = input.getElement(i);
                if (value == 1.0) {
                    // Track the index rather than the weight so that duplicate
                    // detection does not depend on the weights' sign.
                    if (selectedIndex >= 0) {
                        throw new IllegalArgumentException("input must only have one entry equal to 1.0!");
                    }
                    selectedIndex = i;
                    selectedWeight = weight;
                } else if (value != 0.0) {
                    throw new IllegalArgumentException("input entries must be either 0.0 or 1.0");
                }
            }
            if (selectedIndex < 0) {
                throw new IllegalArgumentException("input must have one entry equal to 1.0!");
            }
            return Double.valueOf(selectedWeight / totalWeight);
        }

        /**
         * Returns this PMF itself; no copy is made.
         *
         * @return {@code this}
         */
        @Override
        public PMF getProbabilityFunction() {
            return this;
        }
    }

    /** Creates a distribution over {@link #DEFAULT_NUM_CLASSES} classes, each with weight 1.0. */
    public CategoricalDistribution() {
        this(DEFAULT_NUM_CLASSES);
    }

    /**
     * Creates a distribution over {@code numClasses} classes, each with weight 1.0.
     *
     * @param numClasses number of classes
     * @throws IllegalArgumentException if the resulting dimensionality is below 2
     */
    public CategoricalDistribution(int numClasses) {
        this(VectorFactory.getDefault().createVector(numClasses, 1.0d));
    }

    /**
     * Creates a distribution with the given class weights. The vector is stored
     * by reference, not copied.
     *
     * @param parameters non-negative class weights of dimensionality at least 2
     * @throws IllegalArgumentException if the dimensionality is below 2 or any
     *     weight is negative
     */
    public CategoricalDistribution(Vector parameters) {
        setParameters(parameters);
    }

    /**
     * Creates a distribution whose parameters are a clone (via
     * {@code ObjectUtil.cloneSafe}) of another distribution's parameters.
     *
     * @param other distribution to copy
     */
    public CategoricalDistribution(CategoricalDistribution other) {
        this((Vector) ObjectUtil.cloneSafe(other.getParameters()));
    }

    /**
     * Clones this distribution and replaces the copy's parameter vector with a
     * clone of this one's (via {@code ObjectUtil.cloneSafe}).
     * This is the {@code clone()} method, renamed during decompilation.
     *
     * @return the copy
     */
    @Override
    public CategoricalDistribution mo0clone() {
        final CategoricalDistribution copy = (CategoricalDistribution) super.mo0clone();
        copy.setParameters((Vector) ObjectUtil.cloneSafe(getParameters()));
        return copy;
    }

    /**
     * Returns the class weights. This is the internal vector itself, not a copy.
     *
     * @return the parameter vector
     */
    public Vector getParameters() {
        return this.parameters;
    }

    /**
     * Sets the class weights. The vector is stored by reference, not copied.
     *
     * @param parameters non-negative class weights of dimensionality at least 2
     * @throws IllegalArgumentException if the dimensionality is below 2 or any
     *     weight is negative
     */
    public void setParameters(Vector parameters) {
        final int dimensionality = parameters.getDimensionality();
        if (dimensionality < 2) {
            throw new IllegalArgumentException("Dimensionality must be >= 2");
        }
        for (int i = 0; i < dimensionality; i++) {
            if (parameters.getElement(i) < 0.0d) {
                throw new IllegalArgumentException("All parameter elements must be >= 0.0");
            }
        }
        this.parameters = parameters;
    }

    /**
     * Draws {@code numSamples} one-hot vectors from this distribution.
     *
     * @param random source of randomness
     * @param numSamples number of samples to draw
     * @return the sampled one-hot vectors
     */
    @Override
    public ArrayList<Vector> sample(Random random, int numSamples) {
        // getDomain() yields one-hot vectors in class-index order, so position i
        // in this list corresponds to parameter element i.
        final ArrayList<Vector> domain = CollectionUtil.asArrayList(getDomain());
        final int domainSize = domain.size();
        // Running (unnormalized) sums of the weights; the total is passed alongside.
        final double[] cumulativeWeights = new double[domainSize];
        double totalWeight = 0.0d;
        for (int i = 0; i < domainSize; i++) {
            totalWeight += this.parameters.getElement(i);
            cumulativeWeights[i] = totalWeight;
        }
        return ProbabilityMassFunctionUtil.sampleMultiple(cumulativeWeights, totalWeight, domain, random, numSamples);
    }

    /**
     * Returns the mean of the one-hot outcome, which is the normalized
     * probability vector {@code parameters / sum(parameters)}.
     *
     * @return a new vector of class probabilities
     */
    @Override
    public Vector getMean() {
        // setParameters rejects negative weights, so norm1() is the sum of the
        // weights. Divide by it (not multiply) to normalize.
        return (Vector) this.parameters.scale(1.0d / this.parameters.norm1());
    }

    /**
     * Returns a clone of the parameter vector.
     *
     * @return a copy of the class weights
     */
    @Override
    public Vector convertToVector() {
        return this.parameters.mo0clone();
    }

    /**
     * Replaces the parameters with {@code parameters}, which must have the same
     * dimensionality as the current parameters. The vector is stored by
     * reference, not copied.
     *
     * @param parameters new non-negative class weights
     */
    @Override
    public void convertFromVector(Vector parameters) {
        this.parameters.assertSameDimensionality(parameters);
        setParameters(parameters);
    }

    /**
     * Returns the number of classes, which is the dimensionality of the parameter vector.
     *
     * @return number of classes
     */
    public int getInputDimensionality() {
        return getParameters().getDimensionality();
    }

    /**
     * Returns a new set containing one one-hot vector per class, in class-index
     * order. The set is insertion-ordered, which {@link #sample(Random, int)}
     * relies on to map list positions to parameter indices.
     *
     * @return the one-hot domain vectors
     */
    @Override
    public Set<Vector> getDomain() {
        final int dimensionality = getInputDimensionality();
        final LinkedHashSet<Vector> domain = new LinkedHashSet<Vector>(dimensionality);
        for (int i = 0; i < dimensionality; i++) {
            final Vector oneHot = VectorFactory.getDefault().createVector(dimensionality);
            oneHot.setElement(i, 1.0d);
            domain.add(oneHot);
        }
        return domain;
    }

    /**
     * Returns the number of classes.
     *
     * @return the domain size
     */
    @Override
    public int getDomainSize() {
        return getInputDimensionality();
    }

    /**
     * Returns a new PMF whose parameters are a clone of this distribution's.
     *
     * @return the probability mass function
     */
    @Override
    public PMF getProbabilityFunction() {
        return new PMF(this);
    }
}
