package com.github.waikatodatamining.matrix.algorithm;

import com.github.waikatodatamining.matrix.core.Matrix;
import com.github.waikatodatamining.matrix.core.MatrixHelper;

/* loaded from: input_file:com/github/waikatodatamining/matrix/algorithm/OPLS.class */
/**
 * Single-response PLS variant that first filters the predictors and then
 * delegates model building and prediction to a configurable base PLS
 * algorithm (a {@link PLS1} instance by default).
 *
 * <p>During initialization one weight, score and loading vector is computed
 * per component; these are stored as the columns of {@code W_orth},
 * {@code T_orth} and {@code P_orth}. New data is filtered in
 * {@link #doTransform(Matrix)} as {@code X - X * W_orth * P_orth^T}.
 *
 * <p>NOTE(review): the class name suggests "Orthogonal Projections to Latent
 * Structures" (orthogonal signal correction) - confirm against the project
 * documentation.
 */
public class OPLS extends AbstractSingleResponsePLS {
    private static final long serialVersionUID = -6097279189841762321L;

    /** Name under which {@link #m_Porth} is exposed by {@link #getMatrix(String)}. */
    private static final String NAME_P_ORTH = "P_orth";

    /** Name under which {@link #m_Worth} is exposed by {@link #getMatrix(String)}. */
    private static final String NAME_W_ORTH = "W_orth";

    /** Name under which {@link #m_Torth} is exposed by {@link #getMatrix(String)}. */
    private static final String NAME_T_ORTH = "T_orth";

    /** Loadings, one column per component (rows = predictor columns). */
    protected Matrix m_Porth;

    /** Scores, one column per component (rows = training rows). */
    protected Matrix m_Torth;

    /** Weights, one column per component (rows = predictor columns). */
    protected Matrix m_Worth;

    /** Copy of the training predictors after all components have been deflated. */
    protected Matrix m_Xosc;

    /** The PLS algorithm that is trained on, and predicts from, the filtered data. */
    protected AbstractPLS m_BasePLS;

    /**
     * Returns the PLS algorithm used on the filtered data.
     *
     * @return the base algorithm
     */
    public AbstractPLS getBasePLS() {
        return this.m_BasePLS;
    }

    /**
     * Sets the PLS algorithm used on the filtered data.
     *
     * @param abstractPLS the base algorithm
     */
    public void setBasePLS(AbstractPLS abstractPLS) {
        this.m_BasePLS = abstractPLS;
    }

    /**
     * Discards everything computed by {@link #doPerformInitialization(Matrix, Matrix)}.
     * The configured base PLS is left in place.
     */
    @Override
    public void reset() {
        super.reset();
        this.m_Porth = null;
        this.m_Worth = null;
        this.m_Torth = null;
        // m_Xosc is also a product of initialization; do not let it outlive a reset.
        this.m_Xosc = null;
    }

    /**
     * Sets up the defaults: a fresh {@link PLS1} as base algorithm.
     */
    @Override
    public void initialize() {
        super.initialize();
        setBasePLS(new PLS1());
    }

    /**
     * Returns the names accepted by {@link #getMatrix(String)}.
     *
     * @return the matrix names
     */
    @Override
    public String[] getMatrixNames() {
        return new String[]{NAME_P_ORTH, NAME_W_ORTH, NAME_T_ORTH};
    }

    /**
     * Returns the matrix stored under the given name.
     *
     * @param str one of the names from {@link #getMatrixNames()}
     * @return the matching matrix (which is {@code null} before initialization),
     *         or {@code null} if the name is not recognised
     */
    @Override
    public Matrix getMatrix(String str) {
        switch (str) {
            case NAME_P_ORTH:
                return this.m_Porth;
            case NAME_W_ORTH:
                return this.m_Worth;
            case NAME_T_ORTH:
                return this.m_Torth;
            default:
                return null;
        }
    }

    /**
     * Indicates that loadings are available from this algorithm.
     *
     * @return always {@code true}
     */
    @Override
    public boolean hasLoadings() {
        return true;
    }

    /**
     * Returns the loadings, i.e. the {@code P_orth} matrix.
     *
     * @return the loadings, {@code null} before initialization
     */
    @Override
    public Matrix getLoadings() {
        return getMatrix(NAME_P_ORTH);
    }

    /**
     * Computes the weight/score/loading vectors for every component, keeps
     * the deflated predictors, and initializes the base PLS on the filtered
     * predictors.
     *
     * @param matrix  the predictors; not modified, a copy is deflated instead
     * @param matrix2 the response
     * @return always {@code null}
     * @throws Exception if a matrix operation or the base PLS fails
     */
    @Override
    public String doPerformInitialization(Matrix matrix, Matrix matrix2) throws Exception {
        final int numComponents = getNumComponents();

        Matrix residual = matrix.copy();
        Matrix residualT = residual.transpose();

        this.m_Worth = new Matrix(matrix.numColumns(), numComponents);
        this.m_Porth = new Matrix(matrix.numColumns(), numComponents);
        this.m_Torth = new Matrix(matrix.numRows(), numComponents);

        // w = X^T y / (y^T y), then normalized
        Matrix w = residualT.mul(matrix2).mul(invL2Squared(matrix2));
        MatrixHelper.normalizeVector(w);

        for (int component = 0; component < numComponents; component++) {
            // t = X w / (w^T w)
            Matrix t = residual.mul(w).mul(invL2Squared(w));
            // p = X^T t / (t^T t)
            Matrix p = residualT.mul(t).mul(invL2Squared(t));

            // w_orth = p - w * (w^T p / (w^T w)), then normalized
            double projection = w.transpose().mul(p).mul(invL2Squared(w)).asDouble();
            Matrix wOrth = p.sub(w.mul(projection));
            MatrixHelper.normalizeVector(wOrth);

            // t_orth = X w_orth / (w_orth^T w_orth)
            Matrix tOrth = residual.mul(wOrth).mul(invL2Squared(wOrth));
            // p_orth = X^T t_orth / (t_orth^T t_orth)
            Matrix pOrth = residualT.mul(tOrth).mul(invL2Squared(tOrth));

            // Deflate: X = X - t_orth p_orth^T
            residual = residual.sub(tOrth.mul(pOrth.transpose()));
            residualT = residual.transpose();

            this.m_Worth.setColumn(component, wOrth);
            this.m_Torth.setColumn(component, tOrth);
            this.m_Porth.setColumn(component, pOrth);
        }

        this.m_Xosc = residual.copy();

        // NOTE(review): whatever the base algorithm's initialize(...) returns is
        // discarded here - confirm whether a failure message should be propagated.
        this.m_BasePLS.initialize(doTransform(matrix), matrix2);
        return null;
    }

    /**
     * Computes {@code 1 / norm2(matrix)^2}.
     *
     * @param matrix the vector to take the norm of
     * @return the reciprocal of the squared 2-norm; not finite when the norm is zero
     */
    protected double invL2Squared(Matrix matrix) {
        double norm2 = matrix.norm2();
        return 1.0d / (norm2 * norm2);
    }

    /**
     * Filters the given data: {@code X - X * W_orth * P_orth^T}.
     *
     * @param matrix the data to filter
     * @return the filtered data
     */
    @Override
    protected Matrix doTransform(Matrix matrix) {
        Matrix scores = matrix.mul(this.m_Worth);
        Matrix removed = scores.mul(this.m_Porth.transpose());
        return matrix.sub(removed);
    }

    /**
     * Indicates that this algorithm can make predictions.
     *
     * @return always {@code true}
     */
    @Override
    public boolean canPredict() {
        return true;
    }

    /**
     * Transforms the data and lets the base PLS predict from the result.
     *
     * @param matrix the data to predict for
     * @return the predictions of the base PLS
     * @throws Exception if transforming or predicting fails
     */
    @Override
    public Matrix doPerformPredictions(Matrix matrix) throws Exception {
        return this.m_BasePLS.predict(transform(matrix));
    }
}
