package greycat.ml.common.matrix.blassolver;

import greycat.ml.common.matrix.TransposeType;
import greycat.ml.common.matrix.VolatileDMatrix;
import greycat.ml.common.matrix.blassolver.blas.Blas;
import greycat.struct.DMatrix;

/* loaded from: input_file:greycat/ml/common/matrix/blassolver/LU.class */
/**
 * LU decomposition of a dense matrix, delegating the numeric work to a {@link Blas}
 * implementation whose {@code dgetrf}/{@code dgetrs}/{@code dgetri} methods follow the
 * LAPACK routines of the same name.
 *
 * <p>The factors are kept packed in a single matrix: the strictly-lower part holds L
 * (its unit diagonal is implied) and the upper part, diagonal included, holds U. The
 * row interchanges are kept in {@link #getPivots()}.
 *
 * <p>Matrices are handed to BLAS with a leading dimension equal to {@code rows()}, so
 * their {@code data()} is presumably column-major. NOTE(review): confirm this against
 * DMatrix.
 */
class LU {
    /** Packed L and U factors (see class Javadoc). */
    private final DMatrix lu;
    private final Blas _blas;
    /** Pivot indices written by dgetrf. Regrown by {@link #invert} if too small. */
    private int[] piv;
    private boolean singular;

    /** Returns the packed LU factors (internal storage, not a copy). */
    public DMatrix getLU() {
        return this.lu;
    }

    /**
     * Allocates storage for factorizing a {@code rows x columns} matrix.
     *
     * @param rows    number of rows of the matrices to be factorized
     * @param columns number of columns of the matrices to be factorized
     * @param blas    backend that performs the factorization
     */
    public LU(int rows, int columns, Blas blas) {
        this._blas = blas;
        this.lu = VolatileDMatrix.empty(rows, columns);
        this.piv = new int[Math.min(rows, columns)];
    }

    /**
     * Factorizes a copy of {@code dMatrix}; the argument is left untouched.
     *
     * @param dMatrix matrix to factorize
     * @param blas    backend that performs the factorization
     * @return a new, already-factorized decomposition
     */
    public static LU factorize(DMatrix dMatrix, Blas blas) {
        return new LU(dMatrix.rows(), dMatrix.columns(), blas).factor(dMatrix, false);
    }

    /**
     * Computes the LU factorization of {@code dMatrix} and stores it in this object.
     *
     * @param dMatrix   matrix to factorize
     * @param overwrite when {@code true}, dgetrf works directly on {@code dMatrix},
     *                  which is overwritten with the packed factors. When {@code false},
     *                  a private copy is factorized instead.
     * @return {@code this}, for chaining
     * @throws RuntimeException if dgetrf reports an illegal argument ({@code info < 0})
     */
    public LU factor(DMatrix dMatrix, boolean overwrite) {
        this.singular = false;
        DMatrix work = overwrite ? dMatrix : VolatileDMatrix.cloneFrom(dMatrix);
        int[] info = {0};
        this._blas.dgetrf(work.rows(), work.columns(), work.data(), 0, work.rows(), this.piv, 0, info);
        checkInfo("dgetrf", info[0]);
        // LAPACK convention: info > 0 means U has an exactly-zero diagonal entry.
        this.singular = info[0] > 0;
        this.lu.fillWith(work.data());
        return this;
    }

    /**
     * Extracts the lower-trapezoidal factor L as a {@code rows x min(rows, columns)}
     * matrix with an explicit unit diagonal.
     */
    public DMatrix getL() {
        int rows = this.lu.rows();
        int k = Math.min(rows, this.lu.columns());
        VolatileDMatrix l = VolatileDMatrix.empty(rows, k);
        for (int i = 0; i < rows; i++) {
            if (i < k) {
                l.set(i, i, 1.0d);
            }
            // Copy the entries strictly below the diagonal, limited to the first k columns.
            int end = Math.min(i, k);
            for (int j = 0; j < end; j++) {
                l.set(i, j, this.lu.get(i, j));
            }
        }
        return l;
    }

    /**
     * Extracts the upper-trapezoidal factor U as a {@code min(rows, columns) x columns}
     * matrix.
     */
    public DMatrix getU() {
        int k = Math.min(this.lu.rows(), this.lu.columns());
        int columns = this.lu.columns();
        VolatileDMatrix u = VolatileDMatrix.empty(k, columns);
        for (int i = 0; i < k; i++) {
            for (int j = i; j < columns; j++) {
                u.set(i, j, this.lu.get(i, j));
            }
        }
        return u;
    }

    /** Returns the pivot array written by dgetrf (internal storage, not a copy). */
    public int[] getPivots() {
        return this.piv;
    }

    /** Returns whether the last factorization or inversion found an exactly-zero pivot. */
    public boolean isSingular() {
        return this.singular;
    }

    /**
     * Solves {@code A * X = B} in place, using the stored factors.
     *
     * @param dMatrix right-hand side B; it is overwritten with the solution X
     * @return {@code dMatrix}
     */
    public DMatrix solve(DMatrix dMatrix) {
        return transSolve(dMatrix, TransposeType.NOTRANSPOSE);
    }

    /**
     * Solves {@code op(A) * X = B} in place, using the stored factors.
     *
     * @param dMatrix       right-hand side B; it is overwritten with the solution X
     * @param transposeType selects op(A), passed through to dgetrs
     * @return {@code dMatrix}
     * @throws RuntimeException      if B's row count differs from the factorized matrix
     * @throws IllegalStateException if the stored factorization is not square
     * @throws RuntimeException      if dgetrs reports an illegal argument
     */
    public DMatrix transSolve(DMatrix dMatrix, TransposeType transposeType) {
        if (dMatrix.rows() != this.lu.rows()) {
            throw new RuntimeException("B.numRows() != LU.numRows()");
        }
        int n = this.lu.rows();
        // dgetrs requires a square factorization. Otherwise it would read past piv
        // (length min(rows, columns)).
        if (this.lu.columns() != n) {
            throw new IllegalStateException(
                    "cannot solve with a non-square LU factorization (" + n + "x" + this.lu.columns() + ")");
        }
        int[] info = {0};
        this._blas.dgetrs(transposeType, n, dMatrix.columns(), this.lu.data(), 0, n, this.piv, 0,
                dMatrix.data(), 0, dMatrix.rows(), info);
        checkInfo("dgetrs", info[0]);
        return dMatrix;
    }

    /**
     * Inverts a square matrix in place, using dgetrf followed by dgetri.
     *
     * <p>This does not touch the stored factors ({@link #getLU()}). It does overwrite
     * the pivot array, and it updates {@link #isSingular()}.
     *
     * @param dMatrix square matrix; it is overwritten with its inverse on success
     * @return {@code true} if the inverse was computed, {@code false} if the matrix is
     *         singular or dgetri failed
     * @throws IllegalArgumentException if {@code dMatrix} is not square
     * @throws RuntimeException         if dgetrf reports an illegal argument
     */
    public boolean invert(DMatrix dMatrix) {
        int n = dMatrix.rows();
        if (dMatrix.columns() != n) {
            throw new IllegalArgumentException(
                    "cannot invert a non-square matrix (" + n + "x" + dMatrix.columns() + ")");
        }
        // The pivot array was sized from the constructor's dimensions. A larger
        // matrix needs room for n pivots.
        if (this.piv.length < n) {
            this.piv = new int[n];
        }
        // Reset here, or a previous failed inversion would stay sticky.
        this.singular = false;
        int[] info = {0};
        this._blas.dgetrf(n, n, dMatrix.data(), 0, n, this.piv, 0, info);
        if (info[0] > 0) {
            this.singular = true;
            return false;
        }
        checkInfo("dgetrf", info[0]);
        // LAPACK requires lwork >= max(1, n). n * n is ample; new arrays are already zeroed.
        int lwork = Math.max(1, n * n);
        double[] work = new double[lwork];
        this._blas.dgetri(n, dMatrix.data(), 0, n, this.piv, 0, work, 0, lwork, info);
        return info[0] == 0;
    }

    /**
     * Throws if a LAPACK routine reported an illegal argument.
     *
     * <p>LAPACK convention: {@code info = -i} means argument i was invalid.
     */
    private static void checkInfo(String routine, int info) {
        if (info < 0) {
            throw new RuntimeException(routine + ": illegal value in argument " + (-info));
        }
    }
}
