package com.github.waikatodatamining.matrix.algorithm;

import Jama.Matrix;
import com.github.waikatodatamining.matrix.core.MatrixHelper;
import com.github.waikatodatamining.matrix.transformation.Center;
import com.github.waikatodatamining.matrix.transformation.kernel.AbstractKernel;
import com.github.waikatodatamining.matrix.transformation.kernel.RBFKernel;

/**
 * Kernel partial least squares for multi-response data.
 *
 * <p>Training centres X and Y in input space, builds the training kernel with
 * the configured {@link AbstractKernel} (an {@link RBFKernel} by default),
 * centres that kernel in feature space and then extracts
 * {@code getNumComponents()} components with an iterative (NIPALS-style) loop:
 * starting from a seeded random vector u, repeat {@code t = K u} (normalised),
 * {@code c = Y' t}, {@code u = Y c} (normalised) until u changes by no more than
 * the tolerance or the iteration limit is hit, then deflate K and Y by t.
 * Predictions are {@code K_test U (T' K U)^-1 Q'}, mapped back through the
 * Y centering.
 */
public class KernelPLS extends AbstractMultiResponsePLS {
    private static final long serialVersionUID = -2760078672082710402L;

    /** Base seed for the random start vectors; component {@code i} uses {@code SEED + i}. */
    public static final int SEED = 0;

    /** Training kernel matrix after centering in feature space (n x n). */
    protected Matrix m_K_orig;

    /** Centred training kernel after deflation by every extracted component. */
    protected Matrix m_K_deflated;

    /**
     * Column means of the uncentred training kernel (1 x n). Test kernels
     * must be centred against these, not against the already-centred
     * {@link #m_K_orig}. May be {@code null} for models serialized before this
     * field existed; see {@link #centralizeTestInKernelSpace(Matrix)}.
     */
    protected Matrix m_K_colMeans;

    /** Normalised t vectors, one column per component (n x numComponents). */
    protected Matrix m_T;

    /** Normalised u vectors, one column per component (n x numComponents). */
    protected Matrix m_U;

    /** Loadings computed from the deflated kernel (n x numComponents). */
    protected Matrix m_P;

    /** c vectors ({@code Y' t}), one column per component (numResponses x numComponents). */
    protected Matrix m_Q;

    /** Right-hand factor of the prediction: {@code (T' K U)^-1 Q'}. */
    protected Matrix m_B_RHS;

    /** Centred training predictors, kept to build kernels against new data. */
    protected Matrix m_X;

    /** Kernel used to map the predictors into feature space. */
    protected AbstractKernel m_Kernel;

    /** Convergence tolerance on the change of the u vector between iterations. */
    protected double m_Tol;

    /** Maximum number of iterations per component. */
    protected int m_MaxIter;

    /** Centering filter for the predictors. */
    protected Center m_CenterX;

    /** Centering filter for the responses. */
    protected Center m_CenterY;

    /**
     * Initializes the members with their defaults: an {@link RBFKernel}, a
     * tolerance of 1e-6, at most 500 iterations per component and fresh
     * centering filters for X and Y.
     */
    @Override
    public void initialize() {
        super.initialize();
        setKernel(new RBFKernel());
        setTol(1.0E-6d);
        setMaxIter(500);
        this.m_CenterX = new Center();
        this.m_CenterY = new Center();
    }

    /**
     * Returns the kernel in use.
     *
     * @return the kernel
     */
    public AbstractKernel getKernel() {
        return this.m_Kernel;
    }

    /**
     * Sets the kernel to use.
     *
     * @param abstractKernel the kernel
     */
    public void setKernel(AbstractKernel abstractKernel) {
        this.m_Kernel = abstractKernel;
    }

    /**
     * Returns the maximum number of iterations per component.
     *
     * @return the iteration limit
     */
    public int getMaxIter() {
        return this.m_MaxIter;
    }

    /**
     * Sets the maximum number of iterations per component. At least one
     * iteration is always performed, even if this is less than 1.
     *
     * @param i the iteration limit
     */
    public void setMaxIter(int i) {
        this.m_MaxIter = i;
    }

    /**
     * Returns the convergence tolerance.
     *
     * @return the tolerance
     */
    public double getTol() {
        return this.m_Tol;
    }

    /**
     * Sets the convergence tolerance. A value of 0 (or less) means iterate
     * until the iteration limit is reached.
     *
     * @param d the tolerance
     */
    public void setTol(double d) {
        this.m_Tol = d;
    }

    /** At least one response column is required. */
    @Override
    protected int getMinColumnsResponse() {
        return 1;
    }

    /** No upper limit on the number of response columns (-1). */
    @Override
    protected int getMaxColumnsResponse() {
        return -1;
    }

    /**
     * Trains the model.
     *
     * @param matrix  the predictors (one row per sample)
     * @param matrix2 the responses (one row per sample)
     * @return always {@code null} (presumably "no error" for the base class; confirm)
     * @throws Exception if the kernel or centering filters fail
     */
    @Override
    public String doPerformInitialization(Matrix matrix, Matrix matrix2) throws Exception {
        int numComponents = getNumComponents();

        // Centre X and Y in input space.
        this.m_X = this.m_CenterX.transform(matrix);
        Matrix yResidual = this.m_CenterY.transform(matrix2);

        int numRows = this.m_X.getRowDimension();
        int numResponses = yResidual.getColumnDimension();
        Matrix identity = Matrix.identity(numRows, numRows);

        this.m_T = new Matrix(numRows, numComponents);
        this.m_U = new Matrix(numRows, numComponents);
        this.m_P = new Matrix(numRows, numComponents);
        this.m_Q = new Matrix(numResponses, numComponents);

        // Keep the column means of the *uncentred* kernel. Test kernels are centred
        // against them later, because the centred kernel has zero column means.
        Matrix kernel = this.m_Kernel.applyMatrix(this.m_X, this.m_X);
        this.m_K_colMeans = new Matrix(1, numRows, 1.0d / numRows).times(kernel);
        this.m_K_orig = centralizeTrainInKernelSpace(kernel);
        this.m_K_deflated = this.m_K_orig.copy();

        for (int comp = 0; comp < numComponents; comp++) {
            Matrix u = MatrixHelper.randn(numRows, 1, SEED + comp);
            Matrix t;
            Matrix tUnscaled;
            Matrix c;
            double delta;
            int iter = 0;
            // do/while: always run at least once. Otherwise tol <= 0 or maxIter <= 0
            // would skip the loop and leave t, c and w as all-zero vectors.
            do {
                t = this.m_K_deflated.times(u);
                tUnscaled = t.copy();
                MatrixHelper.normalizeVector(t);
                c = yResidual.transpose().times(t);
                Matrix uPrevious = u;
                u = yResidual.times(c);
                MatrixHelper.normalizeVector(u);
                iter++;
                delta = MatrixHelper.l2VectorNorm(u.minus(uPrevious));
            } while (delta > this.m_Tol && iter < this.m_MaxIter);

            // Deflate the kernel, (I - tt') K (I - tt'), and the Y residual, Y - t c'.
            Matrix projection = identity.minus(t.times(t.transpose()));
            this.m_K_deflated = projection.times(this.m_K_deflated).times(projection);
            yResidual = yResidual.minus(t.times(c.transpose()));

            // Loading from the (already deflated) kernel and the unnormalised t.
            Matrix p = this.m_K_deflated.transpose().times(tUnscaled)
                .times(1.0d / tUnscaled.transpose().times(tUnscaled).get(0, 0));

            MatrixHelper.setColumnVector(t, this.m_T, comp);
            MatrixHelper.setColumnVector(u, this.m_U, comp);
            MatrixHelper.setColumnVector(c, this.m_Q, comp);
            MatrixHelper.setColumnVector(p, this.m_P, comp);
        }

        // Predictions are K_test U (T' K U)^-1 Q'; precompute the right-hand factor.
        this.m_B_RHS = this.m_T.transpose().times(this.m_K_orig).times(this.m_U).inverse().times(this.m_Q.transpose());
        return null;
    }

    /**
     * Centres a training kernel in feature space: {@code C K C} with
     * {@code C = I - 11'/n}, where n is the number of training rows.
     *
     * @param matrix the uncentred n x n training kernel
     * @return the centred kernel
     */
    protected Matrix centralizeTrainInKernelSpace(Matrix matrix) {
        int n = this.m_X.getRowDimension();
        Matrix centering = Matrix.identity(n, n).minus(new Matrix(n, n, 1.0d / n));
        return centering.times(matrix).times(centering);
    }

    /**
     * Centres a test kernel consistently with the training kernel:
     * {@code (K_t - 1 m') (I - 11'/n)}, where {@code m'} holds the column means
     * of the uncentred training kernel. For the training data itself this
     * reproduces the centred training kernel.
     *
     * @param matrix the uncentred nTest x nTrain kernel between test and training rows
     * @return the centred test kernel
     */
    protected Matrix centralizeTestInKernelSpace(Matrix matrix) {
        int nTrain = this.m_X.getRowDimension();
        int nTest = matrix.getRowDimension();
        Matrix trainMeans;
        if (this.m_K_colMeans != null) {
            trainMeans = new Matrix(nTest, 1, 1.0d).times(this.m_K_colMeans);
        } else {
            // Model serialized before m_K_colMeans existed: keep the previous
            // computation (against the already-centred kernel) so it behaves as before.
            trainMeans = new Matrix(nTest, nTrain, 1.0d / nTrain).times(this.m_K_orig);
        }
        Matrix centering = Matrix.identity(nTrain, nTrain).minus(new Matrix(nTrain, nTrain, 1.0d / nTrain));
        return matrix.minus(trainMeans).times(centering);
    }

    /**
     * Predicts responses for new predictors.
     *
     * @param matrix the predictors (one row per sample)
     * @return the predictions, mapped back through the Y centering
     */
    @Override
    public Matrix doPerformPredictions(Matrix matrix) {
        return this.m_CenterY.inverseTransform(doTransform(matrix).times(this.m_B_RHS));
    }

    /**
     * Projects new predictors: centres them with the X filter, builds the
     * kernel against the training predictors, centres it in feature space and
     * multiplies by {@code U}. The result is {@code K_test U}, not rescaled by
     * {@code (T' K U)^-1}.
     *
     * @param matrix the predictors (one row per sample)
     * @return the projected data (nTest x numComponents)
     */
    @Override
    protected Matrix doTransform(Matrix matrix) {
        return centralizeTestInKernelSpace(this.m_Kernel.applyMatrix(this.m_CenterX.transform(matrix), this.m_X)).times(this.m_U);
    }

    /**
     * Returns the names of the matrices available via {@link #getMatrix(String)}.
     *
     * @return the matrix names
     */
    @Override
    public String[] getMatrixNames() {
        return new String[]{"K", "T", "U", "P", "Q"};
    }

    /**
     * Returns the named matrix. "K" is the deflated training kernel.
     *
     * @param str one of "K", "T", "U", "P", "Q"
     * @return the matrix, or {@code null} for an unknown name
     */
    @Override
    public Matrix getMatrix(String str) {
        switch (str) {
            case "K":
                return this.m_K_deflated;
            case "T":
                return this.m_T;
            case "U":
                return this.m_U;
            case "P":
                return this.m_P;
            case "Q":
                return this.m_Q;
            default:
                return null;
        }
    }

    /** Loadings are available (see {@link #getLoadings()}). */
    @Override
    public boolean hasLoadings() {
        return true;
    }

    /** Clears the trained model and replaces the centering filters with fresh ones. */
    @Override
    public void reset() {
        super.reset();
        this.m_K_orig = null;
        this.m_K_deflated = null;
        this.m_K_colMeans = null;
        this.m_T = null;
        this.m_U = null;
        this.m_P = null;
        this.m_Q = null;
        this.m_B_RHS = null;
        this.m_X = null;
        this.m_CenterX = new Center();
        this.m_CenterY = new Center();
    }

    /**
     * Returns the loadings; this implementation returns the {@code T} matrix.
     *
     * @return the T matrix, or {@code null} if not trained
     */
    @Override
    public Matrix getLoadings() {
        return this.m_T;
    }

    /** This algorithm supports predictions. */
    @Override
    public boolean canPredict() {
        return true;
    }
}
