package ai.djl.nn;

import ai.djl.ndarray.NDArray;

/* loaded from: input_file:ai/djl/nn/Blocks.class */
/**
 * Static helpers for flattening {@link NDArray}s and for obtaining simple
 * {@link Block}s built on {@link LambdaBlock}.
 */
public final class Blocks {

    /** Utility class; not meant to be instantiated. */
    private Blocks() {
    }

    /**
     * Reshapes the given array to two dimensions, keeping the length of its first axis.
     *
     * @param nDArray the array to flatten
     * @return the result of reshaping {@code nDArray} to {@code (nDArray.size(0), rest)}
     */
    public static NDArray batchFlatten(NDArray nDArray) {
        long batchSize = nDArray.size(0);
        if (batchSize != 0) {
            // Non-empty first axis: pass -1 for the second dimension
            // (presumably inferred by reshape -- confirm against the NDArray docs).
            return nDArray.reshape(batchSize, -1);
        }
        // Empty first axis: supply the second dimension explicitly, taken from the
        // size of the shape with its first axis sliced off.
        long trailingSize = nDArray.getShape().slice(1).size();
        return nDArray.reshape(batchSize, trailingSize);
    }

    /**
     * Reshapes the given array to two dimensions with a fixed second dimension.
     *
     * @param nDArray the array to flatten
     * @param j the length requested for the second dimension
     * @return the result of reshaping {@code nDArray} to {@code (-1, j)}
     */
    public static NDArray batchFlatten(NDArray nDArray, long j) {
        NDArray flattened = nDArray.reshape(-1, j);
        return flattened;
    }

    /**
     * Creates a block that applies {@link #batchFlatten(NDArray)} to its input.
     *
     * @return a singleton {@link LambdaBlock} wrapping the one-argument flatten
     */
    public static Block batchFlattenBlock() {
        return LambdaBlock.singleton(Blocks::batchFlatten);
    }

    /**
     * Creates a block that applies {@link #batchFlatten(NDArray, long)} with the given size.
     *
     * @param j the second-dimension length forwarded to the flatten call
     * @return a singleton {@link LambdaBlock} wrapping the two-argument flatten
     */
    public static Block batchFlattenBlock(long j) {
        return LambdaBlock.singleton(array -> batchFlatten(array, j));
    }

    /**
     * Creates a block that returns its input list unchanged.
     *
     * @return a {@link LambdaBlock} whose function hands back the list it receives
     */
    public static Block identityBlock() {
        return new LambdaBlock(inputs -> inputs);
    }
}
