package greycat.ml.common.matrix.blassolver;

import greycat.ml.common.matrix.MatrixEngine;
import greycat.ml.common.matrix.MatrixOps;
import greycat.ml.common.matrix.SVDDecompose;
import greycat.ml.common.matrix.TransposeType;
import greycat.ml.common.matrix.VolatileDMatrix;
import greycat.ml.common.matrix.blassolver.blas.Blas;
import greycat.ml.common.matrix.blassolver.blas.NetlibBlas;
import greycat.struct.DMatrix;

/* loaded from: input_file:greycat/ml/common/matrix/blassolver/BlasMatrixEngine.class */
/**
 * {@link MatrixEngine} implementation that delegates dense linear algebra to a pluggable
 * {@link Blas} backend ({@link NetlibBlas} by default).
 *
 * <p>Matrix data arrays are handed to BLAS with the row count as leading dimension, i.e. the
 * backing {@code double[]} of a {@link DMatrix} is treated as column-major storage.
 */
public class BlasMatrixEngine implements MatrixEngine {

    /** BLAS backend used by every operation of this engine. */
    private Blas _blas = new NetlibBlas();

    /**
     * Replaces the BLAS backend used by subsequent operations.
     *
     * @param blas the backend to use
     */
    public void setBlas(Blas blas) {
        this._blas = blas;
    }

    /**
     * Returns the BLAS backend currently in use.
     *
     * @return the current backend
     */
    public Blas getBlas() {
        return this._blas;
    }

    /**
     * Computes {@code C = alpha * op(A) * op(B) + beta * C} with BLAS {@code dgemm}.
     *
     * <p>When {@code beta == 0} or {@code matC} is {@code null}, a fresh zero-initialised result
     * matrix is allocated and {@code matC} is ignored. Otherwise {@code matC} is updated in place
     * and must already have the product's shape {@code op(A).rows x op(B).columns}.
     *
     * @param transA whether A is used as-is or transposed
     * @param alpha  scale factor applied to {@code op(A) * op(B)}
     * @param matA   left operand
     * @param transB whether B is used as-is or transposed
     * @param matB   right operand
     * @param beta   scale factor applied to the existing content of {@code matC}
     * @param matC   accumulator, or {@code null}
     * @return the matrix holding the result ({@code matC} itself when it was reused)
     * @throws RuntimeException if {@code op(A)} and {@code op(B)} are not conformable, or if a
     *     reused {@code matC} does not have the product's dimensions
     */
    @Override
    public DMatrix multiplyTransposeAlphaBeta(TransposeType transA, double alpha, DMatrix matA, TransposeType transB, DMatrix matB, double beta, DMatrix matC) {
        if (!MatrixOps.testDimensionsAB(transA, transB, matA, matB)) {
            throw new RuntimeException("Dimensions mismatch between A,B and C");
        }
        final boolean aAsIs = transA.equals(TransposeType.NOTRANSPOSE);
        final boolean bAsIs = transB.equals(TransposeType.NOTRANSPOSE);
        // op(A) is m x k, op(B) is k x n, so C must be m x n.
        final int m = aAsIs ? matA.rows() : matA.columns();
        final int k = aAsIs ? matA.columns() : matA.rows();
        final int n = bAsIs ? matB.columns() : matB.rows();
        if (beta == 0.0d || matC == null) {
            matC = VolatileDMatrix.empty(m, n);
        } else if (matC.rows() != m || matC.columns() != n) {
            // dgemm trusts M/N/LDC blindly; a mis-shaped C would be silently corrupted.
            throw new RuntimeException("Dimensions mismatch between A,B and C: C is "
                    + matC.rows() + "x" + matC.columns() + ", expected " + m + "x" + n);
        }
        this._blas.dgemm(transA, transB, m, n, k, alpha, matA.data(), 0, matA.rows(), matB.data(), 0, matB.rows(), beta, matC.data(), 0, matC.rows());
        return matC;
    }

    /**
     * Inverts a square matrix using an LU decomposition.
     *
     * @param mat            the matrix to invert
     * @param invertInPlace  if {@code true}, {@code mat} itself is overwritten with its inverse;
     *                       otherwise a copy is inverted and {@code mat} is left untouched
     * @return the inverse, or {@code null} if {@code mat} is not square or the LU inversion
     *     reports failure (presumably a singular matrix)
     */
    @Override
    public DMatrix invert(DMatrix mat, boolean invertInPlace) {
        if (mat.rows() != mat.columns()) {
            return null;
        }
        if (invertInPlace) {
            return new LU(mat.rows(), mat.columns(), this._blas).invert(mat) ? mat : null;
        }
        VolatileDMatrix copy = copyOf(mat);
        return new LU(copy.rows(), copy.columns(), this._blas).invert(copy) ? copy : null;
    }

    /**
     * Computes a pseudo-inverse of {@code mat} by solving {@code mat * X = I}, where {@code I}
     * is the {@code rows x rows} identity (see {@link #solve}).
     *
     * <p>The boolean flag is currently ignored; both solve paths are invoked with their
     * work-in-place option set to {@code false}.
     * NOTE(review): for non-square input this is a QR least-squares solve, which equals the
     * Moore-Penrose pseudo-inverse for full-column-rank (tall) matrices; behaviour for wide
     * matrices depends on {@code QR.solve} — confirm.
     *
     * @param mat            the matrix to pseudo-invert
     * @param invertInPlace  ignored
     * @return the pseudo-inverse
     */
    @Override
    public DMatrix pinv(DMatrix mat, boolean invertInPlace) {
        return solve(mat, VolatileDMatrix.identity(mat.rows(), mat.rows()));
    }

    /**
     * Solves {@code A * X = op(B)} using a QR factorisation of {@code A} (least squares when
     * {@code A} is not square).
     *
     * @param matA        coefficient matrix
     * @param matB        right-hand side
     * @param workInPlace forwarded to {@code QR.factorize}; controls whether {@code matA} may be
     *                    overwritten by the factorisation
     * @param transB      whether {@code matB} is used as-is or transposed
     * @return a new {@code A.columns x op(B).columns} matrix holding the solution
     */
    @Override
    public DMatrix solveQR(DMatrix matA, DMatrix matB, boolean workInPlace, TransposeType transB) {
        QR qr = QR.factorize(matA, workInPlace, this._blas);
        DMatrix rhs = transB != TransposeType.NOTRANSPOSE ? MatrixOps.transpose(matB) : matB;
        // Size the solution from the right-hand side actually solved for, i.e. after transposing.
        VolatileDMatrix solution = VolatileDMatrix.empty(matA.columns(), rhs.columns());
        qr.solve(rhs, solution);
        return solution;
    }

    /**
     * Computes a singular value decomposition of {@code mat}.
     *
     * @param mat         the matrix to decompose
     * @param workInPlace forwarded to {@code SVD.factor}
     * @return the decomposition
     */
    @Override
    public SVDDecompose decomposeSVD(DMatrix mat, boolean workInPlace) {
        SVD svd = new SVD(mat.rows(), mat.columns(), this._blas);
        svd.factor(mat, workInPlace);
        return svd;
    }

    /**
     * Solves {@code op(A) * X = B} (per {@code LU.transSolve}) using an LU factorisation of A.
     *
     * @param matA        coefficient matrix
     * @param matB        right-hand side
     * @param workInPlace if {@code true}, {@code matA} is factored in place and {@code matB} is
     *                    overwritten with the solution; otherwise both are copied first
     * @param transB      transpose option forwarded to {@code LU.transSolve}
     * @return the solution ({@code matB} itself when working in place), or {@code null} if the
     *     factorisation is singular
     */
    @Override
    public DMatrix solveLU(DMatrix matA, DMatrix matB, boolean workInPlace, TransposeType transB) {
        DMatrix lhs = workInPlace ? matA : copyOf(matA);
        LU lu = new LU(lhs.rows(), lhs.columns(), this._blas);
        lu.factor(lhs, true);
        if (lu.isSingular()) {
            return null;
        }
        // Only copy B once we know the system is solvable.
        DMatrix rhs = workInPlace ? matB : copyOf(matB);
        lu.transSolve(rhs, transB);
        return rhs;
    }

    /**
     * Solves {@code A * X = B}: square systems go through LU (factored with in-place flag
     * {@code false}); all others through {@link #solveQR} without transposing B.
     *
     * <p>NOTE(review): unlike {@link #solveLU}, the square path does not check
     * {@code LU.isSingular()} — confirm how {@code LU.solve} behaves on singular input.
     *
     * @param matA coefficient matrix
     * @param matB right-hand side
     * @return the solution
     */
    @Override
    public DMatrix solve(DMatrix matA, DMatrix matB) {
        if (matA.rows() == matA.columns()) {
            return new LU(matA.rows(), matA.columns(), this._blas).factor(matA, false).solve(matB);
        }
        return solveQR(matA, matB, false, TransposeType.NOTRANSPOSE);
    }

    /** Returns a new {@link VolatileDMatrix} with the same shape and data as {@code source}. */
    private static VolatileDMatrix copyOf(DMatrix source) {
        VolatileDMatrix copy = VolatileDMatrix.empty(source.rows(), source.columns());
        System.arraycopy(source.data(), 0, copy.data(), 0, source.rows() * source.columns());
        return copy;
    }
}
