package com.github.waikatodatamining.matrix.algorithms.glsw;

import com.github.waikatodatamining.matrix.algorithms.Center;
import com.github.waikatodatamining.matrix.core.algorithm.SupervisedMatrixAlgorithm;
import com.github.waikatodatamining.matrix.core.exceptions.MatrixAlgorithmsException;
import com.github.waikatodatamining.matrix.core.matrix.Matrix;
import com.github.waikatodatamining.matrix.core.matrix.MatrixFactory;

/* loaded from: input_file:com/github/waikatodatamining/matrix/algorithms/glsw/GLSW.class */
/**
 * Generalized Least Squares Weighting (GLSW).
 *
 * <p>Configured from two equally-shaped matrices, this algorithm builds a
 * square projection matrix {@code G = V * D^-1 * V^T} (see
 * {@link #doConfigure(Matrix, Matrix)}) and transforms input by
 * right-multiplying it with {@code G}.
 */
public class GLSW extends SupervisedMatrixAlgorithm {
    private static final long serialVersionUID = -7474573037658789063L;

    /** Weighting parameter used as divisor in {@link #getWeightMatrix(Matrix)}; defaults to 0.001. */
    protected double m_Alpha = 0.001d;

    /** Projection matrix; {@code null} until configured and after every reset. */
    protected Matrix m_G;

    /**
     * Returns the projection matrix built by the last configuration.
     *
     * @return the projection matrix, or {@code null} if none has been built since the last reset
     */
    public Matrix getProjectionMatrix() {
        return this.m_G;
    }

    /**
     * Returns the current alpha value.
     *
     * @return the alpha value
     */
    public double getAlpha() {
        return this.m_Alpha;
    }

    /**
     * Sets alpha and resets the algorithm. Values that are not strictly
     * positive (including {@code NaN}) are rejected: a warning is logged and
     * the current alpha and state are left untouched.
     *
     * @param alpha the new alpha value, must be &gt; 0
     */
    public void setAlpha(double alpha) {
        // Deliberately phrased as a negated "> 0" test: every comparison with
        // NaN is false, so "alpha <= 0" would let NaN slip through.
        if (!(alpha > 0.0d)) {
            getLogger().warning("Alpha must be > 0 but was " + alpha + ".");
            return;
        }
        this.m_Alpha = alpha;
        reset();
    }

    /**
     * Discards the projection matrix.
     */
    @Override
    protected void doReset() {
        this.m_G = null;
    }

    /**
     * Returns a short description including the shape of the projection
     * matrix. Calls {@code ensureConfigured()} before accessing the matrix.
     *
     * @return the description string
     */
    @Override
    public String toString() {
        ensureConfigured();
        return "Generalized Least Squares Weighting. Projection Matrix shape: " + this.m_G.shapeString();
    }

    /**
     * Verifies that both matrices have the same number of rows and columns.
     *
     * @param x the first matrix
     * @param y the second matrix
     * @throws MatrixAlgorithmsException if the shapes differ
     */
    protected void check(Matrix x, Matrix y) throws MatrixAlgorithmsException {
        boolean sameRows = x.numRows() == y.numRows();
        boolean sameColumns = x.numColumns() == y.numColumns();
        if (!sameRows || !sameColumns) {
            throw new MatrixAlgorithmsException("Matrices X and y must have the same shape");
        }
    }

    /**
     * Builds the projection matrix {@code G = V * D^-1 * V^T}, where
     * {@code V} comes from {@link #getEigenvectorMatrix(Matrix)} and
     * {@code D} from {@link #getWeightMatrix(Matrix)}, both computed on the
     * matrix returned by {@link #getCovarianceMatrix(Matrix, Matrix)}.
     *
     * @param x the first matrix
     * @param y the second matrix, must have the same shape as {@code x}
     */
    @Override
    protected void doConfigure(Matrix x, Matrix y) {
        check(x, y);

        Matrix covariance = getCovarianceMatrix(x, y);
        Matrix eigenvectors = getEigenvectorMatrix(covariance);
        Matrix inverseWeights = getWeightMatrix(covariance).inverse();

        this.m_G = eigenvectors.mul(inverseWeights).mul(eigenvectors.t());
    }

    /**
     * Returns the {@code V} matrix of the eigenvalue decomposition of the
     * given matrix.
     *
     * @param covariance the matrix to decompose
     * @return the result of {@code getEigenvalueDecompositionV()}
     */
    protected Matrix getEigenvectorMatrix(Matrix covariance) {
        return covariance.getEigenvalueDecompositionV();
    }

    /**
     * Computes the weight matrix {@code sqrt(s^2 / alpha + I)}, where
     * {@code s} is the result of {@code svdS()} on the given matrix, the
     * square is taken element-wise and {@code I} is created via
     * {@code MatrixFactory.eyeLike}.
     *
     * <p>NOTE(review): the argument passed by {@code doConfigure} is already
     * {@code diff^T * diff}; if {@code svdS()} returns its singular values,
     * squaring them again may not be what the GLSW formulation intends -
     * confirm against the reference. Also verify that the ordering of the
     * {@code svdS()} values matches the column ordering of
     * {@code getEigenvalueDecompositionV()}.
     *
     * @param covariance the matrix whose singular values are used
     * @return the weight matrix
     */
    protected Matrix getWeightMatrix(Matrix covariance) {
        Matrix squaredValues = covariance.svdS().powElementwise(2.0d);
        Matrix scaled = squaredValues.div(this.m_Alpha);
        Matrix shifted = scaled.add(MatrixFactory.eyeLike(scaled));
        return shifted.sqrt();
    }

    /**
     * Passes each matrix through its own {@link Center} instance, takes the
     * difference {@code centered(y) - centered(x)} and returns
     * {@code diff^T * diff}.
     *
     * @param x the first matrix
     * @param y the second matrix
     * @return {@code diff^T * diff}
     */
    protected Matrix getCovarianceMatrix(Matrix x, Matrix y) {
        // Separate Center instances so each matrix is configured independently.
        Center centerX = new Center();
        Center centerY = new Center();

        Matrix centeredY = centerY.configureAndTransform(y);
        Matrix centeredX = centerX.configureAndTransform(x);
        Matrix difference = centeredY.sub(centeredX);

        return difference.t().mul(difference);
    }

    /**
     * Applies the projection by right-multiplying the input with the
     * projection matrix.
     *
     * @param matrix the matrix to transform
     * @return {@code matrix * G}
     */
    @Override
    protected Matrix doTransform(Matrix matrix) {
        return matrix.mul(this.m_G);
    }

    /**
     * Always returns {@code true}.
     *
     * @return {@code true}
     */
    @Override
    public boolean isNonInvertible() {
        return true;
    }
}
