package gov.sandia.cognition.learning.data.feature;

import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.math.matrix.VectorOutputEvaluator;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.math.matrix.mtj.DenseMatrixFactoryMTJ;
import gov.sandia.cognition.math.matrix.mtj.decomposition.CholeskyDecompositionMTJ;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;
import java.util.Iterator;

/* loaded from: input_file:gov/sandia/cognition/learning/data/feature/MultivariateDecorrelator.class */
/**
 * Evaluator that transforms an input vector by subtracting the mean of a
 * {@link MultivariateGaussian} and multiplying the result by a cached
 * "covariance inverse square root" matrix (a Cholesky factor of the
 * Gaussian's inverse covariance), with the goal of decorrelating the inputs.
 * The input and output dimensionality are both the Gaussian's dimensionality.
 */
public class MultivariateDecorrelator extends AbstractCloneableSerializable implements Evaluator<Vectorizable, Vector>, VectorInputEvaluator<Vectorizable, Vector>, VectorOutputEvaluator<Vectorizable, Vector> {
    /** The Gaussian whose mean and covariance define the transform; may be null. */
    protected MultivariateGaussian gaussian;

    /**
     * Cholesky factor (from {@code CholeskyDecompositionMTJ.getR()}) of the
     * Gaussian's inverse covariance. Recomputed by {@link #setGaussian};
     * null whenever {@link #gaussian} is null.
     */
    private Matrix covarianceInverseSquareRoot;

    /**
     * Batch learner that builds a decorrelator from the per-dimension
     * (diagonal-only) variances of the data.
     */
    public static class DiagonalCovarianceLearner extends AbstractCloneableSerializable implements BatchLearner<Collection<? extends Vectorizable>, MultivariateDecorrelator> {
        /** Default amount added to each learned variance. */
        public static final double DEFAULT_DEFAULT_COVARIANCE = 1.0E-5d;

        /** Non-negative amount added to each learned variance. */
        protected double defaultCovariance;

        /** Creates a learner using {@link #DEFAULT_DEFAULT_COVARIANCE}. */
        public DiagonalCovarianceLearner() {
            this(DEFAULT_DEFAULT_COVARIANCE);
        }

        /**
         * Creates a learner with the given default covariance.
         *
         * @param d amount added to each learned variance; must be non-negative
         * @throws IllegalArgumentException if {@code d} is negative
         */
        public DiagonalCovarianceLearner(double d) {
            setDefaultCovariance(d);
        }

        /**
         * Learns a diagonal-covariance decorrelator from the given data.
         *
         * @param collection the data; must be non-null and non-empty
         * @return the learned decorrelator
         * @throws IllegalArgumentException if {@code collection} is null or empty
         */
        @Override // gov.sandia.cognition.learning.algorithm.BatchLearner
        public MultivariateDecorrelator learn(Collection<? extends Vectorizable> collection) {
            return MultivariateDecorrelator.learnDiagonalCovariance(collection, getDefaultCovariance());
        }

        /** @return the amount added to each learned variance */
        public double getDefaultCovariance() {
            return this.defaultCovariance;
        }

        /**
         * Sets the amount added to each learned variance.
         *
         * @param d the new value; must be non-negative
         * @throws IllegalArgumentException if {@code d} is negative
         */
        public void setDefaultCovariance(double d) {
            if (d < 0.0d) {
                throw new IllegalArgumentException("defaultCovariance cannot be negative.");
            }
            this.defaultCovariance = d;
        }
    }

    /**
     * Batch learner that builds a decorrelator from a full covariance
     * estimated by {@code MultivariateGaussian.MaximumLikelihoodEstimator}.
     */
    public static class FullCovarianceLearner extends AbstractCloneableSerializable implements BatchLearner<Collection<? extends Vectorizable>, MultivariateDecorrelator> {
        /** Default value passed to the maximum-likelihood estimator. */
        public static final double DEFAULT_DEFAULT_COVARIANCE = 1.0E-5d;

        /** Non-negative value passed to the maximum-likelihood estimator. */
        protected double defaultCovariance;

        /** Creates a learner using {@link #DEFAULT_DEFAULT_COVARIANCE}. */
        public FullCovarianceLearner() {
            this(DEFAULT_DEFAULT_COVARIANCE);
        }

        /**
         * Creates a learner with the given default covariance.
         *
         * @param d value passed to the estimator; must be non-negative
         * @throws IllegalArgumentException if {@code d} is negative
         */
        public FullCovarianceLearner(double d) {
            setDefaultCovariance(d);
        }

        /**
         * Learns a full-covariance decorrelator from the given data.
         *
         * @param collection the data; must be non-null and non-empty
         * @return the learned decorrelator
         * @throws IllegalArgumentException if {@code collection} is null or empty
         */
        @Override // gov.sandia.cognition.learning.algorithm.BatchLearner
        public MultivariateDecorrelator learn(Collection<? extends Vectorizable> collection) {
            return MultivariateDecorrelator.learnFullCovariance(collection, getDefaultCovariance());
        }

        /** @return the value passed to the maximum-likelihood estimator */
        public double getDefaultCovariance() {
            return this.defaultCovariance;
        }

        /**
         * Sets the value passed to the maximum-likelihood estimator.
         *
         * @param d the new value; must be non-negative
         * @throws IllegalArgumentException if {@code d} is negative
         */
        public void setDefaultCovariance(double d) {
            if (d < 0.0d) {
                throw new IllegalArgumentException("defaultCovariance cannot be negative.");
            }
            this.defaultCovariance = d;
        }
    }

    /** Creates a decorrelator with no Gaussian; it must be set before use. */
    public MultivariateDecorrelator() {
        // Cast selects the MultivariateGaussian overload over the copy constructor.
        this((MultivariateGaussian) null);
    }

    /**
     * Creates a decorrelator from a mean and covariance.
     *
     * @param vector the mean
     * @param matrix the covariance
     */
    public MultivariateDecorrelator(Vector vector, Matrix matrix) {
        this(new MultivariateGaussian(vector, matrix));
    }

    /**
     * Creates a decorrelator from a Gaussian. The Gaussian is cloned.
     *
     * @param multivariateGaussian the Gaussian, or null
     */
    public MultivariateDecorrelator(MultivariateGaussian multivariateGaussian) {
        setGaussian(multivariateGaussian);
    }

    /**
     * Copy constructor.
     *
     * @param multivariateDecorrelator the decorrelator to copy
     */
    public MultivariateDecorrelator(MultivariateDecorrelator multivariateDecorrelator) {
        // setGaussian already clones its argument, so no extra clone is needed here.
        this(multivariateDecorrelator.getGaussian());
    }

    /**
     * Creates a clone with its own copies of the Gaussian and the cached
     * inverse-square-root matrix.
     */
    @Override // gov.sandia.cognition.util.AbstractCloneableSerializable
    /* renamed from: clone */
    public MultivariateDecorrelator mo0clone() {
        MultivariateDecorrelator multivariateDecorrelator = (MultivariateDecorrelator) super.mo0clone();
        multivariateDecorrelator.gaussian = (MultivariateGaussian) ObjectUtil.cloneSafe(this.gaussian);
        multivariateDecorrelator.covarianceInverseSquareRoot = (Matrix) ObjectUtil.cloneSafe(this.covarianceInverseSquareRoot);
        return multivariateDecorrelator;
    }

    /**
     * Subtracts the mean from the input and multiplies the result by the
     * cached inverse-square-root matrix.
     *
     * @param vectorizable the input
     * @return the transformed vector
     */
    @Override // gov.sandia.cognition.evaluator.Evaluator
    public Vector evaluate(Vectorizable vectorizable) {
        Vector centered = (Vector) vectorizable.convertToVector().minus(getMean());
        // NOTE(review): this computes (x - mean)^T * R with R = getR() of the
        // Cholesky factorization of the inverse covariance. It yields identity
        // output covariance for a non-diagonal covariance only if R is the
        // lower-triangular factor (R * R^T = inverse covariance). Confirm the
        // getR() convention of CholeskyDecompositionMTJ.
        return centered.times(getCovarianceInverseSquareRoot());
    }

    /** @return the dimensionality of the Gaussian (requires a non-null Gaussian) */
    @Override // gov.sandia.cognition.math.matrix.VectorInputEvaluator
    public int getInputDimensionality() {
        return getGaussian().getInputDimensionality();
    }

    /** @return the dimensionality of the Gaussian; output size equals input size */
    @Override // gov.sandia.cognition.math.matrix.VectorOutputEvaluator
    public int getOutputDimensionality() {
        return getGaussian().getInputDimensionality();
    }

    /** @return the mean of the Gaussian (requires a non-null Gaussian) */
    public Vector getMean() {
        return getGaussian().getMean();
    }

    /** @return the covariance of the Gaussian (requires a non-null Gaussian) */
    public Matrix getCovariance() {
        return getGaussian().getCovariance();
    }

    /** @return the Gaussian, possibly null */
    public MultivariateGaussian getGaussian() {
        return this.gaussian;
    }

    /**
     * Sets the Gaussian (stored as a clone) and recomputes the cached
     * inverse-square-root matrix from its inverse covariance. Passing null
     * clears both.
     *
     * @param multivariateGaussian the Gaussian, or null
     */
    public void setGaussian(MultivariateGaussian multivariateGaussian) {
        if (multivariateGaussian == null) {
            this.gaussian = null;
            this.covarianceInverseSquareRoot = null;
        } else {
            this.gaussian = multivariateGaussian.mo0clone();
            // Copy into a dense MTJ matrix because the MTJ Cholesky decomposition
            // is given one. It presumably requires a positive-definite inverse
            // covariance; confirm its failure behavior otherwise.
            this.covarianceInverseSquareRoot = CholeskyDecompositionMTJ.create(DenseMatrixFactoryMTJ.INSTANCE.copyMatrix(multivariateGaussian.getCovarianceInverse())).getR();
        }
    }

    /** @return the cached Cholesky factor of the inverse covariance, or null */
    public Matrix getCovarianceInverseSquareRoot() {
        return this.covarianceInverseSquareRoot;
    }

    /**
     * Learns a decorrelator whose Gaussian comes from
     * {@code MultivariateGaussian.MaximumLikelihoodEstimator}.
     *
     * @param collection the data; must be non-null and non-empty
     * @param d the default covariance passed to the estimator
     * @return the learned decorrelator
     * @throws IllegalArgumentException if {@code collection} is null or empty
     */
    public static MultivariateDecorrelator learnFullCovariance(Collection<? extends Vectorizable> collection, double d) {
        // Validate the same way learnDiagonalCovariance does.
        if (collection == null) {
            throw new IllegalArgumentException("values cannot be null.");
        }
        if (collection.isEmpty()) {
            throw new IllegalArgumentException("values cannot be empty.");
        }
        return new MultivariateDecorrelator(MultivariateGaussian.MaximumLikelihoodEstimator.learn(DatasetUtil.asVectorCollection(collection), d));
    }

    /**
     * Learns a decorrelator with a diagonal covariance. The mean is the
     * sample mean. Each diagonal entry is the 1/N (population) variance of
     * that dimension plus {@code d}.
     *
     * @param collection the data; must be non-null and non-empty
     * @param d amount added to each variance
     * @return the learned decorrelator
     * @throws IllegalArgumentException if {@code collection} is null or empty
     */
    public static MultivariateDecorrelator learnDiagonalCovariance(Collection<? extends Vectorizable> collection, double d) {
        if (collection == null) {
            throw new IllegalArgumentException("values cannot be null.");
        }
        int size = collection.size();
        if (size <= 0) {
            throw new IllegalArgumentException("values cannot be empty.");
        }

        // First pass: the mean. The typed accumulator selects accumulate(Vector)
        // unambiguously; the former (RingAccumulator) cast on a Vector could never succeed.
        RingAccumulator<Vector> meanAccumulator = new RingAccumulator<Vector>();
        for (Vectorizable value : collection) {
            meanAccumulator.accumulate(value.convertToVector());
        }
        Vector mean = meanAccumulator.getMean();
        int dimensionality = mean.getDimensionality();

        // Second pass: per-dimension sum of squared deviations from the mean.
        Vector variances = VectorFactory.getDefault().createVector(dimensionality);
        for (Vectorizable value : collection) {
            Vector delta = (Vector) value.convertToVector().minus(mean);
            delta.dotTimesEquals(delta);
            variances.plusEquals(delta);
        }
        variances.scaleEquals(1.0d / size);
        variances.plusEquals(VectorFactory.getDefault().createVector(dimensionality, d));
        return new MultivariateDecorrelator(mean, MatrixFactory.getDefault().createDiagonal(variances));
    }
}
