package gov.sandia.cognition.learning.algorithm.regression;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.vector.MultivariateDiscriminantWithBias;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ArgumentChecker;
import java.util.Collection;

/* loaded from: input_file:gov/sandia/cognition/learning/algorithm/regression/MultivariateLinearRegression.class */
/**
 * Batch learner that fits a multivariate linear model with bias,
 * {@code y ~= A x + b}, to vector-valued input/output pairs by (optionally
 * ridge-regularized) least squares.
 * <p>
 * Each input is extended with a trailing constant {@code 1.0}, so the bias
 * {@code b} is estimated as one extra coefficient column. The system is solved
 * either with a pseudo-inverse (the default) or by solving the normal
 * equations {@code (X^T X + lambda I) C = X^T Y} directly. In both modes a
 * positive regularization {@code lambda} yields the same ridge solution.
 */
@PublicationReferences(references = {
    @PublicationReference(author = {"Wikipedia"}, title = "Linear regression", type = PublicationType.WebPage, year = 2008, url = "http://en.wikipedia.org/wiki/Linear_regression"),
    @PublicationReference(author = {"Wikipedia"}, title = "Tikhonov regularization", type = PublicationType.WebPage, year = 2011, url = "http://en.wikipedia.org/wiki/Tikhonov_regularization", notes = {"Despite what Wikipedia says, this is always called Ridge Regression"})
})
public class MultivariateLinearRegression extends AbstractCloneableSerializable implements SupervisedBatchLearner<Vector, Vector, MultivariateDiscriminantWithBias> {

    /** Default ridge regularization applied by a new instance: none. */
    public static final double DEFAULT_REGULARIZATION = 0.0d;

    /** Singular-value tolerance passed to {@link Matrix#pseudoInverse(double)}. */
    public static final double DEFAULT_PSEUDO_INVERSE_TOLERANCE = 1.0E-10d;

    /** True to solve via pseudo-inverse; false to solve the normal equations. */
    private boolean usePseudoInverse;

    /** Non-negative ridge penalty added to the diagonal of {@code X^T X}. */
    private double regularization;

    /**
     * Creates a learner that uses the pseudo-inverse solver and
     * {@link #DEFAULT_REGULARIZATION}.
     */
    public MultivariateLinearRegression() {
        // Assign fields directly rather than via overridable setters, so
        // subclasses are not called before they are fully constructed.
        this.usePseudoInverse = true;
        this.regularization = DEFAULT_REGULARIZATION;
    }

    /**
     * Returns a copy of this learner produced by the superclass clone.
     * <p>
     * NOTE: the method name was assigned by the decompiler; it was merged with
     * the bridge method(s) for {@code clone()}.
     *
     * @return the cloned learner
     */
    /* renamed from: clone, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public MultivariateLinearRegression m112clone() {
        return (MultivariateLinearRegression) super.clone();
    }

    /**
     * Fits a {@link MultivariateDiscriminantWithBias} to the given data.
     * <p>
     * Input and output dimensionalities are read from the first pair; the
     * remaining pairs are expected to match them. A per-sample weight from
     * {@link DatasetUtil#getWeight} multiplies both the bias-augmented input
     * and the output of that sample.
     * <p>
     * NOTE(review): because both sides of a row are multiplied by
     * {@code weight}, that sample's squared residual is effectively scaled by
     * {@code weight^2}, not {@code weight}. Conventional weighted least squares
     * would scale by {@code Math.sqrt(weight)}. This is left unchanged so that
     * existing results are preserved; confirm which semantics callers expect.
     *
     * @param collection
     *     The training pairs; must be non-null and non-empty.
     * @return
     *     The learned discriminant matrix {@code A} (output rows by input
     *     columns) together with the bias vector {@code b}.
     * @throws IllegalArgumentException
     *     If {@code collection} is null or empty.
     */
    @Override // gov.sandia.cognition.learning.algorithm.BatchLearner
    public MultivariateDiscriminantWithBias learn(Collection<? extends InputOutputPair<? extends Vector, Vector>> collection) {
        if (collection == null || collection.isEmpty()) {
            throw new IllegalArgumentException(
                "Cannot learn a linear regression from null or empty data");
        }

        final InputOutputPair<? extends Vector, Vector> first = CollectionUtil.getFirst(collection);
        final int inputDimensionality = first.getInput().getDimensionality();
        final int outputDimensionality = first.getOutput().getDimensionality();
        // One coefficient per input dimension plus one for the bias term.
        final int numCoefficients = inputDimensionality + 1;
        final int numSamples = collection.size();

        final boolean solveWithPseudoInverse = getUsePseudoInverse();
        // With the pseudo-inverse solver, ridge regularization is applied by
        // appending sqrt(lambda) * I rows to the design matrix (and zero rows
        // to the targets). pinv([X; sqrt(lambda) I]) [Y; 0] equals
        // (X^T X + lambda I)^-1 X^T Y, the same solution the normal-equation
        // path computes, without squaring the condition number.
        final boolean augmentForRidge = solveWithPseudoInverse && this.regularization > 0.0d;
        final int numRows = augmentForRidge ? numSamples + numCoefficients : numSamples;

        // Row i of `design` is the (weighted) input of sample i with a trailing
        // bias entry; row i of `targets` is the (weighted) output. Extra
        // ridge-augmentation rows of `targets` rely on createMatrix
        // zero-filling (assumed from MatrixFactory's contract).
        final Matrix design = MatrixFactory.getDefault().createMatrix(numRows, numCoefficients);
        final Matrix targets = MatrixFactory.getDefault().createMatrix(numRows, outputDimensionality);

        final Vector biasEntry = VectorFactory.getDefault().copyValues(new double[]{1.0d});
        int row = 0;
        for (InputOutputPair<? extends Vector, Vector> pair : collection) {
            // Presumably stack() returns a new vector, so scaling it in place
            // does not touch the caller's input.
            final Vector input = pair.getInput().convertToVector().stack(biasEntry);
            Vector output = pair.getOutput();
            final double weight = DatasetUtil.getWeight(pair);
            if (weight != 1.0d) {
                input.scaleEquals(weight);
                output = (Vector) output.scale(weight);
            }
            design.setRow(row, input);
            targets.setRow(row, output);
            row++;
        }

        if (augmentForRidge) {
            final double penaltyRoot = Math.sqrt(this.regularization);
            for (int k = 0; k < numCoefficients; k++) {
                design.setElement(numSamples + k, k, penaltyRoot);
            }
        }

        // coefficientsTransposed is numCoefficients x outputDimensionality.
        final Matrix coefficientsTransposed;
        if (solveWithPseudoInverse) {
            coefficientsTransposed = design.pseudoInverse(DEFAULT_PSEUDO_INVERSE_TOLERANCE).times(targets);
        } else {
            final Matrix designTransposed = design.transpose();
            final Matrix normalMatrix = designTransposed.times(design);
            if (this.regularization > 0.0d) {
                // NOTE(review): the penalty also covers the bias coefficient
                // (all numCoefficients diagonal entries), unlike textbook ridge
                // regression, which usually leaves the intercept unpenalized.
                for (int k = 0; k < numCoefficients; k++) {
                    normalMatrix.setElement(k, k, normalMatrix.getElement(k, k) + this.regularization);
                }
            }
            coefficientsTransposed = normalMatrix.solve(designTransposed.times(targets));
        }

        // Columns 0..inputDimensionality-1 form A; the last column is the bias b.
        final Matrix coefficients = coefficientsTransposed.transpose();
        return new MultivariateDiscriminantWithBias(
            coefficients.getSubMatrix(0, outputDimensionality - 1, 0, inputDimensionality - 1),
            coefficients.getColumn(inputDimensionality));
    }

    /**
     * Reports which solver {@link #learn} uses.
     *
     * @return true if the pseudo-inverse is used; false if the normal
     *     equations are solved directly
     */
    public boolean getUsePseudoInverse() {
        return this.usePseudoInverse;
    }

    /**
     * Selects the solver used by {@link #learn}.
     *
     * @param z true to use the pseudo-inverse; false to solve the normal
     *     equations directly
     */
    public void setUsePseudoInverse(boolean z) {
        this.usePseudoInverse = z;
    }

    /**
     * Returns the ridge penalty added to the diagonal of {@code X^T X}.
     *
     * @return the non-negative regularization weight
     */
    public double getRegularization() {
        return this.regularization;
    }

    /**
     * Sets the ridge penalty. A value of zero disables regularization.
     * Negative values are rejected via
     * {@code ArgumentChecker.assertIsNonNegative}.
     *
     * @param d the non-negative regularization weight
     */
    public void setRegularization(double d) {
        ArgumentChecker.assertIsNonNegative("regularization", d);
        this.regularization = d;
    }
}
