package math.linalg;

import java.util.Arrays;
import math.gemm.Dgemm;
import math.gemm.Trans;
import math.solve.LinearEquationsSolver;

/* loaded from: input_file:math/linalg/DMatrix.class */
/**
 * Dense matrix of {@code double} values backed by a single array in column-major order:
 * element {@code (row, col)} is stored at index {@code col * rows + row} (see {@link #idx(int, int)}).
 *
 * <p>Operations without an {@code Inplace} suffix return a newly allocated matrix and leave
 * {@code this} untouched; the {@code ...Inplace} variants write the result into {@code this} and
 * return it. Accessors with an {@code Unsafe} suffix skip explicit index validation, and
 * {@link #getArrayUnsafe()} exposes the internal storage without copying.
 */
public class DMatrix {
    /** Unit roundoff of IEEE 754 double precision: 2^-53, i.e. half of {@code Math.ulp(1.0)}. */
    private static final double UNIT_ROUNDOFF = 1.1102230246251565E-16d;

    /** Relative tolerance applied by {@link #approximatelyEquals(DMatrix, DMatrix, double)}. */
    private static final double DEFAULT_REL_TOL = 1.0E-8d;

    /** Number of rows. */
    protected final int rows;
    /** Number of columns. */
    protected final int cols;
    /**
     * Element storage in column-major order. Holds {@code rows * cols} values when created by the
     * public constructors; the protected constructor stores the given array as-is.
     */
    protected final double[] a;

    /**
     * Creates an {@code n x n} identity matrix.
     *
     * @param n number of rows and columns; must be strictly positive
     * @return a new identity matrix
     * @throws IllegalArgumentException if {@code n <= 0}
     */
    public static DMatrix identity(int n) {
        return diag(n, 1.0d);
    }

    /**
     * Creates an {@code n x n} matrix with {@code d} on the main diagonal and zeros elsewhere.
     *
     * <p>The diagonal is written directly into a zero matrix rather than scaling an identity, so
     * non-finite values of {@code d} (infinity, NaN) do not turn the off-diagonal zeros into NaN.
     *
     * @param n number of rows and columns; must be strictly positive
     * @param d value placed on every diagonal position
     * @return a new diagonal matrix
     * @throws IllegalArgumentException if {@code n <= 0}
     */
    public static DMatrix diag(int n, double d) {
        DMatrix m = new DMatrix(n, n);
        for (int k = 0; k < n; k++) {
            m.setUnsafe(k, k, d);
        }
        return m;
    }

    /**
     * Creates a zero-filled {@code rows x cols} matrix.
     *
     * @param rows number of rows; must be strictly positive
     * @param cols number of columns; must be strictly positive
     * @throws IllegalArgumentException if a dimension is not strictly positive or if
     *     {@code rows * cols} exceeds {@code Integer.MAX_VALUE}
     */
    public DMatrix(int rows, int cols) {
        this.rows = rows;
        this.cols = cols;
        this.a = new double[checkArrayLength(rows, cols)];
    }

    /**
     * Copy constructor: creates a matrix with the same dimensions and an independent copy of the
     * elements of {@code other}.
     *
     * @param other the matrix to copy
     */
    public DMatrix(DMatrix other) {
        this(other.rows, other.cols, Arrays.copyOf(other.a, other.a.length));
    }

    /**
     * Wraps an existing column-major array without copying or validating it. The caller is
     * responsible for supplying an array of length {@code rows * cols}.
     *
     * @param rows number of rows
     * @param cols number of columns
     * @param storage column-major element array, used directly as the backing store
     */
    protected DMatrix(int rows, int cols, double[] storage) {
        this.rows = rows;
        this.cols = cols;
        this.a = storage;
    }

    /**
     * Returns a deep copy of this matrix.
     *
     * @return a new matrix equal to this one that shares no storage with it
     */
    public DMatrix copy() {
        return new DMatrix(this);
    }

    /** @return the number of columns */
    public int numColumns() {
        return this.cols;
    }

    /** @return the number of rows */
    public int numRows() {
        return this.rows;
    }

    /** @return {@code true} if the number of rows equals the number of columns */
    public boolean isSquareMatrix() {
        return this.rows == this.cols;
    }

    /**
     * Returns the element at {@code (row, col)} after validating both indices.
     *
     * @param row zero-based row index
     * @param col zero-based column index
     * @return the element value
     * @throws IllegalArgumentException if an index is out of range
     */
    public double get(int row, int col) {
        checkIndex(row, col);
        return this.a[idx(row, col)];
    }

    /**
     * Returns the element at {@code (row, col)} without explicit index validation. An out-of-range
     * index pair may read a different element (e.g. {@code row == rows} maps into the next column)
     * or throw {@link ArrayIndexOutOfBoundsException}.
     *
     * @param row zero-based row index
     * @param col zero-based column index
     * @return the stored value
     */
    public double getUnsafe(int row, int col) {
        return this.a[idx(row, col)];
    }

    /**
     * Stores {@code value} at {@code (row, col)} after validating both indices.
     *
     * @param row zero-based row index
     * @param col zero-based column index
     * @param value the value to store
     * @return {@code this}, to allow call chaining
     * @throws IllegalArgumentException if an index is out of range
     */
    public DMatrix set(int row, int col, double value) {
        checkIndex(row, col);
        this.a[idx(row, col)] = value;
        return this;
    }

    /**
     * Stores {@code value} at {@code (row, col)} without explicit index validation (see
     * {@link #getUnsafe(int, int)} for the consequences of invalid indices).
     *
     * @param row zero-based row index
     * @param col zero-based column index
     * @param value the value to store
     */
    public void setUnsafe(int row, int col, double value) {
        this.a[idx(row, col)] = value;
    }

    /**
     * Returns the internal column-major storage array itself (not a copy). Writes to the returned
     * array modify this matrix.
     *
     * @return the backing array
     */
    public double[] getArrayUnsafe() {
        return this.a;
    }

    /**
     * Returns a new matrix whose elements are those of this matrix multiplied by {@code factor}.
     *
     * @param factor the scalar multiplier
     * @return a new scaled matrix
     */
    public DMatrix scale(double factor) {
        return scale(factor, new DMatrix(this.rows, this.cols));
    }

    /**
     * Multiplies every element of this matrix by {@code factor} in place.
     *
     * @param factor the scalar multiplier
     * @return {@code this}
     */
    public DMatrix scaleInplace(double factor) {
        return scale(factor, this);
    }

    /** Writes {@code factor * this[k]} into {@code target[k]} for every element; returns {@code target}. */
    private DMatrix scale(double factor, DMatrix target) {
        double[] src = this.a;
        double[] dst = target.a;
        for (int k = 0; k < dst.length; k++) {
            dst[k] = factor * src[k];
        }
        return target;
    }

    /**
     * Returns a new matrix holding the absolute value of every element of this matrix.
     *
     * @return a new matrix of absolute values
     */
    public DMatrix abs() {
        return abs(new DMatrix(this.rows, this.cols));
    }

    /**
     * Replaces every element of this matrix by its absolute value.
     *
     * @return {@code this}
     */
    public DMatrix absInplace() {
        return abs(this);
    }

    /** Writes {@code |this[k]|} into {@code target[k]} for every element; returns {@code target}. */
    private DMatrix abs(DMatrix target) {
        double[] src = this.a;
        double[] dst = target.a;
        for (int k = 0; k < dst.length; k++) {
            dst[k] = Math.abs(src[k]);
        }
        return target;
    }

    /**
     * Returns the transpose of this matrix as a new {@code cols x rows} matrix.
     *
     * @return a new transposed matrix
     */
    public DMatrix transpose() {
        if (this.rows == 1 || this.cols == 1) {
            // A row or column vector has the same column-major layout as its transpose,
            // so copying the storage is enough.
            return new DMatrix(this.cols, this.rows, Arrays.copyOf(this.a, this.a.length));
        }
        DMatrix result = new DMatrix(this.cols, this.rows);
        int resultRows = this.cols;
        int resultCols = this.rows;
        for (int r = 0; r < resultRows; r++) {
            for (int c = 0; c < resultCols; c++) {
                result.setUnsafe(r, c, getUnsafe(c, r));
            }
        }
        return result;
    }

    /**
     * Computes the inverse of this square matrix.
     *
     * <p>The system {@code this * X = I} is handed to {@code LinearEquationsSolver.solve}. The
     * third argument is presumably the output buffer; confirm this against that class. The
     * candidate {@code X} is accepted only if {@code this * X} reproduces the identity within a
     * relative tolerance of {@code 1e-8} or an absolute tolerance of
     * {@code 1.5 * sqrt(2^-53)} (about {@code 1.6e-8}).
     *
     * @return the inverse matrix
     * @throws IllegalArgumentException if this matrix is not square
     * @throws RuntimeException if the verification fails, e.g. because the matrix is (close to) singular
     */
    public DMatrix inverse() {
        if (!isSquareMatrix()) {
            throw new IllegalArgumentException("The inverse is only defined for square matrices");
        }
        DMatrix identity = identity(numRows());
        DMatrix solve = LinearEquationsSolver.solve(this, identity, new DMatrix(numRows(), numColumns()));
        if (approximatelyEquals(mul(solve), identity, 1.5d * Math.sqrt(UNIT_ROUNDOFF))) {
            return solve;
        }
        throw new RuntimeException("Matrix A may be (close to) singular.");
    }

    /**
     * Returns the element-wise sum {@code this + other} as a new matrix.
     *
     * @param other the matrix to add; must have the same dimensions
     * @return a new matrix holding the sum
     * @throws IndexOutOfBoundsException if the dimensions differ
     */
    public DMatrix add(DMatrix other) {
        checkEqualDimension(this, other);
        return add(other, new DMatrix(this.rows, this.cols));
    }

    /**
     * Adds {@code other} to this matrix element-wise in place.
     *
     * @param other the matrix to add; must have the same dimensions
     * @return {@code this}
     * @throws IndexOutOfBoundsException if the dimensions differ
     */
    public DMatrix addInplace(DMatrix other) {
        checkEqualDimension(this, other);
        return add(other, this);
    }

    /** Writes {@code this[k] + other[k]} into {@code target[k]}; dimensions must already match. */
    private DMatrix add(DMatrix other, DMatrix target) {
        double[] lhs = this.a;
        double[] rhs = other.a;
        double[] dst = target.a;
        for (int k = 0; k < lhs.length; k++) {
            dst[k] = lhs[k] + rhs[k];
        }
        return target;
    }

    /**
     * Returns {@code this + other} as a new matrix, broadcasting {@code other} across columns when
     * it is a single column. If {@code other} has the same number of columns as this matrix, a
     * plain element-wise sum is computed.
     *
     * @param other a matrix with the same number of rows and either the same number of columns
     *     or exactly one column
     * @return a new matrix holding the sum
     * @throws IndexOutOfBoundsException if the row counts differ, or if {@code other} has neither
     *     one column nor the same number of columns
     */
    public DMatrix addBroadcastedVector(DMatrix other) {
        checkSameRows(this, other);
        return addBroadcastedVector(other, new DMatrix(this.rows, this.cols));
    }

    /**
     * In-place variant of {@link #addBroadcastedVector(DMatrix)}.
     *
     * @param other a matrix with the same number of rows and either the same number of columns
     *     or exactly one column
     * @return {@code this}
     * @throws IndexOutOfBoundsException under the same conditions as {@link #addBroadcastedVector(DMatrix)}
     */
    public DMatrix addBroadcastedVectorInplace(DMatrix other) {
        checkSameRows(this, other);
        return addBroadcastedVector(other, this);
    }

    /** Broadcasting add into {@code target}; row counts must already have been checked. */
    private DMatrix addBroadcastedVector(DMatrix other, DMatrix target) {
        if (this.cols == other.cols) {
            return add(other, target);
        }
        if (other.numColumns() != 1) {
            throw getSameColsException(this, other);
        }
        double[] src = this.a;
        double[] vec = other.a;
        double[] dst = target.a;
        int nCols = this.cols;
        int nRows = this.rows;
        for (int c = 0; c < nCols; c++) {
            for (int r = 0; r < nRows; r++) {
                dst[idx(r, c)] = src[idx(r, c)] + vec[r];
            }
        }
        return target;
    }

    /**
     * Returns the element-wise difference {@code this - other} as a new matrix.
     *
     * @param other the matrix to subtract; must have the same dimensions
     * @return a new matrix holding the difference
     * @throws IndexOutOfBoundsException if the dimensions differ
     */
    public DMatrix minus(DMatrix other) {
        checkEqualDimension(this, other);
        DMatrix result = new DMatrix(this.rows, this.cols);
        double[] lhs = this.a;
        double[] rhs = other.a;
        double[] dst = result.a;
        for (int k = 0; k < lhs.length; k++) {
            dst[k] = lhs[k] - rhs[k];
        }
        return result;
    }

    /**
     * Returns the matrix product {@code this * other} as a new {@code rows x other.cols} matrix.
     * The product is delegated to {@code Dgemm.dgemm} with both operands untransposed and, by the
     * argument pattern, alpha = 1, beta = 0 and leading dimensions equal to the row counts
     * (assumed BLAS-style semantics; verify against {@code Dgemm}).
     *
     * @param other the right-hand factor; its row count must equal this matrix's column count
     * @return a new matrix holding the product
     * @throws IndexOutOfBoundsException if the inner dimensions do not agree
     */
    public DMatrix mul(DMatrix other) {
        checkMul(this, other);
        DMatrix result = new DMatrix(this.rows, other.cols);
        Dgemm.dgemm(Trans.NO_TRANS, Trans.NO_TRANS, result.rows, result.cols, this.cols, 1.0d, this.a, 0, this.rows, other.a, 0, other.rows, 0.0d, result.a, 0, result.rows);
        return result;
    }

    /**
     * Returns a human-readable rendering: a {@code (rows x cols)} header followed by the elements
     * in {@code %.12E} format, abbreviated for dimensions larger than 6.
     */
    @Override
    public String toString() {
        return toString(this);
    }

    /**
     * Tests whether two matrices have the same dimensions and every pair of corresponding elements
     * is either exactly equal or within {@code max(1e-8 * max(|x|, |y|), absTol)} of each other.
     * Pairs involving NaN, or mismatched infinities, are never considered approximately equal
     * (unless both arguments are the same instance).
     *
     * @param m1 first matrix
     * @param m2 second matrix
     * @param absTol non-negative, finite absolute tolerance
     * @return {@code true} if the matrices are approximately equal
     * @throws IllegalArgumentException if the dimensions agree and {@code absTol} is negative, NaN or infinite
     */
    public static boolean approximatelyEquals(DMatrix m1, DMatrix m2, double absTol) {
        return approximatelyEquals(m1, m2, DEFAULT_REL_TOL, absTol);
    }

    /** Tolerance comparison with explicit relative ({@code relTol}) and absolute ({@code absTol}) tolerances. */
    private static boolean approximatelyEquals(DMatrix m1, DMatrix m2, double relTol, double absTol) {
        if (m1.numRows() != m2.numRows() || m1.numColumns() != m2.numColumns()) {
            return false;
        }
        if (absTol < 0.0d || Double.isNaN(absTol) || Double.isInfinite(absTol)) {
            throw new IllegalArgumentException("illegal absTol : " + absTol);
        }
        if (m1 == m2) {
            return true;
        }
        double[] x = m1.getArrayUnsafe();
        double[] y = m2.getArrayUnsafe();
        for (int k = 0; k < x.length; k++) {
            double xk = x[k];
            double yk = y[k];
            if (xk != yk) {
                double diff = Math.abs(xk - yk);
                // A NaN operand makes every comparison below false. An infinity paired with a finite
                // value yields an infinite relative bound that the difference never exceeds. Either
                // way the pair would wrongly pass, so any non-finite difference is rejected explicitly.
                if (Double.isNaN(diff) || Double.isInfinite(diff)) {
                    return false;
                }
                if (diff > relTol * Math.max(Math.abs(xk), Math.abs(yk)) && diff > absTol) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Maps {@code (row, col)} to its position in the column-major storage array.
     *
     * @param row zero-based row index
     * @param col zero-based column index
     * @return {@code col * rows + row}
     */
    protected final int idx(int row, int col) {
        return (col * this.rows) + row;
    }

    /**
     * Validates that {@code (row, col)} lies inside this matrix.
     *
     * @throws IllegalArgumentException if either index is out of range
     */
    protected void checkIndex(int row, int col) {
        if (row < 0 || row >= this.rows) {
            throw new IllegalArgumentException("Illegal row index " + row + " in (" + this.rows + " x " + this.cols + ") matrix");
        }
        if (col < 0 || col >= this.cols) {
            throw new IllegalArgumentException("Illegal column index " + col + " in (" + this.rows + " x " + this.cols + ") matrix");
        }
    }

    /**
     * Renders {@code m} as text. Dimensions up to 6 are printed in full. For larger dimensions the
     * first 5 rows or columns are shown, followed by an ellipsis and the last row or column.
     *
     * @param m the matrix to render
     * @return the textual representation, with lines separated by {@link System#lineSeparator()}
     */
    protected static String toString(DMatrix m) {
        StringBuilder sb = new StringBuilder();
        sb.append("(").append(m.rows).append(" x ").append(m.cols).append(")").append(System.lineSeparator());
        int shownCols = m.numColumns() <= 6 ? m.numColumns() : 5;
        int shownRows = m.numRows() <= 6 ? m.numRows() : 5;
        int row = 0;
        while (row < shownRows) {
            printRowD(row, shownCols, m, sb);
            row++;
        }
        if (row == 5 && shownRows < m.numRows()) {
            // Placeholder line for the omitted rows: one "......" per printed column
            // (plus one for the column ellipsis when columns are truncated too).
            int placeholders = shownCols < m.numColumns() ? 6 : m.numColumns();
            for (int k = 0; k < placeholders; k++) {
                sb.append("......");
                if (k != placeholders - 1) {
                    sb.append(", ");
                }
            }
            sb.append(System.lineSeparator());
            printRowD(m.numRows() - 1, shownCols, m, sb);
        }
        return sb.toString();
    }

    /**
     * Appends the first {@code shownCols} elements of row {@code row} to {@code sb}. If columns
     * were truncated, an ellipsis and the row's last element follow. The line is terminated with
     * a line separator.
     */
    private static void printRowD(int row, int shownCols, DMatrix m, StringBuilder sb) {
        int col = 0;
        while (col < shownCols) {
            sb.append(String.format("%.12E", Double.valueOf(m.getUnsafe(row, col))));
            if (col < shownCols - 1) {
                sb.append(", ");
            }
            col++;
        }
        if (col == 5 && shownCols < m.numColumns()) {
            sb.append(", ......, ");
            sb.append(String.format("%.12E", Double.valueOf(m.getUnsafe(row, m.numColumns() - 1))));
        }
        sb.append(System.lineSeparator());
    }

    /** @throws IndexOutOfBoundsException if {@code A} and {@code B} differ in their number of rows */
    protected static void checkSameRows(DMatrix A, DMatrix B) {
        if (A.numRows() != B.numRows()) {
            throw new IndexOutOfBoundsException("A.numRows() != B.numRows() (" + A.numRows() + " != " + B.numRows() + ")");
        }
    }

    /** @throws IndexOutOfBoundsException if {@code A} and {@code B} differ in their number of columns */
    protected static void checkSameCols(DMatrix A, DMatrix B) {
        if (A.numColumns() != B.numColumns()) {
            throw getSameColsException(A, B);
        }
    }

    /** @throws IndexOutOfBoundsException if {@code A} and {@code B} do not have identical dimensions */
    protected static void checkEqualDimension(DMatrix A, DMatrix B) {
        checkSameRows(A, B);
        checkSameCols(A, B);
    }

    /** @throws IndexOutOfBoundsException if the product {@code A * B} is not defined */
    protected static void checkMul(DMatrix A, DMatrix B) {
        if (A.numColumns() != B.numRows()) {
            throw new IndexOutOfBoundsException("A.numColumns() != B.numRows() (" + A.numColumns() + " != " + B.numRows() + ")");
        }
    }

    /**
     * Validates the dimensions of {@code C = A * B}.
     *
     * @throws IndexOutOfBoundsException if {@code A * B} is undefined or {@code C} has the wrong shape
     */
    protected static void checkMul(DMatrix A, DMatrix B, DMatrix C) {
        if (A.numRows() != C.numRows()) {
            throw new IndexOutOfBoundsException("A.numRows() != C.numRows() (" + A.numRows() + " != " + C.numRows() + ")");
        }
        if (A.numColumns() != B.numRows()) {
            throw new IndexOutOfBoundsException("A.numColumns() != B.numRows() (" + A.numColumns() + " != " + B.numRows() + ")");
        }
        if (B.numColumns() != C.numColumns()) {
            throw new IndexOutOfBoundsException("B.numColumns() != C.numColumns() (" + B.numColumns() + " != " + C.numColumns() + ")");
        }
    }

    /**
     * Validates that {@code A}, {@code B} and {@code C} all share the same dimensions (as needed
     * for {@code C = A + B}).
     *
     * @throws IndexOutOfBoundsException if any dimension differs
     */
    protected static void checkAdd(DMatrix A, DMatrix B, DMatrix C) {
        checkEqualDimension(A, B);
        if (B.numRows() != C.numRows()) {
            throw new IndexOutOfBoundsException("B.numRows() != C.numRows() (" + B.numRows() + " != " + C.numRows() + ")");
        }
        if (B.numColumns() != C.numColumns()) {
            throw new IndexOutOfBoundsException("B.numColumns() != C.numColumns() (" + B.numColumns() + " != " + C.numColumns() + ")");
        }
    }

    /** Builds (but does not throw) the exception reporting differing column counts of {@code A} and {@code B}. */
    protected static IndexOutOfBoundsException getSameColsException(DMatrix A, DMatrix B) {
        return new IndexOutOfBoundsException("A.numColumns() != B.numColumns() (" + A.numColumns() + " != " + B.numColumns() + ")");
    }

    /**
     * Validates the dimensions and returns the required storage length {@code rows * cols}.
     *
     * @param rows number of rows; must be strictly positive
     * @param cols number of columns; must be strictly positive
     * @return {@code rows * cols}
     * @throws IllegalArgumentException if a dimension is not strictly positive or the product
     *     exceeds {@code Integer.MAX_VALUE}
     */
    protected static int checkArrayLength(int rows, int cols) {
        // Widen before multiplying: an int * int product would wrap around before the range check.
        long length = (long) checkRows(rows) * checkCols(cols);
        if (length > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("rows x cols (= " + length + ") exceeds the maximal possible length (= 2147483647) of an array");
        }
        return (int) length;
    }

    /**
     * @return {@code rows}, unchanged
     * @throws IllegalArgumentException if {@code rows <= 0}
     */
    protected static int checkRows(int rows) {
        if (rows <= 0) {
            throw new IllegalArgumentException("number of rows must be strictly positive : " + rows);
        }
        return rows;
    }

    /**
     * @return {@code cols}, unchanged
     * @throws IllegalArgumentException if {@code cols <= 0}
     */
    protected static int checkCols(int cols) {
        if (cols <= 0) {
            throw new IllegalArgumentException("number of columns must be strictly positive : " + cols);
        }
        return cols;
    }
}
