package com.github.waikatodatamining.matrix.algorithms;

import com.github.waikatodatamining.matrix.core.algorithm.SupervisedMatrixAlgorithmWithResponseTransform;
import com.github.waikatodatamining.matrix.core.exceptions.MatrixAlgorithmsException;
import com.github.waikatodatamining.matrix.core.matrix.Matrix;
import com.github.waikatodatamining.matrix.core.matrix.MatrixFactory;

/* loaded from: input_file:com/github/waikatodatamining/matrix/algorithms/CCAFilter.class */
/**
 * Canonical Correlation Analysis filter.
 *
 * <p>Configuration centers the predictor and response matrices, builds a
 * lambda-regularized Gram matrix for each, and decomposes the cross-product
 * (pre- and post-multiplied by the inverse square roots of those Gram
 * matrices) with an SVD. The first {@code kcca} columns of the singular
 * vector matrices, multiplied by the respective inverse square roots, become
 * the projection matrices for predictors and responses.
 */
public class CCAFilter extends SupervisedMatrixAlgorithmWithResponseTransform {
    private static final long serialVersionUID = 5252111378504552170L;

    /** Scale of the identity matrix added to the predictor Gram matrix. */
    protected double m_lambdaX = 0.01d;

    /** Scale of the identity matrix added to the response Gram matrix. */
    protected double m_lambdaY = 0.01d;

    /** Number of projection columns to keep. */
    protected int m_kcca = 1;

    /** Centering step applied to the predictors. */
    protected Center m_centerX = new Center();

    /** Centering step applied to the responses. */
    protected Center m_centerY = new Center();

    /** Projection matrix for the predictors; {@code null} after a reset. */
    protected Matrix m_ProjX;

    /** Projection matrix for the responses; {@code null} after a reset. */
    protected Matrix m_ProjY;

    /**
     * Returns the number of projection columns.
     *
     * @return the current target dimension
     */
    public int getKcca() {
        return m_kcca;
    }

    /**
     * Sets the number of projection columns and resets the algorithm.
     * Values below 1 are not applied; a warning is logged instead and the
     * current state is left untouched.
     *
     * @param i the requested target dimension
     */
    public void setKcca(int i) {
        if (i < 1) {
            getLogger().warning("Target dimension kcca must be > 0 but was " + i + ".");
            return;
        }
        m_kcca = i;
        reset();
    }

    /**
     * Returns the regularization value used for the predictors.
     *
     * @return the predictor lambda
     */
    public double getLambdaX() {
        return m_lambdaX;
    }

    /**
     * Sets the regularization value used for the predictors and resets the
     * algorithm. The value is stored without validation.
     *
     * @param d the new predictor lambda
     */
    public void setLambdaX(double d) {
        m_lambdaX = d;
        reset();
    }

    /**
     * Returns the regularization value used for the responses.
     *
     * @return the response lambda
     */
    public double getLambdaY() {
        return m_lambdaY;
    }

    /**
     * Sets the regularization value used for the responses and resets the
     * algorithm. The value is stored without validation.
     *
     * @param d the new response lambda
     */
    public void setLambdaY(double d) {
        m_lambdaY = d;
        reset();
    }

    /**
     * Returns the predictor projection matrix.
     *
     * @return the projection matrix, or {@code null} if none has been computed
     *         since the last reset
     */
    public Matrix getProjectionMatrixX() {
        return m_ProjX;
    }

    /**
     * Returns the response projection matrix.
     *
     * @return the projection matrix, or {@code null} if none has been computed
     *         since the last reset
     */
    public Matrix getProjectionMatrixY() {
        return m_ProjY;
    }

    /**
     * Returns a fixed, human-readable name for this algorithm.
     *
     * @return the display name
     */
    @Override
    public String toString() {
        return "Canonical Correlation Analysis Filter (CCARegression)";
    }

    /**
     * Discards both projection matrices and resets both centering steps.
     */
    @Override
    protected void doReset() {
        m_ProjX = null;
        m_ProjY = null;
        m_centerX.reset();
        m_centerY.reset();
    }

    /**
     * Computes the projection matrices from the given predictors and responses.
     *
     * @param matrix  the predictor matrix
     * @param matrix2 the response matrix
     * @throws MatrixAlgorithmsException if {@code kcca} exceeds the smaller of
     *         the two column counts
     */
    @Override
    protected void doConfigure(Matrix matrix, Matrix matrix2) {
        final int numFeatures = matrix.numColumns();
        final int numTargets = matrix2.numColumns();
        if (Math.min(numFeatures, numTargets) < m_kcca) {
            throw new MatrixAlgorithmsException("Projection dimension must be <= min(X.numColumns, Y.numColumns).");
        }

        // Center both inputs; this also configures the centering steps that
        // doTransform/doTransformResponse rely on later.
        final Matrix centeredX = m_centerX.configureAndTransform(matrix);
        final Matrix centeredY = m_centerY.configureAndTransform(matrix2);

        // Regularized Gram matrices and the cross-product between the two sides.
        final Matrix gramX = regularizedGram(centeredX, numFeatures, m_lambdaX);
        final Matrix gramY = regularizedGram(centeredY, numTargets, m_lambdaY);
        final Matrix cross = centeredX.t().mul(centeredY);

        // Sandwich the cross-product between the inverse square roots.
        final Matrix invSqrtX = powMinusHalf(gramX);
        final Matrix invSqrtY = powMinusHalf(gramY);
        final Matrix omega = invSqrtX.mul(cross).mul(invSqrtY);

        // NOTE(review): normalized(0) presumably normalizes along axis 0 -
        // confirm against the Matrix API.
        final Matrix leftVectors = omega.svdU().normalized(0);
        final Matrix rightVectors = omega.svdV().normalized(0);

        final Matrix topLeft = leadingColumns(leftVectors);
        final Matrix topRight = leadingColumns(rightVectors);

        m_ProjX = invSqrtX.mul(topLeft);
        m_ProjY = invSqrtY.mul(topRight);
    }

    /**
     * Builds {@code centered^T * centered + lambda * I} for a centered input.
     *
     * @param centered the centered data matrix
     * @param size     the size passed to {@code MatrixFactory.eye}
     * @param lambda   the factor applied to the identity matrix
     * @return the regularized Gram matrix
     */
    private Matrix regularizedGram(Matrix centered, int size, double lambda) {
        final Matrix ridge = MatrixFactory.eye(size).mul(lambda);
        return centered.t().mul(centered).add(ridge);
    }

    /**
     * Extracts the sub-matrix used for the projection, passing all rows and
     * the range {@code 0..kcca} as column bounds.
     *
     * @param vectors the matrix to cut down
     * @return the sub-matrix limited by {@code kcca}
     */
    private Matrix leadingColumns(Matrix vectors) {
        return vectors.getSubMatrix(0, vectors.numRows(), 0, m_kcca);
    }

    /**
     * Computes the inverse square root of a matrix through its
     * eigendecomposition: {@code V * inverse(sqrt(diag(e))) * V^T}.
     *
     * @param matrix the matrix to process
     * @return the matrix raised to the power of -1/2
     */
    protected Matrix powMinusHalf(Matrix matrix) {
        final Matrix eigenvalues = matrix.getEigenvaluesSortedDescending();
        // NOTE(review): the 'true' flag presumably requests eigenvectors in the
        // same descending order as the eigenvalues - confirm against the Matrix API.
        final Matrix vectors = matrix.getEigenvectors(true);
        final Matrix invSqrtDiag = MatrixFactory.diag(eigenvalues).sqrt().inverse();
        return vectors.mul(invSqrtDiag).mul(vectors.t());
    }

    /**
     * Centers the predictors with the configured centering step and applies
     * the predictor projection.
     *
     * @param matrix the predictor matrix to transform
     * @return the projected predictors
     */
    @Override
    protected Matrix doTransform(Matrix matrix) {
        final Matrix centered = m_centerX.transform(matrix);
        return centered.mul(m_ProjX);
    }

    /**
     * Reports that this transformation cannot be inverted.
     *
     * @return always {@code true}
     */
    @Override
    public boolean isNonInvertible() {
        return true;
    }

    /**
     * Centers the responses with the configured centering step and applies
     * the response projection.
     *
     * @param matrix the response matrix to transform
     * @return the projected responses
     */
    @Override
    protected Matrix doTransformResponse(Matrix matrix) {
        final Matrix centered = m_centerY.transform(matrix);
        return centered.mul(m_ProjY);
    }
}
