package com.gengoai.apollo.math.linalg;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.gengoai.Copyable;
import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.concurrent.AtomicDouble;
import com.gengoai.conversion.Cast;
import com.gengoai.math.Optimum;
import java.util.function.DoubleUnaryOperator;
import lombok.NonNull;
import org.apache.mahout.math.list.FloatArrayList;
import org.apache.mahout.math.list.IntArrayList;
import org.apache.mahout.math.map.OpenIntFloatHashMap;
import org.jblas.DoubleMatrix;
import org.jblas.FloatMatrix;

/**
 * A matrix whose entries are held in an {@link OpenIntFloatHashMap} keyed by the entry's linear
 * index ({@code shape.matrixIndex(row, column)}). Indices absent from the map read as zero, and
 * values are stored with {@code float} precision.
 */
public class SparseMatrix extends Matrix {
    private static final long serialVersionUID = 1;
    /** Linear index -> stored value; absent keys are implicit zeros. */
    private final OpenIntFloatHashMap map;

    /**
     * Creates an all-zero sparse matrix with the given dimensions.
     *
     * @param iArr the dimensions, passed to {@code new Shape(int...)}
     */
    public SparseMatrix(@NonNull int... iArr) {
        this(new Shape(iArr));
        // Lombok-style null check; Java requires the delegating this(...) call to come first.
        if (iArr == null) {
            throw new NullPointerException("dims is marked non-null but is null");
        }
    }

    /**
     * Creates an all-zero sparse matrix with the given shape.
     *
     * @param shape the matrix shape
     */
    public SparseMatrix(@NonNull Shape shape) {
        super(shape);
        if (shape == null) {
            throw new NullPointerException("shape is marked non-null but is null");
        }
        this.map = new OpenIntFloatHashMap();
    }

    /**
     * Copy constructor: shares the source's shape reference with {@code super(...)} and deep-copies
     * the entry map.
     *
     * @param sparseMatrix the matrix to copy
     */
    protected SparseMatrix(@NonNull SparseMatrix sparseMatrix) {
        super(sparseMatrix.shape);
        if (sparseMatrix == null) {
            throw new NullPointerException("toCopy is marked non-null but is null");
        }
        this.map = (OpenIntFloatHashMap) Copyable.deepCopy(sparseMatrix.map);
    }

    /**
     * JSON deserialization constructor. {@code iArr[k]} is the linear index of value {@code fArr[k]}.
     *
     * @param iArr  linear indices of the stored entries (may be absent/null for an empty matrix)
     * @param fArr  values, positionally aligned with {@code iArr}
     * @param shape the matrix shape
     * @param obj   the label
     * @param obj2  the predicted value
     * @param d     the weight
     */
    @JsonCreator
    protected SparseMatrix(@JsonProperty("indices") int[] iArr, @JsonProperty("values") float[] fArr, @JsonProperty("shape") Shape shape, @JsonProperty("label") Object obj, @JsonProperty("predicted") Object obj2, @JsonProperty("weight") double d) {
        this(shape);
        setLabel(obj);
        setPredicted(obj2);
        setWeight(d);
        if (iArr != null && fArr != null) {
            for (int i = 0; i < iArr.length; i++) {
                set(iArr[i], fArr[i]);
            }
        }
    }

    /**
     * Returns the transpose of this matrix as a new {@code SparseMatrix}.
     *
     * <p>Vectors are copied and reshaped (their linear layout does not change); general matrices
     * have every stored entry moved from (row, column) to (column, row).
     */
    @Override
    public NDArray T() {
        if (this.shape.isVector()) {
            SparseMatrix transposed = new SparseMatrix(this);
            // NOTE(review): the copy constructor hands this.shape to super(...); if the base class
            // does not copy the Shape, this reshape also changes this matrix's shape — confirm.
            transposed.shape.reshape(this.shape.columns(), this.shape.rows());
            return transposed;
        }
        SparseMatrix transposed = new SparseMatrix(this.shape.columns(), this.shape.rows());
        this.map.forEachPair((index, value) -> {
            transposed.set(this.shape.toColumn(index), this.shape.toRow(index), value);
            return true;
        });
        return transposed;
    }

    /** Returns a copy of this matrix with {@code nDArray} added element-wise. */
    @Override
    public NDArray add(@NonNull NDArray nDArray) {
        if (nDArray == null) {
            throw new NullPointerException("rhs is marked non-null but is null");
        }
        return m1copy().addi(nDArray);
    }

    /**
     * Adds {@code nDArray} element-wise in place. Dense operands are delegated to the superclass;
     * sparse operands only touch their stored entries.
     */
    @Override
    public NDArray addi(@NonNull NDArray nDArray) {
        if (nDArray == null) {
            throw new NullPointerException("rhs is marked non-null but is null");
        }
        if (nDArray.isDense()) {
            return super.addi(nDArray);
        }
        checkLength(nDArray.shape);
        ((SparseMatrix) Cast.as(nDArray)).map.forEachPair((i, f) -> {
            this.map.adjustOrPutValue(i, f, f);
            return true;
        });
        return this;
    }

    /** Trims the backing map's storage to its current size. */
    @Override
    public NDArray compact() {
        this.map.trimToSize();
        return this;
    }

    /** Returns a copy of this matrix divided element-wise by {@code nDArray}. */
    @Override
    public NDArray div(@NonNull NDArray nDArray) {
        if (nDArray == null) {
            throw new NullPointerException("rhs is marked non-null but is null");
        }
        return m1copy().divi(nDArray);
    }

    /** Dot product over this matrix's stored entries (absent entries contribute zero). */
    @Override
    public double dot(@NonNull NDArray nDArray) {
        if (nDArray == null) {
            throw new NullPointerException("rhs is marked non-null but is null");
        }
        checkLength(nDArray.shape());
        AtomicDouble atomicDouble = new AtomicDouble(0.0d);
        this.map.forEachPair((i, f) -> {
            atomicDouble.addAndGet(nDArray.get(i) * f);
            return true;
        });
        return atomicDouble.get();
    }

    /** Invokes {@code entryConsumer} with (linear index, value) for every stored entry. */
    @Override
    public void forEachSparse(@NonNull NDArray.EntryConsumer entryConsumer) {
        if (entryConsumer == null) {
            throw new NullPointerException("consumer is marked non-null but is null");
        }
        this.map.forEachPair((i, f) -> {
            entryConsumer.apply(i, f);
            return true;
        });
    }

    /** Returns the value at linear index {@code j}. */
    @Override
    public double get(long j) {
        return this.map.get((int) j);
    }

    /** Returns the value at ({@code i}, {@code i2}). */
    @Override
    public double get(int i, int i2) {
        return this.map.get(this.shape.matrixIndex(i, i2));
    }

    /** Returns column {@code i} as a new rows×1 sparse matrix. */
    @Override
    public NDArray getColumn(int i) {
        SparseMatrix sparseMatrix = new SparseMatrix(this.shape.rows(), 1);
        for (int i2 = 0; i2 < this.shape.rows(); i2++) {
            sparseMatrix.set(i2, get(i2, i));
        }
        return sparseMatrix;
    }

    /** Returns row {@code i} as a new 1×columns sparse matrix. */
    @Override
    public NDArray getRow(int i) {
        SparseMatrix sparseMatrix = new SparseMatrix(1, this.shape.columns());
        for (int i2 = 0; i2 < this.shape.columns(); i2++) {
            // The result has a single row, so write to row 0 (not row i of the source).
            sparseMatrix.set(0, i2, get(i, i2));
        }
        return sparseMatrix;
    }

    /**
     * Returns the sub-matrix of rows {@code [i, i2)} and columns {@code [i3, i4)} as a new sparse
     * matrix.
     */
    @Override
    public NDArray getSubMatrix(int i, int i2, int i3, int i4) {
        SparseMatrix sparseMatrix = new SparseMatrix(i2 - i, i4 - i3);
        this.map.forEachPair((i5, f) -> {
            int row = this.shape.toRow(i5);
            int column = this.shape.toColumn(i5);
            if (row < i || row >= i2 || column < i3 || column >= i4) {
                return true;
            }
            sparseMatrix.set(row - i, column - i3, f);
            return true;
        });
        return sparseMatrix;
    }

    /** Always {@code false}: this implementation is sparse. */
    @Override
    public boolean isDense() {
        return false;
    }

    /**
     * Returns a new matrix with {@code doubleUnaryOperator} applied to every element, including
     * implicit zeros.
     */
    @Override
    public NDArray map(@NonNull DoubleUnaryOperator doubleUnaryOperator) {
        if (doubleUnaryOperator == null) {
            throw new NullPointerException("operator is marked non-null but is null");
        }
        NDArray zeroLike = zeroLike();
        for (int i = 0; i < this.shape.matrixLength; i++) {
            zeroLike.set(i, doubleUnaryOperator.applyAsDouble(get(i)));
        }
        return zeroLike;
    }

    /** Applies {@code doubleUnaryOperator} in place to every element, including implicit zeros. */
    @Override
    public NDArray mapi(@NonNull DoubleUnaryOperator doubleUnaryOperator) {
        if (doubleUnaryOperator == null) {
            throw new NullPointerException("operator is marked non-null but is null");
        }
        for (int i = 0; i < this.shape.matrixLength; i++) {
            set(i, doubleUnaryOperator.applyAsDouble(get(i)));
        }
        return this;
    }

    /**
     * Returns the maximum element. If some entries are not stored (implicit zeros), zero is
     * included as a candidate.
     */
    @Override
    public double max() {
        double optimumValue = Optimum.MAXIMUM.optimumValue(this.map.values().elements());
        return this.map.size() == this.shape.matrixLength ? optimumValue : Math.max(0.0d, optimumValue);
    }

    /**
     * Returns the minimum element. If some entries are not stored (implicit zeros), zero is
     * included as a candidate.
     */
    @Override
    public double min() {
        double optimumValue = Optimum.MINIMUM.optimumValue(this.map.values().elements());
        return this.map.size() == this.shape.matrixLength ? optimumValue : Math.min(0.0d, optimumValue);
    }

    /**
     * Matrix product {@code this × nDArray}. Dense operands, and large matrices with sparsity below
     * 0.5, are delegated to the superclass. Otherwise each stored entry (row, k) contributes
     * {@code value * rhs(k, col)} to result (row, col).
     */
    @Override
    public NDArray mmul(@NonNull NDArray nDArray) {
        if (nDArray == null) {
            throw new NullPointerException("rhs is marked non-null but is null");
        }
        if (nDArray.isDense() || (sparsity() < 0.5d && length() > 10000)) {
            return super.mmul(nDArray);
        }
        SparseMatrix product = new SparseMatrix(rows(), nDArray.columns());
        this.map.forEachPair((index, value) -> {
            int row = this.shape.toRow(index);
            int inner = this.shape.toColumn(index);
            // Loop variable must not reuse the lambda parameter's name (illegal redeclaration).
            for (int col = 0; col < nDArray.columns(); col++) {
                float contribution = (float) (value * nDArray.get(inner, col));
                product.map.adjustOrPutValue(product.shape.matrixIndex(row, col), contribution, contribution);
            }
            return true;
        });
        return product;
    }

    /** Returns a copy of this matrix multiplied element-wise by {@code nDArray}. */
    @Override
    public NDArray mul(@NonNull NDArray nDArray) {
        if (nDArray == null) {
            throw new NullPointerException("rhs is marked non-null but is null");
        }
        return m1copy().muli(nDArray);
    }

    /**
     * Multiplies element-wise in place. Only stored entries are visited, since implicit zeros stay
     * zero.
     */
    @Override
    public NDArray muli(@NonNull NDArray nDArray) {
        if (nDArray == null) {
            throw new NullPointerException("rhs is marked non-null but is null");
        }
        checkLength(nDArray.shape());
        this.map.forEachPair((i, f) -> {
            // put() on an existing key only overwrites its value.
            this.map.put(i, f * ((float) nDArray.get(i)));
            return true;
        });
        return this;
    }

    /** Returns the sum of absolute values of the stored entries. */
    @Override
    public double norm1() {
        double total = 0.0d;
        for (float value : this.map.values().elements()) {
            total += Math.abs(value);
        }
        return total;
    }

    /** Reshapes this matrix in place; the stored linear indices are unchanged. */
    @Override
    public NDArray reshape(int... iArr) {
        this.shape.reshape(iArr);
        return this;
    }

    /**
     * Sets the value at linear index {@code j}. A value of zero removes the entry, so zeros are
     * never stored explicitly through this method.
     */
    @Override
    public NDArray set(long j, double d) {
        if (d == 0.0d) {
            this.map.removeKey((int) j);
        } else {
            this.map.put((int) j, (float) d);
        }
        return this;
    }

    /**
     * Sets the value at ({@code i}, {@code i2}). Delegates to {@link #set(long, double)}, so writing
     * zero removes the entry instead of storing an explicit zero.
     */
    @Override
    public NDArray set(int i, int i2, double d) {
        return set((long) this.shape.matrixIndex(i, i2), d);
    }

    /** Overwrites column {@code i} with the elements of {@code nDArray}. */
    @Override
    public NDArray setColumn(int i, @NonNull NDArray nDArray) {
        if (nDArray == null) {
            throw new NullPointerException("array is marked non-null but is null");
        }
        checkLength(this.shape.rows(), nDArray.shape());
        for (int i2 = 0; i2 < nDArray.shape().matrixLength; i2++) {
            set(i2, i, nDArray.get(i2));
        }
        return this;
    }

    /** Overwrites row {@code i} with the elements of {@code nDArray}. */
    @Override
    public NDArray setRow(int i, @NonNull NDArray nDArray) {
        if (nDArray == null) {
            throw new NullPointerException("array is marked non-null but is null");
        }
        checkLength(this.shape.columns(), nDArray.shape());
        for (int i2 = 0; i2 < nDArray.shape().matrixLength; i2++) {
            set(i, i2, nDArray.get(i2));
        }
        return this;
    }

    /** Returns the number of stored entries. */
    @Override
    public long size() {
        return this.map.size();
    }

    /** Returns the linear indices of the stored entries in ascending order. */
    @Override
    @JsonProperty("indices")
    public int[] sparseIndices() {
        IntArrayList keys = this.map.keys();
        keys.sort();
        return keys.toArray(new int[0]);
    }

    /**
     * Returns the stored values in the same order as {@link #sparseIndices()}. The JSON creator
     * pairs {@code indices[k]} with {@code values[k]}, so the two arrays must be aligned.
     */
    @JsonProperty("values")
    private float[] sparseValues() {
        int[] indices = sparseIndices();
        float[] values = new float[indices.length];
        for (int k = 0; k < indices.length; k++) {
            values[k] = this.map.get(indices[k]);
        }
        return values;
    }

    /** Returns a copy of this matrix minus {@code nDArray}. */
    @Override
    public NDArray sub(NDArray nDArray) {
        return m1copy().subi(nDArray);
    }

    /**
     * Subtracts {@code nDArray} element-wise in place. Dense operands are delegated to the
     * superclass; sparse operands only touch their stored entries.
     */
    @Override
    public NDArray subi(@NonNull NDArray nDArray) {
        if (nDArray == null) {
            throw new NullPointerException("rhs is marked non-null but is null");
        }
        if (nDArray.isDense()) {
            // Must subtract, not add.
            return super.subi(nDArray);
        }
        checkLength(nDArray.shape());
        ((SparseMatrix) Cast.as(nDArray)).map.forEachPair((i, f) -> {
            this.map.adjustOrPutValue(i, -f, -f);
            return true;
        });
        return this;
    }

    /** Returns the sum of the stored entries. */
    @Override
    public double sum() {
        double total = 0.0d;
        for (float value : this.map.values().elements()) {
            total += value;
        }
        return total;
    }

    /** Returns the sum of squares of the stored entries. */
    @Override
    public double sumOfSquares() {
        double d = 0.0d;
        for (double d2 : this.map.values().elements()) {
            d += d2 * d2;
        }
        return d;
    }

    /** Returns a dense copy indexed by linear index. */
    @Override
    public double[] toDoubleArray() {
        double[] dArr = new double[(int) length()];
        this.map.forEachPair((i, f) -> {
            dArr[i] = f;
            return true;
        });
        return dArr;
    }

    /** Returns a single-element array holding a dense jblas copy; its data is filled by linear index. */
    @Override
    public DoubleMatrix[] toDoubleMatrix() {
        DoubleMatrix doubleMatrix = new DoubleMatrix(this.shape.rows(), this.shape.columns());
        this.map.forEachPair((i, f) -> {
            doubleMatrix.data[i] = f;
            return true;
        });
        return new DoubleMatrix[]{doubleMatrix};
    }

    /** Returns a dense float copy indexed by linear index. */
    @Override
    public float[] toFloatArray() {
        float[] fArr = new float[(int) length()];
        this.map.forEachPair((i, f) -> {
            fArr[i] = f;
            return true;
        });
        return fArr;
    }

    /** Returns a dense {@code [rows][columns]} float copy. */
    @Override
    public float[][] toFloatArray2() {
        float[][] fArr = new float[rows()][columns()];
        this.map.forEachPair((i, f) -> {
            int row = this.shape.toRow(i);
            fArr[row][this.shape.toColumn(i)] = f;
            return true;
        });
        return fArr;
    }

    /** Returns the dense copy wrapped as a single-slice {@code [1][rows][columns]} array. */
    @Override
    public float[][][] toFloatArray3() {
        return new float[][][]{toFloatArray2()};
    }

    /** Returns a single-element array holding a dense jblas float copy; its data is filled by linear index. */
    @Override
    public FloatMatrix[] toFloatMatrix() {
        FloatMatrix floatMatrix = new FloatMatrix(this.shape.rows(), this.shape.columns());
        this.map.forEachPair((i, f) -> {
            floatMatrix.data[i] = f;
            return true;
        });
        return new FloatMatrix[]{floatMatrix};
    }

    /** Removes all stored entries, making every element zero, and trims the backing storage. */
    @Override
    public NDArray zero() {
        this.map.clear();
        this.map.trimToSize();
        return this;
    }

    /** Returns a new, empty sparse matrix with the same shape. */
    @Override
    public NDArray zeroLike() {
        return new SparseMatrix(this.shape);
    }
}
