package com.gengoai.apollo.math.linalg;

import com.gengoai.Validation;
import com.gengoai.config.Config;
import java.util.Arrays;
import java.util.Collection;
import java.util.Random;
import lombok.NonNull;
import org.jblas.DoubleMatrix;
import org.jblas.FloatMatrix;

/* loaded from: input_file:com/gengoai/apollo/math/linalg/NDArrayFactory.class */
public enum NDArrayFactory {
    /**
     * Delegating factory. Each overridden method forwards to the factory named by the
     * {@code NDArrayFactory.default} configuration entry, using {@code DENSE} when the
     * entry yields nothing usable.
     */
    ND {
        /** Delegate resolved on first use; volatile so the unsynchronized read sees the published value. */
        private volatile NDArrayFactory factory;

        /**
         * Resolves the delegate once (double-checked locking on {@code ND}) and returns it.
         *
         * <p>A configuration that names {@code ND} itself is replaced by {@code DENSE}:
         * every method of this constant forwards to the delegate, so delegating to
         * {@code ND} would re-enter this constant and recurse without bound.
         *
         * @return the factory all calls on {@code ND} are forwarded to; never {@code ND}
         */
        private NDArrayFactory getFactory() {
            // Read the volatile field once; the local is what we test and return.
            NDArrayFactory resolved = this.factory;
            if (resolved == null) {
                synchronized (ND) {
                    resolved = this.factory;
                    if (resolved == null) {
                        resolved = (NDArrayFactory) Config.get("NDArrayFactory.default", new Object[0]).as(NDArrayFactory.class, DENSE);
                        if (resolved == null || resolved == ND) {
                            // Self-reference (or no value at all) cannot be delegated to.
                            resolved = DENSE;
                        }
                        this.factory = resolved;
                    }
                }
            }
            return resolved;
        }

        /** Forwards to the configured factory's {@code columnVector(double[])}. */
        @Override
        public NDArray columnVector(double[] dArr) {
            return getFactory().columnVector(dArr);
        }

        /** Forwards to the configured factory's {@code rowVector(double[])}. */
        @Override
        public NDArray rowVector(double[] dArr) {
            return getFactory().rowVector(dArr);
        }

        /** Forwards to the configured factory's {@code array(Shape)}. */
        @Override
        public NDArray array(Shape shape) {
            return getFactory().array(shape);
        }
    },
    /**
     * Factory producing {@code DenseMatrix} instances (and {@code Tensor}s whose slices are
     * {@code DenseMatrix}), backed by jblas {@code DoubleMatrix} / {@code FloatMatrix}.
     */
    DENSE {
        /**
         * Creates a dense array of the given shape. Tensor shapes get one
         * {@code DenseMatrix} of {@code rows x columns} per slice.
         *
         * @param shape the shape of the array to create
         * @return a {@code DenseMatrix}, or a {@code Tensor} when {@code shape.isTensor()}
         * @throws NullPointerException if {@code shape} is null
         */
        @Override
        public NDArray array(@NonNull Shape shape) {
            if (shape == null) {
                throw new NullPointerException("shape is marked non-null but is null");
            }
            if (shape.isTensor()) {
                Tensor result = new Tensor(shape);
                int slot = 0;
                while (slot < shape.sliceLength) {
                    result.slices[slot] = new DenseMatrix(shape.rows(), shape.columns());
                    slot++;
                }
                return result;
            }
            return new DenseMatrix(shape);
        }

        /** Wraps a jblas {@code DoubleMatrix} built from the one-dimensional data. */
        @Override
        public NDArray array(double[] dArr) {
            DoubleMatrix backing = new DoubleMatrix(dArr);
            return new DenseMatrix(backing);
        }

        /** Wraps a jblas {@code FloatMatrix} built from the one-dimensional data. */
        @Override
        public NDArray array(float[] fArr) {
            FloatMatrix backing = new FloatMatrix(fArr);
            return new DenseMatrix(backing);
        }

        /** Wraps a jblas {@code DoubleMatrix} of {@code i x i2} over the given data. */
        @Override
        public NDArray array(int i, int i2, double[] dArr) {
            DoubleMatrix backing = new DoubleMatrix(i, i2, dArr);
            return new DenseMatrix(backing);
        }

        /** Wraps a jblas {@code FloatMatrix} of {@code i x i2} over the given data. */
        @Override
        public NDArray array(int i, int i2, float[] fArr) {
            FloatMatrix backing = new FloatMatrix(i, i2, fArr);
            return new DenseMatrix(backing);
        }

        /** Wraps a jblas {@code DoubleMatrix} built from the two-dimensional data. */
        @Override
        public NDArray array(double[][] dArr) {
            DoubleMatrix backing = new DoubleMatrix(dArr);
            return new DenseMatrix(backing);
        }

        /** Wraps a jblas {@code FloatMatrix} built from the two-dimensional data. */
        @Override
        public NDArray array(float[][] fArr) {
            FloatMatrix backing = new FloatMatrix(fArr);
            return new DenseMatrix(backing);
        }

        /** Builds the column vector through the one-argument {@code DoubleMatrix} constructor. */
        @Override
        public NDArray columnVector(double[] dArr) {
            DoubleMatrix column = new DoubleMatrix(dArr);
            return new DenseMatrix(column);
        }

        /** Builds a {@code 1 x dArr.length} dense matrix over the data. */
        @Override
        public NDArray rowVector(double[] dArr) {
            DoubleMatrix row = new DoubleMatrix(1, dArr.length, dArr);
            return new DenseMatrix(row);
        }

        /** Builds the column vector through the one-argument {@code FloatMatrix} constructor. */
        @Override
        public NDArray columnVector(float[] fArr) {
            FloatMatrix column = new FloatMatrix(fArr);
            return new DenseMatrix(column);
        }

        /** Builds a {@code 1 x fArr.length} dense matrix over the data. */
        @Override
        public NDArray rowVector(float[] fArr) {
            FloatMatrix row = new FloatMatrix(1, fArr.length, fArr);
            return new DenseMatrix(row);
        }
    },
    /**
     * Factory producing {@code SparseMatrix} instances (and {@code Tensor}s whose slices are
     * {@code SparseMatrix}). All other creation methods use the shared enum implementations.
     */
    SPARSE {
        /**
         * Creates a sparse array of the given shape. Tensor shapes get one
         * {@code SparseMatrix} of {@code rows x columns} per slice.
         *
         * @param shape the shape of the array to create
         * @return a {@code SparseMatrix}, or a {@code Tensor} when {@code shape.isTensor()}
         */
        @Override
        public NDArray array(Shape shape) {
            if (shape.isTensor()) {
                Tensor result = new Tensor(shape);
                int slot = 0;
                while (slot < shape.sliceLength) {
                    result.slices[slot] = new SparseMatrix(shape.rows(), shape.columns());
                    slot++;
                }
                return result;
            }
            return new SparseMatrix(shape);
        }
    };

    /**
     * Wraps the given slices in a single {@code Tensor}, passing {@code 0} and
     * {@code nDArrayArr.length} as the two leading constructor arguments (the same
     * construction {@code array(float[][][])} and {@code array(double[][][])} use).
     *
     * <p>The tensor is built directly instead of going through
     * {@code array(int, int, NDArray[])}: that overload requires the product of its two
     * int arguments to equal the slice count, and {@code 0 * length == length} can only
     * hold for an empty array, so delegating there rejected every non-empty input.
     *
     * @param nDArrayArr the slices of the tensor
     * @return a tensor over the given slices
     */
    public NDArray array(NDArray[] nDArrayArr) {
        return new Tensor(0, nDArrayArr.length, nDArrayArr);
    }

    /**
     * Wraps the given slices in a {@code Tensor} constructed with {@code (i, i2, nDArrayArr)}.
     *
     * @param i          first tensor dimension; {@code i * i2} must equal the slice count
     * @param i2         second tensor dimension
     * @param nDArrayArr the slices
     * @return the new tensor
     */
    public NDArray array(int i, int i2, NDArray[] nDArrayArr) {
        final int expectedSlices = i * i2;
        Validation.checkArgument(
                expectedSlices == nDArrayArr.length,
                () -> "Invalid Slice Length: " + expectedSlices + " != " + nDArrayArr.length);
        return new Tensor(i, i2, nDArrayArr);
    }

    /**
     * Creates an array whose shape is built from the given dimensions.
     *
     * @param iArr the dimensions handed to the {@code Shape} constructor
     * @return the new array
     */
    public NDArray array(int... iArr) {
        Shape dimensions = new Shape(iArr);
        return array(dimensions);
    }

    /**
     * Creates an array from one-dimensional data; equivalent to {@link #columnVector(double[])}.
     *
     * @param dArr the values
     * @return a column vector holding the values
     */
    public NDArray array(double[] dArr) {
        NDArray column = columnVector(dArr);
        return column;
    }

    /**
     * Creates an array from one-dimensional data; equivalent to {@link #columnVector(float[])}.
     *
     * @param fArr the values
     * @return a column vector holding the values
     */
    public NDArray array(float[] fArr) {
        NDArray column = columnVector(fArr);
        return column;
    }

    /**
     * Creates an {@code i x i2} array and copies the data into it by linear index.
     *
     * @param i    first dimension
     * @param i2   second dimension
     * @param dArr the values; its length must equal {@code i * i2}
     * @return the populated array
     */
    public NDArray array(int i, int i2, double[] dArr) {
        final int expectedLength = i * i2;
        Validation.checkArgument(
                expectedLength == dArr.length,
                () -> "Invalid Length: " + expectedLength + " != " + dArr.length);
        NDArray target = array(i, i2);
        int position = 0;
        for (double value : dArr) {
            target.set(position, value);
            position++;
        }
        return target;
    }

    /**
     * Creates an {@code i x i2} array and copies the data into it by linear index.
     *
     * @param i    first dimension
     * @param i2   second dimension
     * @param fArr the values; its length must equal {@code i * i2}
     * @return the populated array
     */
    public NDArray array(int i, int i2, float[] fArr) {
        final int expectedLength = i * i2;
        Validation.checkArgument(
                expectedLength == fArr.length,
                () -> "Invalid Length: " + expectedLength + " != " + fArr.length);
        NDArray target = array(i, i2);
        int position = 0;
        for (float value : fArr) {
            target.set(position, value);
            position++;
        }
        return target;
    }

    /**
     * Creates a matrix from two-dimensional data, one row per inner array. The column
     * count is taken from the first inner array.
     *
     * @param dArr the rows
     * @return the matrix, or {@link #empty()} when there are no rows
     */
    public NDArray array(double[][] dArr) {
        final int rowCount = dArr.length;
        if (rowCount == 0) {
            return empty();
        }
        NDArray matrix = array(rowCount, dArr[0].length);
        int rowIndex = 0;
        for (double[] rowValues : dArr) {
            matrix.setRow(rowIndex, rowVector(rowValues));
            rowIndex++;
        }
        return matrix;
    }

    /**
     * Creates a matrix from two-dimensional data, one row per inner array. The column
     * count is taken from the first inner array.
     *
     * @param fArr the rows
     * @return the matrix, or {@link #empty()} when there are no rows
     */
    public NDArray array(float[][] fArr) {
        final int rowCount = fArr.length;
        if (rowCount == 0) {
            return empty();
        }
        NDArray matrix = array(rowCount, fArr[0].length);
        int rowIndex = 0;
        for (float[] rowValues : fArr) {
            matrix.setRow(rowIndex, rowVector(rowValues));
            rowIndex++;
        }
        return matrix;
    }

    /**
     * Creates a tensor from three-dimensional data: each {@code float[][]} becomes one
     * slice via {@code array(float[][])}.
     *
     * @param fArr the slices
     * @return the tensor, or {@link #empty()} when there are no slices
     */
    public NDArray array(float[][][] fArr) {
        final int sliceCount = fArr.length;
        if (sliceCount == 0) {
            return empty();
        }
        NDArray[] slices = new NDArray[sliceCount];
        int slot = 0;
        for (float[][] plane : fArr) {
            slices[slot] = array(plane);
            slot++;
        }
        return new Tensor(0, slices.length, slices);
    }

    /**
     * Creates a tensor from three-dimensional data: each {@code double[][]} becomes one
     * slice via {@code array(double[][])}.
     *
     * @param dArr the slices
     * @return the tensor, or {@link #empty()} when there are no slices
     */
    public NDArray array(double[][][] dArr) {
        final int sliceCount = dArr.length;
        if (sliceCount == 0) {
            return empty();
        }
        NDArray[] slices = new NDArray[sliceCount];
        int slot = 0;
        for (double[][] plane : dArr) {
            slices[slot] = array(plane);
            slot++;
        }
        return new Tensor(0, slices.length, slices);
    }

    /**
     * Creates a new array of the given shape. Each enum constant supplies its own
     * implementation: {@code ND} forwards to its configured factory, {@code DENSE} creates
     * {@code DenseMatrix} storage and {@code SPARSE} creates {@code SparseMatrix} storage
     * (both wrap per-slice matrices in a {@code Tensor} when the shape is a tensor).
     *
     * @param shape the shape of the array to create
     * @return the new array
     */
    public abstract NDArray array(Shape shape);

    /**
     * Creates an array of the given shape and passes it to the initializer before
     * returning it.
     *
     * @param shape              the shape of the array to create
     * @param nDArrayInitializer receives the freshly created array via {@code accept}
     * @return the created array, after the initializer has run
     */
    public NDArray array(Shape shape, NDArrayInitializer nDArrayInitializer) {
        final NDArray created = array(shape);
        nDArrayInitializer.accept(created);
        return created;
    }

    /**
     * Creates a {@code dArr.length x 1} array and copies the values into it by linear index.
     *
     * @param dArr the values
     * @return the column vector
     */
    public NDArray columnVector(double[] dArr) {
        NDArray column = array(dArr.length, 1);
        int position = 0;
        for (double value : dArr) {
            column.set(position, value);
            position++;
        }
        return column;
    }

    /**
     * Creates a {@code fArr.length x 1} array and copies the values into it by linear index.
     *
     * @param fArr the values
     * @return the column vector
     */
    public NDArray columnVector(float[] fArr) {
        NDArray column = array(fArr.length, 1);
        int position = 0;
        for (float value : fArr) {
            column.set(position, value);
            position++;
        }
        return column;
    }

    /**
     * Creates an array of the given shape and fills it with a single value.
     *
     * @param shape the shape of the array to create
     * @param d     the value passed to {@code fill}
     * @return the result of {@code fill(d)} on the new array
     * @throws NullPointerException if {@code shape} is null
     */
    public NDArray constant(@NonNull Shape shape, double d) {
        if (shape == null) {
            throw new NullPointerException("shape is marked non-null but is null");
        }
        NDArray blank = array(shape);
        return blank.fill(d);
    }

    /**
     * Creates an array with the shape returned by {@code Shape.empty()}.
     *
     * @return the empty array
     */
    public NDArray empty() {
        Shape emptyShape = Shape.empty();
        return array(emptyShape);
    }

    /**
     * Creates an {@code i x i} array and sets every diagonal entry to {@code 1.0}.
     * Off-diagonal entries keep whatever value a freshly created array holds.
     *
     * @param i the number of rows and columns
     * @return the array with ones on its diagonal
     */
    public NDArray eye(int i) {
        NDArray identity = array(i, i);
        int diagonal = 0;
        while (diagonal < i) {
            identity.set(diagonal, diagonal, 1.0d);
            diagonal++;
        }
        return identity;
    }

    /**
     * Converts a TensorFlow tensor of rank 0 to 3 into an array produced by this factory.
     *
     * <ul>
     *   <li>rank 0: the scalar value, via {@link #scalar(double)};</li>
     *   <li>rank 1: the values, via {@code array(float[])} (a column vector);</li>
     *   <li>rank 2: a matrix, via {@code array(float[][])};</li>
     *   <li>rank 3: a {@code Tensor} with one slice per index of the first dimension.</li>
     * </ul>
     *
     * <p>Ranks 0 and 1 are handled explicitly because indexing {@code shape[0]} /
     * {@code shape[1]} for them fails with {@code ArrayIndexOutOfBoundsException},
     * even though the rank check admits them.
     *
     * <p>NOTE(review): values are read through {@code floatValue()} / float arrays, so
     * presumably only FLOAT tensors are accepted; confirm against the TensorFlow
     * {@code Tensor.copyTo} contract.
     *
     * @param tensor the TensorFlow tensor to convert
     * @return the converted array
     * @throws NullPointerException if {@code tensor} is null
     */
    public NDArray fromTensorFlowTensor(@NonNull org.tensorflow.Tensor<?> tensor) {
        if (tensor == null) {
            throw new NullPointerException("tensor is marked non-null but is null");
        }
        long[] shape = tensor.shape();
        Validation.checkArgument(shape.length <= 3, "Only tensors of rank 3 or less are supported.");
        if (shape.length == 0) {
            // Scalars have no dimensions to index and cannot be copied into an array.
            return scalar(tensor.floatValue());
        }
        if (shape.length == 1) {
            float[] values = new float[(int) shape[0]];
            tensor.copyTo(values);
            return array(values);
        }
        if (shape.length == 2) {
            float[][] fArr = new float[(int) shape[0]][(int) shape[1]];
            tensor.copyTo(fArr);
            return array(fArr);
        }
        // Rank 3: copy everything out once, then convert each leading-index plane to a slice.
        Tensor tensor2 = new Tensor(Shape.shape((int) shape[0], (int) shape[1], (int) shape[2]));
        float[][][] fArr2 = new float[(int) shape[0]][(int) shape[1]][(int) shape[2]];
        tensor.copyTo(fArr2);
        for (int i = 0; i < fArr2.length; i++) {
            tensor2.setSlice(i, array(fArr2[i]));
        }
        return tensor2;
    }

    /**
     * Varargs convenience for {@link #hstack(Collection)}.
     *
     * @param nDArrayArr the arrays to place side by side
     * @return the stacked array
     * @throws NullPointerException if {@code nDArrayArr} is null
     */
    public NDArray hstack(@NonNull NDArray... nDArrayArr) {
        if (nDArrayArr == null) {
            throw new NullPointerException("arrays is marked non-null but is null");
        }
        Collection<NDArray> parts = Arrays.asList(nDArrayArr);
        return hstack(parts);
    }

    /**
     * Places the given arrays side by side. The result has the row count of the first
     * array and the summed column count of all arrays; each array's entries are copied
     * at a column offset equal to the total columns of the arrays before it.
     *
     * <p>Entries are visited with {@code forEachSparse} (presumably only the stored /
     * non-zero entries; confirm against {@code NDArray}).
     *
     * @param collection the arrays to stack, all with the same number of rows
     * @return the stacked array, or {@link #empty()} for an empty collection
     * @throws NullPointerException     if {@code collection} is null
     * @throws IllegalArgumentException if the arrays do not all have the same row count
     */
    public NDArray hstack(@NonNull Collection<NDArray> collection) {
        if (collection == null) {
            throw new NullPointerException("arrays is marked non-null but is null");
        }
        if (collection.size() == 0) {
            return empty();
        }
        long distinctRowCounts = collection.stream().mapToLong(part -> part.rows()).distinct().count();
        if (distinctRowCounts > 1) {
            long[] rowCounts = collection.stream().mapToLong(part -> part.rows()).toArray();
            throw new IllegalArgumentException("Row mismatch: " + Arrays.toString(rowCounts));
        }
        int rows = collection.iterator().next().shape().rows();
        long totalColumns = collection.stream().mapToLong(part -> part.columns()).sum();
        NDArray stacked = array(rows, (int) totalColumns);
        int nextColumn = 0;
        for (NDArray part : collection) {
            // Lambdas may only capture effectively final locals, so freeze the offset.
            final int columnOffset = nextColumn;
            part.forEachSparse((index, value) -> {
                int flat = (int) index;
                stacked.set(part.shape.toRow(flat), part.shape.toColumn(flat) + columnOffset, value);
            });
            nextColumn += part.columns();
        }
        return stacked;
    }

    /**
     * Creates an array of the given dimensions filled with {@code 1.0}.
     *
     * @param iArr the dimensions handed to {@code Shape.shape}
     * @return the filled array
     */
    public NDArray ones(int... iArr) {
        Shape dimensions = Shape.shape(iArr);
        return constant(dimensions, 1.0d);
    }

    /**
     * Creates an array of the given shape filled with {@code 1.0}.
     *
     * @param shape the shape of the array to create
     * @return the filled array
     */
    public NDArray ones(Shape shape) {
        final double one = 1.0d;
        return constant(shape, one);
    }

    /**
     * Creates an array of the given dimensions initialized with {@code NDArrayInitializer.rand}.
     *
     * @param iArr the dimensions handed to {@code Shape.shape}
     * @return the initialized array
     */
    public NDArray rand(int... iArr) {
        Shape dimensions = Shape.shape(iArr);
        return array(dimensions, NDArrayInitializer.rand);
    }

    /**
     * Creates an array of the given shape initialized with {@code NDArrayInitializer.rand}.
     *
     * @param shape the shape of the array to create
     * @return the initialized array
     */
    public NDArray rand(Shape shape) {
        NDArray randomized = array(shape, NDArrayInitializer.rand);
        return randomized;
    }

    /**
     * Creates an array of the given dimensions initialized with
     * {@code NDArrayInitializer.randn} over a new {@link Random}.
     *
     * @param iArr the dimensions handed to {@code Shape.shape}
     * @return the initialized array
     */
    public NDArray randn(int... iArr) {
        Shape dimensions = Shape.shape(iArr);
        Random source = new Random();
        return array(dimensions, NDArrayInitializer.randn(source));
    }

    /**
     * Creates an array of the given shape initialized with
     * {@code NDArrayInitializer.randn} over a new {@link Random}.
     *
     * @param shape the shape of the array to create
     * @return the initialized array
     */
    public NDArray randn(Shape shape) {
        Random source = new Random();
        return array(shape, NDArrayInitializer.randn(source));
    }

    /**
     * Creates a {@code 1 x dArr.length} array and copies the values into it by linear index.
     *
     * @param dArr the values
     * @return the row vector
     */
    public NDArray rowVector(double[] dArr) {
        NDArray row = array(1, dArr.length);
        int position = 0;
        for (double value : dArr) {
            row.set(position, value);
            position++;
        }
        return row;
    }

    /**
     * Creates a {@code 1 x fArr.length} array and copies the values into it by linear index.
     *
     * @param fArr the values
     * @return the row vector
     */
    public NDArray rowVector(float[] fArr) {
        NDArray row = array(1, fArr.length);
        int position = 0;
        for (float value : fArr) {
            row.set(position, value);
            position++;
        }
        return row;
    }

    /**
     * Creates an array from an empty dimension list and stores the value at linear index 0.
     *
     * @param d the value to store
     * @return the array holding the value
     */
    public NDArray scalar(double d) {
        int[] noDimensions = new int[0];
        NDArray holder = array(noDimensions);
        holder.set(0L, d);
        return holder;
    }

    /**
     * Creates an array of the given shape initialized with {@code NDArrayInitializer.rand(i, i2)}.
     *
     * @param shape the shape of the array to create
     * @param i     first bound passed to the initializer factory
     * @param i2    second bound passed to the initializer factory
     * @return the initialized array
     * @throws NullPointerException if {@code shape} is null
     */
    public NDArray uniform(@NonNull Shape shape, int i, int i2) {
        if (shape == null) {
            throw new NullPointerException("shape is marked non-null but is null");
        }
        NDArrayInitializer bounded = NDArrayInitializer.rand(i, i2);
        return array(shape, bounded);
    }

    /**
     * Varargs convenience for {@link #vstack(Collection)}.
     *
     * @param nDArrayArr the arrays to place on top of one another
     * @return the stacked array
     * @throws NullPointerException if {@code nDArrayArr} is null
     */
    public NDArray vstack(@NonNull NDArray... nDArrayArr) {
        if (nDArrayArr == null) {
            throw new NullPointerException("arrays is marked non-null but is null");
        }
        Collection<NDArray> parts = Arrays.asList(nDArrayArr);
        return vstack(parts);
    }

    /**
     * Places the given arrays on top of one another. The result has the summed row count
     * of all arrays and the column count of the first array; each array's entries are
     * copied at a row offset equal to the total rows of the arrays before it.
     *
     * <p>Entries are visited with {@code forEachSparse} (presumably only the stored /
     * non-zero entries; confirm against {@code NDArray}).
     *
     * @param collection the arrays to stack, all with the same number of columns
     * @return the stacked array, or {@link #empty()} for an empty collection
     * @throws NullPointerException     if {@code collection} is null
     * @throws IllegalArgumentException if the arrays do not all have the same column count
     */
    public NDArray vstack(@NonNull Collection<NDArray> collection) {
        if (collection == null) {
            throw new NullPointerException("arrays is marked non-null but is null");
        }
        if (collection.size() == 0) {
            return empty();
        }
        long distinctColumnCounts = collection.stream().mapToLong(part -> part.columns()).distinct().count();
        if (distinctColumnCounts > 1) {
            throw new IllegalArgumentException("Column mismatch");
        }
        int totalRows = collection.stream().mapToInt(part -> part.rows()).sum();
        int columns = collection.iterator().next().shape().columns();
        NDArray stacked = array(totalRows, columns);
        int nextRow = 0;
        for (NDArray part : collection) {
            // Lambdas may only capture effectively final locals, so freeze the offset.
            final int rowOffset = nextRow;
            part.forEachSparse((index, value) -> {
                int flat = (int) index;
                int sourceRow = part.shape.toRow(flat);
                stacked.set(sourceRow + rowOffset, part.shape.toColumn(flat), value);
            });
            nextRow += part.rows();
        }
        return stacked;
    }
}
