package org.linqs.psl.util;

import com.github.fommil.netlib.BLAS;
import com.github.fommil.netlib.LAPACK;
import java.util.Arrays;
import org.netlib.util.intW;

/* loaded from: input_file:org/linqs/psl/util/FloatMatrix.class */
public final class FloatMatrix {
    /**
     * Backing values in column-major order: element (row, col) lives at index
     * {@code col * numRows + row}. This flat layout is passed directly to BLAS/LAPACK.
     */
    private float[] data;
    private int numRows;
    private int numCols;

    /** Creates an empty 0 x 0 matrix. */
    public FloatMatrix() {
        this.data = new float[0];
        this.numRows = 0;
        this.numCols = 0;
    }

    /**
     * Creates a matrix from column-major values, copying the array.
     *
     * @param data column-major values; must contain exactly {@code numRows * numCols} elements
     * @param numRows number of rows
     * @param numCols number of columns
     * @throws IllegalArgumentException if a dimension is negative or does not match the data length
     */
    public FloatMatrix(float[] data, int numRows, int numCols) {
        this(data, numRows, numCols, true);
    }

    /**
     * Creates a matrix from column-major values.
     *
     * @param data column-major values; must contain exactly {@code numRows * numCols} elements
     * @param numRows number of rows
     * @param numCols number of columns
     * @param copy if true the array is copied; if false this matrix takes ownership of
     *     {@code data}, so later in-place operations write through to the caller's array
     * @throws IllegalArgumentException if a dimension is negative or does not match the data length
     */
    public FloatMatrix(float[] data, int numRows, int numCols, boolean copy) {
        if (numRows < 0 || numCols < 0) {
            throw new IllegalArgumentException(String.format(
                    "Matrix dimensions must be non-negative, got (%d x %d).", numRows, numCols));
        }
        if (data.length != numRows * numCols) {
            throw new IllegalArgumentException(String.format(
                    "Length of data (%d) and size of matrix (%d x %d = %d) does not match.",
                    data.length, numRows, numCols, numRows * numCols));
        }
        this.numRows = numRows;
        this.numCols = numCols;
        this.data = copy ? Arrays.copyOf(data, data.length) : data;
    }

    /**
     * Creates a matrix from a row-major grid ({@code grid[row][col]}), copying the values.
     * A null or zero-length grid yields an empty 0 x 0 matrix.
     *
     * @param grid rows of the matrix; every row must have the same length
     * @throws IllegalArgumentException if the rows have inconsistent lengths
     */
    public FloatMatrix(float[][] grid) {
        if (grid == null || grid.length == 0) {
            this.data = new float[0];
            this.numRows = 0;
            this.numCols = 0;
            return;
        }

        this.numRows = grid.length;
        this.numCols = grid[0].length;
        this.data = new float[this.numRows * this.numCols];

        for (int row = 0; row < this.numRows; row++) {
            if (grid[row].length != this.numCols) {
                throw new IllegalArgumentException(String.format(
                        "Matrix does not have consistent number of columns. Expecting %d, found %d (row %d).",
                        this.numCols, grid[row].length, row));
            }
            for (int col = 0; col < this.numCols; col++) {
                this.data[col * this.numRows + row] = grid[row][col];
            }
        }
    }

    /**
     * Replaces this matrix's contents in place, taking ownership of {@code data} without copying
     * or validating it. The caller is responsible for supplying a column-major array with at least
     * {@code numRows * numCols} elements.
     */
    public void assume(float[] data, int numRows, int numCols) {
        this.data = data;
        this.numRows = numRows;
        this.numCols = numCols;
    }

    /** Returns a new {@code numRows x numCols} matrix filled with 0. */
    public static FloatMatrix zeroes(int numRows, int numCols) {
        return new FloatMatrix(new float[numRows * numCols], numRows, numCols, false);
    }

    /** Returns a new {@code numRows x numCols} matrix filled with 1. */
    public static FloatMatrix ones(int numRows, int numCols) {
        float[] values = new float[numRows * numCols];
        Arrays.fill(values, 1.0f);
        return new FloatMatrix(values, numRows, numCols, false);
    }

    /** Returns a new {@code size x size} identity matrix. */
    public static FloatMatrix eye(int size) {
        float[] values = new float[size * size];
        for (int i = 0; i < size; i++) {
            // Diagonal element (i, i) in column-major layout: i * size + i.
            values[i * (size + 1)] = 1.0f;
        }
        return new FloatMatrix(values, size, size, false);
    }

    /** Returns an {@code n x 1} matrix holding a copy of {@code values}. */
    public static FloatMatrix columnVector(float[] values) {
        return columnVector(values, true);
    }

    /** Returns an {@code n x 1} matrix over {@code values}, copying it only if {@code copy}. */
    public static FloatMatrix columnVector(float[] values, boolean copy) {
        return new FloatMatrix(values, values.length, 1, copy);
    }

    /** Returns a {@code 1 x n} matrix holding a copy of {@code values}. */
    public static FloatMatrix rowVector(float[] values) {
        return rowVector(values, true);
    }

    /** Returns a {@code 1 x n} matrix over {@code values}, copying it only if {@code copy}. */
    public static FloatMatrix rowVector(float[] values, boolean copy) {
        return new FloatMatrix(values, 1, values.length, copy);
    }

    /** Returns an independent deep copy of this matrix. */
    public FloatMatrix copy() {
        // Copy exactly the logical elements so a buffer handed over via assume() that is longer
        // than numRows * numCols does not trip the constructor's length check.
        return new FloatMatrix(Arrays.copyOf(data, size()), numRows, numCols, false);
    }

    /** Returns the values as a new row-major grid ({@code grid[row][col]}). */
    public float[][] asGrid() {
        float[][] grid = new float[numRows][numCols];
        for (int row = 0; row < numRows; row++) {
            for (int col = 0; col < numCols; col++) {
                grid[row][col] = get(row, col);
            }
        }
        return grid;
    }

    /** Returns a copy of the backing column-major array. */
    public float[] asColumnArray() {
        return Arrays.copyOf(data, data.length);
    }

    /** Returns the values as a new flat row-major array. */
    public float[] asRowArray() {
        float[] values = new float[size()];
        for (int row = 0; row < numRows; row++) {
            for (int col = 0; col < numCols; col++) {
                values[row * numCols + col] = get(row, col);
            }
        }
        return values;
    }

    /** Returns the element at ({@code row}, {@code col}); bounds are not checked per-dimension. */
    public float get(int row, int col) {
        return data[col * numRows + row];
    }

    /** Sets the element at ({@code row}, {@code col}); bounds are not checked per-dimension. */
    public void set(int row, int col, float value) {
        data[col * numRows + row] = value;
    }

    /** Returns the number of elements, {@code numRows * numCols}. */
    public int size() {
        return numRows * numCols;
    }

    public int numRows() {
        return numRows;
    }

    public int numCols() {
        return numCols;
    }

    @Override
    public int hashCode() {
        // Only the shape is hashed. equals() compares elements with MathUtils.equals (presumably a
        // tolerance-based comparison -- confirm), so hashing raw values could give matrices that
        // are equal under equals() different hash codes.
        return HashCode.build(HashCode.build(Integer.valueOf(numRows)), Integer.valueOf(numCols));
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof FloatMatrix)) {
            return false;
        }

        FloatMatrix other = (FloatMatrix) obj;
        if (numRows != other.numRows || numCols != other.numCols) {
            return false;
        }

        for (int i = 0; i < size(); i++) {
            if (!MathUtils.equals(data[i], other.data[i])) {
                return false;
            }
        }
        return true;
    }

    /** Formats the matrix row by row, e.g. {@code [[1.0, 2.0], [3.0, 4.0]]}. */
    @Override
    public String toString() {
        if (data == null || numRows == 0) {
            return "[]";
        }

        StringBuilder builder = new StringBuilder();
        builder.append("[");
        for (int row = 0; row < numRows; row++) {
            builder.append("[");
            for (int col = 0; col < numCols; col++) {
                builder.append(get(row, col));
                if (col != numCols - 1) {
                    builder.append(", ");
                }
            }
            builder.append("]");
            if (row != numRows - 1) {
                builder.append(", ");
            }
        }
        builder.append("]");
        return builder.toString();
    }

    /** Returns a new matrix holding {@code this - other}, element-wise. */
    public FloatMatrix elementSub(FloatMatrix other) {
        return elementSub(other, false);
    }

    /**
     * Computes {@code this - other} element-wise.
     *
     * @param inPlace if true, overwrite and return this matrix; otherwise return a new matrix
     */
    public FloatMatrix elementSub(FloatMatrix other, boolean inPlace) {
        assertSameShape(other);
        float[] out = outputBuffer(inPlace);
        for (int i = 0; i < size(); i++) {
            out[i] = data[i] - other.data[i];
        }
        return wrapResult(out, inPlace);
    }

    /** Returns a new matrix holding {@code this + other}, element-wise. */
    public FloatMatrix elementAdd(FloatMatrix other) {
        return elementAdd(other, false);
    }

    /**
     * Computes {@code this + other} element-wise.
     *
     * @param inPlace if true, overwrite and return this matrix; otherwise return a new matrix
     */
    public FloatMatrix elementAdd(FloatMatrix other, boolean inPlace) {
        assertSameShape(other);
        float[] out = outputBuffer(inPlace);
        for (int i = 0; i < size(); i++) {
            out[i] = data[i] + other.data[i];
        }
        return wrapResult(out, inPlace);
    }

    /** Returns a new matrix holding the element-wise (Hadamard) product {@code this * other}. */
    public FloatMatrix elementMul(FloatMatrix other) {
        return elementMul(other, false);
    }

    /**
     * Computes the element-wise (Hadamard) product {@code this * other}.
     *
     * @param inPlace if true, overwrite and return this matrix; otherwise return a new matrix
     */
    public FloatMatrix elementMul(FloatMatrix other, boolean inPlace) {
        assertSameShape(other);
        float[] out = outputBuffer(inPlace);
        for (int i = 0; i < size(); i++) {
            out[i] = data[i] * other.data[i];
        }
        return wrapResult(out, inPlace);
    }

    /** Returns a new matrix holding {@code this / other}, element-wise. */
    public FloatMatrix elementDiv(FloatMatrix other) {
        return elementDiv(other, false);
    }

    /**
     * Computes {@code this / other} element-wise (IEEE float division; no zero check).
     *
     * @param inPlace if true, overwrite and return this matrix; otherwise return a new matrix
     */
    public FloatMatrix elementDiv(FloatMatrix other, boolean inPlace) {
        assertSameShape(other);
        float[] out = outputBuffer(inPlace);
        for (int i = 0; i < size(); i++) {
            out[i] = data[i] / other.data[i];
        }
        return wrapResult(out, inPlace);
    }

    /** Returns a new matrix holding the natural log of every element. */
    public FloatMatrix elementLog() {
        return elementLog(false);
    }

    /**
     * Applies the natural log ({@link Math#log}) to every element.
     *
     * @param inPlace if true, overwrite and return this matrix; otherwise return a new matrix
     */
    public FloatMatrix elementLog(boolean inPlace) {
        float[] out = outputBuffer(inPlace);
        for (int i = 0; i < size(); i++) {
            out[i] = (float) Math.log(data[i]);
        }
        return wrapResult(out, inPlace);
    }

    /** Returns a new matrix with {@code value} subtracted from every element. */
    public FloatMatrix sub(float value) {
        return sub(value, false);
    }

    /**
     * Subtracts {@code value} from every element.
     *
     * @param inPlace if true, overwrite and return this matrix; otherwise return a new matrix
     */
    public FloatMatrix sub(float value, boolean inPlace) {
        float[] out = outputBuffer(inPlace);
        for (int i = 0; i < size(); i++) {
            out[i] = data[i] - value;
        }
        return wrapResult(out, inPlace);
    }

    /** Returns a new matrix with {@code value} added to every element. */
    public FloatMatrix add(float value) {
        return add(value, false);
    }

    /**
     * Adds {@code value} to every element.
     *
     * @param inPlace if true, overwrite and return this matrix; otherwise return a new matrix
     */
    public FloatMatrix add(float value, boolean inPlace) {
        float[] out = outputBuffer(inPlace);
        for (int i = 0; i < size(); i++) {
            out[i] = data[i] + value;
        }
        return wrapResult(out, inPlace);
    }

    /** Returns a new matrix with every element multiplied by {@code value}. */
    public FloatMatrix mul(float value) {
        return mul(value, false);
    }

    /**
     * Multiplies every element by {@code value}.
     *
     * @param inPlace if true, overwrite and return this matrix; otherwise return a new matrix
     */
    public FloatMatrix mul(float value, boolean inPlace) {
        float[] out = outputBuffer(inPlace);
        for (int i = 0; i < size(); i++) {
            out[i] = data[i] * value;
        }
        return wrapResult(out, inPlace);
    }

    /** Returns a new matrix with every element divided by {@code value}. */
    public FloatMatrix div(float value) {
        return div(value, false);
    }

    /**
     * Divides every element by {@code value} (IEEE float division; no zero check).
     *
     * @param inPlace if true, overwrite and return this matrix; otherwise return a new matrix
     */
    public FloatMatrix div(float value, boolean inPlace) {
        float[] out = outputBuffer(inPlace);
        for (int i = 0; i < size(); i++) {
            out[i] = data[i] / value;
        }
        return wrapResult(out, inPlace);
    }

    /** Returns the sum of absolute values of all elements (entry-wise L1 norm). */
    public float norm1() {
        float sum = 0.0f;
        for (int i = 0; i < size(); i++) {
            sum += Math.abs(data[i]);
        }
        return sum;
    }

    /** Returns the square root of the sum of squared elements (Frobenius / entry-wise L2 norm). */
    public float norm2() {
        float sum = 0.0f;
        for (int i = 0; i < size(); i++) {
            sum = (float) (sum + Math.pow(data[i], 2.0d));
        }
        return (float) Math.sqrt(sum);
    }

    /** Returns a new matrix that is the transpose of this one. */
    public FloatMatrix transpose() {
        FloatMatrix result = zeroes(numCols, numRows);
        for (int row = 0; row < numRows; row++) {
            for (int col = 0; col < numCols; col++) {
                result.set(col, row, get(row, col));
            }
        }
        return result;
    }

    /** Returns the matrix product {@code this * other} in a new matrix. */
    public FloatMatrix mul(FloatMatrix other) {
        return mul(other, false, false, 1.0f);
    }

    /**
     * Returns {@code alpha * op(this) * op(other)} in a new matrix, where {@code op} optionally
     * transposes its operand.
     */
    public FloatMatrix mul(FloatMatrix other, boolean transposeThis, boolean transposeOther,
            float alpha) {
        return mul(other, null, transposeThis, transposeOther, alpha, 0.0f);
    }

    /**
     * Computes {@code alpha * op(this) * op(other) + beta * out}; see
     * {@link #blas_sgemm(FloatMatrix, FloatMatrix, boolean, boolean, float, float)}.
     */
    public FloatMatrix mul(FloatMatrix other, FloatMatrix out, boolean transposeThis,
            boolean transposeOther, float alpha, float beta) {
        return blas_sgemm(other, out, transposeThis, transposeOther, alpha, beta);
    }

    /** Returns the dot product of two equally sized vectors; see {@link #blas_sdot(FloatMatrix)}. */
    public float dot(FloatMatrix other) {
        return blas_sdot(other);
    }

    /**
     * Returns the inverse of this square matrix, computed by solving {@code A X = I} with LAPACK
     * sgesv on a copy (this matrix is not modified).
     *
     * @throws IllegalArgumentException if the matrix is not square
     * @throws ArithmeticException if LAPACK reports the matrix as singular; the LAPACK error is
     *     attached as the cause
     */
    public FloatMatrix inverse() {
        if (numRows != numCols) {
            throw new IllegalArgumentException(String.format(
                    "Cannot invert a non-square matrix (%d x %d).", numRows, numCols));
        }

        // sgesv overwrites its A argument with LU factors, so work on a copy.
        FloatMatrix factors = copy();
        FloatMatrix result = eye(numRows);
        try {
            factors.lapack_sgesv(result);
            return result;
        } catch (ArithmeticException ex) {
            ArithmeticException wrapped =
                    new ArithmeticException("Non-invertible matrix: " + toString());
            wrapped.initCause(ex);
            throw wrapped;
        }
    }

    /** Returns a new matrix holding the Cholesky factorization (LAPACK spotrf, lower form). */
    public FloatMatrix choleskyDecomposition() {
        return choleskyDecomposition(false);
    }

    /**
     * Computes the Cholesky factorization via LAPACK spotrf with {@code uplo = "L"}.
     *
     * <p>NOTE(review): per the LAPACK spotrf contract only the lower triangle is overwritten with
     * the factor; the strictly upper triangle keeps the input values. Confirm callers ignore it.
     *
     * @param inPlace if true, factor this matrix in place; otherwise factor a copy
     * @throws IllegalArgumentException if the matrix is not square
     * @throws ArithmeticException if the matrix is not positive definite
     */
    public FloatMatrix choleskyDecomposition(boolean inPlace) {
        FloatMatrix target = inPlace ? this : copy();
        target.lapack_spotrf(false);
        return target;
    }

    /**
     * Solves {@code A X = B} with LAPACK sgesv, where A is this square matrix.
     *
     * <p>Both matrices are modified in place: this matrix is overwritten with the LU factors of A,
     * and {@code b} is overwritten with the solution X.
     *
     * @param b right-hand side(s), one per column; receives the solution
     * @return the pivot indices produced by sgesv (1-based, as returned by LAPACK)
     * @throws IllegalArgumentException if A is not square or LAPACK rejects an argument
     * @throws ArithmeticException if A is singular
     */
    public int[] lapack_sgesv(FloatMatrix b) {
        if (numRows != numCols) {
            throw new IllegalArgumentException(String.format(
                    "sgesv requires a square A matrix, got (%d x %d).", numRows, numCols));
        }

        int[] pivots = new int[numRows];
        intW info = new intW(0);
        // LAPACK requires LDA/LDB >= max(1, N); clamping keeps 0 x 0 systems legal.
        LAPACK.getInstance().sgesv(numRows, b.numCols, data, Math.max(1, numRows), pivots,
                b.data, Math.max(1, b.numRows), info);

        if (info.val < 0) {
            throw new IllegalArgumentException(String.format(
                    "Error in the %d argument to sgesv.", -info.val));
        }
        if (info.val > 0) {
            throw new ArithmeticException(String.format(
                    "Error in sgesv. U(%d, %d) is singular, so the solution could not be computed.",
                    info.val, info.val));
        }
        return pivots;
    }

    /**
     * Cholesky-factors this square matrix in place with LAPACK spotrf.
     *
     * @param upper if true use {@code uplo = "U"}, otherwise {@code uplo = "L"}
     * @throws IllegalArgumentException if the matrix is not square or LAPACK rejects an argument
     * @throws ArithmeticException if the matrix is not positive definite
     */
    public void lapack_spotrf(boolean upper) {
        if (numRows != numCols) {
            throw new IllegalArgumentException(String.format(
                    "spotrf requires a square A matrix, got (%d x %d).", numRows, numCols));
        }

        intW info = new intW(0);
        LAPACK.getInstance().spotrf(upper ? "U" : "L", numRows, data, Math.max(1, numRows), info);

        if (info.val < 0) {
            throw new IllegalArgumentException(String.format(
                    "Error in the %d argument to spotrf.", -info.val));
        }
        if (info.val > 0) {
            throw new ArithmeticException(String.format(
                    "Error in spotrf (%d). Matrix is not positive definite.", info.val));
        }
    }

    /**
     * Computes {@code alpha * op(this) * op(other) + beta * out} with BLAS sgemm.
     *
     * @param other right operand
     * @param out output matrix of shape (rows of op(this)) x (cols of op(other)), overwritten
     *     with the result; if null a new zero matrix is allocated and {@code beta} is forced to 0
     * @param transposeThis if true, op(this) is the transpose of this matrix
     * @param transposeOther if true, op(other) is the transpose of {@code other}
     * @param alpha scale applied to the product
     * @param beta scale applied to the existing contents of {@code out}
     * @return {@code out}, or the newly allocated result matrix
     * @throws IllegalArgumentException if the inner dimensions disagree or {@code out} has the
     *     wrong shape
     */
    public FloatMatrix blas_sgemm(FloatMatrix other, FloatMatrix out, boolean transposeThis,
            boolean transposeOther, float alpha, float beta) {
        // Shapes after optional transposition: op(this) is m x k, op(other) is otherInner x n.
        int m = transposeThis ? numCols : numRows;
        int k = transposeThis ? numRows : numCols;
        int otherInner = transposeOther ? other.numCols : other.numRows;
        int n = transposeOther ? other.numRows : other.numCols;

        if (k != otherInner) {
            throw new IllegalArgumentException(String.format(
                    "Cannot multiply matrices of (post transposed) dimensions (%d x %d) and (%d x %d).",
                    m, k, otherInner, n));
        }

        if (out == null) {
            out = zeroes(m, n);
            beta = 0.0f;
        } else if (out.numRows != m || out.numCols != n) {
            throw new IllegalArgumentException(String.format(
                    "Output matrix has dimensions (%d x %d), but the product has dimensions (%d x %d).",
                    out.numRows, out.numCols, m, n));
        }

        // Leading dimensions describe each array as stored (column-major), independent of the
        // transpose flags, and must be at least 1.
        BLAS.getInstance().sgemm(
                transposeThis ? "T" : "N", transposeOther ? "T" : "N",
                m, n, k,
                alpha, data, Math.max(1, numRows),
                other.data, Math.max(1, other.numRows),
                beta, out.data, Math.max(1, out.numRows));
        return out;
    }

    /**
     * Returns the dot product of two vectors (either may be a row or column vector) via BLAS sdot.
     *
     * @throws IllegalArgumentException if either operand is not a vector or their sizes differ
     */
    public float blas_sdot(FloatMatrix other) {
        boolean thisIsVector = numRows == 1 || numCols == 1;
        boolean otherIsVector = other.numRows == 1 || other.numCols == 1;
        if (!thisIsVector || !otherIsVector) {
            throw new IllegalArgumentException(String.format(
                    "sdot only works with vectors. Got (%d x %d) and (%d x %d).",
                    numRows, numCols, other.numRows, other.numCols));
        }
        if (size() != other.size()) {
            throw new IllegalArgumentException(String.format(
                    "sdot only works with same sized vectors. Got %d and %d.",
                    size(), other.size()));
        }
        return BLAS.getInstance().sdot(size(), data, 1, other.data, 1);
    }

    /** Asserts (when assertions are enabled) that {@code other} has the same shape as this. */
    private void assertSameShape(FloatMatrix other) {
        assert numRows == other.numRows : "row mismatch: " + numRows + " vs " + other.numRows;
        assert numCols == other.numCols : "column mismatch: " + numCols + " vs " + other.numCols;
    }

    /** Returns the array an element-wise op writes into: this matrix's buffer, or a fresh one. */
    private float[] outputBuffer(boolean inPlace) {
        return inPlace ? data : new float[size()];
    }

    /** Returns {@code this} for in-place ops, else a new matrix owning {@code values}. */
    private FloatMatrix wrapResult(float[] values, boolean inPlace) {
        return inPlace ? this : new FloatMatrix(values, numRows, numCols, false);
    }
}
