package com.github.waikatodatamining.matrix.algorithm.pls;

import com.github.waikatodatamining.matrix.core.Matrix;
import com.github.waikatodatamining.matrix.core.MatrixFactory;
import com.github.waikatodatamining.matrix.core.MatrixHelper;

/* loaded from: input_file:com/github/waikatodatamining/matrix/algorithm/pls/PLS1.class */
/**
 * Single-response partial least squares (PLS1) built on the project's
 * {@link Matrix} type.
 *
 * <p>Initialization extracts {@code getNumComponents()} components from the
 * predictors by repeatedly computing a normalized direction, the resulting
 * score vector, a loading vector, and then deflating the predictors. The
 * fitted matrices are exposed through {@link #getMatrix(String)} under the
 * names listed by {@link #getMatrixNames()}.
 */
public class PLS1 extends AbstractSingleResponsePLS {
    private static final long serialVersionUID = 4899661745515419256L;

    /** Regression vector, computed as {@code W * (P^T * W)^-1 * b_hat}. */
    protected Matrix m_r_hat;

    /** Loadings; one column per component (returned by {@link #getLoadings()}). */
    protected Matrix m_P;

    /** Normalized {@code X^T * y} direction of each component; one column per component. */
    protected Matrix m_W;

    /** Per-component coefficient {@code (t . y) / (t . t)}; one row per component. */
    protected Matrix m_b_hat;

    /**
     * Clears all fitted matrices after delegating to the superclass reset.
     */
    @Override
    public void reset() {
        super.reset();
        this.m_r_hat = null;
        this.m_P = null;
        this.m_W = null;
        this.m_b_hat = null;
    }

    /**
     * Returns the names of the matrices this algorithm exposes.
     *
     * @return the matrix names accepted by {@link #getMatrix(String)}
     */
    @Override
    public String[] getMatrixNames() {
        return new String[]{"r_hat", "P", "W", "b_hat"};
    }

    /**
     * Returns the fitted matrix registered under the given name.
     *
     * <p>The regression vector is available both as {@code "r_hat"} (the name
     * advertised by {@link #getMatrixNames()}) and under the older
     * {@code "RegVector"} alias.
     *
     * @param str the matrix name; must not be {@code null}
     * @return the matrix, or {@code null} if the name is unknown or the
     *         matrix has not been computed yet
     */
    @Override
    public Matrix getMatrix(String str) {
        switch (str) {
            case "r_hat":
            case "RegVector": // alias kept so existing callers keep working
                return this.m_r_hat;
            case "P":
                return this.m_P;
            case "W":
                return this.m_W;
            case "b_hat":
                return this.m_b_hat;
            default:
                return null;
        }
    }

    /**
     * Indicates that this algorithm provides loadings.
     *
     * @return always {@code true}
     */
    @Override
    public boolean hasLoadings() {
        return true;
    }

    /**
     * Returns the loadings matrix.
     *
     * @return the matrix registered as {@code "P"}, or {@code null} if not computed
     */
    @Override
    public Matrix getLoadings() {
        return getMatrix("P");
    }

    /**
     * Fits the model: computes W, P, b_hat and the regression vector r_hat.
     *
     * @param matrix  the predictors (rows are instances, columns are attributes)
     * @param matrix2 the response column
     * @return always {@code null}
     *         (NOTE(review): presumably {@code null} signals success - confirm against the superclass contract)
     * @throws Exception if a matrix operation fails
     */
    @Override
    public String doPerformInitialization(Matrix matrix, Matrix matrix2) throws Exception {
        final int numComponents = getNumComponents();
        Matrix weights = MatrixFactory.zeros(matrix.numColumns(), numComponents);
        Matrix loadings = MatrixFactory.zeros(matrix.numColumns(), numComponents);
        Matrix coefficients = MatrixFactory.zeros(numComponents, 1);

        // Working copy of the predictors; deflated after every component.
        Matrix deflated = matrix;
        for (int j = 0; j < numComponents; j++) {
            // Direction: w_j = normalized(X_j^T * y)
            Matrix w = deflated.transpose().mul(matrix2).normalized();
            weights.setColumn(j, w);

            // Scores: t_j = X_j * w_j
            Matrix t = deflated.mul(w);
            double tDotT = t.vectorDot(t);

            // Coefficient: b_j = (t_j . y) / (t_j . t_j)
            coefficients.set(j, 0, t.vectorDot(matrix2) / tDotT);

            // Loadings: p_j = X_j^T * t_j / (t_j . t_j)
            Matrix p = deflated.transpose().mul(t).div(tDotT);
            loadings.setColumn(j, p);

            // Deflation: X_{j+1} = X_j - t_j * p_j^T
            deflated = deflated.sub(t.mul(p.transpose()));
        }

        // r_hat = W * (P^T * W)^-1 * b_hat
        this.m_r_hat = weights.mul(loadings.transpose().mul(weights).inverse()).mul(coefficients);
        this.m_P = loadings;
        this.m_W = weights;
        this.m_b_hat = coefficients;
        return null;
    }

    /**
     * Projects every row of the input onto the fitted components.
     *
     * @param matrix the predictors to transform
     * @return a matrix with one row per input row and one column per component
     * @throws Exception if a matrix operation fails
     */
    @Override
    protected Matrix doTransform(Matrix matrix) throws Exception {
        Matrix result = MatrixFactory.zeros(matrix.numRows(), getNumComponents());
        for (int i = 0; i < matrix.numRows(); i++) {
            result.setRow(i, computeScores(MatrixHelper.rowAsVector(matrix, i)));
        }
        return result;
    }

    /**
     * Indicates that this algorithm can make predictions.
     *
     * @return always {@code true}
     */
    @Override
    public boolean canPredict() {
        return true;
    }

    /**
     * Predicts the response for every row of the input.
     *
     * <p>Each prediction is the product of the row's component scores with
     * {@code b_hat}.
     *
     * @param matrix the predictors to predict for
     * @return a single-column matrix with one prediction per input row
     * @throws Exception if a matrix operation fails
     */
    @Override
    public Matrix doPerformPredictions(Matrix matrix) throws Exception {
        Matrix result = MatrixFactory.zeros(matrix.numRows(), 1);
        for (int i = 0; i < matrix.numRows(); i++) {
            Matrix scores = computeScores(MatrixHelper.rowAsVector(matrix, i));
            result.set(i, 0, scores.mul(this.m_b_hat).asDouble());
        }
        return result;
    }

    /**
     * Computes the component scores of a single instance.
     *
     * <p>For every component the score is {@code t_j = x_j * w_j}; the
     * instance is then deflated with {@code x_{j+1} = x_j - t_j * p_j^T}
     * before the next component is evaluated.
     *
     * @param row the instance, as produced by {@code MatrixHelper.rowAsVector}
     * @return a 1 x numComponents matrix holding the scores
     * @throws Exception if a matrix operation fails
     */
    private Matrix computeScores(Matrix row) throws Exception {
        final int numComponents = getNumComponents();
        Matrix scores = MatrixFactory.zeros(1, numComponents);
        Matrix residual = row;
        for (int j = 0; j < numComponents; j++) {
            // The product is a single value, read back below via asDouble().
            Matrix t = residual.mul(this.m_W.getColumn(j));
            // scores has exactly one row, so the component index addresses a column.
            scores.setColumn(j, t);
            residual = residual.sub(this.m_P.getColumn(j).transpose().mul(t.asDouble()));
        }
        return scores;
    }
}
