package org.apache.mahout.math.als;

import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import java.util.Iterator;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.QRDecomposition;
import org.apache.mahout.math.Vector;

/* loaded from: input_file:libarx-3.7.1.jar:org/apache/mahout/math/als/AlternatingLeastSquaresSolver.class */
public final class AlternatingLeastSquaresSolver {

    private AlternatingLeastSquaresSolver() {
        // static helper methods only; not meant to be instantiated
    }

    /**
     * Solves one regularized least-squares subproblem of alternating least squares.
     *
     * <p>Let {@code M} be the {@code numFeatures x n} matrix whose columns are the first
     * {@code numFeatures} components of the given feature vectors, {@code R} the column of the
     * {@code n} non-default entries of {@code ratingVector}, and {@code E} the identity. This method
     * builds {@code A = M * M^T + lambda * n * E} and {@code V = M * R}, and returns the solution of
     * {@code A x = V} computed by QR decomposition.
     *
     * @param featureVectors one feature vector per non-default rating; NOTE(review): they are paired
     *     with ratings purely by iteration order, which must match the order of
     *     {@code ratingVector.nonZeroes()} — this is not checked here
     * @param ratingVector the ratings; must report sequential access
     * @param lambda regularization weight, multiplied by the number of ratings before being added to
     *     the diagonal
     * @param numFeatures number of leading components read from each feature vector; also the length
     *     of the returned vector
     * @return the vector {@code x} solving {@code A x = V}
     * @throws NullPointerException if {@code featureVectors} or {@code ratingVector} is null
     * @throws IllegalArgumentException if {@code featureVectors} is empty, {@code ratingVector} has no
     *     non-default elements, their counts differ, or {@code ratingVector} is not sequential access
     */
    public static Vector solve(Iterable<Vector> featureVectors, Vector ratingVector, double lambda, int numFeatures) {
        Preconditions.checkNotNull(featureVectors, "Feature Vectors cannot be null");
        Preconditions.checkArgument(!Iterables.isEmpty(featureVectors), "Feature Vectors cannot be empty");
        Preconditions.checkNotNull(ratingVector, "Rating Vector cannot be null");
        int numRatings = ratingVector.getNumNondefaultElements();
        Preconditions.checkArgument(numRatings > 0, "Rating Vector cannot be empty");
        int numFeatureVectors = Iterables.size(featureVectors);
        Preconditions.checkArgument(numFeatureVectors == numRatings,
            "Number of feature vectors (%s) must equal the number of ratings (%s)", numFeatureVectors, numRatings);

        // Build the rating column first: it validates sequential access before the expensive Gram product.
        Matrix ratingColumn = createRiIiMaybeTransposed(ratingVector);
        Matrix miIi = createMiIi(featureVectors, numFeatures);
        Matrix lhs = miTimesMiTransposePlusLambdaTimesNuiTimesE(miIi, lambda, numRatings);
        Matrix rhs = miIi.times(ratingColumn);
        return solve(lhs, rhs);
    }

    /**
     * Solves {@code lhs * X = rhs} via QR decomposition and returns the first column of {@code X}.
     */
    private static Vector solve(Matrix lhs, Matrix rhs) {
        return new QRDecomposition(lhs).solve(rhs).viewColumn(0);
    }

    /**
     * Adds {@code lambda * nui} to every diagonal entry of {@code matrix}, in place.
     *
     * @param matrix square matrix to modify
     * @param lambda regularization weight
     * @param nui multiplier applied to {@code lambda}
     * @return the same {@code matrix} instance, after modification
     * @throws IllegalArgumentException if {@code matrix} is not square
     */
    static Matrix addLambdaTimesNuiTimesE(Matrix matrix, double lambda, int nui) {
        Preconditions.checkArgument(matrix.numCols() == matrix.numRows(), "Must be a Square Matrix");
        double lambdaTimesNui = lambda * nui;
        int size = matrix.numCols();
        for (int k = 0; k < size; k++) {
            matrix.setQuick(k, k, matrix.getQuick(k, k) + lambdaTimesNui);
        }
        return matrix;
    }

    /**
     * Computes {@code mi * mi^T + lambda * nui * E} as a new dense square matrix.
     *
     * <p>Only the upper triangle is computed with dot products; it is mirrored into the lower
     * triangle because the result is symmetric.
     *
     * @param mi matrix whose rows are dotted pairwise
     * @param lambda regularization weight
     * @param nui multiplier applied to {@code lambda} before it is added to the diagonal
     * @return a new {@code mi.numRows() x mi.numRows()} dense matrix
     */
    private static Matrix miTimesMiTransposePlusLambdaTimesNuiTimesE(Matrix mi, double lambda, int nui) {
        double lambdaTimesNui = lambda * nui;
        int size = mi.numRows();
        double[][] gram = new double[size][size];
        for (int row = 0; row < size; row++) {
            // Loop-invariant for the inner loop; the inner operand is still a fresh view, as before.
            Vector rowVector = mi.viewRow(row);
            for (int col = row; col < size; col++) {
                double dot = rowVector.dot(mi.viewRow(col));
                if (row == col) {
                    gram[row][row] = dot + lambdaTimesNui;
                } else {
                    gram[row][col] = dot;
                    gram[col][row] = dot;
                }
            }
        }
        return new DenseMatrix(gram, true);
    }

    /**
     * Packs the feature vectors as the columns of a {@code numFeatures x n} dense matrix.
     *
     * <p>Column {@code c} holds components {@code 0 .. numFeatures-1} of the {@code c}-th vector, read
     * with {@code getQuick} (no bounds check here — NOTE(review): callers are assumed to pass vectors
     * with at least {@code numFeatures} components).
     *
     * @param featureVectors vectors to pack, in iteration order
     * @param numFeatures number of leading components taken from each vector
     * @return a new dense matrix backed by the filled array
     */
    static Matrix createMiIi(Iterable<Vector> featureVectors, int numFeatures) {
        double[][] miIi = new double[numFeatures][Iterables.size(featureVectors)];
        int col = 0;
        for (Vector featureVector : featureVectors) {
            for (int row = 0; row < numFeatures; row++) {
                miIi[row][col] = featureVector.getQuick(row);
            }
            col++;
        }
        return new DenseMatrix(miIi, true);
    }

    /**
     * Packs the non-zero entries of {@code ratingVector} into a single-column dense matrix, in the
     * order produced by {@code ratingVector.nonZeroes()}.
     *
     * <p>NOTE(review): the column is sized by {@code getNumNondefaultElements()} but filled from
     * {@code nonZeroes()}; if a vector reports more non-default elements than it yields non-zeroes,
     * the trailing rows stay 0 and ratings no longer line up with feature vectors — confirm callers
     * only pass vectors for which both counts agree.
     *
     * @param ratingVector ratings to pack; must report sequential access
     * @return a new {@code getNumNondefaultElements() x 1} dense matrix
     * @throws IllegalArgumentException if {@code ratingVector} is not sequential access
     */
    static Matrix createRiIiMaybeTransposed(Vector ratingVector) {
        Preconditions.checkArgument(ratingVector.isSequentialAccess(),
            "Ratings should be iterable in Index or Sequential Order");
        double[][] ratingColumn = new double[ratingVector.getNumNondefaultElements()][1];
        int row = 0;
        for (Vector.Element element : ratingVector.nonZeroes()) {
            ratingColumn[row++][0] = element.get();
        }
        return new DenseMatrix(ratingColumn, true);
    }
}
