package com.github.waikatodatamining.matrix.algorithm.pls;

import com.github.waikatodatamining.matrix.core.Matrix;
import com.github.waikatodatamining.matrix.core.MatrixFactory;

/* loaded from: input_file:com/github/waikatodatamining/matrix/algorithm/pls/PRM.class */
/**
 * PRM (presumably partial robust M-regression) for a single response variable.
 *
 * <p>The model is an iteratively reweighted {@link SIMPLS}: every iteration fits SIMPLS on the
 * rows of X and y scaled by {@code sqrt(Wr .* Wx)}, where {@code Wr} are residual weights and
 * {@code Wx} are leverage weights, both produced by the Fair function
 * {@code 1 / (1 + |d / c|)^2}. Iteration stops once the squared change of the SIMPLS y-loadings
 * ({@code Q}) drops below {@link #getTol()}, or after {@link #getMaxIter()} iterations.
 *
 * <p>When X has more columns than rows, the fit runs on the SVD-reduced predictors
 * {@code V * S} (from {@code X^T = U * S * V^T}). The regression coefficients and the transform
 * weights are mapped back to the original feature space through {@code U}.
 */
public class PRM extends AbstractSingleResponsePLS {

    private static final long serialVersionUID = 4864232250283829109L;

    /** Convergence tolerance for the reweighting loop and the geometric-median search. */
    protected double m_Tol;

    /** Maximum number of iterations for the reweighting loop and the geometric-median search. */
    protected int m_MaxIter;

    /** Tuning constant {@code c} of the Fair weight function; must be non-zero. */
    protected double m_C;

    /** Residual weights, one per row (n x 1). */
    protected Matrix m_Wr;

    /** Leverage weights, one per row (n x 1). */
    protected Matrix m_Wx;

    /** Scores of the last SIMPLS fit, each row divided by the square root of its combined weight. */
    protected Matrix m_T;

    /** Transposed y-loadings ({@code Q}) of the last SIMPLS fit. */
    protected Matrix m_Gamma;

    /** Number of coefficients handed to the internal SIMPLS model (-1 by default). */
    protected int m_NumSimplsCoefficients;

    /** Regression coefficients in the original feature space. */
    protected Matrix m_FinalRegressionCoefficients;

    /** SIMPLS model fitted in the last reweighting iteration. */
    protected SIMPLS m_Simpls;

    /** Left singular vectors of X^T when the SVD reduction was used, otherwise {@code null}. */
    protected Matrix m_SvdU;

    /**
     * Sets the number of coefficients handed to the internal SIMPLS model and resets the model.
     *
     * @param i the number of coefficients
     */
    public void setNumSimplsCoefficients(int i) {
        m_NumSimplsCoefficients = i;
        reset();
    }

    /**
     * Returns the number of coefficients handed to the internal SIMPLS model.
     *
     * @return the number of coefficients
     */
    public int getNumSimplsCoefficients() {
        return m_NumSimplsCoefficients;
    }

    /**
     * Returns the maximum number of iterations.
     *
     * @return the iteration limit
     */
    public int getMaxIter() {
        return m_MaxIter;
    }

    /**
     * Sets the maximum number of iterations and resets the model. Negative values are rejected
     * with a logged warning and leave the current value untouched.
     *
     * @param i the iteration limit
     */
    public void setMaxIter(int i) {
        if (i < 0) {
            m_Logger.warning("Maximum iterations parameter must be positive but was " + i + ".");
        } else {
            m_MaxIter = i;
            reset();
        }
    }

    /**
     * Returns the convergence tolerance.
     *
     * @return the tolerance
     */
    public double getTol() {
        return m_Tol;
    }

    /**
     * Sets the convergence tolerance and resets the model. Negative values are rejected with a
     * logged warning and leave the current value untouched.
     *
     * @param d the tolerance
     */
    public void setTol(double d) {
        if (d < 0.0d) {
            m_Logger.warning("Tolerance parameter must be positive but was " + d + ".");
        } else {
            m_Tol = d;
            reset();
        }
    }

    /**
     * Returns the Fair function tuning constant {@code c}.
     *
     * @return the constant
     */
    public double getC() {
        return m_C;
    }

    /**
     * Sets the Fair function tuning constant {@code c} and resets the model. Values with an
     * absolute value below 1e-10 are rejected with a logged warning, because {@code c} is a
     * divisor in {@link #fairFunction(double, double)}.
     *
     * @param d the constant
     */
    public void setC(double d) {
        if (Math.abs(d) < 1.0E-10d) {
            m_Logger.warning("Parameter c must not be zero!");
        } else {
            m_C = d;
            reset();
        }
    }

    /**
     * Discards all fitted state: weights, scores, loadings, coefficients, the SIMPLS model and
     * the SVD projection.
     */
    @Override
    public void reset() {
        super.reset();
        m_Wr = null;
        m_Wx = null;
        m_FinalRegressionCoefficients = null;
        m_Gamma = null;
        m_T = null;
        m_Simpls = null;
        m_SvdU = null;
    }

    /**
     * Applies the defaults: c = 4, tol = 1e-6, maxIter = 500, numSimplsCoefficients = -1.
     */
    @Override
    public void initialize() {
        super.initialize();
        setC(4.0d);
        setTol(1.0E-6d);
        setMaxIter(500);
        setNumSimplsCoefficients(-1);
    }

    /**
     * Returns the names accepted by {@link #getMatrix(String)}.
     *
     * @return the matrix names
     */
    @Override
    public String[] getMatrixNames() {
        return new String[]{"B", "Wr", "Wx", "W"};
    }

    /**
     * Returns one of the model matrices by name.
     *
     * @param name "B" (regression coefficients), "Wr" (residual weights), "Wx" (leverage
     *             weights) or "W" (element-wise product of Wr and Wx)
     * @return the matrix, or {@code null} for an unknown name or a matrix not yet computed
     */
    @Override
    public Matrix getMatrix(String name) {
        switch (name) {
            case "B":
                return m_FinalRegressionCoefficients;
            case "Wr":
                return m_Wr;
            case "Wx":
                return m_Wx;
            case "W":
                // Combined weights only exist once both weight vectors have been computed.
                if (m_Wr == null || m_Wx == null) {
                    return null;
                }
                return m_Wr.mulElementwise(m_Wx);
            default:
                return null;
        }
    }

    /**
     * PRM does not expose loadings.
     *
     * @return always {@code false}
     */
    @Override
    public boolean hasLoadings() {
        return false;
    }

    /**
     * PRM does not expose loadings.
     *
     * @return always {@code null}
     */
    @Override
    public Matrix getLoadings() {
        return null;
    }

    @Override
    public String toString() {
        return "";
    }

    /**
     * Fair weight function {@code 1 / (1 + |d / d2|)^2}.
     *
     * @param d  the (standardized) value to weight
     * @param d2 the tuning constant c
     * @return the weight, in (0, 1] for finite inputs
     */
    protected double fairFunction(double d, double d2) {
        return 1.0d / StrictMath.pow(1.0d + StrictMath.abs(d / d2), 2.0d);
    }

    /**
     * Computes the initial residual and leverage weights.
     *
     * @param matrix  the predictors
     * @param matrix2 the response (single column)
     */
    protected void initWeights(Matrix matrix, Matrix matrix2) {
        updateResidualWeights(matrix, matrix2);
        updateLeverageWeights(matrix);
    }

    /**
     * Recomputes {@code m_Wx} from each row's Euclidean distance to the geometric median of the
     * rows, divided by the median of those distances and passed through the Fair function.
     *
     * @param matrix the rows to weight (X initially, the scores afterwards)
     */
    private void updateLeverageWeights(Matrix matrix) {
        int numRows = matrix.numRows();
        m_Wx = MatrixFactory.zeros(numRows, 1);
        Matrix distances = cdist(matrix, geometricMedian(matrix));
        // NOTE(review): a median distance of zero makes the ratios below NaN/Infinity.
        double medianDistance = distances.median();
        for (int i = 0; i < numRows; i++) {
            m_Wx.set(i, 0, fairFunction(distances.get(i, 0) / medianDistance, m_C));
        }
    }

    /**
     * Recomputes {@code m_Wr} from the response residuals, scaled by their median absolute
     * deviation and passed through the Fair function. Until both scores and loadings are
     * available, the residual is taken against the median of the response. Afterwards it is
     * taken against the fitted value {@code T_i * gamma}.
     *
     * @param matrix  the predictors (only the row count is used)
     * @param matrix2 the response (single column)
     */
    protected void updateResidualWeights(Matrix matrix, Matrix matrix2) {
        int numRows = matrix.numRows();
        m_Wr = MatrixFactory.zeros(numRows, 1);
        Matrix residuals = MatrixFactory.zeros(numRows, 1);
        // Require BOTH scores and loadings; a half-initialized model falls back to the median.
        boolean haveModel = m_T != null && m_Gamma != null;
        double responseMedian = haveModel ? 0.0d : matrix2.median();
        for (int i = 0; i < numRows; i++) {
            double reference = haveModel ? m_T.getRow(i).mul(m_Gamma).asDouble() : responseMedian;
            residuals.set(i, 0, matrix2.get(i, 0) - reference);
        }
        // NOTE(review): a MAD of zero yields NaN/Infinity weights.
        residuals.divi(medianAbsoluteDeviation(residuals));
        for (int i = 0; i < numRows; i++) {
            m_Wr.set(i, 0, fairFunction(residuals.get(i, 0), m_C));
        }
    }

    /**
     * Median absolute deviation: {@code median(|x - median(x)|)}.
     *
     * @param matrix the values
     * @return the MAD
     */
    public double medianAbsoluteDeviation(Matrix matrix) {
        return matrix.sub(matrix.median()).abs().median();
    }

    /**
     * Fits the model by iteratively reweighted SIMPLS.
     *
     * @param predictors the predictors X
     * @param response   the response y (single column)
     * @return always {@code null}
     * @throws Exception if SIMPLS fails
     */
    @Override
    public String doPerformInitialization(Matrix predictors, Matrix response) throws Exception {
        Matrix x = predictors;
        // Start from a clean slate so no state from a previous fit leaks into this one.
        m_SvdU = null;
        m_T = null;
        m_Gamma = null;
        if (predictors.numColumns() > predictors.numRows()) {
            // Wide data: with X^T = U S V^T, X U = V S, so fit in the reduced space and map
            // results back through U afterwards.
            Matrix xt = predictors.t();
            m_SvdU = xt.svdU();
            x = xt.svdV().mul(xt.svdS());
        }

        initWeights(x, response);
        int numComponents = getNumComponents();
        m_Gamma = MatrixFactory.zeros(numComponents, 1);

        int iteration = 0;
        boolean converged;
        do {
            Matrix xWeighted = getReweightedMatrix(x);
            Matrix yWeighted = getReweightedMatrix(response);
            m_Simpls = new SIMPLS();
            m_Simpls.setNumCoefficients(m_NumSimplsCoefficients);
            m_Simpls.setNumComponents(numComponents);
            m_Simpls.initialize(xWeighted, yWeighted);

            Matrix previousGamma = m_Gamma.copy();
            m_T = m_Simpls.transform(xWeighted);
            m_Gamma = m_Simpls.getMatrix("Q").t();
            // Presumably undoes the sqrt(weight) row scaling applied to the SIMPLS input.
            for (int row = 0; row < m_T.numRows(); row++) {
                m_T.setRow(row, m_T.getRow(row).div(Math.sqrt(getCombinedWeight(row))));
            }
            updateResidualWeights(xWeighted, yWeighted);
            updateLeverageWeights(m_T);
            iteration++;
            // Keep reweighting until the loadings stop moving (previously this test was
            // inverted and stopped after the first iteration whenever gamma changed).
            converged = m_Gamma.sub(previousGamma).norm2squared() < m_Tol;
        } while (!converged && iteration < m_MaxIter);

        m_FinalRegressionCoefficients = m_Simpls.getMatrix("B");
        if (m_SvdU != null) {
            m_FinalRegressionCoefficients = m_SvdU.mul(m_FinalRegressionCoefficients);
        }
        return null;
    }

    /**
     * PRM supports predictions.
     *
     * @return always {@code true}
     */
    @Override
    public boolean canPredict() {
        return true;
    }

    /**
     * Projects X onto the SIMPLS weights of the last iteration. When the fit ran in the
     * SVD-reduced space, the weights are first mapped back to the original feature space.
     *
     * @param matrix the predictors
     * @return the scores
     * @throws Exception if the multiplication fails
     */
    @Override
    protected Matrix doTransform(Matrix matrix) throws Exception {
        Matrix weights = m_Simpls.getMatrix("W");
        if (m_SvdU != null) {
            weights = m_SvdU.mul(weights);
        }
        return matrix.mul(weights);
    }

    /**
     * Predicts the response as {@code X * B}.
     *
     * @param matrix the predictors
     * @return the predictions
     * @throws Exception if the multiplication fails
     */
    @Override
    public Matrix doPerformPredictions(Matrix matrix) throws Exception {
        return matrix.mul(m_FinalRegressionCoefficients);
    }

    /**
     * Returns a copy of {@code matrix} with each row scaled by {@code sqrt(Wr_i * Wx_i)}.
     *
     * @param matrix the matrix to reweight
     * @return the reweighted copy
     */
    protected Matrix getReweightedMatrix(Matrix matrix) {
        return matrix.copy().scaleByColumnVector(m_Wr.mulElementwise(m_Wx).sqrt());
    }

    /**
     * Returns {@code Wx_i * Wr_i} for row {@code i}.
     *
     * @param i the row index
     * @return the combined weight
     */
    protected double getCombinedWeight(int i) {
        return m_Wx.get(i, 0) * m_Wr.get(i, 0);
    }

    /**
     * Approximates the geometric median of the rows by Weiszfeld-style iteration, starting at
     * the column mean. It stops when the squared step falls below the tolerance or after
     * {@code m_MaxIter} steps.
     *
     * @param matrix the rows
     * @return the estimate as a row vector
     */
    protected Matrix geometricMedian(Matrix matrix) {
        Matrix estimate = matrix.mean(0);
        for (int i = 0; i < m_MaxIter; i++) {
            // Inverse distances; rows coinciding with the estimate get a fixed weight of 10.
            Matrix inverseDistances = cdist(matrix, estimate).modifyEach(
                d -> Math.abs(d) < 1.0E-10d ? Double.valueOf(10.0d) : Double.valueOf(1.0d / d));
            // Scale a copy (as getReweightedMatrix does) so the caller's matrix is never modified.
            Matrix next = matrix.copy().scaleByColumnVector(inverseDistances).sum(0)
                .div(inverseDistances.sum(0).asDouble());
            double step = next.sub(estimate).norm2squared();
            estimate = next;
            if (step < m_Tol) {
                break;
            }
        }
        return estimate;
    }

    /**
     * Euclidean distance of every row of {@code matrix} to the row vector {@code matrix2}.
     *
     * @param matrix  the rows
     * @param matrix2 the reference row
     * @return the distances as an n x 1 column
     */
    protected Matrix cdist(Matrix matrix, Matrix matrix2) {
        Matrix distances = MatrixFactory.zeros(matrix.numRows(), 1);
        for (int i = 0; i < matrix.numRows(); i++) {
            distances.set(i, 0, matrix.getRow(i).sub(matrix2).norm2());
        }
        return distances;
    }
}
