package greycat.ml.common.matrix;

import greycat.struct.DMatrix;
import java.util.Random;

/* loaded from: input_file:greycat/ml/common/matrix/MatrixOps.class */
/**
 * Static helpers for dense {@link DMatrix} arithmetic.
 *
 * <p>Matrix products, inversion and pseudo-inversion are delegated to a pluggable
 * {@link MatrixEngine} (see {@link #defaultEngine()} / {@link #setDefaultEngine(MatrixEngine)}).
 * The remaining operations work directly on each matrix's flat storage through
 * {@code unsafeGet}/{@code unsafeSet}.
 *
 * <p>The transpose helpers index the flat storage as column-major, i.e. element
 * {@code (row, col)} at flat index {@code row + col * rows()}.
 * NOTE(review): confirm this layout holds for every {@code DMatrix} implementation.
 */
public class MatrixOps {
    /** Lazily created engine; replaced via {@link #setDefaultEngine(MatrixEngine)}. */
    private static MatrixEngine _defaultEngine = null;

    /**
     * When both dimensions of a non-square matrix exceed this value, {@link #transpose(DMatrix)}
     * switches from the straightforward copy to the tiled (cache-blocked) copy.
     */
    private static final int TRANSPOSE_BLOCK_THRESHOLD = 375;

    /** Edge length of the tiles used by the tiled transpose. */
    private static final int TRANSPOSE_BLOCK_SIZE = 60;

    /**
     * Returns the engine used by the multiply / invert / pinv helpers, creating a
     * {@link HybridMatrixEngine} on first use if none was set.
     *
     * <p>NOTE(review): the lazy initialisation is not synchronised; concurrent first calls
     * may each construct an engine, and only one of them will be retained.
     *
     * @return the current default engine, never {@code null} after this call
     */
    public static MatrixEngine defaultEngine() {
        if (_defaultEngine == null) {
            _defaultEngine = new HybridMatrixEngine();
        }
        return _defaultEngine;
    }

    /**
     * Replaces the engine used by the multiply / invert / pinv helpers.
     * Passing {@code null} makes the next {@link #defaultEngine()} call recreate the default.
     *
     * @param matrixEngine the engine to use from now on
     */
    public static void setDefaultEngine(MatrixEngine matrixEngine) {
        _defaultEngine = matrixEngine;
    }

    /**
     * Copies every element of {@code dMatrix} into {@code dMatrix2} using 2-D indexing.
     * The loop runs over the source's dimensions; no dimension check is performed, so the
     * destination must be at least as large in both directions.
     *
     * @param dMatrix  source matrix
     * @param dMatrix2 destination matrix (written in place)
     */
    public static void copyMatrix(DMatrix dMatrix, DMatrix dMatrix2) {
        for (int row = 0; row < dMatrix.rows(); row++) {
            for (int col = 0; col < dMatrix.columns(); col++) {
                dMatrix2.set(row, col, dMatrix.get(row, col));
            }
        }
    }

    /**
     * Overwrites every element with a uniform random value {@code d + u * (d2 - d)},
     * where {@code u} comes from {@link Random#nextDouble()} (range {@code [0, 1)}).
     *
     * @param dMatrix matrix to fill in place
     * @param random  random source
     * @param d       lower bound (inclusive)
     * @param d2      upper bound (exclusive when {@code d2 > d})
     * @return {@code dMatrix}, for chaining
     */
    public static DMatrix fillWithRandom(DMatrix dMatrix, Random random, double d, double d2) {
        final int length = dMatrix.length();
        final double range = d2 - d;
        for (int i = 0; i < length; i++) {
            dMatrix.unsafeSet(i, (random.nextDouble() * range) + d);
        }
        return dMatrix;
    }

    /**
     * Overwrites every element with a zero-mean Gaussian value scaled by {@code d}.
     *
     * @param dMatrix matrix to fill in place
     * @param random  random source
     * @param d       standard deviation (multiplier applied to {@link Random#nextGaussian()})
     * @return {@code dMatrix}, for chaining
     */
    public static DMatrix fillWithRandomStd(DMatrix dMatrix, Random random, double d) {
        final int length = dMatrix.length();
        for (int i = 0; i < length; i++) {
            dMatrix.unsafeSet(i, random.nextGaussian() * d);
        }
        return dMatrix;
    }

    /** Returns {@code A * B} computed by the default engine. */
    public static DMatrix multiply(DMatrix dMatrix, DMatrix dMatrix2) {
        return defaultEngine().multiplyTransposeAlphaBeta(TransposeType.NOTRANSPOSE, 1.0d, dMatrix, TransposeType.NOTRANSPOSE, dMatrix2, 0.0d, null);
    }

    /** Returns {@code op(A) * op(B)}, where each {@code op} is selected by its {@link TransposeType}. */
    public static DMatrix multiplyTranspose(TransposeType transposeType, DMatrix dMatrix, TransposeType transposeType2, DMatrix dMatrix2) {
        return defaultEngine().multiplyTransposeAlphaBeta(transposeType, 1.0d, dMatrix, transposeType2, dMatrix2, 0.0d, null);
    }

    /** Returns {@code alpha * op(A) * op(B)}, where {@code alpha} is {@code d}. */
    public static DMatrix multiplyTransposeAlpha(TransposeType transposeType, double d, DMatrix dMatrix, TransposeType transposeType2, DMatrix dMatrix2) {
        return defaultEngine().multiplyTransposeAlphaBeta(transposeType, d, dMatrix, transposeType2, dMatrix2, 0.0d, null);
    }

    /**
     * Forwards to the default engine's {@code multiplyTransposeAlphaBeta}.
     * The argument order ({@code alpha = d}, {@code beta = d2}, {@code C = dMatrix3}) suggests
     * the GEMM form {@code alpha * op(A) * op(B) + beta * C}.
     * NOTE(review): the exact semantics are defined by the engine, not here.
     */
    public static DMatrix multiplyTransposeAlphaBeta(TransposeType transposeType, double d, DMatrix dMatrix, TransposeType transposeType2, DMatrix dMatrix2, double d2, DMatrix dMatrix3) {
        return defaultEngine().multiplyTransposeAlphaBeta(transposeType, d, dMatrix, transposeType2, dMatrix2, d2, dMatrix3);
    }

    /** Delegates to the default engine's {@code invert}; the meaning of {@code z} is engine-defined. */
    public static DMatrix invert(DMatrix dMatrix, boolean z) {
        return defaultEngine().invert(dMatrix, z);
    }

    /** Delegates to the default engine's {@code pinv}; the meaning of {@code z} is engine-defined. */
    public static DMatrix pinv(DMatrix dMatrix, boolean z) {
        return defaultEngine().pinv(dMatrix, z);
    }

    /**
     * Returns a newly allocated transpose of {@code dMatrix}.
     *
     * <p>The copy strategy depends on the shape:
     * <ul>
     *   <li>square matrices use a diagonal-split copy;</li>
     *   <li>small or thin matrices use a straightforward copy;</li>
     *   <li>large non-square matrices use a tiled copy.</li>
     * </ul>
     *
     * @param dMatrix matrix to transpose (not modified)
     * @return a {@code columns() x rows()} {@link VolatileDMatrix}
     */
    public static DMatrix transpose(DMatrix dMatrix) {
        VolatileDMatrix result = VolatileDMatrix.empty(dMatrix.columns(), dMatrix.rows());
        if (dMatrix.columns() == dMatrix.rows()) {
            transposeSquare(dMatrix, result);
        } else if (dMatrix.columns() <= TRANSPOSE_BLOCK_THRESHOLD || dMatrix.rows() <= TRANSPOSE_BLOCK_THRESHOLD) {
            transposeStandard(dMatrix, result);
        } else {
            transposeBlock(dMatrix, result);
        }
        return result;
    }

    /**
     * Transposes a square {@code src} into a distinct {@code dst}.
     * For each diagonal element, the rest of that column is exchanged with the rest of that row.
     */
    private static void transposeSquare(DMatrix src, DMatrix dst) {
        final int n = src.columns();
        for (int k = 0; k < n; k++) {
            final int diag = k * (n + 1);
            dst.unsafeSet(diag, src.unsafeGet(diag));
            int below = diag + 1; // walks (k+1 .. n-1, k) down column k
            int right = diag + n; // walks (k, k+1 .. n-1) along row k
            for (int j = k + 1; j < n; j++) {
                dst.unsafeSet(below, src.unsafeGet(right));
                dst.unsafeSet(right, src.unsafeGet(below));
                below++;
                right += n;
            }
        }
    }

    /**
     * Straightforward transpose.
     * The destination is written sequentially; the source is read with a stride of {@code src.rows()}.
     */
    private static void transposeStandard(DMatrix src, DMatrix dst) {
        final int srcRows = src.rows();
        final int dstRows = dst.rows();
        final int dstCols = dst.columns();
        int out = 0;
        for (int dstCol = 0; dstCol < dstCols; dstCol++) {
            int in = dstCol; // src(dstCol, 0)
            for (int dstRow = 0; dstRow < dstRows; dstRow++) {
                dst.unsafeSet(out++, src.unsafeGet(in));
                in += srcRows;
            }
        }
    }

    /**
     * Tiled transpose for large non-square matrices.
     *
     * <p>The source is processed in {@code TRANSPOSE_BLOCK_SIZE x TRANSPOSE_BLOCK_SIZE} tiles
     * to keep both read and write strides local. The decompiled original had an inner
     * {@code while (true)} with no exit, which hung on every call and left the following
     * statement unreachable. This version restores proper loop bounds.
     */
    private static void transposeBlock(DMatrix src, DMatrix dst) {
        final int srcRows = src.rows();
        final int srcCols = src.columns();
        final int dstRows = dst.rows();
        for (int colStart = 0; colStart < srcCols; colStart += TRANSPOSE_BLOCK_SIZE) {
            final int tileCols = Math.min(TRANSPOSE_BLOCK_SIZE, srcCols - colStart);
            int srcIndex = colStart * srcRows; // src(0, colStart)
            int dstIndex = colStart;           // dst(colStart, 0)
            for (int rowStart = 0; rowStart < srcRows; rowStart += TRANSPOSE_BLOCK_SIZE) {
                final int srcIndexEnd = srcIndex + Math.min(TRANSPOSE_BLOCK_SIZE, srcRows - rowStart);
                // One iteration per source row inside the tile (= one destination column).
                for (; srcIndex < srcIndexEnd; srcIndex++) {
                    int in = srcIndex;
                    final int outEnd = dstIndex + tileCols;
                    for (int out = dstIndex; out < outEnd; out++) {
                        dst.unsafeSet(out, src.unsafeGet(in));
                        in += srcRows;
                    }
                    dstIndex += dstRows;
                }
            }
        }
    }

    /**
     * Checks whether {@code op(A) * op(B)} is dimensionally valid.
     * It compares the inner dimension of {@code op(A)} with that of {@code op(B)}.
     *
     * @param transposeType  whether {@code A} is transposed
     * @param transposeType2 whether {@code B} is transposed
     * @param dMatrix        {@code A}
     * @param dMatrix2       {@code B}
     * @return {@code true} when the inner dimensions agree
     */
    public static boolean testDimensionsAB(TransposeType transposeType, TransposeType transposeType2, DMatrix dMatrix, DMatrix dMatrix2) {
        final int innerA = transposeType.equals(TransposeType.NOTRANSPOSE) ? dMatrix.columns() : dMatrix.rows();
        final int innerB = transposeType2.equals(TransposeType.NOTRANSPOSE) ? dMatrix2.rows() : dMatrix2.columns();
        return innerA == innerB;
    }

    /**
     * Requires both matrices to have identical shape.
     *
     * @return always {@code true} (kept for backward compatibility)
     * @throws IllegalArgumentException if rows or columns differ
     */
    public static boolean testDim(DMatrix dMatrix, DMatrix dMatrix2) {
        if (dMatrix.rows() == dMatrix2.rows() && dMatrix.columns() == dMatrix2.columns()) {
            return true;
        }
        throw new IllegalArgumentException("Matrices original and destination have different dimensions");
    }

    /**
     * Returns a new matrix holding {@code dMatrix - dMatrix2} element-wise.
     *
     * @throws IllegalArgumentException if the shapes differ
     */
    public static DMatrix sub(DMatrix dMatrix, DMatrix dMatrix2) {
        testDim(dMatrix, dMatrix2);
        VolatileDMatrix result = VolatileDMatrix.empty(dMatrix.rows(), dMatrix.columns());
        final int length = dMatrix.length();
        for (int i = 0; i < length; i++) {
            result.unsafeSet(i, dMatrix.unsafeGet(i) - dMatrix2.unsafeGet(i));
        }
        return result;
    }

    /**
     * Multiplies every element of {@code dMatrix} by {@code d} in place.
     * A factor of exactly {@code 0.0} short-circuits to {@code fill(0.0)}, so existing
     * NaN/Infinity entries become 0 rather than NaN.
     */
    public static void scaleInPlace(double d, DMatrix dMatrix) {
        if (d == 0.0d) {
            dMatrix.fill(0.0d);
            return;
        }
        final int size = dMatrix.rows() * dMatrix.columns();
        for (int i = 0; i < size; i++) {
            dMatrix.unsafeSet(i, d * dMatrix.unsafeGet(i));
        }
    }

    /**
     * In-place linear combination: {@code dMatrix = d * dMatrix + d2 * dMatrix2}.
     *
     * <p>Coefficients of exactly 0 or 1 take fast paths. A zero coefficient drops that operand
     * entirely, so NaN/Infinity in the dropped operand does not propagate.
     *
     * @throws IllegalArgumentException if the shapes differ
     */
    public static void addInPlace(DMatrix dMatrix, double d, DMatrix dMatrix2, double d2) {
        testDim(dMatrix, dMatrix2);
        final int length = dMatrix.length();
        if (d == 0.0d) {
            if (d2 == 0.0d) {
                dMatrix.fill(0.0d);
                return;
            }
            if (d2 == 1.0d) {
                // NOTE(review): relies on fillWith copying (not aliasing) dMatrix2's backing data.
                dMatrix.fillWith(dMatrix2.data());
                return;
            }
            for (int i = 0; i < length; i++) {
                dMatrix.unsafeSet(i, dMatrix2.unsafeGet(i) * d2);
            }
            return;
        }
        if (d == 1.0d) {
            if (d2 == 0.0d) {
                return; // dMatrix is already the result
            }
            if (d2 == 1.0d) {
                for (int i = 0; i < length; i++) {
                    dMatrix.unsafeSet(i, dMatrix.unsafeGet(i) + dMatrix2.unsafeGet(i));
                }
                return;
            }
            for (int i = 0; i < length; i++) {
                dMatrix.unsafeSet(i, dMatrix.unsafeGet(i) + (dMatrix2.unsafeGet(i) * d2));
            }
            return;
        }
        if (d2 == 0.0d) {
            for (int i = 0; i < length; i++) {
                dMatrix.unsafeSet(i, dMatrix.unsafeGet(i) * d);
            }
            return;
        }
        if (d2 == 1.0d) {
            for (int i = 0; i < length; i++) {
                dMatrix.unsafeSet(i, (dMatrix.unsafeGet(i) * d) + dMatrix2.unsafeGet(i));
            }
            return;
        }
        for (int i = 0; i < length; i++) {
            dMatrix.unsafeSet(i, (dMatrix.unsafeGet(i) * d) + (dMatrix2.unsafeGet(i) * d2));
        }
    }

    /**
     * Copies the flat storage of {@code dMatrix} into {@code dMatrix2}, one element at a time.
     * No dimension check is performed; the destination must have at least {@code dMatrix.length()} slots.
     */
    public static void copy(DMatrix dMatrix, DMatrix dMatrix2) {
        final int length = dMatrix.length();
        for (int i = 0; i < length; i++) {
            dMatrix2.unsafeSet(i, dMatrix.unsafeGet(i));
        }
    }

    /**
     * In place: {@code dMatrix += dMatrix2}.
     *
     * @throws IllegalArgumentException if the shapes differ
     */
    public static void addtoMatrix(DMatrix dMatrix, DMatrix dMatrix2) {
        testDim(dMatrix, dMatrix2);
        final int length = dMatrix.length();
        for (int i = 0; i < length; i++) {
            dMatrix.unsafeSet(i, dMatrix.unsafeGet(i) + dMatrix2.unsafeGet(i));
        }
    }

    /**
     * In place: {@code dMatrix += d * dMatrix2}.
     *
     * @throws IllegalArgumentException if the shapes differ
     */
    public static void scaleThenAddtoMatrix(DMatrix dMatrix, DMatrix dMatrix2, double d) {
        testDim(dMatrix, dMatrix2);
        final int length = dMatrix.length();
        for (int i = 0; i < length; i++) {
            dMatrix.unsafeSet(i, dMatrix.unsafeGet(i) + (dMatrix2.unsafeGet(i) * d));
        }
    }

    /**
     * Returns a new matrix holding {@code dMatrix + dMatrix2} element-wise.
     *
     * @throws IllegalArgumentException if the shapes differ
     */
    public static DMatrix add(DMatrix dMatrix, DMatrix dMatrix2) {
        testDim(dMatrix, dMatrix2);
        VolatileDMatrix result = VolatileDMatrix.empty(dMatrix.rows(), dMatrix.columns());
        final int length = dMatrix.length();
        for (int i = 0; i < length; i++) {
            result.unsafeSet(i, dMatrix.unsafeGet(i) + dMatrix2.unsafeGet(i));
        }
        return result;
    }

    /**
     * Returns a new matrix holding the element-wise (Hadamard) product.
     *
     * @throws IllegalArgumentException if the shapes differ
     */
    public static DMatrix HadamardMult(DMatrix dMatrix, DMatrix dMatrix2) {
        testDim(dMatrix, dMatrix2);
        VolatileDMatrix result = VolatileDMatrix.empty(dMatrix.rows(), dMatrix.columns());
        final int length = dMatrix.length();
        for (int i = 0; i < length; i++) {
            result.unsafeSet(i, dMatrix.unsafeGet(i) * dMatrix2.unsafeGet(i));
        }
        return result;
    }
}
