package com.github.waikatodatamining.matrix.algorithm;

import com.github.waikatodatamining.matrix.core.Matrix;
import com.github.waikatodatamining.matrix.core.MatrixHelper;

/* loaded from: input_file:com/github/waikatodatamining/matrix/algorithm/PLS1.class */
/**
 * Single-response PLS ("PLS1") implementation.
 *
 * <p>Initialization builds, one component at a time, the matrices {@code W},
 * {@code P} and the vector {@code b_hat}, and from them the regression vector
 * {@code r_hat = W * (P^T * W)^-1 * b_hat}. Transformation and prediction
 * replay the same per-component score/deflation steps on each input row.
 */
public class PLS1 extends AbstractSingleResponsePLS {
    private static final long serialVersionUID = 4899661745515419256L;

    /** Name under which {@link #getMatrixNames()} advertises the regression vector. */
    private static final String NAME_R_HAT = "r_hat";

    /** Older name for the regression vector; still accepted by {@link #getMatrix(String)}. */
    private static final String NAME_REG_VECTOR = "RegVector";

    /** Regression vector, computed as {@code W * (P^T * W)^-1 * b_hat}. */
    protected Matrix m_r_hat;

    /** The P matrix (returned as the loadings); one column per component. */
    protected Matrix m_P;

    /** The W matrix; one normalized {@code X^T * y} column per component. */
    protected Matrix m_W;

    /** The b_hat vector; one regression coefficient per component. */
    protected Matrix m_b_hat;

    /**
     * Resets the superclass state and discards all matrices built by the last
     * initialization.
     */
    @Override
    public void reset() {
        super.reset();
        this.m_r_hat = null;
        this.m_P = null;
        this.m_W = null;
        this.m_b_hat = null;
    }

    /**
     * Returns the names of the matrices this algorithm exposes.
     *
     * @return the matrix names, each of which is accepted by {@link #getMatrix(String)}
     */
    @Override
    public String[] getMatrixNames() {
        return new String[]{"r_hat", "P", "W", "b_hat"};
    }

    /**
     * Returns the matrix stored under the given name.
     *
     * <p>The regression vector is available both under {@code "r_hat"} (the name
     * reported by {@link #getMatrixNames()}) and under {@code "RegVector"}.
     *
     * @param str the name of the matrix to look up
     * @return the matrix, or {@code null} if the name is unknown or the matrix
     *         has not been computed yet
     */
    @Override
    public Matrix getMatrix(String str) {
        switch (str) {
            case NAME_R_HAT:
            case NAME_REG_VECTOR:
                return this.m_r_hat;
            case "P":
                return this.m_P;
            case "W":
                return this.m_W;
            case "b_hat":
                return this.m_b_hat;
            default:
                return null;
        }
    }

    /**
     * Indicates that this algorithm provides loadings.
     *
     * @return always {@code true}
     */
    @Override
    public boolean hasLoadings() {
        return true;
    }

    /**
     * Returns the loadings, i.e. the matrix registered as {@code "P"}.
     *
     * @return the P matrix, or {@code null} if not initialized
     */
    @Override
    public Matrix getLoadings() {
        return getMatrix("P");
    }

    /**
     * Builds W, P, b_hat and the regression vector from the training data.
     *
     * @param matrix  the predictors (one row per observation)
     * @param matrix2 the response (multiplied as a single column)
     * @return always {@code null}
     * @throws Exception if a matrix operation fails
     */
    @Override
    public String doPerformInitialization(Matrix matrix, Matrix matrix2) throws Exception {
        final int numComponents = getNumComponents();

        Matrix x = matrix;
        Matrix y = matrix2;

        // NOTE(review): the transpose is taken once from the incoming predictors and is
        // not refreshed after x is deflated below; the original code behaves the same
        // way, so this is preserved - confirm it is intended.
        Matrix xTrans = x.transpose();

        Matrix w = new Matrix(x.numColumns(), numComponents);
        Matrix p = new Matrix(x.numColumns(), numComponents);
        Matrix t = new Matrix(x.numRows(), numComponents);
        Matrix bHat = new Matrix(numComponents, 1);

        for (int j = 0; j < numComponents; j++) {
            // 1. w_j = normalized X^T * y
            Matrix wj = xTrans.mul(y);
            MatrixHelper.normalizeVector(wj);
            w.setColumn(j, wj);

            // 2. t_j = X * w_j
            Matrix tj = x.mul(wj);
            Matrix tjTrans = tj.transpose();
            t.setColumn(j, tj);

            // t_j^T * t_j is needed by both step 3 and step 4
            double tjSquared = tjTrans.mul(tj).asDouble();

            // 3. b_j = (t_j^T * y) / (t_j^T * t_j)
            double bj = tjTrans.mul(y).asDouble() / tjSquared;
            bHat.set(j, 0, bj);

            // 4. p_j = (X^T * t_j) * 1 / (t_j^T * t_j)
            Matrix pj = xTrans.mul(tj).mul(1.0d / tjSquared);
            p.setColumn(j, pj);

            // 5. deflate predictors and response for the next component
            x = x.sub(tj.mul(pj.transpose()));
            y = y.sub(tj.mul(bj));
        }

        this.m_r_hat = w.mul(p.transpose().mul(w).inverse()).mul(bHat);
        this.m_P = p;
        this.m_W = w;
        this.m_b_hat = bHat;
        return null;
    }

    /**
     * Projects every row of the input onto the components.
     *
     * @param matrix the predictors to transform (one row per observation)
     * @return a matrix with one row per input row and one column per component
     * @throws Exception if a matrix operation fails
     */
    @Override
    protected Matrix doTransform(Matrix matrix) throws Exception {
        Matrix result = new Matrix(matrix.numRows(), getNumComponents());
        for (int i = 0; i < matrix.numRows(); i++) {
            result.setRow(i, computeRowScores(MatrixHelper.rowAsVector(matrix, i)));
        }
        return result;
    }

    /**
     * Indicates that this algorithm can make predictions.
     *
     * @return always {@code true}
     */
    @Override
    public boolean canPredict() {
        return true;
    }

    /**
     * Predicts the response for every row of the input as the product of the
     * row's component scores with {@code b_hat}.
     *
     * @param matrix the predictors (one row per observation)
     * @return a single-column matrix with one prediction per input row
     * @throws Exception if a matrix operation fails
     */
    @Override
    public Matrix doPerformPredictions(Matrix matrix) throws Exception {
        Matrix result = new Matrix(matrix.numRows(), 1);
        for (int i = 0; i < matrix.numRows(); i++) {
            Matrix scores = computeRowScores(MatrixHelper.rowAsVector(matrix, i));
            result.set(i, 0, scores.mul(this.m_b_hat).asDouble());
        }
        return result;
    }

    /**
     * Computes the component scores of a single observation, deflating the
     * observation with the corresponding P column after each component.
     *
     * <p>Scores are stored column-wise because the result has exactly one row;
     * this is the layout both {@link #doTransform(Matrix)} and
     * {@link #doPerformPredictions(Matrix)} rely on.
     *
     * @param row the observation as returned by {@code MatrixHelper.rowAsVector}
     * @return a {@code 1 x numComponents} matrix holding the scores
     */
    private Matrix computeRowScores(Matrix row) {
        final int numComponents = getNumComponents();
        Matrix scores = new Matrix(1, numComponents);
        Matrix x = row;
        for (int j = 0; j < numComponents; j++) {
            // t_j = x_j * w_j (a 1x1 matrix)
            Matrix tj = x.mul(this.m_W.getColumn(j));
            scores.setColumn(j, tj);
            // x_{j+1} = x_j - t_j * p_j^T
            x = x.sub(this.m_P.getColumn(j).transpose().mul(tj.asDouble()));
        }
        return scores;
    }
}
