package com.github.waikatodatamining.matrix.algorithms.glsw;

import com.github.waikatodatamining.matrix.core.algorithm.MatrixAlgorithm;
import com.github.waikatodatamining.matrix.core.exceptions.MatrixAlgorithmsException;
import com.github.waikatodatamining.matrix.core.matrix.Matrix;
import com.github.waikatodatamining.matrix.core.matrix.MatrixFactory;
import java.util.Comparator;
import java.util.stream.IntStream;

/**
 * GLSW variant whose covariance matrix is computed from predictors and a single-column response.
 * <p>
 * The samples are sorted by ascending response value, then predictors and response are passed
 * through {@link SavitzkyGolayFilter}. Each sample's filtered response value {@code dy} yields the
 * weight {@code w = 2^(-dy / sd(dy))}, where {@code sd} is the sample standard deviation of the
 * filtered response. The covariance returned is {@code dX^T W^2 dX} with {@code W = diag(w)}.
 */
public class YGradientGLSW extends GLSW {
    private static final long serialVersionUID = 4080767826836437539L;

    /**
     * Applies a fixed 5-point derivative filter (Savitzky-Golay first-derivative coefficients)
     * down the rows of a matrix, i.e. independently to every column.
     * <p>
     * Output row {@code i} equals
     * {@code 0.2*x[i-2] + 0.1*x[i-1] - 0.1*x[i+1] - 0.2*x[i+2]}. The first and last input rows
     * are replicated to pad the edges, so the output has the same shape as the input. The input
     * must contain at least one row.
     * <p>
     * NOTE(review): applied this way the filter yields the negative of the usual first
     * derivative (its output is non-positive for ascending data). Confirm the intended sign
     * against the reference implementation before changing it, since it determines the sample
     * weights computed in {@code YGradientGLSW.getCovarianceMatrix}.
     */
    public static class SavitzkyGolayFilter extends MatrixAlgorithm {
        private static final long serialVersionUID = 7783793644680234716L;

        /** Filter coefficients; {@link #smoothRow} multiplies input row {@code i + k} by {@code m_Coef[2 + k]}. */
        protected static final double[] m_Coef = {0.2d, 0.1d, 0.0d, -0.1d, -0.2d};

        /** Number of neighbouring rows on each side of the centre row (2 for the 5-point window). */
        private static final int HALF_WINDOW = m_Coef.length / 2;

        /**
         * Filters every row of {@code matrix}.
         *
         * @param matrix the input, at least one row
         * @return a matrix of the same shape holding the filtered rows
         */
        @Override
        protected Matrix doTransform(Matrix matrix) {
            Matrix extended = extendMatrix(matrix);
            Matrix result = MatrixFactory.zerosLike(matrix);
            for (int row = 0; row < matrix.numRows(); row++) {
                // Input row `row` sits HALF_WINDOW rows further down in the padded matrix.
                result.setRow(row, smoothRow(row + HALF_WINDOW, extended));
            }
            return result;
        }

        /**
         * Pads {@code matrix} with {@code HALF_WINDOW} copies of its first row on top and of its
         * last row at the bottom, so the filter window never leaves the data.
         *
         * @param matrix the input, at least one row
         * @return the padded matrix with {@code numRows() + 2 * HALF_WINDOW} rows
         */
        protected Matrix extendMatrix(Matrix matrix) {
            Matrix first = matrix.getRow(0);
            Matrix last = matrix.getRow(matrix.numRows() - 1);
            Matrix extended = matrix;
            for (int k = 0; k < HALF_WINDOW; k++) {
                extended = first.concatAlongRows(extended).concatAlongRows(last);
            }
            return extended;
        }

        /**
         * Computes the filtered value of row {@code i} of an already padded matrix.
         *
         * @param i      row index into the padded matrix, in {@code [HALF_WINDOW, numRows - HALF_WINDOW)}
         * @param matrix the padded matrix
         * @return a {@code 1 x numColumns} row holding the weighted sum of the window around row {@code i}
         */
        protected Matrix smoothRow(int i, Matrix matrix) {
            Matrix smoothed = MatrixFactory.zeros(1, matrix.numColumns());
            // Walk the window from the bottom (offset +2) to the top (offset -2); this fixed
            // order keeps the floating-point summation reproducible.
            for (int offset = HALF_WINDOW; offset >= -HALF_WINDOW; offset--) {
                smoothed = smoothed.add(matrix.getRow(i + offset).mul(m_Coef[HALF_WINDOW + offset]));
            }
            return smoothed;
        }
    }

    /**
     * Computes the y-gradient weighted covariance matrix.
     * <p>
     * If every filtered response value is identical (e.g. a constant response, or only one or two
     * samples), the standard deviation is zero or undefined. In that case the weighting formula
     * would assign all samples the same weight, so a uniform weight of 1 is used instead of
     * producing NaN or infinite weights.
     *
     * @param x the predictors, one sample per row
     * @param y the response, one row per sample; only column 0 is used
     * @return the {@code x.numColumns() x x.numColumns()} matrix {@code dX^T W^2 dX}
     */
    @Override
    protected Matrix getCovarianceMatrix(Matrix x, Matrix y) {
        // Order the samples by ascending response (Stream.sorted is stable for ties).
        double[] response = y.toRawCopy1D();
        int[] order = IntStream.range(0, response.length)
            .boxed()
            .sorted(Comparator.comparingDouble(idx -> response[idx]))
            .mapToInt(Integer::intValue)
            .toArray();
        int[] allColumns = IntStream.range(0, x.numColumns()).toArray();
        Matrix xSorted = x.getSubMatrix(order, allColumns);
        Matrix ySorted = y.getSubMatrix(order, new int[]{0});

        SavitzkyGolayFilter filter = new SavitzkyGolayFilter();
        Matrix xFiltered = filter.transform(xSorted);
        Matrix yFiltered = filter.transform(ySorted);

        // Sample standard deviation (n - 1 denominator) of the filtered response.
        double yStd = yFiltered.sub(yFiltered.mean(-1).asDouble())
            .powElementwise(2.0d)
            .sum(-1)
            .div(yFiltered.numRows() - 1)
            .sqrt()
            .asDouble();
        // `!(yStd > 0)` also catches NaN (single sample: 0 / 0).
        boolean degenerate = !(yStd > 0.0d);

        int numSamples = yFiltered.numRows();
        Matrix weights = MatrixFactory.zeros(numSamples, numSamples);
        for (int i = 0; i < numSamples; i++) {
            double weight = degenerate ? 1.0d : Math.pow(2.0d, -yFiltered.get(i, 0) / yStd);
            weights.set(i, i, weight);
        }
        return xFiltered.t().mul(weights.mul(weights)).mul(xFiltered);
    }

    /**
     * Validates the inputs for {@link #getCovarianceMatrix}.
     * <p>
     * The response must be a single column. {@code getCovarianceMatrix} sorts over every value
     * returned by {@code toRawCopy1D()}, so extra columns would produce out-of-range row indices.
     *
     * @param x the predictors
     * @param y the response
     * @throws MatrixAlgorithmsException if the row counts differ or {@code y} does not have exactly one column
     */
    @Override
    protected void check(Matrix x, Matrix y) {
        if (x.numRows() != y.numRows()) {
            throw new MatrixAlgorithmsException("Predictors and response must have the same number of rows!");
        }
        if (y.numColumns() != 1) {
            throw new MatrixAlgorithmsException("Response must have exactly one column!");
        }
    }
}
