package com.github.waikatodatamining.matrix.algorithm;

import com.github.waikatodatamining.matrix.core.Matrix;
import com.github.waikatodatamining.matrix.core.MatrixFactory;
import com.github.waikatodatamining.matrix.transformation.Standardize;

/* loaded from: input_file:com/github/waikatodatamining/matrix/algorithm/NIPALS.class */
/**
 * Partial least squares with the NIPALS inner loop, supporting one or more response columns.
 *
 * <p>X and Y are standardized before fitting. For each component, {@link #nipalsLoop} estimates an
 * X weight vector and a Y weight vector; the resulting scores are used to compute loadings, and
 * X and Y are then deflated. How Y is deflated is selected with {@link DeflationMode}.
 */
public class NIPALS extends AbstractMultiResponsePLS {
    private static final long serialVersionUID = -2760078672082710402L;

    /** Squared-norm threshold below which an X score vector is treated as null. */
    private static final double NULL_SCORE_THRESHOLD = 1.0E-10d;

    /** Small constant guarding the inner loop against division by zero. */
    private static final double EPS = 1.0E-16d;

    /** X scores (T): rows x components, one column per fitted component. */
    protected Matrix m_XScores;
    /** Y scores (U): rows x components. */
    protected Matrix m_YScores;
    /** X loadings (P): X columns x components. */
    protected Matrix m_XLoadings;
    /** Y loadings (Q): Y columns x components. */
    protected Matrix m_YLoadings;
    /** X weights: X columns x components. */
    protected Matrix m_XWeights;
    /** Y weights: Y columns x components. */
    protected Matrix m_YWeights;
    /** Projection applied to standardized X by {@link #doTransform}. */
    protected Matrix m_XRotations;
    /** Projection applied to standardized Y by {@link #doTransformResponse}. */
    protected Matrix m_YRotations;
    /** Standardized X after deflation by all fitted components. */
    protected Matrix m_X;
    /** Coefficients mapping standardized X to Y (before the Y means are added back). */
    protected Matrix m_Coef;
    /** Convergence tolerance on the squared change of the X weights between inner iterations. */
    protected double m_Tol;
    /** Upper bound on the inner-loop iteration counter. */
    protected int m_MaxIter;
    /** Whether Y weights are normalized to unit length inside the inner loop. */
    protected boolean m_NormYWeights;
    /** Standardizer fitted on X during initialization. */
    protected Standardize m_StandardizeX;
    /** Standardizer fitted on Y during initialization. */
    protected Standardize m_StandardizeY;
    /** How Y is deflated after each component. */
    protected DeflationMode m_deflationMode;

    /** Strategy used to deflate Y after each component has been extracted. */
    public enum DeflationMode {
        /** Deflate Y with its own scores (U) and Y loadings. */
        CANONICAL,
        /** Deflate Y with the X scores (T). */
        REGRESSION
    }

    /** Weight vectors and iteration count produced by one run of {@link #nipalsLoop}. */
    public static class NipalsLoopResult {
        Matrix xWeights;
        Matrix yWeights;
        int iterations;

        /**
         * @param xWeights the X weight vector (column)
         * @param yWeights the Y weight vector (column)
         * @param iterations value of the iteration counter when the loop stopped
         */
        public NipalsLoopResult(Matrix xWeights, Matrix yWeights, int iterations) {
            this.xWeights = xWeights;
            this.yWeights = yWeights;
            this.iterations = iterations;
        }
    }

    /**
     * Sets the defaults: tolerance 1e-6, at most 500 inner iterations, unnormalized Y weights,
     * {@link DeflationMode#REGRESSION}, and fresh standardizers.
     */
    @Override
    public void initialize() {
        super.initialize();
        setTol(1.0E-6d);
        setMaxIter(500);
        setNormYWeights(false);
        setDeflationMode(DeflationMode.REGRESSION);
        this.m_StandardizeX = new Standardize();
        this.m_StandardizeY = new Standardize();
    }

    /** @return whether Y weights are normalized to unit length in the inner loop */
    public boolean isNormYWeights() {
        return this.m_NormYWeights;
    }

    /**
     * Sets whether Y weights are normalized in the inner loop and discards any fitted model.
     *
     * @param normYWeights true to normalize the Y weights
     */
    public void setNormYWeights(boolean normYWeights) {
        this.m_NormYWeights = normYWeights;
        reset();
    }

    /** @return the maximum value of the inner-loop iteration counter */
    public int getMaxIter() {
        return this.m_MaxIter;
    }

    /**
     * Sets the inner-loop iteration limit and discards any fitted model. Negative values are
     * rejected with a warning and leave the current value unchanged.
     *
     * @param maxIter the new limit, must be {@code >= 0}
     */
    public void setMaxIter(int maxIter) {
        if (maxIter < 0) {
            getLogger().warning("Maximum iterations parameter must be positive but was " + maxIter + ".");
            return;
        }
        this.m_MaxIter = maxIter;
        reset();
    }

    /** @return the convergence tolerance of the inner loop */
    public double getTol() {
        return this.m_Tol;
    }

    /**
     * Sets the inner-loop convergence tolerance and discards any fitted model. Negative values are
     * rejected with a warning and leave the current value unchanged.
     *
     * @param tol the new tolerance, must be {@code >= 0}
     */
    public void setTol(double tol) {
        if (tol < 0.0d) {
            getLogger().warning("Tolerance parameter must be positive but was " + tol + ".");
            return;
        }
        this.m_Tol = tol;
        reset();
    }

    @Override
    protected int getMinColumnsResponse() {
        return 1;
    }

    // NOTE(review): -1 presumably means "no upper limit" in the base class -- confirm.
    @Override
    protected int getMaxColumnsResponse() {
        return -1;
    }

    /** @return how Y is deflated after each component */
    public DeflationMode getDeflationMode() {
        return this.m_deflationMode;
    }

    /**
     * Sets how Y is deflated and discards any fitted model. {@code null} is rejected with a
     * warning, because fitting would otherwise fail inside the deflation switch.
     *
     * @param deflationMode the new deflation mode
     */
    public void setDeflationMode(DeflationMode deflationMode) {
        if (deflationMode == null) {
            getLogger().warning("Deflation mode must not be null.");
            return;
        }
        this.m_deflationMode = deflationMode;
        reset();
    }

    /**
     * Fits the model. X and Y are standardized, then up to {@link #getNumComponents()} components
     * are extracted one at a time, deflating X and Y after each. Extraction stops early when the X
     * scores of a component are (numerically) null.
     *
     * @param predictors the predictor matrix X (rows are samples)
     * @param response the response matrix Y with the same number of rows as X
     * @return always {@code null}
     * @throws Exception if standardization or a matrix operation fails
     */
    @Override
    public String doPerformInitialization(Matrix predictors, Matrix response) throws Exception {
        // Presumably makes sure the (lazily created) m_Logger used below exists -- TODO confirm.
        getLogger();

        Matrix x = this.m_StandardizeX.transform(predictors);
        Matrix y = this.m_StandardizeY.transform(response);
        int numRows = x.numRows();
        int numXCols = x.numColumns();
        int numYCols = y.numColumns();
        int numComponents = getNumComponents();

        this.m_XScores = MatrixFactory.zeros(numRows, numComponents);
        this.m_YScores = MatrixFactory.zeros(numRows, numComponents);
        this.m_XWeights = MatrixFactory.zeros(numXCols, numComponents);
        this.m_YWeights = MatrixFactory.zeros(numYCols, numComponents);
        this.m_XLoadings = MatrixFactory.zeros(numXCols, numComponents);
        this.m_YLoadings = MatrixFactory.zeros(numYCols, numComponents);

        for (int k = 0; k < numComponents; k++) {
            NipalsLoopResult loop = nipalsLoop(x, y);
            Matrix xWeight = loop.xWeights;
            Matrix yWeight = loop.yWeights;
            Matrix xScore = x.mul(xWeight);
            Matrix yScore = y.mul(yWeight).div(yWeight.norm2squared());
            double xScoreNormSq = xScore.norm2squared();

            // No X variance is left along this direction, so further components are degenerate.
            if (xScoreNormSq < NULL_SCORE_THRESHOLD) {
                this.m_Logger.warning("X scores are null at component " + k);
                break;
            }

            Matrix xLoading = x.t().mul(xScore).div(xScoreNormSq);
            x.subi(xScore.mul(xLoading.t()));

            Matrix yLoading;
            switch (this.m_deflationMode) {
                case CANONICAL:
                    yLoading = y.t().mul(yScore).div(yScore.norm2squared());
                    y.subi(yScore.mul(yLoading.t()));
                    break;
                case REGRESSION:
                    yLoading = y.t().mul(xScore).div(xScoreNormSq);
                    y.subi(xScore.mul(yLoading.t()));
                    break;
                default:
                    throw new IllegalStateException("Unhandled deflation mode: " + this.m_deflationMode);
            }

            this.m_XScores.setColumn(k, xScore);
            this.m_YScores.setColumn(k, yScore);
            this.m_XWeights.setColumn(k, xWeight);
            this.m_YWeights.setColumn(k, yWeight);
            this.m_XLoadings.setColumn(k, xLoading);
            this.m_YLoadings.setColumn(k, yLoading);
        }

        this.m_X = x;
        // NOTE(review): if extraction stopped early, the trailing all-zero columns presumably make
        // the matrices inverted below singular -- confirm how Matrix.inverse() handles that.
        this.m_XRotations = this.m_XWeights.mul(this.m_XLoadings.t().mul(this.m_XWeights).inverse());
        if (numYCols > 1) {
            this.m_YRotations = this.m_YWeights.mul(this.m_YLoadings.t().mul(this.m_YWeights).inverse());
        } else {
            this.m_YRotations = MatrixFactory.filled(1, 1, 1.0d);
        }
        // Coefficients operate on standardized X; scaling by the Y standard deviations returns
        // predictions to Y's original scale (the Y means are added in doPerformPredictions).
        this.m_Coef = this.m_XRotations.mul(this.m_YLoadings.t())
                .scaleByVector(MatrixFactory.fromColumn(this.m_StandardizeY.getStdDevs()));
        return null;
    }

    /**
     * Runs the NIPALS inner loop on the current (deflated) X and Y to obtain one pair of weight
     * vectors. The Y score is seeded with Y's first column. The loop stops when any of these holds:
     * the squared change of the X weights between iterations drops below {@link #getTol()}; Y has a
     * single column; or the iteration counter reaches {@link #getMaxIter()}.
     *
     * @param x the current X matrix (not modified)
     * @param y the current Y matrix (not modified)
     * @return the X weights (unit length), the Y weights and the final iteration counter
     */
    protected NipalsLoopResult nipalsLoop(Matrix x, Matrix y) {
        Matrix xt = x.t();
        Matrix yt = y.t();
        Matrix yScore = y.getColumn(0);
        Matrix xWeightOld = MatrixFactory.zeros(x.numColumns(), 1);
        Matrix xWeight;
        Matrix yWeight;
        int iterations = 0;

        while (true) {
            xWeight = xt.mul(yScore).div(yScore.norm2squared());
            if (xWeight.norm2squared() < EPS) {
                xWeight.addi(EPS);
            }
            xWeight.divi(Math.sqrt(xWeight.norm2squared()) + EPS);

            Matrix xScore = x.mul(xWeight);
            yWeight = yt.mul(xScore).div(xScore.norm2squared());
            if (this.m_NormYWeights) {
                yWeight.divi(Math.sqrt(yWeight.norm2squared()) + EPS);
            }
            yScore = y.mul(yWeight).div(yWeight.norm2squared() + EPS);

            boolean converged = xWeight.sub(xWeightOld).norm2squared() < this.m_Tol;
            if (converged || y.numColumns() == 1 || iterations >= this.m_MaxIter) {
                break;
            }
            // Compare against this iteration's weights next time round; without this the
            // convergence test always measures distance from the zero vector.
            xWeightOld = xWeight;
            iterations++;
        }
        return new NipalsLoopResult(xWeight, yWeight, iterations);
    }

    /**
     * Predicts Y for the given X using the fitted coefficients.
     *
     * @param matrix the predictor matrix
     * @return predictions in Y's original scale
     */
    @Override
    public Matrix doPerformPredictions(Matrix matrix) {
        Matrix standardized = this.m_StandardizeX.transform(matrix);
        return standardized.mul(this.m_Coef).addByVector(MatrixFactory.fromColumn(this.m_StandardizeY.getMeans()));
    }

    /** Projects standardized X onto the fitted X rotations. */
    @Override
    protected Matrix doTransform(Matrix matrix) {
        return this.m_StandardizeX.transform(matrix).mul(this.m_XRotations);
    }

    /** Projects standardized Y onto the fitted Y rotations. */
    protected Matrix doTransformResponse(Matrix matrix) {
        return this.m_StandardizeY.transform(matrix).mul(this.m_YRotations);
    }

    /** @return the names accepted by {@link #getMatrix(String)} */
    @Override
    public String[] getMatrixNames() {
        return new String[]{"T", "U", "P", "Q"};
    }

    /**
     * Returns a fitted matrix by name: "T" X scores, "U" Y scores, "P" X loadings, "Q" Y loadings.
     *
     * @param name the matrix name
     * @return the matrix, or {@code null} for an unknown or {@code null} name
     */
    @Override
    public Matrix getMatrix(String name) {
        if (name == null) {
            return null;
        }
        switch (name) {
            case "T":
                return this.m_XScores;
            case "U":
                return this.m_YScores;
            case "P":
                return this.m_XLoadings;
            case "Q":
                return this.m_YLoadings;
            default:
                return null;
        }
    }

    @Override
    public boolean hasLoadings() {
        return true;
    }

    /**
     * Discards all fitted state (scores, loadings, weights, rotations, coefficients and the
     * standardizers). User-configured parameters are kept.
     */
    @Override
    public void reset() {
        super.reset();
        this.m_XScores = null;
        this.m_YScores = null;
        this.m_XLoadings = null;
        this.m_YLoadings = null;
        this.m_XWeights = null;
        this.m_YWeights = null;
        this.m_Coef = null;
        this.m_X = null;
        this.m_XRotations = null;
        this.m_YRotations = null;
        this.m_StandardizeX = new Standardize();
        this.m_StandardizeY = new Standardize();
    }

    /** @return the X loadings (P) */
    @Override
    public Matrix getLoadings() {
        return this.m_XLoadings;
    }

    @Override
    public boolean canPredict() {
        return true;
    }

    /** @return the regression coefficients computed by the last fit */
    public Matrix getCoef() {
        return this.m_Coef;
    }
}
