package greycat.ml.common.matrix.blassolver;

import greycat.ml.common.matrix.MatrixOps;
import greycat.ml.common.matrix.TransposeType;
import greycat.ml.common.matrix.VolatileDMatrix;
import greycat.ml.common.matrix.blassolver.blas.Blas;
import greycat.struct.DMatrix;

/* loaded from: input_file:greycat/ml/common/matrix/blassolver/QR.class */
/**
 * Thin QR decomposition {@code A = Q * R} of an {@code m x n} matrix with {@code m >= n}, computed
 * through the {@code dgeqrf}/{@code dorgqr} routines exposed by {@link Blas}.
 *
 * <p>After {@link #factor}, {@link #getQ()} is the {@code m x n} matrix produced by {@code dorgqr}
 * and {@link #getR()} is the {@code n x n} matrix holding the upper triangle written by
 * {@code dgeqrf}. Matrix storage is handed to BLAS with leading dimension {@code m}. NOTE(review):
 * this assumes {@link DMatrix#data()} is column-major; confirm against the DMatrix implementation.
 *
 * <p>Instances keep mutable workspace arrays and factor results in fields and are not safe for
 * concurrent use.
 */
class QR {
    private DMatrix Q;
    private DMatrix R;
    private Blas _blas;
    /** Number of rows of the factored matrix. */
    int m;
    /** Number of columns of the factored matrix ({@code n <= m}). */
    int n;
    /** Number of elementary reflectors, {@code min(m, n)}. */
    int k;
    /** Workspace for {@code dgeqrf}, sized by a workspace query in {@link #factor}. */
    double[] work;
    /** Workspace for {@code dorgqr}, sized by a workspace query in {@link #factor}. */
    double[] workGen;
    /** Reflector scalars written by {@code dgeqrf} and consumed by {@code dorgqr}. */
    double[] tau;

    /**
     * Creates a decomposition object for {@code rows x columns} matrices.
     *
     * @param rows number of rows {@code m}
     * @param columns number of columns {@code n}; must not exceed {@code rows}
     * @param blas BLAS/LAPACK backend used for the factorization
     * @throws IllegalArgumentException if {@code columns > rows}
     */
    public QR(int rows, int columns, Blas blas) {
        this._blas = blas;
        if (columns > rows) {
            throw new IllegalArgumentException("n > m");
        }
        this.m = rows;
        this.n = columns;
        this.k = Math.min(this.m, this.n);
        this.tau = new double[this.k];
        this.R = VolatileDMatrix.empty(this.n, this.n);
    }

    /**
     * Convenience entry point: builds a {@code QR} sized for {@code dMatrix} and factors it.
     *
     * @param dMatrix matrix to factor
     * @param workInPlace if {@code true}, {@code dMatrix} is overwritten and becomes Q
     * @param blas BLAS/LAPACK backend
     * @return the factored decomposition
     */
    public static QR factorize(DMatrix dMatrix, boolean workInPlace, Blas blas) {
        return new QR(dMatrix.rows(), dMatrix.columns(), blas).factor(dMatrix, workInPlace);
    }

    /**
     * Computes the QR factorization of {@code dMatrix}.
     *
     * @param dMatrix matrix to factor; must be exactly {@code m x n}
     * @param workInPlace if {@code true}, {@code dMatrix} itself is overwritten and becomes Q;
     *     otherwise a copy is factored and {@code dMatrix} is left unchanged
     * @return {@code this}, for chaining
     * @throws IllegalArgumentException if {@code dMatrix} is not {@code m x n}
     * @throws RuntimeException if {@code dgeqrf} or {@code dorgqr} report an illegal argument
     *     (negative {@code info})
     */
    public QR factor(DMatrix dMatrix, boolean workInPlace) {
        // BLAS is told the matrix is m x n with leading dimension m; any other shape would make it
        // read/write the wrong elements (or past the end of the storage).
        if (dMatrix.rows() != this.m || dMatrix.columns() != this.n) {
            throw new IllegalArgumentException("expected a " + this.m + "x" + this.n
                    + " matrix, got " + dMatrix.rows() + "x" + dMatrix.columns());
        }
        DMatrix a = workInPlace ? dMatrix : VolatileDMatrix.cloneFrom(dMatrix);
        allocateWorkspaces();

        int[] info = {0};
        this._blas.dgeqrf(this.m, this.n, a.data(), 0, this.m, this.tau, 0, this.work, 0, this.work.length, info);
        if (info[0] < 0) {
            throw new RuntimeException("dgeqrf failed, info=" + info[0]);
        }
        // R must be captured before dorgqr overwrites the same storage with Q.
        copyUpperTriangle(a);

        info[0] = 0;
        this._blas.dorgqr(this.m, this.n, this.k, a.data(), 0, this.m, this.tau, 0, this.workGen, 0, this.workGen.length, info);
        if (info[0] < 0) {
            throw new RuntimeException("dorgqr failed, info=" + info[0]);
        }
        this.Q = a;
        return this;
    }

    /**
     * Sizes {@link #work} and {@link #workGen} via LAPACK-style workspace queries
     * ({@code lwork = -1}), falling back to {@code n} entries when a query reports a nonzero info.
     */
    private void allocateWorkspaces() {
        int[] info = {0};
        this.work = new double[1];
        this._blas.dgeqrf(this.m, this.n, new double[0], 0, this.m, new double[0], 0, this.work, 0, -1, info);
        this.work = new double[Math.max(1, info[0] != 0 ? this.n : (int) this.work[0])];

        this.workGen = new double[1];
        info[0] = 0;
        this._blas.dorgqr(this.m, this.n, this.k, new double[0], 0, this.m, new double[0], 0, this.workGen, 0, -1, info);
        this.workGen = new double[Math.max(1, info[0] != 0 ? this.n : (int) this.workGen[0])];
    }

    /**
     * Copies the upper triangle (diagonal included) of the first {@code n} columns of {@code a}
     * into {@link #R}. Entries of R below the diagonal are never written here.
     */
    private void copyUpperTriangle(DMatrix a) {
        for (int col = 0; col < this.n; col++) {
            for (int row = 0; row <= col; row++) {
                this.R.set(row, col, a.get(row, col));
            }
        }
    }

    /**
     * Solves {@code A * X = B} column by column as {@code X = R^-1 * (Q^T * B)}. For a full-rank
     * {@code A} with {@code m > n} this is the least-squares solution. A zero on R's diagonal is
     * not detected and yields infinite/NaN entries.
     *
     * @param dMatrix right-hand side {@code B}; needs at least {@code m} rows
     * @param dMatrix2 output {@code X}; rows {@code 0..n-1} of each of B's columns are written
     * @throws IllegalStateException if {@link #factor} has not been called
     * @throws IllegalArgumentException if {@code dMatrix} or {@code dMatrix2} is too small
     */
    public void solve(DMatrix dMatrix, DMatrix dMatrix2) {
        if (this.Q == null) {
            throw new IllegalStateException("factor() must be called before solve()");
        }
        int columns = dMatrix.columns();
        if (dMatrix.rows() < this.m) {
            throw new IllegalArgumentException("right-hand side needs at least " + this.m
                    + " rows, got " + dMatrix.rows());
        }
        if (dMatrix2.rows() < this.n || dMatrix2.columns() < columns) {
            throw new IllegalArgumentException("result needs at least " + this.n + "x" + columns
                    + ", got " + dMatrix2.rows() + "x" + dMatrix2.columns());
        }
        VolatileDMatrix rhs = VolatileDMatrix.empty(this.m, 1);
        for (int col = 0; col < columns; col++) {
            for (int row = 0; row < this.m; row++) {
                rhs.unsafeSet(row, dMatrix.get(row, col));
            }
            DMatrix qtb = MatrixOps.multiplyTranspose(TransposeType.TRANSPOSE, this.Q, TransposeType.NOTRANSPOSE, rhs);
            // Back-substitution is done in place on qtb's backing array.
            solveU(this.R, qtb.data(), this.n);
            for (int row = 0; row < this.n; row++) {
                dMatrix2.set(row, col, qtb.unsafeGet(row));
            }
        }
    }

    /**
     * Back-substitution: overwrites the first {@code size} entries of {@code x} with {@code y}
     * solving {@code U * y = x}, reading only the upper triangle of {@code u}.
     */
    private static void solveU(DMatrix u, double[] x, int size) {
        for (int row = size - 1; row >= 0; row--) {
            double sum = x[row];
            for (int col = row + 1; col < size; col++) {
                sum -= u.get(row, col) * x[col];
            }
            x[row] = sum / u.get(row, row);
        }
    }

    /** @return the {@code n x n} upper-triangular factor R (allocated at construction). */
    public DMatrix getR() {
        return this.R;
    }

    /** @return the {@code m x n} factor Q, or {@code null} before {@link #factor} is called. */
    public DMatrix getQ() {
        return this.Q;
    }
}
